import os
import sys
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional

import numpy as np
import whisper_timestamped
from typeguard import typechecked

try:
    from common_utils.logger import create_logger
except ModuleNotFoundError:
    from common_module.common_utils.logger import create_logger

LOGGER = create_logger(__name__)


@contextmanager
def suppress_stdout() -> Iterator[None]:
    """Temporarily redirect ``sys.stdout`` to ``os.devnull``.

    Auxiliary helper to suppress Whisper logs (it is quite verbose).
    The previous ``sys.stdout`` is restored on exit, even if the wrapped
    code raises.

    All credit goes to:
    https://thesmithfam.org/blog/2012/10/25/temporarily-suppress-console-output-in-python/
    """
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout


@typechecked
class WhisperTranscriber:
    """Callable wrapper around a ``whisper_timestamped`` model.

    The instance keeps a text buffer of past transcriptions; calling the
    instance transcribes a waveform with that buffer as ``initial_prompt``
    and then appends the new text to the buffer.
    """

    def __init__(
        self,
        model: str,
        device: Optional[str] = None,
        # dev_idx: int,
        # compute_type: str,
        # model_dir: str,
    ):
        """
        model: name of the model or path to the model. Examples:
            - OpenAI-Whisper identifier: "large-v3", "medium.en", ...
            - HuggingFace identifier: "openai/whisper-large-v3",
              "distil-whisper/distil-large-v2", ...
            - File name: "path/to/model.pt", "path/to/model.ckpt",
              "path/to/model.bin"
            - Folder name: "path/to/folder". The folder must contain either
              "pytorch_model.bin", "model.safetensors", or sharded versions
              of those, or "whisper.ckpt".
        device : device to use. If None, use CUDA if there is a GPU
            available, otherwise CPU. The value is forwarded unchanged to
            ``whisper_timestamped.load_model``.
        """
        self.model = whisper_timestamped.load_model(
            model,
            device,
            # device_index=dev_idx,
            # compute_type=compute_type,
            # download_root=model_dir,
        )
        # Accumulated text of previous transcriptions, used as the prompt
        # for the next one (see __call__).
        self._buffer = ""

    def transcribe(
        self, waveform: np.ndarray, options: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Transcribe audio using Whisper.

        waveform: audio samples handed to ``whisper_timestamped.transcribe``.
        options: extra keyword arguments for
            ``whisper_timestamped.transcribe``. Keys given here take
            precedence over the defaults set below (beam_size, best_of,
            temperature, vad, initial_prompt, verbose).

        Returns the dictionary produced by ``whisper_timestamped.transcribe``.
        This method does not modify the transcription buffer.
        """
        LOGGER.info("Transcribing...")
        # Pad/trim audio to fit 30 s as required by Whisper
        # tweaked_audio = whisper_timestamped.pad_or_trim(waveform)

        # Whisper models can "hallucinate" text when given a segment w/o speech.
        # This can be avoided by running VAD and gluing speech segments together
        # before transcribing
        decode_kwargs: Dict[str, Any] = {
            "beam_size": 5,
            "best_of": 5,
            "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
            "vad": "silero:v3.1",
            # use past transcriptions to condition the model
            "initial_prompt": self._buffer,
            "verbose": True,
        }
        # Merge instead of spreading `**options` next to explicit keywords:
        # a duplicated key would otherwise raise
        # "TypeError: got multiple values for keyword argument".
        decode_kwargs.update(options)

        # Transcribe the given audio while suppressing logs
        with suppress_stdout():
            transcription = whisper_timestamped.transcribe(
                self.model,
                waveform,  # tweaked_audio,
                **decode_kwargs,
            )
        return transcription

    def __call__(self, waveform: np.ndarray, options: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe ``waveform`` and remember its text for later calls.

        Delegates to :meth:`transcribe`, then appends ``transcription["text"]``
        to the internal buffer that is passed as ``initial_prompt`` on the
        next transcription.
        """
        # transcribe
        transcription = self.transcribe(waveform, options)
        # update transcription buffer
        # NOTE(review): the buffer is never trimmed or reset, so it grows with
        # every call - confirm this is acceptable for long-running sessions.
        self._buffer += transcription["text"]
        return transcription