import os
import tempfile
from uuid import uuid4

import numpy as np
from ffmpeg import FFmpeg

from mpai_cae_arp.audio import AudioWave, Noise
from mpai_cae_arp.files import File, FileType
from mpai_cae_arp.types.irregularity import Irregularity, IrregularityFile, Source
from mpai_cae_arp.time import frames_to_seconds, seconds_to_frames, seconds_to_string, time_to_seconds

# Temporary JSON file mapping each audio irregularity ID (as a string) to the index of
# the channel it was detected on. It is written by get_irregularities_from_audio and
# read, then deleted, by extract_audio_irregularities.
# NOTE(review): the path is fixed, so concurrent runs sharing a temp dir would clash.
TMP_CHANNELS_MAP = os.path.join(
    tempfile.gettempdir(),
    "mpai",
    "channels_map.json",
)

def calculate_offset(audio: "AudioWave", video: "AudioWave") -> float:
    """
    Calculates the offset between two audio signals based on their cross-correlation.

    The offset is the lag at which the absolute value of the full cross-correlation
    is maximal, converted to seconds. It follows ``numpy.correlate``'s definition
    ``c[k] = sum_n audio[n + k] * video[n]``. A positive result means the shared
    content occurs later in ``audio`` than in ``video``.

    Parameters
    ----------
    audio : AudioWave
        The audio to be used as reference. Its ``array`` must be one-dimensional.
    video : AudioWave
        The audio to be used as target, e.g. the soundtrack of a video.
        Its ``array`` must be one-dimensional.

    Returns
    -------
    float
        The offset of ``audio`` with respect to ``video``, in seconds.

    Raises
    ------
    ValueError
        If the two signals have different sample rates, or if either signal is
        not one-dimensional or is empty.

    Notes
    -----
    ``numpy.correlate`` computes the correlation directly, not via FFT. The cost
    therefore grows with ``len(audio) * len(video)``, which is slow for long
    recordings.
    """
    if audio.samplerate != video.samplerate:
        raise ValueError(
            f"Sample rates differ ({audio.samplerate} != {video.samplerate}); "
            "the offset in samples would be meaningless."
        )

    # Work in float64 so that integer PCM samples cannot overflow while being
    # multiplied and accumulated by np.correlate.
    reference = np.asarray(audio.array, dtype=np.float64)
    target = np.asarray(video.array, dtype=np.float64)
    if reference.ndim != 1 or target.ndim != 1:
        raise ValueError("Both signals must be one-dimensional (single channel).")
    if reference.size == 0 or target.size == 0:
        raise ValueError("Both signals must contain at least one sample.")

    corr = np.correlate(reference, target, mode="full")
    # In "full" mode, index i of `corr` corresponds to lag i - (len(target) - 1).
    # The range therefore starts at -(len(target) - 1) and stops before len(reference).
    lags = np.arange(-len(target) + 1, len(reference))
    lag_idx = np.argmax(np.abs(corr))

    return float(lags[lag_idx] / audio.samplerate)


def get_audio_from_video(
    video_src: str,
    output_path: str | None = None,
    samplerate: int | None = None,
) -> str:
    """
    Extracts the audio track of a video into a 16-bit PCM stereo WAV file.

    Runs the equivalent of::

        ffmpeg -y -i <video_src> -acodec pcm_s16le -ac 2 [-ar <samplerate>] <output_path>

    Parameters
    ----------
    video_src : str
        Path of the video whose audio track is extracted.
    output_path : str, optional
        Where to write the WAV file. When omitted, a new temporary ``.wav`` file
        is created, and the caller is responsible for deleting it.
    samplerate : int, optional
        When given, ffmpeg resamples the audio to this rate (``-ar``).

    Returns
    -------
    str
        The path of the written WAV file.
    """
    created_tmp = output_path is None
    if created_tmp:
        # mkstemp reserves a unique name, so concurrent calls never share a file.
        # ffmpeg overwrites the reserved file because of the "-y" option below.
        fd, output_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)

    options = {"acodec": "pcm_s16le", "ac": "2"}
    if samplerate is not None:
        options["ar"] = str(samplerate)

    ffmpeg = (
        FFmpeg()
        .option("y")
        .input(video_src)
        .output(output_path, options)
    )
    try:
        ffmpeg.execute()
    except Exception:
        # Do not leave an empty temporary file behind when ffmpeg fails.
        if created_tmp:
            os.remove(output_path)
        raise

    return output_path


