lib.py 4.62 KB
Newer Older
Matteo Spanio's avatar
Matteo Spanio committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
MPAI CAE-ARP Tape Irregularity Classifier.

Implements MPAI CAE-ARP Tape Irregularity Classifier Technical Specification.
"""

import json
import os
from os import path

import cv2
import numpy as np

from tensorflow import keras

from mpai_cae_arp.io import Color, pprint
from mpai_cae_arp.files import File, FileType
from mpai_cae_arp.types.irregularity import Irregularity, IrregularityFile, IrregularityType

# Module metadata
__copyright__ = "Copyright 2022, Audio Innova S.r.l."
__credits__ = [
    "Niccolò Pretto",
    "Nadir Dalla Pozza",
    "Sergio Canazza",
]
__status__ = "Production"


# Base path of the serialized network, relative to the current directory.
# NOTE(review): presumably given without extension so that the ".json" and
# ".h5" suffixes can be appended when loading - confirm against the caller.
MODEL_PATH = './model/model_ROI'

# Class labels, positioned at their class index:
#   index 0 -> 'sp' (Splice)
#   index 1 -> 's'  (Shadow)
CLASSES = ['sp', 's']

# Minimum probability an Irregularity must reach to be classified as
# interesting.
# NOTE(review): the value looks like a percentage - confirm where it is used.
PROB = 75


def verify_path(working_path: str, files_name: str) -> str:
    """
    Validate the working directory layout and locate the input directory.

    The expected layout is ``<working_path>/temp/<files_name>``; each level is
    checked in turn and the first missing one is reported.

    Parameters
    ----------
    working_path : str
        the path where all files resulting from previous AIMs are stored,
    files_name : str
        the Preservation files name, identifying the input directory.

    Raises
    ------
    FileNotFoundError
        if ``working_path`` does not exist, if it has no ``temp`` entry, or if
        ``temp`` has no entry named ``files_name``.

    Returns
    -------
    str
        the path ``<working_path>/temp/<files_name>``, where the files to be
        processed during the current execution are stored.
    """
    # Level 1: the working directory itself
    if not path.exists(working_path):
        raise FileNotFoundError("The specified WORKING_PATH is non-existent!")

    # Level 2: the mandatory 'temp' entry inside the working directory
    temp_root = path.join(working_path, 'temp')
    if not path.exists(temp_root):
        raise FileNotFoundError("The specified WORKING_PATH structure is not conformant!")

    # Level 3: the entry dedicated to the requested Preservation files
    files_dir = path.join(temp_root, files_name)
    if path.exists(files_dir):
        return files_dir
    raise FileNotFoundError("The specified FILES_NAME is non-existent!")


def collect_irregularity_images(irregularities: list[Irregularity]) -> list[np.ndarray]:
    """
    Load the image of every Irregularity, resized for the network.

    Irregularities whose image is missing or cannot be decoded are removed from
    ``irregularities`` in place, so that after the call the list and the
    returned images have the same length and the same order.

    Parameters
    ----------
    irregularities : list[Irregularity]
        the Irregularities to load images for; each one is read through its
        ``image_URI`` attribute. The list is modified in place.

    Returns
    -------
    list[np.ndarray]
        one 224x224 ``float32`` image per remaining Irregularity.
    """
    images = []
    kept = []
    for irr in irregularities:
        img_path = irr.image_URI
        if path.exists(img_path):
            # cv2.imread does not raise on an undecodable file: it returns None
            img = cv2.imread(img_path)
            problem = None if img is not None else "Irregularity Image could not be read!"
        else:
            img = None
            problem = "Irregularity Image not found!"

        if problem is not None:
            pprint(problem, color=Color.YELLOW)
            pprint(f"    ImageURI: {img_path}", color=Color.YELLOW)
            pprint(f"    IrregularityID: {irr.irregularity_ID}", color=Color.YELLOW)
            continue

        # All images must be resized to 224x224 due to network requirements
        resized = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
        # astype returns a new array, so its result has to be the one stored
        images.append(resized.astype('float32'))
        kept.append(irr)

    # Removing items from the list while looping over it skips the element that
    # follows each removal; replace the content in one step instead, keeping
    # the caller's list object.
    irregularities[:] = kept
    return images


def get_irregularity_file(filepath: str) -> IrregularityFile:
    """
    Returns the Irregularity File stored at the specified path.

    Parameters
    ----------
    filepath : str
        The path of the Irregularity File.

    Raises
    ------
    FileNotFoundError
        If the file is not found.
    json.decoder.JSONDecodeError
        If the file is not a valid JSON file.

    Returns
    -------
    IrregularityFile
        The Irregularity File from the specified path.
    """
    try:
        irregularity_file = File(filepath, FileType.JSON).get_content()
        return IrregularityFile.from_json(irregularity_file)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"{filepath} not found!") from err
    except json.decoder.JSONDecodeError as err:
        # JSONDecodeError requires (msg, doc, pos): building it from the
        # message alone raises a TypeError instead of the documented error.
        # Reuse the document and position of the original failure.
        raise json.decoder.JSONDecodeError(
            f"{filepath} is not a valid JSON file!", err.doc, err.pos
        ) from err

    
def load_model(path_to_model: str):
    """
    Builds the network from its JSON description and loads its weights.

    Parameters
    ----------
    path_to_model : str
        The path of the model files without extension: the architecture is
        read from ``<path_to_model>.json`` and the weights from
        ``<path_to_model>.h5``.

    Raises
    ------
    FileNotFoundError
        If ``<path_to_model>.json`` does not exist.

    Returns
    -------
    The model returned by ``keras.models.model_from_json``, with the weights
    loaded.
    """
    # Read through a context manager so the file handle is always closed
    with open(f"{path_to_model}.json", 'r', encoding='utf-8') as architecture_file:
        architecture = architecture_file.read()
    model = keras.models.model_from_json(architecture)
    model.load_weights(f'{path_to_model}.h5')
    return model


def prepare_images_for_prediction(images: list[np.ndarray]) -> np.ndarray:
    """
    Stacks the images in a single batch and scales their values in [0, 1].

    No resizing happens here: the images must already share the same shape,
    otherwise they cannot be stacked.

    Parameters
    ----------
    images : list[np.ndarray]
        The images to be stacked, with channel values in [0, 255].

    Raises
    ------
    ValueError
        If ``images`` is empty, or if the images do not share the same shape.

    Returns
    -------
    np.ndarray
        The batch, with the images along the first axis and the values divided
        by 255.
    """
    # len() instead of truthiness: the input may already be a NumPy array,
    # whose truth value is ambiguous
    if len(images) == 0:
        raise ValueError("No images to prepare: at least one image is required.")

    batch = np.stack(images, axis=0)
    # Scale BGR values in [0, 1]
    return batch / 255


def predict(model, images: list[str]):
    """
    Runs the model on the images and picks the most probable class of each.

    Parameters
    ----------
    model
        An object exposing ``predict(images, batch_size=...)``.
    images : list[str]
        The input forwarded unchanged to ``model.predict``.
        NOTE(review): presumably the already prepared image batch rather than
        a list of strings - confirm against the caller.

    Returns
    -------
    tuple
        ``(probabilities, y)``: ``probabilities`` is the first element of the
        value returned by ``model.predict``; ``y`` holds, for each of its rows,
        the index of the highest probability, i.e. the class index (0, 1, ...).
    """
    outputs = model.predict(images, batch_size=1)
    # Only the first array of prediction probabilities is of interest
    first_output = outputs[0]
    return first_output, np.argmax(first_output, axis=1)