trs_class.py 3.15 KB
Newer Older
Mattia Bergagio's avatar
Mattia Bergagio committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import sys
from contextlib import contextmanager, redirect_stdout
from typing import Any, Dict, Optional

import numpy as np
import whisper_timestamped
from typeguard import typechecked

try:
    from common_utils.logger import create_logger
except ModuleNotFoundError:
    from common_module.common_utils.logger import create_logger

# Module-level logger built with the project's create_logger helper
# (imported above from common_utils / common_module.common_utils).
LOGGER = create_logger(__name__)


@contextmanager
def suppress_stdout():
    """Temporarily discard everything written to ``sys.stdout``.

    Used to silence Whisper's console output, which is quite verbose.
    ``sys.stdout`` is restored on exit, even when the body raises.
    Only stdout is redirected; anything written to ``sys.stderr`` still
    reaches the console.

    Adapted from:
    https://thesmithfam.org/blog/2012/10/25/temporarily-suppress-console-output-in-python/
    """
    # redirect_stdout saves and restores sys.stdout for us, replacing the
    # hand-rolled swap + try/finally.
    with open(os.devnull, "w") as devnull, redirect_stdout(devnull):
        yield


@typechecked
class WhisperTranscriber:
    """Stateful wrapper around a whisper_timestamped model.

    Each call through ``__call__`` appends the transcribed text to an internal
    buffer. That buffer is passed as ``initial_prompt`` to later
    transcriptions, so the model is conditioned on what it already produced.
    """

    # Decoding options forwarded to whisper_timestamped.transcribe unless the
    # caller overrides them through the ``options`` argument.
    _DEFAULT_DECODE_OPTIONS = {
        "beam_size": 5,
        "best_of": 5,
        "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
        # Whisper models can "hallucinate" text on segments without speech;
        # running VAD first restricts transcription to detected speech.
        "vad": "silero:v3.1",
        # Whatever this prints to stdout is discarded by suppress_stdout().
        "verbose": True,
    }

    def __init__(
        self,
        model: str,
        device: Optional[str] = None,
    ):
        """
        Load the Whisper model.

        model: name of the model or path to the model.
        Examples:
        - OpenAI-Whisper identifier: "large-v3", "medium.en", ...
        - HuggingFace identifier: "openai/whisper-large-v3", "distil-whisper/distil-large-v2", ...
        - File name: "path/to/model.pt", "path/to/model.ckpt", "path/to/model.bin"
        - Folder name: "path/to/folder".
        The folder must contain either "pytorch_model.bin", "model.safetensors",
        or sharded versions of those, or "whisper.ckpt".
        device: device to use, e.g. "cpu" or "cuda". If None (the default),
        use CUDA if there is a GPU available, otherwise CPU.
        """
        self.model = whisper_timestamped.load_model(model, device)

        # Text transcribed so far by __call__; used as the prompt for the
        # next transcription.
        # NOTE(review): this grows without bound across calls.
        self._buffer = ""

    def transcribe(
        self, waveform: np.ndarray, options: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Transcribe audio using Whisper.

        waveform: audio samples to transcribe (presumably 16 kHz mono float
        samples, as Whisper expects -- confirm with callers).
        options: extra keyword arguments for whisper_timestamped.transcribe.
        They take precedence over the class defaults (beam_size, best_of,
        temperature, vad, verbose) and over the buffer-based initial_prompt,
        so passing one of those names no longer raises a duplicate-keyword
        TypeError.
        Returns the result dict from whisper_timestamped.transcribe.
        This method does not update the prompt buffer; __call__ does.
        """
        LOGGER.info("Transcribing...")

        decode_options: Dict[str, Any] = {
            **self._DEFAULT_DECODE_OPTIONS,
            # Condition the model on past transcriptions. With no history
            # yet, pass None rather than an empty-string prompt.
            "initial_prompt": self._buffer or None,
            **(options or {}),
        }

        # Whisper is verbose on stdout; silence it for the duration.
        with suppress_stdout():
            return whisper_timestamped.transcribe(
                self.model, waveform, **decode_options
            )

    def __call__(
        self, waveform: np.ndarray, options: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Transcribe ``waveform`` and append the text to the prompt buffer.

        Takes the same arguments and returns the same dict as transcribe().
        """
        transcription = self.transcribe(waveform, options)

        # Remember this output so the next call is conditioned on it.
        self._buffer += transcription["text"]

        return transcription