diariz_funs.py 6.18 KB
Newer Older
Mattia Bergagio's avatar
Mattia Bergagio committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import os
from pathlib import Path
import subprocess
from typing import Dict, List

import torch
from pyannote.audio import core, Pipeline
from typeguard import typechecked

import util_funs

try:
    from common_utils.logger import create_logger
except ModuleNotFoundError:
    from common_module.common_utils.logger import create_logger

try:
    from common_utils.saves import save
except ModuleNotFoundError:
    from common_module.common_utils.saves import save

try:
    from common_utils.gpus_torch import pick_best_gpu
except ModuleNotFoundError:
    from common_module.common_utils.gpus_torch import pick_best_gpu

# Module-wide logger shared by every function below;
# create_logger is handed this module's __name__.
LOGGER = create_logger(__name__)


def mk_pipeline(model_dir: str, conf_dir: str) -> Pipeline:
    """
    model_dir: dir the model is saved to.
    conf_dir: dir diar_conf.yaml is saved to.

    Builds the speaker-diarization pipeline from a local copy of the model.

    Steps:
    1. points pyannote's cache dir (env var PYANNOTE_CACHE and
       core.model.CACHE_DIR) at model_dir;
    2. substitutes "bin_path" and "voxceleb_path" in diar_conf.yaml via
       util_funs.replace_str_in_fil, writing tmp_diar_conf.yaml and then
       new_diar_conf.yaml into conf_dir;
    3. loads the pipeline from new_diar_conf.yaml;
    4. moves the pipeline to the GPU returned by pick_best_gpu, or to the
       CPU if that attempt raises RuntimeError.

    Reads env var HUGGINGFACE_TOKEN; raises KeyError if it is not set.
    """
    hf_token = os.environ["HUGGINGFACE_TOKEN"]

    LOGGER.debug(f"pwd = {os.getcwd()}")

    # SpeechBrain_EncoderClassifier uses CACHE_DIR
    # https://github.com/pyannote/pyannote-audio/blob/a810a5a53ac6e241606fd4ec822ea842f4c0a9b5/pyannote/audio/pipelines/speaker_verification.py#L262
    # CACHE_DIR is set at import time here, hence the explicit override:
    # https://github.com/pyannote/pyannote-audio/blob/a810a5a53ac6e241606fd4ec822ea842f4c0a9b5/pyannote/audio/core/model.py#L56
    os.environ["PYANNOTE_CACHE"] = model_dir
    LOGGER.debug(f'{os.environ["PYANNOTE_CACHE"]=}')
    LOGGER.debug(f"def: {core.model.CACHE_DIR=}")
    core.model.CACHE_DIR = os.environ["PYANNOTE_CACHE"]
    LOGGER.debug(f"upd: {core.model.CACHE_DIR=}")

    # the original config, the half-substituted one, and the final one
    src_yml = os.path.join(conf_dir, "diar_conf.yaml")
    tmp_yml = os.path.join(conf_dir, "tmp_diar_conf.yaml")
    new_yml = os.path.join(conf_dir, "new_diar_conf.yaml")

    # (input YML, output YML, string to replace, replacement)
    substitutions = (
        (
            src_yml,
            tmp_yml,
            "bin_path",
            # path of pytorch_model.bin
            os.path.join(model_dir, "segmentation", "pytorch_model.bin"),
        ),
        (
            tmp_yml,
            new_yml,
            "voxceleb_path",
            # path of speechbrain/spkrec-ecapa-voxceleb
            os.path.join(model_dir, "speechbrain", "spkrec-ecapa-voxceleb"),
        ),
    )

    LOGGER.debug("loading model from local dir...")
    # TODO
    # upgrade to pyannote/speaker-diarization-3.0, i.e. load
    # "pyannote/speaker-diarization@2.1" (or newer) from the hub
    # instead of the local YML

    for yml_in, yml_out, old_str, new_str in substitutions:
        util_funs.replace_str_in_fil(yml_in, yml_out, old_str, new_str)

    # dump the final YML to the debug log, one line per record
    with open(new_yml, "r") as yml_fil:
        for yml_line in yml_fil:
            LOGGER.debug(yml_line)

    speaker_mmc_aus = Pipeline.from_pretrained(new_yml, use_auth_token=hf_token)

    try:
        # pick best GPU and push the pipeline to it;
        # both steps stay inside the try so that a RuntimeError
        # from either one triggers the CPU fallback
        gpu_idx = pick_best_gpu()
        LOGGER.debug(f"using best GPU = {gpu_idx}")
        speaker_mmc_aus = speaker_mmc_aus.to(torch.device(f"cuda:{gpu_idx}"))
    except RuntimeError as gpu_err:
        LOGGER.debug(f"Unexpected {gpu_err=}, {type(gpu_err)=}")
        LOGGER.debug("using CPU")
        # push the pipeline to CPU
        speaker_mmc_aus = speaker_mmc_aus.to(torch.device("cpu"))

    return speaker_mmc_aus


