import os
import operator
from pathlib import Path
from typing import Dict, List

import numpy as np
from speechbrain.inference.speaker import SpeakerRecognition
from typeguard import typechecked

try:
    from common_utils.times import timeit
except ModuleNotFoundError:
    from common_module.common_utils.times import timeit
try:
    from common_utils.logger import create_logger
except ModuleNotFoundError:
    from common_module.common_utils.logger import create_logger
try:
    from common_utils.saves import save
except ModuleNotFoundError:
    from common_module.common_utils.saves import save
try:
    from common_utils.gpus_torch import pick_best_gpu
except ModuleNotFoundError:
    from common_module.common_utils.gpus_torch import pick_best_gpu

LOGGER = create_logger(__name__)

# Model identifier passed to SpeakerRecognition.from_hparams.
# https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb
_SPKREC_SOURCE = "speechbrain/spkrec-ecapa-voxceleb"


def _load_verifier(model_dir: str):
    """Load the speaker-verification model, trying GPU first and then CPU.

    model_dir: dir under which the model files are stored
        (in "speechbrain/spkrec-ecapa-voxceleb").

    Returns the loaded model, or None if loading failed on the CPU as well.
    """
    # path of speechbrain/spkrec-ecapa-voxceleb
    voxceleb_path = os.path.join(model_dir, "speechbrain", "spkrec-ecapa-voxceleb")
    try:
        LOGGER.debug("trying to use GPU")
        # pick best GPU
        dev_idx = pick_best_gpu()
        device = f"cuda:{dev_idx}"
        LOGGER.debug(f"using best GPU = {dev_idx}")
        return SpeakerRecognition.from_hparams(
            source=_SPKREC_SOURCE,
            savedir=voxceleb_path,
            run_opts={"device": device},
        )
    except RuntimeError:
        # GPU attempt failed: fall back to the default (CPU) device below
        LOGGER.debug("using CPU")
    try:
        return SpeakerRecognition.from_hparams(
            source=_SPKREC_SOURCE,
            savedir=voxceleb_path,
        )
    except Exception as err:
        # best effort: the caller turns a missing model into an empty result
        LOGGER.error(f"{err=}")
        return None


def _map_dataset_wavs(dataset_dir: str) -> Dict[str, List[str]]:
    """Map every WAV found below dataset_dir to its speaker ID.

    The speaker ID is the name of the immediate child of dataset_dir on the
    way to the WAV (for a WAV lying directly in dataset_dir, this is the
    file name itself).

    Returns {"wav": [paths], "id": [speaker IDs]}, index-aligned.
    """
    dataset_map: Dict[str, List[str]] = {"wav": [], "id": []}
    for root, _, filenames in os.walk(dataset_dir):
        for filename in filenames:
            if not filename.endswith(".wav"):
                continue
            full_path = os.path.join(root, filename)
            # relpath normalises both paths, so this also works when
            # dataset_dir contains a trailing slash, "//" or ".." (walking
            # up with Path.parent until string equality never terminates
            # in those cases)
            rel_path = os.path.relpath(full_path, dataset_dir)
            dataset_map["wav"].append(full_path)
            dataset_map["id"].append(rel_path.split(os.sep)[0])
    return dataset_map


