Commit 27ebaea9 authored by Mattia Bergagio's avatar Mattia Bergagio
Browse files

Initial commit

parents
Pipeline #43 canceled with stages
# Base image: Ubuntu 20.04, CUDA 11.6.2, cuDNN 8 (runtime flavour).
FROM nvcr.io/nvidia/cuda:11.6.2-cudnn8-runtime-ubuntu20.04

# Settings used by the build steps below and kept in the running container.
ENV TZ='Europe/Rome'
ENV BASE_FOLDER='/mmc_asr'
ENV LOGS_FOLDER='/LOGS'
ENV APP_USER='devuser'

# OS packages, plus timezone configuration. update/install/cleanup stay in one
# layer so the apt lists never persist in the image.
RUN \
    apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
        python3.8 \
        python3-pip \
        tzdata \
        ffmpeg \
        git \
        # TODO Other packets common to all images go in here
    && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/* && \
    ln -snf "/usr/share/zoneinfo/$TZ" '/etc/localtime' && \
    echo "$TZ" > '/etc/timezone'

# Unprivileged user that owns the application and log folders.
RUN useradd -m "$APP_USER" && \
    mkdir -p "$BASE_FOLDER" && \
    chown -R "${APP_USER}:${APP_USER}" "$BASE_FOLDER" && \
    mkdir -p "$LOGS_FOLDER" && \
    chown -R "${APP_USER}:${APP_USER}" "$LOGS_FOLDER"

USER "$APP_USER"
WORKDIR "$BASE_FOLDER"
ENV PATH="${PATH}:/root/.local/bin:/home/${APP_USER}/.local/bin"
ENV PYTHONPATH="${PYTHONPATH}:${BASE_FOLDER}"

# Python dependencies are installed before the sources are copied so this layer
# is reused as long as requirements.txt does not change.
COPY --chown="${APP_USER}:${APP_USER}" /requirements.txt ./requirements.txt
RUN python3 -m pip install --no-cache-dir -r requirements.txt

# Credentials for the private common_module repository. They are declared right
# before their only use so that a changed token does not invalidate the cached
# apt/pip layers above.
# NOTE(review): build args remain visible in `docker history`; a BuildKit secret
# mount (`RUN --mount=type=secret,...`) would avoid that, but it changes how the
# image is built, so it is left as a follow-up.
ARG GIT_NAME
ARG GIT_TOKEN

WORKDIR "$BASE_FOLDER"/src
# Clone, then immediately reset the remote URL so the token is not stored in
# common_module/.git/config inside the image.
RUN git clone "https://${GIT_NAME}:${GIT_TOKEN}@gitlab.eurixgroup.com/mpai/common_module.git" && \
    git -C common_module remote set-url origin 'https://gitlab.eurixgroup.com/mpai/common_module.git'

# Back to the application root (absolute path instead of a relative `..`).
WORKDIR "$BASE_FOLDER"
COPY --chown="${APP_USER}:${APP_USER}" /src ./src
CMD ["python3.8", "src/main.py"]
```
# Download the Whisper model into $PATH_SHARED/models/mmc_asr.
# Abort with a clear message if PATH_SHARED is unset (a bare `cd` would
# silently switch to $HOME instead).
cd "${PATH_SHARED:?PATH_SHARED must point to the shared directory}"
# -p makes the step repeatable: it does not fail if the folders already exist.
mkdir -p models/mmc_asr
cd models/mmc_asr
git lfs install
git clone https://huggingface.co/openai/whisper-large-v3
```
import os
from pathlib import Path
from typing import Dict, List
import torch
import whisper_timestamped
from whisper_timestamped.make_subtitles import write_srt
import pysrt
from typeguard import typechecked
import trs_class
try:
from common_utils.times import timeit
except ModuleNotFoundError:
from common_module.common_utils.times import timeit
try:
from common_utils.logger import create_logger
except ModuleNotFoundError:
from common_module.common_utils.logger import create_logger
try:
from common_utils.saves import save
except ModuleNotFoundError:
from common_module.common_utils.saves import save
try:
from common_utils.gpus_torch import pick_best_gpu
except ModuleNotFoundError:
from common_module.common_utils.gpus_torch import pick_best_gpu
# Module-level logger, created with this module's name.
LOGGER = create_logger(__name__)
def _merge_srt_files(srt_paths: List[str], merged_path: str) -> None:
    """Concatenate SRT files into merged_path, renumbering cues so indices keep increasing."""
    # https://unix.stackexchange.com/a/728438
    offset = 0
    with open(merged_path, "w", encoding="utf-8") as merged:
        for srt_path in srt_paths:
            LOGGER.debug(f"reading {srt_path}...")
            # last index written so far; stays at `offset` if this file has no cues
            last_num = offset
            # a cue index can only appear at the start of the file or after a
            # blank line, so digit-only subtitle text is never renumbered
            expect_index = True
            with open(srt_path, "r", encoding="utf-8") as src:
                for lin in src:
                    stripped = lin.strip()
                    if expect_index and stripped.isdigit():
                        last_num = int(stripped) + offset
                        merged.write(f"{last_num}\n")
                        expect_index = False
                    else:
                        merged.write(lin)
                        expect_index = not stripped
            # upd offset
            offset = last_num
            LOGGER.debug(f"{offset=}")


@typechecked
def whisper_trs(
    aud_path: str,
    diar_segments: List[Dict],
    device: str,
    compute_type: str,
    out_path: str,
    model_dir: str,
    dev_idx: int = 0,
    max_len: int = 200,
) -> Dict[str, list]:
    """
    Transcribes each segment from mmc_aus and writes subtitles.

    aud_path: path of audio file (currently unused here).
    diar_segments: segments from mmc_aus; each needs "start", "end", "label".
    device: device the Whisper model is loaded on.
    compute_type: currently unused here.
    out_path: dir the SRT files are written to. The per-segment WAVs are read
        from the sibling dir "mmc_aus" as "split.<index>.wav".
    model_dir: dir containing the "whisper-large-v3" model folder.
    dev_idx: currently unused here.
    max_len: max length of a subtitle, in chars (currently unused here).

    Returns {"transcription": chunks}, one chunk per input segment.
    Segments shorter than 1 (in the units of "start"/"end") are not transcribed;
    their chunk has None in every field. "rel_start"/"rel_end" are also None
    when the transcription contains no segments.
    """
    my_trs = trs_class.WhisperTranscriber(
        os.path.join(model_dir, "whisper-large-v3"),
        device,
    )
    transcribe_options = {"task": "transcribe"}
    chunks = []
    end_srt_fils = []
    # read WAVs from mmc_aus
    mmc_aus_out_path = os.path.join(Path(out_path).parent, "mmc_aus")

    # loop over segments from mmc_aus
    for cnt_diar_segm, diar_segm in enumerate(diar_segments):
        LOGGER.debug(f"{cnt_diar_segm=}")
        LOGGER.debug(f'{diar_segm["start"]=}')
        LOGGER.debug(f'{diar_segm["end"]=}')

        if diar_segm["end"] - diar_segm["start"] < 1:
            # suppress short utterances (pyannote artifact).
            # Nothing is transcribed, so there are no timings to report; the
            # previous version read the transcription of an earlier segment
            # here (or crashed when the first segment was short).
            chunks.append(
                {
                    "rel_start": None,
                    "rel_end": None,
                    "text": None,
                    "language": None,
                    "speaker": None,
                }
            )
            continue

        diar_segm_path = os.path.join(mmc_aus_out_path, f"split.{cnt_diar_segm}.wav")
        diar_audio = whisper_timestamped.load_audio(diar_segm_path)

        # get transcription
        transcr_d = my_trs(diar_audio, transcribe_options)
        trs_segments = transcr_d["segments"]
        chunks.append(
            {
                # guard against a transcription without segments
                "rel_start": trs_segments[0]["start"] if trs_segments else None,
                "rel_end": trs_segments[-1]["end"] if trs_segments else None,
                "text": transcr_d["text"],
                "language": transcr_d["language"],
                "speaker": diar_segm["label"],
            }
        )
        LOGGER.debug(f"{chunks[-1]=}")

        LOGGER.debug("writing subs...")
        # https://github.com/linto-ai/whisper-timestamped/issues/42#issuecomment-1443866219
        tmp_srt_fil = os.path.join(out_path, f"tmp_subs.{cnt_diar_segm}.srt")
        end_srt_fils.append(os.path.join(out_path, f"end_subs.{cnt_diar_segm}.srt"))
        # close (and flush) the file before pysrt reads it back
        with open(tmp_srt_fil, "w", encoding="utf-8") as tmp_srt:
            write_srt(trs_segments, tmp_srt)

        LOGGER.debug("upd'ing subs...")
        subs = pysrt.open(tmp_srt_fil, encoding="utf-8")
        LOGGER.debug("shifting subs...")
        # move the segment-relative timestamps to the segment's start
        subs.shift(seconds=diar_segm["start"])
        LOGGER.debug("saving subs...")
        subs.save(end_srt_fils[-1], encoding="utf-8")

    # merge subs
    _merge_srt_files(end_srt_fils, os.path.join(out_path, "end_subs.srt"))

    return {"transcription": chunks}
@timeit
@typechecked
def trs(
    aud_path: str,
    segments: List[Dict],
    out_path: str,
    model_dir: str,
) -> Dict[str, List]:
    """
    Transcribes audio and returns the annotation.

    aud_path: path of audio file.
    segments: mmc_aus output.
    out_path: path subs are saved to.
    model_dir: dir the model is saved to.

    The first attempt uses CUDA when torch reports it as available. If that
    attempt raises RuntimeError, the error is logged and the transcription is
    retried on CPU. If the retry fails too, the error is logged and
    {"transcription": []} is returned. Any other exception raised by the first
    attempt propagates to the caller.
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        LOGGER.debug("go to whisper_trs")
        compute_type = "float16"
        # pick best GPU
        dev_idx = pick_best_gpu()
        LOGGER.debug(f"using best GPU = {dev_idx}")
        return whisper_trs(
            aud_path, segments, device, compute_type, out_path, model_dir, dev_idx
        )
    except RuntimeError as first_err:
        # do not swallow the reason for the fallback silently
        LOGGER.error(f"first transcription attempt failed, retrying on CPU: {first_err=}")
        try:
            # TODO
            # try gpu 0?
            device = "cpu"
            compute_type = "int8"
            LOGGER.debug("using CPU")
            return whisper_trs(
                aud_path, segments, device, compute_type, out_path, model_dir
            )
        except Exception as err:
            LOGGER.error(f"{err=}")
            return {"transcription": []}
def trs_save(
    aud_path: str, segments: List[Dict], json_path: str, out_path: str, model_dir: str
):
    """
    Runs trs() on the audio and hands the resulting annotation to save()
    together with json_path.
    Returns None.
    """
    save(trs(aud_path, segments, out_path, model_dir), json_path)
@typechecked
def dl_trs_save(
    message_body: dict,
    out_json: str,
    out_path: str,
    base_dir: str,
    model_dir: str,
) -> bool:
    """
    Validates the mmc_aus segments of a message and transcribes the audio.

    message_body: msg body. message_body["programme"] must provide "uid",
        "external_id" and "mmc_aus" (with a "segments" entry).
    out_json: JSON the output is saved to.
    out_path: forwarded to trs_save.
    base_dir: base dir; the audio is expected at
        <base_dir>/<uid>/<external_id>.wav.
    model_dir: dir the model is saved to.

    Returns True if the segments are a list whose items all contain the keys
    "start", "end" and "label" (transcription is then run and saved);
    False otherwise.
    """
    programme = message_body["programme"]
    aud_path = os.path.join(
        base_dir,
        programme["uid"],
        f'{programme["external_id"]}.wav',
    )
    diar_out = programme["mmc_aus"]["segments"]

    required_keys = ["start", "end", "label"]
    # -1 = failure, 0 = success
    ret_code = -1
    segments_ok = isinstance(diar_out, list) and all(
        [all(k in voice for k in required_keys) for voice in diar_out]
    )
    if segments_ok:
        # transcribe aud
        trs_save(aud_path, diar_out, out_json, out_path, model_dir)
        ret_code = 0

    LOGGER.debug(f"return code: {ret_code}")
    return ret_code == 0
# Disabled helper, kept only as a bare string literal (it is never executed).
# NOTE(review): as written it pops from annotation["segments"] while enumerating
# that same list, so re-enabling it unchanged would skip the element that
# follows each removal.
"""
def find_hallucinations(annotation):
hall = [" Amara.org", "www.mesmerism.info", " QTSS", " ... "]
for i, tr in enumerate(annotation["segments"]):
for w in hall:
if w in tr["text"]:
LOGGER.debug(f"Hallucination removed: {tr['text']}")
annotation["segments"].pop(i)
break
return annotation
"""
from run_funs import run
try:
from common_utils import adapter
except ModuleNotFoundError:
from common_module.common_utils import adapter
try:
from common_utils import rabbitmq
except ModuleNotFoundError:
from common_module.common_utils import rabbitmq
if __name__ == "__main__":
    # Entry point: register `run` as the callback of the module's queue and
    # start listening through the adapter.
    worker = rabbitmq.Worker()
    worker.register_callback(callback=run, queue="queue_module_mmc_asr")
    adapter.Adapter(worker).start_listening()
import os
from pathlib import Path
from typeguard import typechecked
import asr_funs
try:
from common_utils import msg_builder, rabbitmq
except ModuleNotFoundError:
from common_module.common_utils import msg_builder, rabbitmq
# Both folders live under the directory named by the AI_FW_DIR environment
# variable (a KeyError is raised at import time if it is not set).
_ai_fw_dir = os.environ["AI_FW_DIR"]
# TODO
# get vid_dir from input msg
base_dir = os.path.join(_ai_fw_dir, "vids")
# TODO
# get model_dir from input msg
model_dir = os.path.join(_ai_fw_dir, "models", "mmc_asr")
# make sure the model folder exists
os.makedirs(model_dir, exist_ok=True)
@typechecked
def run(message_body: dict, worker: rabbitmq.Worker) -> bool:
    """
    Callback for the mmc_asr queue.

    Collects the module settings (`defs`) and the keys to hand over to the
    output message (`extras`), then delegates to msg_builder.build_msg with
    asr_funs.dl_trs_save as the processing function.
    Returns whatever build_msg returns.
    """
    mod_name = "mmc_asr"
    defs = {
        # module name in msg
        "mod_name": mod_name,
        # metadata key in output msg
        "metadata_key": "transcription",
        # metadata type in output msg
        "metadata_type": "transcription",
        # main key in output JSON
        "out_json_key": "transcription",
        # error msg if output JSON is not found
        "not_found_msg": "cannot transcribe!",
        # error msg if input msg is invalid
        "invalid_msg": "External ID/UID/Application/Diar Required!",
    }
    extras = {"programme": {"module": mod_name}}

    if "programme" in message_body:
        programme = message_body["programme"]
        if "external_id" in programme:
            # name of output JSON
            defs["out_json"] = f'{programme["external_id"]}.json'
        # copy over the keys that must be handed to the next module
        extras["programme"].update(
            {k: programme[k] for k in msg_builder.handed_over_keys() if k in programme}
        )

    # keys the input message must contain
    required_keys = ["external_id", "application", "uid", "mmc_aus"]
    return msg_builder.build_msg(
        message_body,
        worker,
        mod_name,
        asr_funs.dl_trs_save,
        msg_builder.validate_message,
        required_keys,
        base_dir,
        model_dir,
        defs,
        extras,
    )
import sys
import os
from contextlib import contextmanager
from typing import Any, Dict
import numpy as np
import whisper_timestamped
from typeguard import typechecked
try:
from common_utils.logger import create_logger
except ModuleNotFoundError:
from common_module.common_utils.logger import create_logger
# Module-level logger, created with this module's name.
LOGGER = create_logger(__name__)
@contextmanager
def suppress_stdout():
    """
    Context manager that points sys.stdout at os.devnull for the duration of
    the `with` body and restores the previous sys.stdout afterwards, even if
    the body raises. Used to silence Whisper, which is quite verbose.

    All credit goes to:
    https://thesmithfam.org/blog/2012/10/25/temporarily-suppress-console-output-in-python/
    """
    with open(os.devnull, "w") as sink:
        saved_stdout = sys.stdout
        sys.stdout = sink
        try:
            yield
        finally:
            # always hand the original stream back
            sys.stdout = saved_stdout
@typechecked
class WhisperTranscriber:
    """
    Wraps a whisper_timestamped model. Calling the instance transcribes a
    waveform and appends the resulting text to an internal buffer, which is
    passed as the initial prompt of the following transcriptions.
    """

    def __init__(
        self,
        model: str,
        device: str,
    ):
        """
        model: name of the model or path to the model.
            Examples:
            - OpenAI-Whisper identifier: "large-v3", "medium.en", ...
            - HuggingFace identifier: "openai/whisper-large-v3", "distil-whisper/distil-large-v2", ...
            - File name: "path/to/model.pt", "path/to/model.ckpt", "path/to/model.bin"
            - Folder name: "path/to/folder".
              The folder must contain either "pytorch_model.bin", "model.safetensors",
              or sharded versions of those, or "whisper.ckpt".
        device: device to use.
        """
        # text of all transcriptions made so far by this instance
        self._buffer = ""
        self.model = whisper_timestamped.load_model(model, device)

    def transcribe(
        self, waveform: np.ndarray, options: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Transcribes audio using Whisper.

        waveform: audio samples handed to whisper_timestamped.transcribe.
        options: extra keyword arguments forwarded to whisper_timestamped.transcribe.
        Returns the transcription produced by whisper_timestamped.
        """
        LOGGER.info("Transcribing...")
        temperatures = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
        # Whisper models can "hallucinate" text when given a segment w/o speech.
        # This can be avoided by running VAD and gluing speech segments together
        # before transcribing.
        # The call runs with stdout silenced because Whisper is verbose.
        with suppress_stdout():
            return whisper_timestamped.transcribe(
                self.model,
                waveform,
                beam_size=5,
                best_of=5,
                temperature=temperatures,
                vad="silero:v3.1",
                # use past transcriptions to condition the model
                initial_prompt=self._buffer,
                verbose=True,
                **options,
            )

    def __call__(self, waveform: np.ndarray, options: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribes the waveform and appends its text to the prompt buffer."""
        result = self.transcribe(waveform, options)
        # upd transcription buffer
        self._buffer = self._buffer + result["text"]
        return result
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment