Commit 4f7a8fc5 authored by Matteo Spanio's avatar Matteo Spanio
Browse files

Upload New File

parent c6f06ab3
Pipeline #28 canceled with stages
"""
MPAI CAE-ARP Tape Irregularity Classifier.
Implements MPAI CAE-ARP Tape Irregularity Classifier Technical Specification.
"""
import json
import os
from os import path
import cv2
import numpy as np
from tensorflow import keras
from mpai_cae_arp.io import Color, pprint
from mpai_cae_arp.files import File, FileType
from mpai_cae_arp.types.irregularity import Irregularity, IrregularityFile, IrregularityType
__copyright__ = "Copyright 2022, Audio Innova S.r.l."
__credits__ = ["Niccolò Pretto", "Nadir Dalla Pozza", "Sergio Canazza"]
__status__ = "Production"
MODEL_PATH = './model/model_ROI'
CLASSES = [
'sp', # 0: Splice
's' # 1: Shadow
]
# Minimum probability required for classifying an Irregularity as interesting
PROB = 75
def verify_path(working_path: str, files_name: str) -> str:
"""
Method to check that the environment is conformant to the standard.
Parameters
----------
working_path : str
the path where all files resulting from previous AIMs are stored,
files_name : str
the Preservation files name, to identify the input directory.
Raises
------
FileNotFoundError
if the specified WORKING_PATH is non-existent, or if the specified WORKING_PATH structure is not conformant, or if the specified FILES_NAME is non-existent.
Returns
-------
str
the path where the files to be processed during the current execution are stored.
"""
if not path.exists(working_path):
raise FileNotFoundError("The specified WORKING_PATH is non-existent!")
temp_path = path.join(working_path, 'temp')
if not os.path.exists(temp_path):
raise FileNotFoundError("The specified WORKING_PATH structure is not conformant!")
temp_path = path.join(temp_path, files_name)
if not os.path.exists(temp_path):
raise FileNotFoundError("The specified FILES_NAME is non-existent!")
return temp_path
def collect_irregularity_images(irregularities: list[Irregularity]) -> list[str]:
"""
Collects the images of the Irregularities and returns them in a list. If an Irregularity Image is not found, it is removed from the list.
"""
images = []
for irr in irregularities:
img_path = irr.image_URI
if path.exists(img_path):
# The Irregularity Image was found, then save its data
img = cv2.imread(img_path)
# All images must be resized to 224x224 due to network requirements
resized = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
resized.astype('float32')
images.append(resized)
else:
irregularities.remove(irr)
pprint("Irregularity Image not found!", color=Color.YELLOW)
pprint(f" ImageURI: {img_path}", color=Color.YELLOW)
pprint(f" IrregularityID: {irr.irregularity_ID}", color=Color.YELLOW)
return images
def get_irregularity_file(filepath: str) -> IrregularityFile:
"""
Returns the Irregularity File from the specified path, or quits if the file is not found.
Parameters
----------
path : str
The path of the Irregularity File.
Raises
------
FileNotFoundError
If the file is not found.
json.decoder.JSONDecodeError
If the file is not a valid JSON file.
Returns
-------
IrregularityFile
The Irregularity File from the specified path.
"""
try:
irregularity_file = File(filepath, FileType.JSON).get_content()
return IrregularityFile.from_json(irregularity_file)
except FileNotFoundError:
raise FileNotFoundError(f"{filepath} not found!")
except json.decoder.JSONDecodeError:
raise json.decoder.JSONDecodeError(f"{filepath} is not a valid JSON file!")
def load_model(path_to_model):
model = keras.models.model_from_json(open(f"{path_to_model}.json", 'r').read())
model.load_weights(f'{path_to_model}.h5')
return model
def prepare_images_for_prediction(images: list[str]) -> np.ndarray:
"""
Prepares the images for prediction by resizing them to 224x224 and scaling the BGR values in [0, 1].
"""
images = np.stack(images, axis=0)
# Scale BGR values in [0, 1]
return images/255
def predict(model, images: list[str]):
"""
Predicts the probability of interest for each image.
"""
prediction = model.predict(images, batch_size=1)
# We are interested in the first array of prediction probabilities
probabilities = prediction[0]
# Get the position where the highest probability is stored
# It corresponds to the class index (0, 1, ...)
y = np.argmax(probabilities, axis=1)
return probabilities, y
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment