reab5555 commited on
Commit
7c1ee96
·
verified ·
1 Parent(s): 7bbb7f4

Update voice_analysis.py

Browse files
Files changed (1) hide show
  1. voice_analysis.py +10 -4
voice_analysis.py CHANGED
@@ -30,20 +30,26 @@ def get_speaker_embeddings(audio_path, diarization, model_name="pyannote/embeddi
30
  raise ValueError("py_annote_hf_token environment variable is not set. Please check your Hugging Face Space's Variables and secrets section.")
31
 
32
  model = Model.from_pretrained(model_name, use_auth_token=hf_token)
 
33
 
34
  waveform, sample_rate = torchaudio.load(audio_path)
35
  embeddings = []
36
 
37
  for turn, _, speaker in diarization.itertracks(yield_label=True):
38
- start = int(turn.start * sample_rate)
39
- end = int(turn.end * sample_rate)
40
 
41
- segment = waveform[:, start:end]
42
  if segment.shape[1] == 0:
43
  continue
44
 
 
 
 
 
 
45
  with torch.no_grad():
46
- embedding = model({"waveform": segment, "sample_rate": sample_rate})
47
 
48
  embeddings.append({"time": turn.start, "embedding": embedding.squeeze().cpu().numpy(), "speaker": speaker})
49
 
 
30
  raise ValueError("py_annote_hf_token environment variable is not set. Please check your Hugging Face Space's Variables and secrets section.")
31
 
32
  model = Model.from_pretrained(model_name, use_auth_token=hf_token)
33
+ model.eval() # Set the model to evaluation mode
34
 
35
  waveform, sample_rate = torchaudio.load(audio_path)
36
  embeddings = []
37
 
38
  for turn, _, speaker in diarization.itertracks(yield_label=True):
39
+ start_frame = int(turn.start * sample_rate)
40
+ end_frame = int(turn.end * sample_rate)
41
 
42
+ segment = waveform[:, start_frame:end_frame]
43
  if segment.shape[1] == 0:
44
  continue
45
 
46
+ # Ensure the segment is long enough (at least 1 second)
47
+ if segment.shape[1] < sample_rate:
48
+ padding = torch.zeros(1, sample_rate - segment.shape[1])
49
+ segment = torch.cat([segment, padding], dim=1)
50
+
51
  with torch.no_grad():
52
+ embedding = model(segment) # Pass the tensor directly, not a dictionary
53
 
54
  embeddings.append({"time": turn.start, "embedding": embedding.squeeze().cpu().numpy(), "speaker": speaker})
55