File size: 1,735 Bytes
429df62
 
 
 
 
4f38470
429df62
 
4f38470
429df62
 
 
4f38470
429df62
 
 
 
4f38470
 
 
 
 
429df62
 
4f38470
 
 
 
 
 
 
 
 
 
 
 
429df62
 
 
 
4f38470
 
 
429df62
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Setup:
#     pip install "pyannote.audio>=3.1"
# Requirement: Submit an access request for the following models.
#     https://huggingface.co/pyannote/speaker-diarization-3.1
#     https://huggingface.co/pyannote/segmentation-3.0
# wget https://huggingface.co/kotoba-tech/kotoba-whisper-v2.2/resolve/main/sample_audio/sample_diarization_japanese.mp3
import soundfile as sf
import numpy as np
from typing import Union, Dict, List

import torch
from pyannote.audio import Pipeline
from diarizers import SegmentationModel


class SpeakerDiarization:
    """Callable wrapper around the ``pyannote/speaker-diarization-3.1`` pipeline.

    The pipeline's segmentation model is replaced with the
    ``diarizers-community/speaker-segmentation-fine-tuned-callhome-jpn``
    checkpoint when the instance is constructed.
    """

    def __init__(self):
        """Load the diarization pipeline and swap in the fine-tuned segmentation model.

        Raises:
            RuntimeError: If the pretrained pipeline could not be loaded.
        """
        self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
        # NOTE(review): pyannote's `from_pretrained` is understood to return
        # None (instead of raising) when the gated model cannot be downloaded;
        # fail here with a clear message rather than with an AttributeError on
        # the next line.
        if self.pipeline is None:
            raise RuntimeError(
                "Could not load 'pyannote/speaker-diarization-3.1'; "
                "check that access to the model has been granted and that "
                "you are authenticated with Hugging Face."
            )
        self.pipeline._segmentation.model = SegmentationModel().from_pretrained(
            'diarizers-community/speaker-segmentation-fine-tuned-callhome-jpn'
        ).to_pyannote_model()

    def __call__(self,
                 audio: Union[torch.Tensor, np.ndarray],
                 sampling_rate: int) -> Dict[str, List[List[float]]]:
        """Run diarization on an in-memory waveform.

        Args:
            audio: Waveform shaped ``(time,)`` or ``(channel, time)``. A 1-D
                input is promoted to a single channel. Values are converted
                to ``float32``.
            sampling_rate: Sampling rate of ``audio``.

        Returns:
            A mapping from each label reported by the pipeline to a list of
            ``[start, end]`` pairs taken from that label's timeline.

        Raises:
            ValueError: If ``sampling_rate`` is ``None`` or ``audio`` is not
                1-D or 2-D.
        """
        if sampling_rate is None:
            raise ValueError("sampling_rate must be provided")
        # A single conversion handles ndarray, Tensor and sequence inputs alike.
        waveform = torch.as_tensor(audio, dtype=torch.float32)
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        elif waveform.ndim != 2:
            # Anything other than (channel, time) is rejected up front instead
            # of being forwarded to the pipeline.
            raise ValueError("audio shape must be (channel, time)")
        output = self.pipeline({"waveform": waveform, "sample_rate": sampling_rate})
        return {
            label: [[segment.start, segment.end] for segment in output.label_timeline(label)]
            for label in output.labels()
        }


# Demo: diarize the sample recording and print the segments found per speaker.
pipeline = SpeakerDiarization()
a, sr = sf.read("sample_diarization_japanese.mp3")
# The decoded array is transposed before being handed to the pipeline, which
# validates its input as (channel, time).
speaker_segments = pipeline(audio=a.T, sampling_rate=sr)
print(speaker_segments)