# kotoba-whisper-v2.2/pipeline/test_speaker_diarization.py
# Setup:
# pip install pyannote.audio>=3.1
# Requirement: Submit an access request for each of the following gated models.
# https://huggingface.co/pyannote/speaker-diarization-3.1
# https://huggingface.co/pyannote/segmentation-3.0
# wget https://huggingface.co/kotoba-tech/kotoba-whisper-v2.2/resolve/main/sample_audio/sample_diarization_japanese.mp3
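# Additional notes (assumptions on top of the original setup steps): the SegmentationModel wrapper
# used below comes from the separate `diarizers` package (https://github.com/huggingface/diarizers),
# which has to be installed as well, and downloading the gated pyannote checkpoints typically
# requires an authenticated Hugging Face session, e.g. via `huggingface-cli login` or an HF_TOKEN
# environment variable.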
import soundfile as sf
import numpy as np
from typing import Union, Dict, List
import torch
from pyannote.audio import Pipeline
from diarizers import SegmentationModel


class SpeakerDiarization:

    def __init__(self):
        # Load the pyannote diarization pipeline and replace its segmentation model with one
        # fine-tuned on Japanese conversations (CallHome Japanese).
        self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
        self.pipeline._segmentation.model = SegmentationModel().from_pretrained(
            'diarizers-community/speaker-segmentation-fine-tuned-callhome-jpn'
        ).to_pyannote_model()

    def __call__(self,
                 audio: Union[torch.Tensor, np.ndarray],
                 sampling_rate: int) -> Dict[str, List[List[float]]]:
        if sampling_rate is None:
            raise ValueError("sampling_rate must be provided")
        # pyannote expects a float32 waveform of shape (channel, time).
        audio = torch.as_tensor(audio, dtype=torch.float32)
        if audio.ndim == 1:
            audio = audio.unsqueeze(0)
        elif audio.ndim != 2:
            raise ValueError("audio shape must be (channel, time)")
        audio = {"waveform": audio, "sample_rate": sampling_rate}
        output = self.pipeline(audio)
        # Return {speaker_label: [[start, end], ...]} with timestamps in seconds.
        return {s: [[i.start, i.end] for i in output.label_timeline(s)] for s in output.labels()}


# Run diarization on the sample audio (soundfile returns (frames, channels) for multi-channel
# files, hence the transpose to (channel, time)).
pipeline = SpeakerDiarization()
a, sr = sf.read("sample_diarization_japanese.mp3")
print(pipeline(a.T, sampling_rate=sr))
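
# Expected result shape (a sketch; exact labels and timestamps depend on the audio):
# {"SPEAKER_00": [[0.03, 2.51], ...], "SPEAKER_01": [[2.6, 5.1], ...]}
#
# Optional extension, not part of the original script: pyannote pipelines expose `.to(device)`,
# so inference can be moved to GPU when one is available.
if torch.cuda.is_available():
    pipeline.pipeline.to(torch.device("cuda"))
    print(pipeline(a.T, sampling_rate=sr))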