Update voice_analysis.py
Browse files- voice_analysis.py +9 -2
voice_analysis.py
CHANGED
@@ -22,12 +22,16 @@ def diarize_speakers(audio_path):
|
|
22 |
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=hf_token)
|
23 |
diarization = pipeline(audio_path)
|
24 |
return diarization
|
25 |
-
|
26 |
def get_speaker_embeddings(audio_path, diarization, model_name="pyannote/embedding"):
|
27 |
model = Model.from_pretrained(model_name, use_auth_token=os.environ.get("py_annote_hf_token"))
|
28 |
waveform, sample_rate = torchaudio.load(audio_path)
|
29 |
duration = waveform.shape[1] / sample_rate
|
30 |
|
|
|
|
|
|
|
|
|
31 |
embeddings = []
|
32 |
for turn, _, speaker in diarization.itertracks(yield_label=True):
|
33 |
start_frame = int(turn.start * sample_rate)
|
@@ -35,8 +39,11 @@ def get_speaker_embeddings(audio_path, diarization, model_name="pyannote/embeddi
|
|
35 |
segment = waveform[:, start_frame:end_frame]
|
36 |
|
37 |
if segment.shape[1] > 0:
|
|
|
|
|
|
|
38 |
with torch.no_grad():
|
39 |
-
embedding = model(segment
|
40 |
embeddings.append({"time": turn.start, "duration": turn.duration, "embedding": embedding.cpu().numpy(), "speaker": speaker})
|
41 |
|
42 |
# Ensure embeddings cover the entire duration
|
|
|
22 |
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=hf_token)
|
23 |
diarization = pipeline(audio_path)
|
24 |
return diarization
|
25 |
+
|
26 |
def get_speaker_embeddings(audio_path, diarization, model_name="pyannote/embedding"):
|
27 |
model = Model.from_pretrained(model_name, use_auth_token=os.environ.get("py_annote_hf_token"))
|
28 |
waveform, sample_rate = torchaudio.load(audio_path)
|
29 |
duration = waveform.shape[1] / sample_rate
|
30 |
|
31 |
+
# Convert stereo to mono if necessary
|
32 |
+
if waveform.shape[0] == 2:
|
33 |
+
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
34 |
+
|
35 |
embeddings = []
|
36 |
for turn, _, speaker in diarization.itertracks(yield_label=True):
|
37 |
start_frame = int(turn.start * sample_rate)
|
|
|
39 |
segment = waveform[:, start_frame:end_frame]
|
40 |
|
41 |
if segment.shape[1] > 0:
|
42 |
+
# Ensure the segment is on the correct device
|
43 |
+
segment = segment.to(model.device)
|
44 |
+
|
45 |
with torch.no_grad():
|
46 |
+
embedding = model(segment)
|
47 |
embeddings.append({"time": turn.start, "duration": turn.duration, "embedding": embedding.cpu().numpy(), "speaker": speaker})
|
48 |
|
49 |
# Ensure embeddings cover the entire duration
|