import os
from pathlib import Path
from typing import Dict, List

import torch
import whisper_timestamped
from whisper_timestamped.make_subtitles import write_srt
import pysrt
from typeguard import typechecked

import trs_class

try:
    from common_utils.times import timeit
except ModuleNotFoundError:
    from common_module.common_utils.times import timeit

try:
    from common_utils.logger import create_logger
except ModuleNotFoundError:
    from common_module.common_utils.logger import create_logger

try:
    from common_utils.saves import save
except ModuleNotFoundError:
    from common_module.common_utils.saves import save

try:
    from common_utils.gpus_torch import pick_best_gpu
except ModuleNotFoundError:
    from common_module.common_utils.gpus_torch import pick_best_gpu

# Module-wide logger, created through the project's create_logger helper.
LOGGER = create_logger(__name__)


def _skipped_chunk(diar_segm: Dict) -> Dict:
    """Placeholder chunk for a diarization segment that was not transcribed."""
    # Timings are relative to the segment's own audio: it spans 0 .. duration.
    return {
        "rel_start": 0.0,
        "rel_end": diar_segm["end"] - diar_segm["start"],
        "text": None,
        "language": None,
        "speaker": None,
    }


def _transcription_chunk(transcr_d: Dict, diar_segm: Dict) -> Dict:
    """Build the output chunk for one transcribed diarization segment."""
    segments = transcr_d["segments"]
    if segments:
        rel_start, rel_end = segments[0]["start"], segments[-1]["end"]
    else:
        # Whisper found no speech: fall back to the segment's own bounds
        # instead of raising IndexError.
        rel_start, rel_end = 0.0, diar_segm["end"] - diar_segm["start"]
    return {
        "rel_start": rel_start,
        "rel_end": rel_end,
        "text": transcr_d["text"],
        "language": transcr_d["language"],
        "speaker": diar_segm["label"],
    }


def _write_shifted_srt(
    segments: List, offset_s, tmp_srt_fil: str, end_srt_fil: str
):
    """Write segments as SRT, shift them by offset_s seconds, save and return them."""
    LOGGER.debug("writing subs...")
    # https://github.com/linto-ai/whisper-timestamped/issues/42#issuecomment-1443866219
    # The `with` closes (and so flushes) the file before pysrt re-reads it.
    with open(tmp_srt_fil, "w", encoding="utf-8") as tmp_w:
        write_srt(segments, tmp_w)

    LOGGER.debug("upd'ing subs...")
    subs = pysrt.open(tmp_srt_fil, encoding="utf-8")

    LOGGER.debug("shifting subs...")
    subs.shift(seconds=offset_s)

    LOGGER.debug("saving subs...")
    subs.save(end_srt_fil, encoding="utf-8")
    return subs


def _merge_srts(subs_list: List, merged_path: str) -> None:
    """Concatenate subtitle files in order, renumber cues 1..N and save them."""
    merged = pysrt.SubRipFile()
    for subs in subs_list:
        merged.extend(subs)
    # Renumber in concatenation order (no time sort). Working on parsed cues
    # avoids mistaking an all-digit subtitle text line for a cue index.
    for idx, item in enumerate(merged, start=1):
        item.index = idx
    LOGGER.debug(f"merged {len(merged)} subs into {merged_path}")
    merged.save(merged_path, encoding="utf-8")


@typechecked
def whisper_trs(
    aud_path: str,
    diar_segments: List[Dict],
    device: str,
    compute_type: str,
    out_path: str,
    model_dir: str,
    dev_idx: int = 0,
    max_len: int = 200,
) -> Dict[str, list]:
    """
    Transcribe each diarization segment produced by mmc_aus with Whisper.

    aud_path: path of the full audio file. Currently unused: each segment is
        read from "<parent of out_path>/mmc_aus/split.<i>.wav".
    diar_segments: segments from mmc_aus. Each dict needs "start" and "end"
        (seconds) and "label".
    device: device handed to the Whisper transcriber.
    compute_type, dev_idx: accepted for API compatibility; not currently
        forwarded to the transcriber.
    out_path: directory the SRT files are written to.
    model_dir: directory containing the "whisper-large-v3" model.
    max_len: max length of a subtitle, in chars (currently unused).

    Segments shorter than 1 s are skipped as pyannote artifacts. They still
    get a placeholder chunk, so chunks[i] corresponds to diar_segments[i].

    For every transcribed segment i, two files are written to out_path:
    "tmp_subs.<i>.srt" (segment-relative times) and "end_subs.<i>.srt"
    (shifted by the segment start). All shifted subtitles are then merged
    into "end_subs.srt" with cues renumbered 1..N.

    Returns {"transcription": chunks}.
    """
    my_trs = trs_class.WhisperTranscriber(
        os.path.join(model_dir, "whisper-large-v3"),
        device,
    )
    transcribe_options = {"task": "transcribe"}

    # per-segment WAVs written by mmc_aus live next to out_path
    mmc_aus_out_path = os.path.join(Path(out_path).parent, "mmc_aus")

    chunks = []
    shifted_subs = []

    for cnt_diar_segm, diar_segm in enumerate(diar_segments):
        LOGGER.debug(f"{cnt_diar_segm=}")
        LOGGER.debug(f'{diar_segm["start"]=}')
        LOGGER.debug(f'{diar_segm["end"]=}')

        if diar_segm["end"] - diar_segm["start"] < 1:
            # suppress short utterances (pyannote artifact)
            chunks.append(_skipped_chunk(diar_segm))
            continue

        diar_segm_path = os.path.join(mmc_aus_out_path, f"split.{cnt_diar_segm}.wav")
        diar_audio = whisper_timestamped.load_audio(diar_segm_path)
        transcr_d = my_trs(diar_audio, transcribe_options)

        chunks.append(_transcription_chunk(transcr_d, diar_segm))
        LOGGER.debug(f"{chunks[-1]=}")

        shifted_subs.append(
            _write_shifted_srt(
                transcr_d["segments"],
                diar_segm["start"],
                os.path.join(out_path, f"tmp_subs.{cnt_diar_segm}.srt"),
                os.path.join(out_path, f"end_subs.{cnt_diar_segm}.srt"),
            )
        )

    _merge_srts(shifted_subs, os.path.join(out_path, "end_subs.srt"))
    return {"transcription": chunks}


