"""
Nemo diarizer
"""
import os
import json
import wget
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from nemo.collections.asr.models import ClusteringDiarizer
from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object
from pyannote.core import notebook
from diarizers.diarizer import Diarizer
class NemoDiarizer(Diarizer):
    """Speaker diarization backed by NVIDIA NeMo's ClusteringDiarizer.

    On construction, downloads NeMo's offline-diarization config into
    ``data_dir`` if it is not already there.  ``diarize_audio`` writes a
    one-entry inference manifest, runs clustering diarization, and stores
    the prediction as a pyannote ``Annotation`` in ``self.diarization``.
    """
    def __init__(self, audio_path: str, data_dir: str):
        """
        Nemo diarizer class

        Args:
            audio_path (str): the path to the audio file
            data_dir (str): working directory used for the inference manifest,
                the downloaded model config and the diarization outputs
        """
        self.audio_path = audio_path
        self.data_dir = data_dir
        self.diarization = None  # pyannote Annotation; set by diarize_audio()
        self.manifest_dir = os.path.join(self.data_dir, 'input_manifest.json')
        self.model_config = os.path.join(self.data_dir, 'offline_diarization.yaml')
        if not os.path.exists(self.model_config):
            config_url = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/" \
                         "speaker_tasks/diarization/conf/offline_diarization.yaml"
            # wget.download returns the path of the downloaded file
            self.model_config = wget.download(config_url, self.data_dir)
        self.config = OmegaConf.load(self.model_config)
    def _create_manifest_file(self):
        """
        Function that creates the single-entry NeMo inference manifest
        (JSON-lines format) for ``self.audio_path``.
        """
        meta = {
            'audio_filepath': self.audio_path,
            'offset': 0,
            'duration': None,      # None -> NeMo infers duration from the file
            'label': 'infer',
            'text': '-',
            'num_speakers': None,  # unknown; clustering estimates the count
            'rttm_filepath': None,
            'uem_filepath': None
        }
        # Explicit encoding so manifest writing is platform-independent.
        with open(self.manifest_dir, 'w', encoding='utf-8') as fp:
            json.dump(meta, fp)
            fp.write('\n')
    def _apply_config(self, pretrained_speaker_model: str = 'ecapa_tdnn'):
        """
        Function that edits the inference configuration file

        Args:
            pretrained_speaker_model (str): the pre-trained embedding model options are
            ('escapa_tdnn', vad_telephony_marblenet, titanet_large, ecapa_tdnn)
            https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/
            speaker_diarization/results.html
        """
        pretrained_vad = 'vad_marblenet'
        output_dir = os.path.join(self.data_dir, 'outputs')
        self.config.diarizer.manifest_filepath = self.manifest_dir
        self.config.diarizer.out_dir = output_dir
        self.config.diarizer.ignore_overlap = False
        self.config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
        self.config.diarizer.speaker_embeddings.parameters.window_length_in_sec = 1.5
        self.config.diarizer.speaker_embeddings.parameters.shift_length_in_sec = 0.75
        self.config.diarizer.oracle_vad = False  # use the VAD model, not oracle segments
        self.config.diarizer.clustering.parameters.oracle_num_speakers = False
        # Here we use our inhouse pretrained NeMo VAD
        self.config.diarizer.vad.model_path = pretrained_vad
        # FIX: window/shift live under vad.parameters in offline_diarization.yaml
        # (like onset/offset below); writing them on vad directly created unused
        # keys and left NeMo running with the yaml defaults.
        self.config.diarizer.vad.parameters.window_length_in_sec = 0.15
        self.config.diarizer.vad.parameters.shift_length_in_sec = 0.01
        self.config.diarizer.vad.parameters.onset = 0.8
        self.config.diarizer.vad.parameters.offset = 0.6
        self.config.diarizer.vad.parameters.min_duration_on = 0.1
        self.config.diarizer.vad.parameters.min_duration_off = 0.4
    def diarize_audio(self, pretrained_speaker_model: str = 'ecapa_tdnn'):
        """
        function that diarizes the audio and stores the result (a pyannote
        Annotation) in ``self.diarization``

        Args:
            pretrained_speaker_model (str): the pre-trained embedding model options are
            ('escapa_tdnn', vad_telephony_marblenet, titanet_large, ecapa_tdnn)
            https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/
            speaker_diarization/results.html
        """
        self._create_manifest_file()
        self._apply_config(pretrained_speaker_model)
        sd_model = ClusteringDiarizer(cfg=self.config)
        sd_model.diarize()
        # NeMo names the predicted RTTM after the input audio file.
        audio_file_name_without_extension = os.path.splitext(
            os.path.basename(self.audio_path))[0]
        output_diarization_pred = os.path.join(
            self.data_dir, 'outputs', 'pred_rttms',
            f'{audio_file_name_without_extension}.rttm')
        pred_labels = rttm_to_labels(output_diarization_pred)
        self.diarization = labels_to_pyannote_object(pred_labels)
    def get_diarization_figure(self) -> plt.Figure:
        """
        Function that returns a matplotlib figure visualising the diarization
        timeline, running diarization first if no result is cached yet.
        """
        if not self.diarization:
            self.diarize_audio()
        # FIX: annotate with plt.Figure (plt.gcf is a function, not a type)
        # and return the figure we just created instead of re-querying pyplot.
        figure, ax = plt.subplots()
        notebook.plot_annotation(self.diarization, ax=ax, time=True, legend=True)
        return figure