import os
import operator
from pathlib import Path
from typing import Dict, List

import numpy as np
from speechbrain.inference.speaker import SpeakerRecognition
from typeguard import typechecked

try:
    from common_utils.times import timeit
except ModuleNotFoundError:
    from common_module.common_utils.times import timeit

try:
    from common_utils.logger import create_logger
except ModuleNotFoundError:
    from common_module.common_utils.logger import create_logger

try:
    from common_utils.saves import save
except ModuleNotFoundError:
    from common_module.common_utils.saves import save

try:
    from common_utils.gpus_torch import pick_best_gpu
except ModuleNotFoundError:
    from common_module.common_utils.gpus_torch import pick_best_gpu

# Module-level logger named after this module; used by the functions below
# for debug/error messages.
LOGGER = create_logger(__name__)


# Hugging Face identifier of the pretrained speechbrain speaker-verification model.
# https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb
_VOXCELEB_SOURCE = "speechbrain/spkrec-ecapa-voxceleb"


def _load_verification(voxceleb_path: str):
    """
    Load the speechbrain speaker-verification model.

    The GPU chosen by ``pick_best_gpu`` is tried first. If that raises
    ``RuntimeError``, the model is loaded again without ``run_opts``
    (speechbrain's default device, logged here as CPU). Any other exception
    from the GPU attempt propagates.

    voxceleb_path: local dir the model files are saved to.

    Returns the loaded ``SpeakerRecognition`` model, or ``None`` if the
    fallback load fails too (the error is logged).
    """
    try:
        LOGGER.debug("trying to use GPU")

        # pick best GPU
        dev_idx = pick_best_gpu()
        device = f"cuda:{dev_idx}"
        LOGGER.debug(f"using best GPU = {dev_idx}")

        return SpeakerRecognition.from_hparams(
            source=_VOXCELEB_SOURCE,
            savedir=voxceleb_path,
            run_opts={"device": device},
        )
    except RuntimeError:
        pass

    LOGGER.debug("using CPU")
    try:
        return SpeakerRecognition.from_hparams(
            source=_VOXCELEB_SOURCE,
            savedir=voxceleb_path,
        )
    except Exception as err:
        LOGGER.error(f"{err=}")
        return None


def _map_dataset(dataset_dir: str) -> Dict[str, list]:
    """
    Map every WAV under dataset_dir to its speaker ID.

    The speaker ID is the first path component below dataset_dir, i.e. the
    name of the immediate subdir the WAV lives under. A WAV placed directly
    in dataset_dir gets its own file name as ID.

    Returns {"wav": [paths], "id": [speaker IDs]}; the two lists are aligned
    by index.
    """
    dataset_map = {"wav": [], "id": []}
    for root, _, filnames in os.walk(dataset_dir):
        for filname in filnames:
            if not filname.endswith(".wav"):
                continue
            full_path = os.path.join(root, filname)
            # relpath normalizes trailing slashes and "./" prefixes, so unlike
            # climbing with Path.parent until the string equals dataset_dir,
            # it cannot loop forever
            speaker_id = Path(os.path.relpath(full_path, dataset_dir)).parts[0]
            dataset_map["wav"].append(full_path)
            dataset_map["id"].append(speaker_id)
    return dataset_map


def _segment_number(filname: str) -> int:
    """
    Read the segment number N from a "...split.N.wav" file name.

    Raises IndexError if "split." is absent, ValueError if N is not an int.
    """
    return int(filname.split("split.")[1].split(".wav")[0])


