import os
import tempfile
from uuid import uuid4
import numpy as np

from mpai_cae_arp.audio import AudioWave, Noise
from mpai_cae_arp.files import File, FileType
from mpai_cae_arp.types.irregularity import Irregularity, IrregularityFile, Source
from mpai_cae_arp.time import frames_to_seconds, seconds_to_frames

# System temporary directory used for scratch files created by this module.
temp_dir = tempfile.gettempdir()
# JSON file mapping each audio irregularity ID to the index of the channel it
# was detected on. It is written by get_irregularities_from_audio and read
# (then deleted) by extract_audio_irregularities.
# NOTE(review): the file name is fixed, so two runs sharing the same temporary
# directory would overwrite each other's map -- confirm this is acceptable.
TMP_CHANNELS_MAP = os.path.join(temp_dir, "channels_map.json")

def calculate_offset(audio: AudioWave, video: AudioWave) -> float:
    """
    Calculates the offset between two audio files based on their cross-correlation.

    The offset is the lag at which the absolute value of the full
    cross-correlation of ``audio.array`` and ``video.array`` is largest,
    converted to seconds with the sample rate of ``audio``.

    Parameters
    ----------
    audio : AudioWave
        The audio file to be used as reference.
    video : AudioWave
        The audio file to be used as target.

    Returns
    -------
    float
        The offset in seconds. A positive value means that the matching
        content appears later in ``audio`` than in ``video``; a negative
        value means it appears earlier.
    """
    reference = audio.array
    target = video.array

    # NOTE(review): the lag is converted with audio.samplerate only, so both
    # signals are assumed to share the same sample rate -- confirm at callers.
    corr = np.correlate(reference, target, mode="full")

    # For np.correlate(a, v, "full") the output has len(a) + len(v) - 1
    # entries and entry i corresponds to lag i - (len(v) - 1), i.e. the lags
    # run from -(len(v) - 1) up to len(a) - 1. Building the axis from the
    # other operand's length is only correct when both lengths are equal.
    lags = np.arange(-len(target) + 1, len(reference))
    peak_idx = np.argmax(np.abs(corr))

    return float(lags[peak_idx] / audio.samplerate)


def get_irregularities_from_audio(audio_src: AudioWave) -> list[Irregularity]:
    """
    Builds the list of audio irregularities found on every channel of a recording.

    Each channel of ``audio_src`` is scanned with ``get_silence_slices`` using
    the noise bands "A", "B" and "C"; every slice returned yields one
    ``Irregularity`` whose time label is the start of the slice converted from
    frames to seconds.

    As a side effect, a map from irregularity ID to channel index is written
    to ``TMP_CHANNELS_MAP`` so that the channel each irregularity came from can
    be recovered later.

    Parameters
    ----------
    audio_src : AudioWave
        The recording to analyse.

    Returns
    -------
    list[Irregularity]
        The irregularities, grouped by channel in the order the channels are
        listed by ``audio_src.channels``.
    """
    input_channels: list[AudioWave] = [
        audio_src.get_channel(channel) for channel in audio_src.channels
    ]

    # Keys are stored as strings: UUID objects are not valid JSON object keys,
    # and whatever is read back from the JSON file is keyed by strings anyway.
    channels_map: dict[str, int] = {}

    irreg_list: list[Irregularity] = []
    for idx, audio in enumerate(input_channels):
        # NOTE(review): the Noise bounds look like dB levels and ``length``
        # like a minimum duration -- confirm against get_silence_slices.
        slices_by_noise = audio.get_silence_slices(
            [
                Noise("A", -50, -63),
                Noise("B", -63, -69),
                Noise("C", -69, -72),
            ],
            length=500,
        )
        for noise_list in slices_by_noise.values():
            for start, _ in noise_list:
                irreg_id = uuid4()
                irreg_list.append(
                    Irregularity(
                        uuid=irreg_id,
                        source=Source.AUDIO,
                        time_label=frames_to_seconds(start, audio.samplerate),
                    )
                )
                channels_map[str(irreg_id)] = idx

    File(TMP_CHANNELS_MAP, FileType.JSON).write_content(channels_map)

    return irreg_list


def create_irreg_file(audio_src: str, video_src: str) -> IrregularityFile:
    """
    Creates the irregularity file for an audio recording and its video counterpart.

    Parameters
    ----------
    audio_src : str
        Path of the audio file to analyse.
    video_src : str
        Path of the file whose audio track is compared with ``audio_src`` to
        compute the offset.

    Returns
    -------
    IrregularityFile
        The irregularities detected in the audio file together with the offset
        between the two recordings, as returned by ``calculate_offset``.
    """
    audio = AudioWave.from_file(audio_src, bufferize=True)
    # calculate_offset reads ``.array`` from both of its arguments, so the
    # video path has to be loaded as an AudioWave instead of being passed as
    # a plain string.
    # NOTE(review): assumes AudioWave.from_file can read the audio track of
    # ``video_src`` -- confirm the supported formats.
    video = AudioWave.from_file(video_src, bufferize=True)

    offset = calculate_offset(audio, video)
    return IrregularityFile(get_irregularities_from_audio(audio), offset=offset)


def merge_irreg_files(
    file1: IrregularityFile,
    file2: IrregularityFile) -> IrregularityFile:
    """
    Merges two irregularity files into a new one.

    Parameters
    ----------
    file1 : IrregularityFile
        The first file to merge.
    file2 : IrregularityFile
        The second file to merge.

    Returns
    -------
    IrregularityFile
        A new file holding the irregularities of both inputs sorted by
        ``time_label``, with the larger of the two offsets.
    """
    new_file = IrregularityFile(
        irregularities=file1.irregularities + file2.irregularities,
        # The largest offset itself is wanted here; np.argmax would only give
        # its position (0 or 1) in the pair.
        offset=max(file1.offset, file2.offset))

    new_file.irregularities.sort(key=lambda irreg: irreg.time_label)

    return new_file


def extract_audio_irregularities(
    audio: AudioWave,
    irreg_file: IrregularityFile,
    path: str) -> None:
    """
    Saves a short audio excerpt for every audio irregularity of a file.

    For each irregularity whose source is ``Source.AUDIO``, half a second of
    audio (``audio.samplerate // 2`` frames) starting at the irregularity's
    time label is cut from the channel recorded in the channels map and saved
    as ``{path}/AudioBlocks/{irregularity_ID}.wav``. The channels map stored at
    ``TMP_CHANNELS_MAP`` is deleted once every excerpt has been written.

    Parameters
    ----------
    audio : AudioWave
        The recording the excerpts are cut from.
    irreg_file : IrregularityFile
        The file listing the irregularities to extract.
    path : str
        The directory containing the ``AudioBlocks`` output folder.
    """
    channels_map = File(TMP_CHANNELS_MAP, FileType.JSON).get_content()
    excerpt_frames = audio.samplerate // 2

    for irreg in irreg_file.irregularities:
        if irreg.source != Source.AUDIO:
            continue

        # Object keys read back from JSON are always strings, so the ID is
        # converted before the lookup (a no-op when it already is a string).
        channel_idx = channels_map[str(irreg.irregularity_ID)]
        start = seconds_to_frames(irreg.time_label, audio.samplerate)

        chunk = audio.get_channel(channel_idx)[start:start + excerpt_frames]
        # NOTE(review): assumes the AudioBlocks directory already exists or is
        # created by ``save`` -- confirm.
        chunk.save(f"{path}/AudioBlocks/{irreg.irregularity_ID}.wav")

    os.remove(TMP_CHANNELS_MAP)
