import os
from pathlib import Path
from typing import Dict, List

import torch
import whisper_timestamped
from whisper_timestamped.make_subtitles import write_srt
import pysrt
from typeguard import typechecked

import trs_class

# The shared helpers live either at top level or under ``common_module``,
# depending on how the project is deployed; try both locations.
try:
    from common_utils.times import timeit
except ModuleNotFoundError:
    from common_module.common_utils.times import timeit
try:
    from common_utils.logger import create_logger
except ModuleNotFoundError:
    from common_module.common_utils.logger import create_logger
try:
    from common_utils.saves import save
except ModuleNotFoundError:
    from common_module.common_utils.saves import save
try:
    from common_utils.gpus_torch import pick_best_gpu
except ModuleNotFoundError:
    from common_module.common_utils.gpus_torch import pick_best_gpu

LOGGER = create_logger(__name__)


def _merge_srt_files(srt_paths: List[str], merged_path: str) -> None:
    """
    srt_paths: per-segment SRT files, in playback order.
    merged_path: SRT file the merged subtitles are written to.

    Concatenates the SRT files, renumbering cues so that indexes keep
    increasing across files.
    """
    # https://unix.stackexchange.com/a/728438
    offset = 0
    with open(merged_path, "w", encoding="utf-8") as merged:
        for srt_path in srt_paths:
            LOGGER.debug(f"reading {srt_path}...")
            # stays at `offset` when the file holds no cue at all
            last_index = offset
            # a cue index is the first line of the file or follows a blank
            # line; digits-only subtitle text must not be renumbered
            expect_index = True
            with open(srt_path, "r", encoding="utf-8") as part:
                for line in part:
                    stripped = line.strip()
                    if expect_index and stripped.isdigit():
                        last_index = int(stripped) + offset
                        merged.write(f"{last_index}\n")
                    else:
                        merged.write(line)
                    expect_index = not stripped
            # upd offset
            offset = last_index
            LOGGER.debug(f"{offset=}")


@typechecked
def whisper_trs(
    aud_path: str,
    diar_segments: List[Dict],
    device: str,
    compute_type: str,
    out_path: str,
    model_dir: str,
    dev_idx: int = 0,
    max_len: int = 200,
) -> Dict[str, list]:
    """
    aud_path: path of audio file (currently unused).
    diar_segments: segments from mmc_aus. Each needs "start", "end", "label".
    device: device passed to the transcriber.
    compute_type: currently unused.
    out_path: dir the SRT files are written to. The per-segment WAVs are
        read from the "mmc_aus" dir next to it.
    model_dir: dir containing the "whisper-large-v3" model.
    dev_idx: currently unused.
    max_len: max length of a subtitle, in chars (currently unused).

    Transcribes each segment from mmc_aus.
    Writes one shifted SRT per segment plus the merged "end_subs.srt".
    Returns {"transcription": [chunk, ...]}, one chunk per segment.
    """
    my_trs = trs_class.WhisperTranscriber(
        os.path.join(model_dir, "whisper-large-v3"),
        device,
        # dev_idx,
        # compute_type,
        # , model_dir
    )
    transcribe_options = {"task": "transcribe"}
    chunks = []
    end_srt_fils = []
    # read WAVs from mmc_aus
    mmc_aus_out_path = os.path.join(Path(out_path).parent, "mmc_aus")

    # loop over segments from mmc_aus
    for cnt_diar_segm, diar_segm in enumerate(diar_segments):
        LOGGER.debug(f"{cnt_diar_segm=}")
        LOGGER.debug(f'{diar_segm["start"]=}')
        LOGGER.debug(f'{diar_segm["end"]=}')
        duration = diar_segm["end"] - diar_segm["start"]

        if duration < 1:
            # suppress short utterances (pyannote artifact)
            # Nothing is transcribed here, so the relative span is simply
            # the whole segment (no transcription result exists to read).
            chunks.append(
                {
                    "rel_start": 0.0,
                    "rel_end": duration,
                    "text": None,
                    "language": None,
                    "speaker": None,
                }
            )
            continue

        diar_segm_path = os.path.join(mmc_aus_out_path, f"split.{cnt_diar_segm}.wav")
        diar_audio = whisper_timestamped.load_audio(diar_segm_path)
        # get transcription
        transcr_d = my_trs(
            diar_audio,
            transcribe_options,
        )
        trs_segments = transcr_d["segments"]
        # the transcriber may find no speech at all; fall back to the
        # whole segment instead of indexing an empty list
        rel_start = trs_segments[0]["start"] if trs_segments else 0.0
        rel_end = trs_segments[-1]["end"] if trs_segments else duration
        chunks.append(
            {
                "rel_start": rel_start,
                "rel_end": rel_end,
                "text": transcr_d["text"],
                "language": transcr_d["language"],
                "speaker": diar_segm["label"],
            }
        )
        LOGGER.debug(f"{chunks[-1]=}")

        LOGGER.debug("writing subs...")
        # https://github.com/linto-ai/whisper-timestamped/issues/42#issuecomment-1443866219
        tmp_srt_fil = os.path.join(out_path, f"tmp_subs.{cnt_diar_segm}.srt")
        end_srt_fils.append(os.path.join(out_path, f"end_subs.{cnt_diar_segm}.srt"))
        # close (and flush) the file before pysrt reads it back
        with open(tmp_srt_fil, "w", encoding="utf-8") as tmp_srt:
            write_srt(trs_segments, tmp_srt)

        LOGGER.debug("upd'ing subs...")
        subs = pysrt.open(tmp_srt_fil, encoding="utf-8")
        LOGGER.debug("shifting subs...")
        # timestamps are relative to the segment; make them absolute
        subs.shift(seconds=diar_segm["start"])
        LOGGER.debug("saving subs...")
        subs.save(end_srt_fils[-1], encoding="utf-8")

    # merge subs
    _merge_srt_files(end_srt_fils, os.path.join(out_path, "end_subs.srt"))

    return {"transcription": chunks}