@typechecked
def verify_speaker(
    dataset: List[str], out_path: str, model_dir: str
) -> Dict[str, list]:
    """
    Identify the speaker of each mmc_aus segment WAV against a reference dataset.

    dataset: dataset path, given as path components joined below the
        AI_FW_DIR environment variable. The dataset dir holds one subdir
        per speaker ID containing that speaker's WAVs.
    out_path: output path; segment WAVs are read from the "mmc_aus" dir
        next to it (Path(out_path).parent / "mmc_aus").
    model_dir: dir the speechbrain model is saved to.

    Every "*.wav" in mmc_aus is scored against every dataset WAV. The
    speaker ID of the best-scoring dataset WAV is assigned to the segment.
    Files whose name carries no "split.<N>" segment number are skipped.

    Returns {"mmc_sir": [{"segment", "speaker_id", "score"}, ...]} sorted by
    segment. The list is empty if the dataset contains no WAVs or the model
    cannot be loaded.

    Raises KeyError if AI_FW_DIR is not set.
    """
    spkrecs = []

    # TODO
    # read AI_FW_DIR from input msg?
    dataset_dir = os.path.join(os.environ["AI_FW_DIR"], *dataset)
    LOGGER.debug(f"{dataset_dir=}")

    LOGGER.debug("dataset: mapping WAV to ID...")
    dataset_map = _map_dataset(dataset_dir)
    LOGGER.debug(f"{len(dataset_map['wav'])=}")

    if not dataset_map["wav"]:
        # nothing to compare against; np.argmax([]) would raise ValueError
        LOGGER.error(f"no WAVs found in {dataset_dir=}")
        return {"mmc_sir": spkrecs}

    # path of speechbrain/spkrec-ecapa-voxceleb
    voxceleb_path = os.path.join(model_dir, "speechbrain", "spkrec-ecapa-voxceleb")
    verification = _load_verification(voxceleb_path)
    if verification is None:
        return {"mmc_sir": spkrecs}

    # read WAVs from mmc_aus
    mmc_asr_out_path = os.path.join(Path(out_path).parent, "mmc_aus")

    for mmc_asr_out_fil in os.listdir(mmc_asr_out_path):
        if not mmc_asr_out_fil.endswith(".wav"):
            continue
        mmc_asr_out = os.path.join(mmc_asr_out_path, mmc_asr_out_fil)

        # parse the segment number before the (expensive) scoring
        try:
            segment = _segment_number(mmc_asr_out_fil)
        except (IndexError, ValueError):
            LOGGER.error(f"skipping {mmc_asr_out}: no segment number in file name")
            continue

        LOGGER.debug(f"proc'ing {mmc_asr_out}...")

        # verify_files returns (score, decision). The score is presumably a
        # one-element tensor on the model's device; float() converts it on
        # both GPU and CPU.
        scores = [
            float(verification.verify_files(mmc_asr_out, dataset_wav)[0])
            for dataset_wav in dataset_map["wav"]
        ]

        score_argmax = int(np.argmax(scores))
        LOGGER.debug(f"{score_argmax=}")

        spkrecs.append(
            {
                "segment": segment,
                "speaker_id": dataset_map["id"][score_argmax],
                "score": scores[score_argmax],
            }
        )

    return {"mmc_sir": sorted(spkrecs, key=operator.itemgetter("segment"))}


@typechecked
def verify_speaker_save(
    dataset: List[str], json_path: str, out_path: str, model_dir: str
):
    """
    Run speaker recognition and write the result to a JSON file.

    dataset: dataset path components, forwarded to verify_speaker.
    json_path: JSON file the annotation is saved to.
    out_path: output path, forwarded to verify_speaker.
    model_dir: dir the model is saved to, forwarded to verify_speaker.
    """
    save(verify_speaker(dataset, out_path, model_dir), json_path)


@typechecked
def spkrec_save(
    message_body: dict,
    out_json: str,
    out_path: str,
    base_dir: str,
    model_dir: str,
) -> bool:
    """
    Recogs speakers and saves the output to JSON.

    message_body: input msg. Reads ["programme"]["mmc_aus"]["segments"]
        (diarization output; only validated here) and
        ["programme"]["spkrec_dataset"] (dataset path components).
    out_json: JSON the output is saved to.
    out_path: output path; forwarded so segment WAVs are read from the
        "mmc_aus" dir next to it.
    base_dir: unused input.
    model_dir: dir the model is saved to.

    Returns True on success. Returns False if the diarization output is
    malformed, i.e. not a list, or containing an entry that is not a dict
    with "start", "end" and "label" keys. An empty list counts as valid.
    """
    diar_out = message_body["programme"]["mmc_aus"]["segments"]

    # Every segment must be a dict holding these keys. The isinstance check
    # stops strings from passing via substring matching of `k in voice`.
    required_keys = ("start", "end", "label")
    is_valid = isinstance(diar_out, list) and all(
        isinstance(voice, dict) and all(k in voice for k in required_keys)
        for voice in diar_out
    )

    if is_valid:
        # recog speakers
        verify_speaker_save(
            message_body["programme"]["spkrec_dataset"],
            out_json,
            out_path,
            model_dir,
        )

    ret_code = 0 if is_valid else -1
    LOGGER.debug(f"return code: {ret_code}")

    return is_valid
