import numpy as np
import soundfile as sf
import torch
from typing import Dict, List, Union

from pyannote.audio import Pipeline
from diarizers import SegmentationModel


class SpeakerDiarization:

    def __init__(self):
        # Build the pretrained pyannote diarization pipeline. Depending on the
        # environment, from_pretrained may require a Hugging Face auth token
        # (pass use_auth_token=...).
        self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
        # Swap in the segmentation model fine-tuned on Japanese CallHome data.
        # from_pretrained is a class method, so no instance is needed first.
        self.pipeline._segmentation.model = SegmentationModel.from_pretrained(
            "diarizers-community/speaker-segmentation-fine-tuned-callhome-jpn"
        ).to_pyannote_model()

    def __call__(self,
                 audio: Union[torch.Tensor, np.ndarray],
                 sampling_rate: int) -> Dict[str, List[List[float]]]:
        if sampling_rate is None:
            raise ValueError("sampling_rate must be provided")
        # Accept NumPy arrays or torch tensors; a single conversion to float32
        # covers both cases.
        audio = torch.as_tensor(audio, dtype=torch.float32)
        if audio.ndim == 1:
            # Mono input: add the channel dimension pyannote expects.
            audio = audio.unsqueeze(0)
        elif audio.ndim > 2:
            raise ValueError("audio shape must be (channel, time)")
        inputs = {"waveform": audio, "sample_rate": sampling_rate}
        output = self.pipeline(inputs)
        # Map each speaker label to its [start, end] segments, in seconds.
        return {s: [[i.start, i.end] for i in output.label_timeline(s)] for s in output.labels()}


pipeline = SpeakerDiarization()
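# Optional sketch (not in the original script): run on a GPU when one is
# available. This assumes a CUDA-capable runtime; pyannote pipelines expose
# .to(torch.device(...)) for device placement.
if torch.cuda.is_available():
    pipeline.pipeline.to(torch.device("cuda"))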
# soundfile returns multichannel audio as (time, channels); transpose to the
# (channel, time) layout the pipeline expects. (MP3 decoding requires a
# recent libsndfile build.)
a, sr = sf.read("sample_diarization_japanese.mp3")
print(pipeline(a.T, sampling_rate=sr))
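# Illustrative result shape (speaker labels and timestamps are hypothetical):
# {'SPEAKER_00': [[0.03, 1.91], [5.02, 6.87]], 'SPEAKER_01': [[2.20, 4.93]]}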