@timeit
@typechecked
def trs(
    aud_path: str,
    segments: List[Dict],
    out_path: str,
    model_dir: str,
) -> Dict[str, List]:
    """
    aud_path: path of audio file.
    segments: mmc_aus output.
    out_path: path subs are saved to.
    model_dir: dir the model is saved to.

    Transcribes audio. Returns annotation.
    Retries on CPU if the first attempt raises RuntimeError; returns an
    empty annotation if the CPU attempt fails too.
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        LOGGER.debug("go to whisper_trs")
        compute_type = "float16"
        # pick best GPU
        dev_idx = pick_best_gpu()
        LOGGER.debug(f"using best GPU = {dev_idx}")
        return whisper_trs(
            aud_path, segments, device, compute_type, out_path, model_dir, dev_idx
        )
    except RuntimeError as gpu_err:
        # keep a trace of why the first attempt failed before falling back
        LOGGER.debug(f"first attempt failed, falling back to CPU: {gpu_err=}")
        try:
            # TODO
            # try gpu 0?
            device = "cpu"
            compute_type = "int8"
            LOGGER.debug("using CPU")
            return whisper_trs(
                aud_path, segments, device, compute_type, out_path, model_dir
            )
        except Exception as err:
            LOGGER.error(f"{err=}")
            return {"transcription": []}


def trs_save(
    aud_path: str, segments: List[Dict], json_path: str, out_path: str, model_dir: str
) -> None:
    """
    aud_path: path of audio file.
    segments: mmc_aus output.
    json_path: JSON the annotation is saved to.
    out_path: path subs are saved to.
    model_dir: dir the model is saved to.

    Transcribes audio. Saves annotation to JSON.
    """
    annotation = trs(aud_path, segments, out_path, model_dir)
    save(annotation, json_path)


@typechecked
def dl_trs_save(
    message_body: dict,
    out_json: str,
    out_path: str,
    base_dir: str,
    model_dir: str,
) -> bool:
    """
    message_body: msg body. Must have a "programme" entry holding "uid",
        "external_id" and "mmc_aus" -> "segments".
    out_json: JSON the output is saved to.
    out_path: path subs are saved to.
    base_dir: base dir. The audio is read from
        <base_dir>/<uid>/<external_id>.wav.
    model_dir: dir the model is saved to.

    Validates the mmc_aus segments. Transcribes audio.
    Returns True if success, False if the segments are malformed.
    """
    # copies landmark
    ret_code = -1

    programme = message_body["programme"]
    aud_path = os.path.join(
        base_dir,
        programme["uid"],
        f'{programme["external_id"]}.wav',
    )
    diar_out = programme["mmc_aus"]["segments"]

    # check diar_out
    # True if diar_out is a list
    # True if all entries are dicts with keys ["start", "end", "label"]
    # (the dict check avoids substring matches on str entries and
    # TypeError on non-container entries)
    required_keys = ("start", "end", "label")
    if isinstance(diar_out, list) and all(
        isinstance(voice, dict) and all(k in voice for k in required_keys)
        for voice in diar_out
    ):
        # transcribe aud
        trs_save(aud_path, diar_out, out_json, out_path, model_dir)
        # success
        ret_code = 0

    # copies landmark
    LOGGER.debug(f"return code: {ret_code}")

    # success if ret_code = 0
    return ret_code == 0


# Disabled helper kept for reference; this bare string is never executed.
"""
def find_hallucinations(annotation):
    hall = [" Amara.org", "www.mesmerism.info", " QTSS", " ... "]
    for i, tr in enumerate(annotation["segments"]):
        for w in hall:
            if w in tr["text"]:
                LOGGER.debug(f"Hallucination removed: {tr['text']}")
                annotation["segments"].pop(i)
                break
    return annotation
"""