# NOTE(review): removed scraped web-page artifact ("Spaces: / Build error") that was not part of the module source.
""" | |
Nemo diarizer | |
""" | |
import os | |
import json | |
import wget | |
import matplotlib.pyplot as plt | |
from omegaconf import OmegaConf | |
from nemo.collections.asr.models import ClusteringDiarizer | |
from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object | |
from pyannote.core import notebook | |
from diarizers.diarizer import Diarizer | |
class NemoDiarizer(Diarizer):
    """Speaker diarization backed by NVIDIA NeMo's ``ClusteringDiarizer``.

    Workflow: write a one-entry inference manifest, patch the downloaded
    ``offline_diarization.yaml`` config, run clustering diarization, and
    expose the result as a pyannote annotation (plus an optional plot).
    """

    def __init__(self, audio_path: str, data_dir: str):
        """
        Nemo diarizer class

        Args:
            audio_path (str): the path to the audio file
            data_dir (str): working directory that holds the inference
                manifest, the model config and the diarization outputs
        """
        self.audio_path = audio_path
        self.data_dir = data_dir
        # Populated by diarize_audio(); holds a pyannote Annotation.
        self.diarization = None
        # NOTE: despite the name, this is the manifest *file* path; the
        # attribute name is kept for backward compatibility with callers.
        self.manifest_dir = os.path.join(self.data_dir, 'input_manifest.json')
        self.model_config = os.path.join(self.data_dir, 'offline_diarization.yaml')
        if not os.path.exists(self.model_config):
            # Fetch the reference inference config from the NeMo repository.
            config_url = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/" \
                         "speaker_tasks/diarization/conf/offline_diarization.yaml"
            self.model_config = wget.download(config_url, self.data_dir)
        self.config = OmegaConf.load(self.model_config)

    def _create_manifest_file(self):
        """
        Function that creates the single-entry inference manifest file

        NeMo expects one JSON object per line; duration/num_speakers are left
        as None so they are inferred at run time.
        """
        meta = {
            'audio_filepath': self.audio_path,
            'offset': 0,
            'duration': None,
            'label': 'infer',
            'text': '-',
            'num_speakers': None,
            'rttm_filepath': None,
            'uem_filepath': None
        }
        with open(self.manifest_dir, 'w') as fp:
            json.dump(meta, fp)
            fp.write('\n')

    def _apply_config(self, pretrained_speaker_model: str = 'ecapa_tdnn'):
        """
        Function that edits the inference configuration file

        Args:
            pretrained_speaker_model (str): the pre-trained embedding model options are
                ('ecapa_tdnn', vad_telephony_marblenet, titanet_large)
                https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/
                speaker_diarization/results.html
        """
        pretrained_vad = 'vad_marblenet'
        output_dir = os.path.join(self.data_dir, 'outputs')
        self.config.diarizer.manifest_filepath = self.manifest_dir
        self.config.diarizer.out_dir = output_dir
        self.config.diarizer.ignore_overlap = False
        self.config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
        self.config.diarizer.speaker_embeddings.parameters.window_length_in_sec = 1.5
        self.config.diarizer.speaker_embeddings.parameters.shift_length_in_sec = 0.75
        self.config.diarizer.oracle_vad = False
        self.config.diarizer.clustering.parameters.oracle_num_speakers = False
        # Here we use our inhouse pretrained NeMo VAD
        self.config.diarizer.vad.model_path = pretrained_vad
        # NOTE(review): recent NeMo configs nest window/shift under
        # vad.parameters — confirm these two keys match the YAML in use.
        self.config.diarizer.vad.window_length_in_sec = 0.15
        self.config.diarizer.vad.shift_length_in_sec = 0.01
        self.config.diarizer.vad.parameters.onset = 0.8
        self.config.diarizer.vad.parameters.offset = 0.6
        self.config.diarizer.vad.parameters.min_duration_on = 0.1
        self.config.diarizer.vad.parameters.min_duration_off = 0.4

    def diarize_audio(self, pretrained_speaker_model: str = 'ecapa_tdnn'):
        """
        function that diarizes the audio and stores the result in
        ``self.diarization`` as a pyannote annotation object

        Args:
            pretrained_speaker_model (str): the pre-trained embedding model options are
                ('ecapa_tdnn', vad_telephony_marblenet, titanet_large)
                https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/
                speaker_diarization/results.html
        """
        self._create_manifest_file()
        self._apply_config(pretrained_speaker_model)
        sd_model = ClusteringDiarizer(cfg=self.config)
        sd_model.diarize()
        # NeMo writes predictions to <out_dir>/pred_rttms/<audio stem>.rttm
        audio_stem = os.path.splitext(os.path.basename(self.audio_path))[0]
        output_diarization_pred = os.path.join(
            self.data_dir, 'outputs', 'pred_rttms', f'{audio_stem}.rttm')
        pred_labels = rttm_to_labels(output_diarization_pred)
        self.diarization = labels_to_pyannote_object(pred_labels)

    def get_diarization_figure(self) -> plt.Figure:
        """
        Function that returns the diarization timeline figure, running the
        diarization first if it has not been computed yet
        """
        if not self.diarization:
            self.diarize_audio()
        # Return the figure we draw on rather than relying on plt.gcf(),
        # which could pick up an unrelated current figure.
        figure, ax = plt.subplots()
        notebook.plot_annotation(self.diarization, ax=ax, time=True, legend=True)
        return figure