def get_irregularities_from_audio(audio_src: AudioWave) -> list[Irregularity]:
    """
    Detects silence-based irregularities on every channel of an audio wave.

    Each channel is analysed separately; a mono wave is analysed as a whole. Every
    slice returned by ``get_silence_slices`` becomes an audio ``Irregularity``
    whose time label is the slice start. A map from irregularity ID (as a string)
    to channel index is written to ``TMP_CHANNELS_MAP``, so the matching channel
    can be cut later.

    Parameters
    ----------
    audio_src : AudioWave
        The audio to analyse.

    Returns
    -------
    list[Irregularity]
        One irregularity per detected slice, in detection order (not sorted by time).
    """
    if audio_src.channels > 1:
        input_channels = [audio_src.get_channel(channel) for channel in range(audio_src.channels)]
    else:
        input_channels = [audio_src]

    # Silence classes passed to get_silence_slices.
    # The numbers are presumably upper/lower dB bounds; TODO confirm against Noise.
    noise_classes = [
        Noise("A", -50, -63),
        Noise("B", -63, -69),
        Noise("C", -69, -72),
    ]

    channels_map: dict[str, int] = {}
    irreg_list: list[Irregularity] = []
    for channel_idx, audio in enumerate(input_channels):
        # length=500 is presumably the minimum slice length in ms; TODO confirm.
        silences = audio.get_silence_slices(noise_classes, length=500)
        for noise_list in silences.values():
            for start, _ in noise_list:
                irreg_id = uuid4()
                irreg_list.append(
                    Irregularity(
                        irregularity_ID=irreg_id,
                        source=Source.AUDIO,
                        time_label=seconds_to_string(frames_to_seconds(start, audio.samplerate)),
                    )
                )
                channels_map[str(irreg_id)] = channel_idx

    # The "mpai" sub-directory of the system temp dir is not guaranteed to exist.
    os.makedirs(os.path.dirname(TMP_CHANNELS_MAP), exist_ok=True)
    File(TMP_CHANNELS_MAP, FileType.JSON).write_content(channels_map)

    return irreg_list


def _first_channel(wave: AudioWave) -> AudioWave:
    """Return channel 0 of a multi-channel wave, or the wave itself if it is mono."""
    return wave.get_channel(0) if wave.channels > 1 else wave


def create_irreg_file(audio_src: str, video_src: str) -> IrregularityFile:
    """
    Builds the irregularity file for an audio recording and its video.

    The offset is estimated by cross-correlating the first channel of the audio
    with the first channel of the video's soundtrack. The soundtrack is extracted
    with ffmpeg at the audio's sample rate. Irregularities are detected on the
    audio and sorted by time label.

    Parameters
    ----------
    audio_src : str
        Path of the audio file.
    video_src : str
        Path of the video file.

    Returns
    -------
    IrregularityFile
        The sorted irregularities and the estimated offset in seconds.
    """
    audio = AudioWave.from_file(audio_src, bufferize=True)

    # calculate_offset needs the soundtrack as an AudioWave at the same sample
    # rate, not the video path.
    video_audio_path = get_audio_from_video(video_src, samplerate=audio.samplerate)
    try:
        video_audio = AudioWave.from_file(video_audio_path, bufferize=True)
        offset = calculate_offset(_first_channel(audio), _first_channel(video_audio))
    finally:
        os.remove(video_audio_path)

    irregularities = get_irregularities_from_audio(audio)
    irregularities.sort(key=lambda x: time_to_seconds(x.time_label))

    return IrregularityFile(irregularities=irregularities, offset=offset)


def merge_irreg_files(
    file1: IrregularityFile,
    file2: IrregularityFile
) -> IrregularityFile:
    """
    Merges two irregularity files into a new one.

    The merged offset depends on which offsets are set:

    - both set: the larger of the two;
    - only one set: that one;
    - neither set: ``None``.

    The irregularities of both files are concatenated and sorted by time label.
    Neither input file is modified.
    """
    available = [value for value in (file1.offset, file2.offset) if value is not None]
    merged_offset = max(available) if available else None

    def by_time_label(irregularity):
        return time_to_seconds(irregularity.time_label)

    merged = file1.irregularities + file2.irregularities
    merged.sort(key=by_time_label)

    return IrregularityFile(irregularities=merged, offset=merged_offset)


def extract_audio_irregularities(
    audio_src: str,
    irreg_file: IrregularityFile,
    path: str
) -> IrregularityFile:
    """
    Saves a half-second audio block for every irregularity and records its path.

    For each irregularity, ``samplerate // 2`` frames starting at its time label
    are written to ``{path}/AudioBlocks/{irregularity_ID}.wav``. That path is
    stored in the irregularity's ``audio_block_URI``.

    Which samples are cut depends on the channels map in ``TMP_CHANNELS_MAP``:

    - irregularity listed in the map: the block is cut from the mapped channel;
    - irregularity not listed: the block is cut from the whole audio wave.

    The channels map file is deleted once every block has been written.

    Parameters
    ----------
    audio_src : str
        Path of the audio file to cut blocks from.
    irreg_file : IrregularityFile
        The irregularities to extract. It is updated in place.
    path : str
        Output directory; the ``AudioBlocks`` sub-directory is created if missing.

    Returns
    -------
    IrregularityFile
        The same ``irreg_file`` object, with ``audio_block_URI`` filled in.
    """
    channels_map = File(TMP_CHANNELS_MAP, FileType.JSON).get_content()
    blocks_dir = f"{path}/AudioBlocks"
    os.makedirs(blocks_dir, exist_ok=True)

    audio = AudioWave.from_file(audio_src, bufferize=True)
    block_frames = audio.samplerate // 2

    for irreg in irreg_file.irregularities:
        channel = channels_map.get(str(irreg.irregularity_ID))
        source = audio if channel is None else audio.get_channel(channel)

        start = seconds_to_frames(time_to_seconds(irreg.time_label), audio.samplerate)
        block_path = f"{blocks_dir}/{irreg.irregularity_ID}.wav"
        source[start:start + block_frames].save(block_path)
        irreg.audio_block_URI = block_path

    os.remove(TMP_CHANNELS_MAP)

    return irreg_file
