"""
MPAI CAE-ARP Tape Irregularity Classifier.

Implements the MPAI CAE-ARP Tape Irregularity Classifier Technical Specification.
"""

import json
import os
from os import path

import cv2
import numpy as np

from tensorflow import keras

from mpai_cae_arp.io import Color, pprint
from mpai_cae_arp.files import File, FileType
from mpai_cae_arp.types.irregularity import Irregularity, IrregularityFile, IrregularityType

__copyright__ = "Copyright 2022, Audio Innova S.r.l."
__credits__ = [
    "Niccolò Pretto",
    "Nadir Dalla Pozza",
    "Sergio Canazza",
]
__status__ = "Production"


# Base path (no extension) of the classifier model. load_model() appends
# ".json" (architecture) and ".h5" (weights) to the path it is given.
MODEL_PATH = './model/model_ROI'

# Class labels, listed so that list index == class index.
CLASSES = [
    'sp',  # index 0 -> Splice
    's',   # index 1 -> Shadow
]

# Minimum probability required to classify an Irregularity as interesting.
# NOTE(review): the value looks like a percentage (75 == 75%); confirm the
# scale against the code that compares it with the model probabilities.
PROB = 75


def verify_path(working_path: str, files_name: str) -> str:
    """
    Check that the working environment is laid out as the standard requires.

    The expected layout is ``<working_path>/temp/<files_name>``. Each level is
    checked in order, so the raised error names the first missing level.

    Parameters
    ----------
    working_path : str
        directory holding every file produced by the previous AIMs.
    files_name : str
        name of the Preservation files, used to locate the input directory.

    Raises
    ------
    FileNotFoundError
        if WORKING_PATH does not exist, if it has no ``temp`` entry, or if
        ``temp/<files_name>`` does not exist.

    Returns
    -------
    str
        the path holding the files to be processed in the current execution.
    """
    if not path.exists(working_path):
        raise FileNotFoundError("The specified WORKING_PATH is non-existent!")

    current = working_path
    levels = (
        ('temp', "The specified WORKING_PATH structure is not conformant!"),
        (files_name, "The specified FILES_NAME is non-existent!"),
    )
    for segment, error_message in levels:
        # Join one level at a time so each check sees only the prefix so far.
        current = path.join(current, segment)
        if not path.exists(current):
            raise FileNotFoundError(error_message)
    return current


def _warn_skipped_irregularity(reason: str, irr: Irregularity, img_path: str) -> None:
    """Print a yellow warning explaining why an Irregularity was dropped."""
    pprint(reason, color=Color.YELLOW)
    pprint(f"    ImageURI: {img_path}", color=Color.YELLOW)
    pprint(f"    IrregularityID: {irr.irregularity_ID}", color=Color.YELLOW)


def collect_irregularity_images(irregularities: list[Irregularity]) -> list[np.ndarray]:
    """
    Load the image of every Irregularity, resized to the network input size.

    Irregularities whose image file does not exist, or cannot be decoded by
    OpenCV, are removed from ``irregularities`` *in place*. On return, the
    list and the returned images are aligned index by index.

    Parameters
    ----------
    irregularities : list[Irregularity]
        the Irregularities to process; mutated in place (skipped items removed).

    Returns
    -------
    list[np.ndarray]
        one 224x224 float32 image per kept Irregularity, in the same order.
    """
    images = []
    kept = []
    # Iterate without mutating the list: removing items during iteration
    # would silently skip the element following each removed one.
    for irr in irregularities:
        img_path = irr.image_URI
        if not path.exists(img_path):
            _warn_skipped_irregularity("Irregularity Image not found!", irr, img_path)
            continue
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None when the file cannot be read/decoded.
            _warn_skipped_irregularity("Irregularity Image could not be read!", irr, img_path)
            continue
        # All images must be resized to 224x224 due to network requirements
        resized = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
        # astype returns a new array; keep its result (the original discarded it).
        images.append(resized.astype('float32'))
        kept.append(irr)

    # Slice-assign so the caller's list object is updated, not just rebound.
    irregularities[:] = kept
    return images


def get_irregularity_file(filepath: str) -> IrregularityFile:
    """
    Read and parse the Irregularity File at the given path.

    Parameters
    ----------
    filepath : str
        The path of the Irregularity File.

    Raises
    ------
    FileNotFoundError
        If the file is not found.
    json.decoder.JSONDecodeError
        If the file is not a valid JSON file.

    Returns
    -------
    IrregularityFile
        The Irregularity File parsed from the specified path.
    """
    try:
        irregularity_file = File(filepath, FileType.JSON).get_content()
        return IrregularityFile.from_json(irregularity_file)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"{filepath} not found!") from err
    except json.decoder.JSONDecodeError as err:
        # JSONDecodeError's constructor requires (msg, doc, pos). Calling it
        # with only a message raised TypeError instead of the intended error.
        raise json.decoder.JSONDecodeError(
            f"{filepath} is not a valid JSON file!", err.doc, err.pos
        ) from err

    
def load_model(path_to_model: str):
    """
    Build a Keras model from a JSON architecture file and HDF5 weights file.

    Parameters
    ----------
    path_to_model : str
        base path without extension. ``<path_to_model>.json`` must hold the
        architecture and ``<path_to_model>.h5`` the weights.

    Returns
    -------
    The Keras model returned by ``keras.models.model_from_json``, with the
    weights loaded.
    """
    # Use a context manager so the architecture file handle is always closed.
    with open(f"{path_to_model}.json", 'r', encoding='utf-8') as architecture_file:
        architecture = architecture_file.read()
    model = keras.models.model_from_json(architecture)
    model.load_weights(f'{path_to_model}.h5')
    return model


def prepare_images_for_prediction(images: list[np.ndarray]) -> np.ndarray:
    """
    Stack images into one batch array and scale the BGR values into [0, 1].

    No resizing happens here: all images must already share the same shape
    (``np.stack`` requires it).

    Parameters
    ----------
    images : list[np.ndarray]
        images with identical shapes and values in [0, 255].

    Raises
    ------
    ValueError
        if ``images`` is empty, or (from ``np.stack``) if shapes differ.

    Returns
    -------
    np.ndarray
        array of shape ``(len(images), *image_shape)`` with values divided by 255.
    """
    if len(images) == 0:
        raise ValueError("No images to prepare for prediction")
    batch = np.stack(images, axis=0)
    # Scale BGR values in [0, 1]
    return batch / 255


def predict(model, images: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Run the model on a batch of images and pick the most likely class.

    Parameters
    ----------
    model
        object exposing ``predict(images, batch_size=...)`` (e.g. the model
        returned by ``load_model``).
    images : np.ndarray
        batch of preprocessed images, one image per leading index.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(probabilities, y)``: ``probabilities`` is the first element of
        the model output; ``y`` is the per-row argmax, i.e. the class index
        (0, 1, ...) for each image.
    """
    prediction = model.predict(images, batch_size=1)
    # We are interested in the first array of prediction probabilities.
    # NOTE(review): argmax(axis=1) below requires this to be 2-D
    # (batch, classes), which suggests the model returns a list of outputs
    # rather than a single array — confirm against the model definition.
    probabilities = prediction[0]
    # The position of the highest probability is the class index.
    y = np.argmax(probabilities, axis=1)
    return probabilities, y
