reab5555 committed on
Commit
38c3415
·
verified ·
1 Parent(s): 9263021

Update voice_analysis.py

Browse files
Files changed (1) hide show
  1. voice_analysis.py +6 -3
voice_analysis.py CHANGED
@@ -34,6 +34,12 @@ def get_speaker_embeddings(audio_path, diarization, model_name="pyannote/embeddi
34
 
35
  waveform, sample_rate = torchaudio.load(audio_path)
36
  print(f"Sample rate: {sample_rate}")
 
 
 
 
 
 
37
  embeddings = []
38
 
39
  for turn, _, speaker in diarization.itertracks(yield_label=True):
@@ -55,9 +61,6 @@ def get_speaker_embeddings(audio_path, diarization, model_name="pyannote/embeddi
55
  if segment.shape[1] > 10 * sample_rate:
56
  segment = segment[:, :10 * sample_rate]
57
 
58
- # Reshape the segment to match the model's expected input
59
- segment = segment.unsqueeze(0) # Add batch dimension
60
-
61
  print(f"Segment shape after processing: {segment.shape}")
62
 
63
  with torch.no_grad():
 
34
 
35
  waveform, sample_rate = torchaudio.load(audio_path)
36
  print(f"Sample rate: {sample_rate}")
37
+ print(f"Waveform shape: {waveform.shape}")
38
+
39
+ # Convert stereo to mono if necessary
40
+ if waveform.shape[0] == 2:
41
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
42
+
43
  embeddings = []
44
 
45
  for turn, _, speaker in diarization.itertracks(yield_label=True):
 
61
  if segment.shape[1] > 10 * sample_rate:
62
  segment = segment[:, :10 * sample_rate]
63
 
 
 
 
64
  print(f"Segment shape after processing: {segment.shape}")
65
 
66
  with torch.no_grad():