@typechecked
def verify_speaker(
    dataset: List[str], out_path: str, model_dir: str
) -> Dict[str, list]:
    """
    Recognise the speaker of every segment WAV by scoring it against a
    dataset of known speakers.

    dataset: path components of the dataset dir, joined below the
        AI_FW_DIR environment variable. WAVs are grouped in subdirs named
        after the speaker ID.
    out_path: path whose parent dir contains the "mmc_aus" dir holding the
        segment WAVs. File names must contain "split.<segment no.>.wav".
    model_dir: dir the model is saved to.

    Returns {"mmc_sir": [...]}, a list of
    {"segment", "speaker_id", "score"} dicts sorted by segment no.
    The list is empty if the model could not be loaded.

    Raises KeyError if AI_FW_DIR is not set, and ValueError if a segment
    has to be scored but the dataset contains no WAV.
    """
    spkrecs = []

    verification = _load_verifier(model_dir)
    if verification is None:
        return {"mmc_sir": spkrecs}

    # read WAVs from mmc_aus
    mmc_asr_out_path = os.path.join(Path(out_path).parent, "mmc_aus")

    # read WAVs from dataset
    # dataset is a list of dirs
    # TODO
    # read AI_FW_DIR from input msg?
    dataset_dir = os.path.join(os.environ["AI_FW_DIR"], *dataset)
    LOGGER.debug(f"{dataset_dir=}")

    LOGGER.debug("dataset: mapping WAV to ID...")
    # map: path of WAV in dataset -> speaker ID
    dataset_map = _map_dataset_wavs(dataset_dir)
    LOGGER.debug(f"{len(dataset_map['wav'])=}")

    for mmc_asr_out_fil in os.listdir(mmc_asr_out_path):
        if not mmc_asr_out_fil.endswith(".wav"):
            continue
        mmc_asr_out = os.path.join(mmc_asr_out_path, mmc_asr_out_fil)
        LOGGER.debug(f"proc'ing {mmc_asr_out}...")

        if not dataset_map["wav"]:
            # np.argmax would raise an opaque ValueError on an empty list
            raise ValueError(f"no WAV files found in dataset dir {dataset_dir}")

        # compute scores
        # torch.Tensor -> np.array; .cpu() is a no-op for tensors that
        # already live on the CPU, so one code path serves both devices
        scores = [
            verification.verify_files(mmc_asr_out, dataset_wav)[0].cpu().numpy()
            for dataset_wav in dataset_map["wav"]
        ]
        score_argmax = np.argmax(scores)
        LOGGER.debug(f"{score_argmax=}")

        spkrecs.append(
            {
                # read segm no. from filename
                "segment": int(mmc_asr_out_fil.split("split.")[1].split(".wav")[0]),
                "speaker_id": dataset_map["id"][score_argmax],
                "score": float(np.max(scores)),
            }
        )

    return {"mmc_sir": sorted(spkrecs, key=operator.itemgetter("segment"))}


@typechecked
def verify_speaker_save(
    dataset: List[str], json_path: str, out_path: str, model_dir: str
) -> None:
    """
    Recognise speakers and save the output to JSON.

    dataset: list of dirs (path components) locating the dataset WAVs.
    json_path: JSON file the annotation is saved to.
    out_path: path whose parent dir contains the "mmc_aus" segment WAVs.
    model_dir: dir the model is saved to.
    """
    annotation = verify_speaker(dataset, out_path, model_dir)
    save(annotation, json_path)


@typechecked
def spkrec_save(
    message_body: dict,
    out_json: str,
    out_path: str,
    base_dir: str,
    model_dir: str,
) -> bool:
    """
    Validate the diarisation output in the message and, if it is valid,
    recognise speakers and save the result.

    message_body: input message. Reads
        message_body["programme"]["mmc_aus"]["segments"] and
        message_body["programme"]["spkrec_dataset"].
    out_json: JSON the output is saved to.
    out_path: path whose parent dir contains the "mmc_aus" segment WAVs.
    base_dir: base dir (unused here).
    model_dir: dir the model is saved to.

    Returns True on success, False if the diarisation output is not a list
    of entries that all have the keys "start", "end" and "label".
    """
    diar_out = message_body["programme"]["mmc_aus"]["segments"]

    # check diar_out:
    # it must be a list, and every entry must have all the required keys
    required_keys = ("start", "end", "label")
    is_valid = isinstance(diar_out, list) and all(
        all(key in voice for key in required_keys) for voice in diar_out
    )

    if is_valid:
        # recog speakers
        verify_speaker_save(
            message_body["programme"]["spkrec_dataset"],
            out_json,
            out_path,
            model_dir,
        )

    # 0 on success, -1 otherwise (kept for the log message only)
    ret_code = 0 if is_valid else -1
    LOGGER.debug(f"return code: {ret_code}")
    return is_valid