@timeit
@typechecked
def trs(
    aud_path: str,
    segments: List[Dict],
    out_path: str,
    model_dir: str,
) -> Dict[str, List]:
    """
    Transcribe audio, preferring CUDA and falling back to CPU.

    aud_path: path of audio file.
    segments: mmc_aus output.
    out_path: path subs are saved to.
    model_dir: dir the model is saved to.

    First attempt: device "cuda" if torch reports it available (else "cpu"),
    compute_type "float16", and the GPU index returned by pick_best_gpu().

    If that attempt raises RuntimeError (including from pick_best_gpu), the
    cause is logged and transcription is retried on "cpu" with compute_type
    "int8". Any other exception from the first attempt propagates.

    If the CPU retry raises anything, it is logged with its traceback and an
    empty annotation {"transcription": []} is returned.

    Returns the annotation dict.
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16"

        # pick best GPU
        dev_idx = pick_best_gpu()
        LOGGER.debug(f"using best GPU = {dev_idx}")

        LOGGER.debug("go to whisper_trs")
        return whisper_trs(
            aud_path, segments, device, compute_type, out_path, model_dir, dev_idx
        )
    except RuntimeError as first_err:
        # Record why we are falling back; previously this was swallowed.
        LOGGER.warning(f"transcription failed ({first_err!r}); falling back to CPU")

    # TODO: try GPU 0 before falling back to CPU?
    try:
        LOGGER.debug("using CPU")
        return whisper_trs(aud_path, segments, "cpu", "int8", out_path, model_dir)
    except Exception as err:
        # Best-effort boundary: keep the traceback, return an empty annotation.
        LOGGER.exception(f"CPU transcription failed: {err=}")
        return {"transcription": []}


def trs_save(
    aud_path: str, segments: List[Dict], json_path: str, out_path: str, model_dir: str
):
    """
    Run trs() on the audio and persist the resulting annotation.

    aud_path, segments, out_path, model_dir: forwarded unchanged to trs().
    json_path: destination handed to save() for the annotation.

    Returns None.
    """
    save(trs(aud_path, segments, out_path, model_dir), json_path)


@typechecked
def dl_trs_save(
    message_body: dict,
    out_json: str,
    out_path: str,
    base_dir: str,
    model_dir: str,
) -> bool:
    """
    Validate the mmc_aus segments carried in a message, then transcribe and save.

    message_body: msg body. The function reads these keys:
        message_body["programme"]["uid"]
        message_body["programme"]["external_id"]
        message_body["programme"]["mmc_aus"]["segments"]
    out_json: JSON the output is saved to.
    out_path: path subs are saved to (forwarded to trs_save).
    base_dir: base dir. Audio is read from <base_dir>/<uid>/<external_id>.wav.
    model_dir: dir the model is saved to.

    Returns True if the segments are a list of dicts that each have "start",
    "end" and "label", and trs_save was run. Returns False otherwise; in that
    case nothing is transcribed.

    NOTE(review): True only means trs_save returned. Its result is not
    inspected, so an empty transcription still counts as success.
    """
    aud_path = os.path.join(
        base_dir,
        message_body["programme"]["uid"],
        f'{message_body["programme"]["external_id"]}.wav',
    )
    diar_out = message_body["programme"]["mmc_aus"]["segments"]

    # Require a list of dicts, each with "start", "end" and "label". The
    # isinstance check keeps non-dict entries from crashing the `in` test.
    required_keys = ("start", "end", "label")
    valid = isinstance(diar_out, list) and all(
        isinstance(voice, dict) and all(k in voice for k in required_keys)
        for voice in diar_out
    )

    if valid:
        # transcribe aud
        trs_save(aud_path, diar_out, out_json, out_path, model_dir)

    # keep the historical return-code log line (0 = success, -1 = invalid input)
    LOGGER.debug(f"return code: {0 if valid else -1}")
    return valid


"""
def find_hallucinations(annotation):
    hall = [" Amara.org", "www.mesmerism.info", " QTSS", " ... "]

    for i, tr in enumerate(annotation["segments"]):
        for w in hall:
            if w in tr["text"]:
                LOGGER.debug(f"Hallucination removed: {tr['text']}")
                annotation["segments"].pop(i)
                break
    return annotation
"""