@typechecked
def diarize(
    audio: str, model_dir: str, conf_dir: str, out_path: str
) -> Dict[str, List]:
    """
    audio: path of the audio file to diarize.
    model_dir: dir the model is saved to.
    conf_dir: dir diar_conf.yaml is saved to.
    out_path: dir the per-segment audio splits are written to;
        created if missing.

    Diarizes audio.
    Cuts one file per diarized segment with FFmpeg (stream copy):
    the i-th segment is written to <out_path>/split.<i>.wav.

    Returns {"voices": [{"start": ..., "end": ..., "label": ...}, ...]},
    in the order the segments are yielded by the pipeline output.

    Raises RuntimeError if FFmpeg exits with a non-zero status
    (e.g. because a split file already exists).
    """
    diar_pipeline = mk_pipeline(model_dir, conf_dir)

    # num_speakers, min_speakers, max_speakers
    # can be set if they are known
    who_speaks_when = diar_pipeline(
        audio,
        num_speakers=None,
        min_speakers=None,
        max_speakers=None,
    )

    # FFmpeg does not create missing output dirs.
    # An empty out_path means "current dir": nothing to create.
    if out_path:
        os.makedirs(out_path, exist_ok=True)

    speakers = []
    tracks = who_speaks_when.itertracks(yield_label=True)
    for seg_idx, (segment, _, speaker) in enumerate(tracks):
        speakers.append({"start": segment.start, "end": segment.end, "label": speaker})

        span = segment.end - segment.start
        diar_segm_path = os.path.join(out_path, f"split.{seg_idx}.wav")
        ffmpeg_split = [
            "ffmpeg",
            # never read from stdin: without this FFmpeg may wait on the
            # "overwrite?" prompt or swallow the parent's stdin
            "-nostdin",
            "-ss",
            str(segment.start),
            "-i",
            audio,
            "-t",
            str(span),
            "-c",
            "copy",
            diar_segm_path,
        ]
        try:
            # capture FFmpeg's output so its diagnostics can be attached
            # to the error instead of being lost
            subprocess.run(ffmpeg_split, check=True, capture_output=True)
        except subprocess.CalledProcessError as err:
            ffmpeg_log = (
                err.stderr.decode(errors="replace").strip() if err.stderr else ""
            )
            raise RuntimeError(f"FFMPEG error {str(err)}: {ffmpeg_log}") from err

    return {"voices": speakers}


@typechecked
def diarize_save(
    audio: str, out_json: str, out_path: str, model_dir: str, conf_dir: str
) -> None:
    """
    audio: audio file handed to diarize.
    out_json: JSON the annotation is saved to.
    out_path: forwarded to diarize as its out_path.
    model_dir: dir the model is saved to.
    conf_dir: dir diar_conf.yaml is saved to.

    Runs diarize on audio, logs its result at INFO level,
    and hands the result to save together with out_json.
    """
    LOGGER.info(f"diarizing {audio}...")

    annotation = diarize(
        audio=audio,
        model_dir=model_dir,
        conf_dir=conf_dir,
        out_path=out_path,
    )
    LOGGER.info(annotation)

    save(annotation, out_json)


@typechecked
def dl_diarize_save(
    message_body: dict,
    out_json: str,
    out_path: str,
    base_dir: str,
    model_dir: str,
) -> bool:
    """
    message_body: msg body. message_body["programme"] must hold
        "uid" and "external_id"; "conf_dir" is read too when the wav exists.
    out_json: JSON the output is saved to.
    out_path: forwarded to diarize_save as its out_path.
    base_dir: base dir; the audio is looked up at
        <base_dir>/<uid>/<external_id>.wav.
    model_dir: dir the model is saved to.

    Looks up the audio on disk (nothing is downloaded here).
    Diarizes audio.
    Saves output.

    Returns True if the wav was found and processed,
    False if the wav is not available.
    Raises KeyError if a required key is missing from message_body.
    """
    programme = message_body["programme"]

    # access audio
    wav_path = (
        os.path.join(base_dir, programme["uid"], programme["external_id"]) + ".wav"
    )

    if Path(wav_path).is_file():
        diarize_save(
            wav_path,
            out_json,
            out_path,
            model_dir,
            programme["conf_dir"],
        )

        # success
        ret_code = 0
    else:
        # wav is not available
        LOGGER.error("Wav is not available")

        # failure
        ret_code = 2

    # the numeric code is kept only for the debug log
    LOGGER.debug(f"return code: {ret_code}")

    # success if ret_code = 0
    return ret_code == 0