reab5555 committed on
Commit 36f6fd8 · verified · 1 Parent(s): 1a48270

Update voice_analysis.py

Files changed (1)
  1. voice_analysis.py +9 -2
voice_analysis.py CHANGED
@@ -22,12 +22,16 @@ def diarize_speakers(audio_path):
     pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=hf_token)
     diarization = pipeline(audio_path)
     return diarization
-
+
 def get_speaker_embeddings(audio_path, diarization, model_name="pyannote/embedding"):
     model = Model.from_pretrained(model_name, use_auth_token=os.environ.get("py_annote_hf_token"))
     waveform, sample_rate = torchaudio.load(audio_path)
     duration = waveform.shape[1] / sample_rate
 
+    # Convert stereo to mono if necessary
+    if waveform.shape[0] == 2:
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+
     embeddings = []
     for turn, _, speaker in diarization.itertracks(yield_label=True):
         start_frame = int(turn.start * sample_rate)
@@ -35,8 +39,11 @@ def get_speaker_embeddings(audio_path, diarization, model_name="pyannote/embedding"):
         segment = waveform[:, start_frame:end_frame]
 
         if segment.shape[1] > 0:
+            # Ensure the segment is on the correct device
+            segment = segment.to(model.device)
+
             with torch.no_grad():
-                embedding = model(segment.to(model.device))
+                embedding = model(segment)
             embeddings.append({"time": turn.start, "duration": turn.duration, "embedding": embedding.cpu().numpy(), "speaker": speaker})
 
     # Ensure embeddings cover the entire duration
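
For readers following along, below is a minimal, self-contained sketch of the two behaviors this commit adds, assuming torch and torchaudio are installed; load_mono_waveform and embed_segment are hypothetical helper names used for illustration, not functions in voice_analysis.py. torchaudio.load returns a (channels, frames) tensor, so averaging over dim=0 with keepdim=True collapses a stereo file to shape (1, frames) while preserving the layout the segment slicing expects.

import torch
import torchaudio

def load_mono_waveform(audio_path):
    # torchaudio.load returns (waveform, sample_rate); waveform has shape
    # (channels, frames), so a stereo file arrives as (2, frames).
    waveform, sample_rate = torchaudio.load(audio_path)

    # Average the two channels into one. keepdim=True keeps the result as
    # (1, frames), so the waveform[:, start_frame:end_frame] slicing in
    # get_speaker_embeddings works unchanged. The commit checks shape[0] == 2;
    # shape[0] > 1 would also cover files with more than two channels.
    if waveform.shape[0] == 2:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    return waveform, sample_rate

def embed_segment(model, segment):
    # Mirror of the commit's inference pattern: move the segment onto the
    # model's device first, then run the forward pass with autograd disabled,
    # and return a CPU numpy array for storage alongside the turn metadata.
    segment = segment.to(model.device)
    with torch.no_grad():
        embedding = model(segment)
    return embedding.cpu().numpy()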