Commit bc8d73d9 authored by Mattia Bergagio's avatar Mattia Bergagio
Browse files

cleanup

parent 23ca7d86
# Ubuntu 20.04
# CUDA 11.6.2
# cuDNN 8
FROM nvcr.io/nvidia/cuda:11.6.2-cudnn8-runtime-ubuntu20.04
ENV TZ='Europe/Rome'
ENV BASE_FOLDER='/mmc_asr'
ENV LOGS_FOLDER='/LOGS'
ENV APP_USER='devuser'
ARG GIT_NAME
ARG GIT_TOKEN
RUN \
apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
python3.8 \
python3-pip \
tzdata \
ffmpeg \
git \
# TODO Other packets common to all images go in here
&& \
apt-get clean && \
rm -rf /var/lib/apt/lists/* && \
ln -snf "/usr/share/zoneinfo/$TZ" '/etc/localtime' && \
echo "$TZ" > '/etc/timezone'
RUN useradd -m "$APP_USER" && \
mkdir -p "$BASE_FOLDER" && \
chown -R "${APP_USER}:${APP_USER}" "$BASE_FOLDER" && \
mkdir -p "$LOGS_FOLDER" && \
chown -R "${APP_USER}:${APP_USER}" "$LOGS_FOLDER"
USER "$APP_USER"
WORKDIR "$BASE_FOLDER"
ENV PATH="${PATH}:/root/.local/bin:/home/${APP_USER}/.local/bin"
ENV PYTHONPATH="${PYTHONPATH}:${BASE_FOLDER}"
COPY --chown="${APP_USER}:${APP_USER}" /requirements.txt ./requirements.txt
RUN python3 -m pip install --no-cache-dir -r requirements.txt
WORKDIR "$BASE_FOLDER"/src
RUN git clone https://${GIT_NAME}:${GIT_TOKEN}@experts.mpai.community/software/mpai-aif/osd_tma/common_module.git
RUN pwd
RUN ls
WORKDIR ..
COPY --chown="${APP_USER}:${APP_USER}" /src ./src
CMD ["python3.8", "src/main.py"]
```
cd $PATH_SHARED
mkdir models
cd models
mkdir mmc_asr
cd mmc_asr
git lfs install
git clone https://huggingface.co/openai/whisper-large-v3
```
onnxruntime # VAD needs it
pika==1.3.1
pysrt
torchaudio # VAD needs it
transformers
typeguard==4.1.5
typing_extensions==4.8.0
whisper-timestamped
import os
from pathlib import Path
from typing import Dict, List
import torch
import whisper_timestamped
from whisper_timestamped.make_subtitles import write_srt
import pysrt
from typeguard import typechecked
import trs_class
try:
from common_utils.times import timeit
except ModuleNotFoundError:
from common_module.common_utils.times import timeit
try:
from common_utils.logger import create_logger
except ModuleNotFoundError:
from common_module.common_utils.logger import create_logger
try:
from common_utils.saves import save
except ModuleNotFoundError:
from common_module.common_utils.saves import save
try:
from common_utils.gpus_torch import pick_best_gpu
except ModuleNotFoundError:
from common_module.common_utils.gpus_torch import pick_best_gpu
LOGGER = create_logger(__name__)
@typechecked
def whisper_trs(
aud_path: str,
diar_segments: List[Dict],
device: str,
compute_type: str,
out_path: str,
model_dir: str,
dev_idx: int = 0,
max_len: int = 200,
) -> Dict[str, list]:
"""
diar_segments: segments from mmc_aus.
max_len: max length of a subtitle, in chars.
Transcribes each segment from mmc_aus.
"""
my_trs = trs_class.WhisperTranscriber(
os.path.join(model_dir, "whisper-large-v3"),
device,
# dev_idx,
# compute_type, # , model_dir
)
transcribe_options = {"task": "transcribe"}
chunks = []
end_srt_fils = []
# read WAVs from mmc_aus
mmc_aus_out_path = os.path.join(Path(out_path).parent, "mmc_aus")
# loop over segments from mmc_aus
for cnt_diar_segm, diar_segm in enumerate(diar_segments):
LOGGER.debug(f"{cnt_diar_segm=}")
LOGGER.debug(f'{diar_segm["start"]=}')
LOGGER.debug(f'{diar_segm["end"]=}')
if diar_segm["end"] - diar_segm["start"] < 1:
# suppress short utterances (pyannote artifact)
chunks.append(
{
"rel_start": transcr_d["segments"][0]["start"],
"rel_end": transcr_d["segments"][-1]["end"],
"text": None,
"language": None,
"speaker": None,
}
)
continue
diar_segm_path = os.path.join(mmc_aus_out_path, f"split.{cnt_diar_segm}.wav")
diar_audio = whisper_timestamped.load_audio(diar_segm_path)
# get transcription
transcr_d = my_trs(
diar_audio,
transcribe_options,
)
chunks.append(
{
"rel_start": transcr_d["segments"][0]["start"],
"rel_end": transcr_d["segments"][-1]["end"],
"text": transcr_d["text"],
"language": transcr_d["language"],
"speaker": diar_segm["label"],
}
)
LOGGER.debug(f"{chunks[-1]=}")
LOGGER.debug("writing subs...")
# https://github.com/linto-ai/whisper-timestamped/issues/42#issuecomment-1443866219
tmp_srt_fil = os.path.join(out_path, f"tmp_subs.{cnt_diar_segm}.srt")
end_srt_fils += [os.path.join(out_path, f"end_subs.{cnt_diar_segm}.srt")]
write_srt(
transcr_d["segments"],
open(
tmp_srt_fil,
"w",
encoding="utf-8",
),
)
LOGGER.debug("upd'ing subs...")
subs = pysrt.open(tmp_srt_fil, encoding="utf-8")
LOGGER.debug("shifting subs...")
subs.shift(seconds=diar_segm["start"])
LOGGER.debug("saving subs...")
subs.save(end_srt_fils[-1], encoding="utf-8")
# merge subs
# https://unix.stackexchange.com/a/728438
offset = 0
with open(
os.path.join(out_path, "end_subs.srt"), "w", encoding="utf-8"
) as end_subs_w:
for subi in end_srt_fils:
LOGGER.debug(f"reading {subi}...")
with open(subi, "r", encoding="utf-8") as subir:
for lin in subir:
# check if line contains a single integer
if lin.strip().isdigit():
# convert lin to int
num = int(lin)
num += offset
# write num to output file
end_subs_w.write(f"{num}\n")
else:
# write line
end_subs_w.write(lin)
# upd offset
offset = num
LOGGER.debug(f"{offset=}")
return {"transcription": chunks}
@timeit
@typechecked
def trs(
aud_path: str,
segments: List[Dict],
out_path: str,
model_dir: str,
) -> Dict[str, List]:
"""
aud_path: path of audio file.
segments: mmc_aus output.
out_path: path subs are saved to.
model_dir: dir the model is saved to.
Transcribes audio.
Returns annotation.
"""
try:
device = "cuda" if torch.cuda.is_available() else "cpu"
LOGGER.debug("go to whisper_trs")
compute_type = "float16"
# pick best GPU
dev_idx = pick_best_gpu()
LOGGER.debug(f"using best GPU = {dev_idx}")
return whisper_trs(
aud_path, segments, device, compute_type, out_path, model_dir, dev_idx
)
except RuntimeError:
try:
# TODO
# try gpu 0?
device = "cpu"
compute_type = "int8"
LOGGER.debug("using CPU")
return whisper_trs(
aud_path, segments, device, compute_type, out_path, model_dir
)
except Exception as err:
LOGGER.error(f"{err=}")
return {"transcription": []}
def trs_save(
aud_path: str, segments: List[Dict], json_path: str, out_path: str, model_dir: str
):
"""
Transcribes audio.
Saves annotation to JSON.
"""
annotation = trs(aud_path, segments, out_path, model_dir)
save(annotation, json_path)
return
@typechecked
def dl_trs_save(
message_body: dict,
out_json: str,
out_path: str,
base_dir: str,
model_dir: str,
) -> bool:
"""
message_body: msg body. Has "FS" & "voices" as keys.
out_json: JSON the output is saved to.
out_path: unused input.
base_dir: base dir.
model_dir: dir the model is saved to.
Downloads mmc_aus JSON.
Transcribes audio.
Returns 0 if success.
"""
# copies landmark
ret_code = -1
aud_path = os.path.join(
base_dir,
message_body["programme"]["uid"],
f'{message_body["programme"]["external_id"]}.wav',
)
diar_out = message_body["programme"]["mmc_aus"]["segments"]
# check diar_out
# True if diar_out is a list
# True if all dicts have keys ["start", "end", "label"]
if isinstance(diar_out, list) and all(
[
# True if all keys in a dict are ["start", "end", "label"]
all(k in voice for k in ["start", "end", "label"])
for voice in diar_out
]
):
# transcribe aud
trs_save(aud_path, diar_out, out_json, out_path, model_dir)
# success
ret_code = 0
# copies landmark
LOGGER.debug(f"return code: {ret_code}")
if ret_code != 0:
return False
else:
# success if ret_code = 0
return True
"""
def find_hallucinations(annotation):
hall = [" Amara.org", "www.mesmerism.info", " QTSS", " ... "]
for i, tr in enumerate(annotation["segments"]):
for w in hall:
if w in tr["text"]:
LOGGER.debug(f"Hallucination removed: {tr['text']}")
annotation["segments"].pop(i)
break
return annotation
"""
from run_funs import run
try:
from common_utils import adapter
except ModuleNotFoundError:
from common_module.common_utils import adapter
try:
from common_utils import rabbitmq
except ModuleNotFoundError:
from common_module.common_utils import rabbitmq
if __name__ == "__main__":
Worker = rabbitmq.Worker()
Worker.register_callback(queue="queue_module_mmc_asr", callback=run)
this_adapter = adapter.Adapter(Worker)
this_adapter.start_listening()
import os
from pathlib import Path
from typeguard import typechecked
import asr_funs
try:
from common_utils import msg_builder, rabbitmq
except ModuleNotFoundError:
from common_module.common_utils import msg_builder, rabbitmq
# TODO
# get vid_dir from input msg
base_dir = os.path.join(os.environ["AI_FW_DIR"], "vids")
# TODO
# get model_dir from input msg
model_dir = os.path.join(os.environ["AI_FW_DIR"], "models", "mmc_asr")
Path(model_dir).mkdir(parents=True, exist_ok=True)
@typechecked
def run(message_body: dict, worker: rabbitmq.Worker) -> bool:
defs = {
# module name in msg
"mod_name": "mmc_asr",
# metadata key in output msg
"metadata_key": "transcription",
# metadata type in output msg
"metadata_type": "transcription",
# main key in output JSON
"out_json_key": "transcription",
# error msg if output JSON is not found
"not_found_msg": "cannot transcribe!",
# error msg if input msg is invalid
"invalid_msg": "External ID/UID/Application/Diar Required!",
}
extras = {"programme": {"module": defs["mod_name"]}}
if "programme" in message_body:
if "external_id" in message_body["programme"]:
# name of output JSON
defs["out_json"] = f'{message_body["programme"]["external_id"]}.json'
for k in msg_builder.handed_over_keys():
if k in message_body["programme"]:
extras["programme"][k] = message_body["programme"][k]
return msg_builder.build_msg(
message_body,
worker,
"mmc_asr",
asr_funs.dl_trs_save,
msg_builder.validate_message,
["external_id", "application", "uid", "mmc_aus"],
base_dir,
model_dir,
defs,
extras,
)
import sys
import os
from contextlib import contextmanager
from typing import Any, Dict
import numpy as np
import whisper_timestamped
from typeguard import typechecked
try:
from common_utils.logger import create_logger
except ModuleNotFoundError:
from common_module.common_utils.logger import create_logger
LOGGER = create_logger(__name__)
@contextmanager
def suppress_stdout():
# Auxiliary function to suppress Whisper logs (it is quite verbose)
# All credit goes to: https://thesmithfam.org/blog/2012/10/25/temporarily-suppress-console-output-in-python/
with open(os.devnull, "w") as devnull:
old_stdout = sys.stdout
sys.stdout = devnull
try:
yield
finally:
sys.stdout = old_stdout
@typechecked
class WhisperTranscriber:
def __init__(
self,
model: str,
device: str,
# dev_idx: int,
# compute_type: str,
# model_dir: str,
):
"""
model: name of the model or path to the model.
Examples:
- OpenAI-Whisper identifier: "large-v3", "medium.en", ...
- HuggingFace identifier: "openai/whisper-large-v3", "distil-whisper/distil-large-v2", ...
- File name: "path/to/model.pt", "path/to/model.ckpt", "path/to/model.bin"
- Folder name: "path/to/folder".
The folder must contain either "pytorch_model.bin", "model.safetensors",
or sharded versions of those, or "whisper.ckpt".
device : device to use. If None, use CUDA if there is a GPU available, otherwise CPU.
"""
self.model = whisper_timestamped.load_model(
model,
device,
# device_index=dev_idx,
# compute_type=compute_type,
# download_root=model_dir,
)
self._buffer = ""
def transcribe(
self, waveform: np.ndarray, options: Dict[str, Any]
) -> Dict[str, Any]:
"""
Transcribes audio using Whisper
"""
LOGGER.info(f"Transcribing...")
# Pad/trim audio to fit 30 s as required by Whisper
# tweaked_audio = whisper_timestamped.pad_or_trim(waveform)
# Transcribe the given audio while suppressing logs
# Whisper models can "hallucinate" text when given a segment w/o speech.
# This can be avoided by running VAD and gluing speech segments together
# before transcribing
with suppress_stdout():
transcription = whisper_timestamped.transcribe(
self.model,
waveform, # tweaked_audio,
beam_size=5,
best_of=5,
temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
vad="silero:v3.1",
# use past transcriptions to condition the model
initial_prompt=self._buffer,
verbose=True,
**options,
)
return transcription
def __call__(self, waveform: np.ndarray, options: Dict[str, Any]) -> Dict[str, Any]:
# transcribe
transcription = self.transcribe(waveform, options)
# upd transcription buffer
self._buffer += transcription["text"]
return transcription
Implementation of MMC-ASR as a class in Python.
## Installation
Code was designed and tested on an Ubuntu 20.04 operating system using anaconda 23.7.2 and Python 3.9.
An environment with all the necessary libraries can be created using:
```bash
conda create --name <env> --file requirements.txt
```
# MPAI-MMC Automatic Speech Recognition
This code refers to the implementation of the MMC-ASR, as described in the [AIM](https://mpai.community/standards/mpai-mmc/v2-2/ai-modules/automatic-speech-recognition/).
### Guide to the ASR code #1
The code takes Speech Objects from MMC-AUS and generates Text Segments (called text transcripts). It uses the whisper-large-v3 model to convert an input Speech Object (speaker’s turn) into a Text Segment (here called text transcript). Disfluencies (e.g., repetitions, repairs, filled pauses) are often omitted. The Whisper reference document is available.
The MMC-ASR Reference Software is found at the MPAI gitlab site. Use of this AI Modules is for developers who are familiar with Python, Docker, RabbitMQ, and downloading models from HuggingFace. The Reference Software contains:
1. src: a folder with the Python code implementing the AIM
2. Dockerfile: a Docker file containing only the libraries required to build the Docker image and run the container
3. requirements.txt: dependencies installed in the Docker image
4. README.md: commands for cloning https://huggingface.co/openai/whisper-large-v3
Library: https://github.com/linto-ai/whisper-timestamped
### Guide to the ASR code #2
Use of this AI Modules is for developers who are familiar with Python and downloading models from HuggingFace,
Use of this AI Module is for developers who are familiar with Python and downloading models from HuggingFace,
A wrapper for the Whisper NN Module:
......@@ -26,8 +10,17 @@ A wrapper for the Whisper NN Module:
2. Performs Speech Recognition on each Speech Object by executing the Whisper Module.
3. Outputs Recognised Text.
The MMC-ASR Reference Software is found at the NNW gitlab site (registration required). It contains:
The MMC-ASR Reference Software is found at the NNW gitlab site (registration required). It contains:
1. The python code implementing the AIM.
2. The required libraries are: pytorch and transformers (HuggingFace).
Implementation of MMC-ASR as a class in Python.
## Installation
Code was designed and tested on an Ubuntu 20.04 operating system using anaconda 23.7.2 and Python 3.9.
An environment with all the necessary libraries can be created using:
```bash
conda create --name <env> --file requirements.txt
```
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment