reab5555 committed on
Commit 4a67bd7 · verified · 1 Parent(s): bafab47

Update voice_analysis.py

Files changed (1)
  1. voice_analysis.py +12 -37
voice_analysis.py CHANGED
@@ -24,51 +24,26 @@ def diarize_speakers(audio_path):
     return diarization
 
 def get_speaker_embeddings(audio_path, diarization, model_name="pyannote/embedding"):
-    hf_token = os.environ.get("py_annote_hf_token")
-
-    if not hf_token:
-        raise ValueError("py_annote_hf_token environment variable is not set. Please check your Hugging Face Space's Variables and secrets section.")
-
-    model = Model.from_pretrained(model_name, use_auth_token=hf_token)
-    model.eval()  # Set the model to evaluation mode
-
+    model = Model.from_pretrained(model_name, use_auth_token=os.environ.get("py_annote_hf_token"))
     waveform, sample_rate = torchaudio.load(audio_path)
-    print(f"Sample rate: {sample_rate}")
-    print(f"Waveform shape: {waveform.shape}")
-
-    # Convert stereo to mono if necessary
-    if waveform.shape[0] == 2:
-        waveform = torch.mean(waveform, dim=0, keepdim=True)
+    duration = waveform.shape[1] / sample_rate
 
     embeddings = []
-
     for turn, _, speaker in diarization.itertracks(yield_label=True):
         start_frame = int(turn.start * sample_rate)
         end_frame = int(turn.end * sample_rate)
-
         segment = waveform[:, start_frame:end_frame]
-        print(f"Segment shape before processing: {segment.shape}")
 
-        if segment.shape[1] == 0:
-            continue
-
-        # Ensure the segment is long enough (at least 2 seconds)
-        if segment.shape[1] < 2 * sample_rate:
-            padding = torch.zeros(1, 2 * sample_rate - segment.shape[1])
-            segment = torch.cat([segment, padding], dim=1)
-
-        # Ensure the segment is not too long (maximum 10 seconds)
-        if segment.shape[1] > 10 * sample_rate:
-            segment = segment[:, :10 * sample_rate]
-
-        print(f"Segment shape after processing: {segment.shape}")
-
-        with torch.no_grad():
-            embedding = model(segment)  # Pass the tensor directly, not a dictionary
-
-        embeddings.append({"time": turn.start, "embedding": embedding.squeeze().cpu().numpy(), "speaker": speaker})
-
-    return embeddings
+        if segment.shape[1] > 0:
+            with torch.no_grad():
+                embedding = model(segment.to(model.device))
+            embeddings.append({"time": turn.start, "duration": turn.duration, "embedding": embedding.cpu().numpy(), "speaker": speaker})
+
+    # Ensure embeddings cover the entire duration
+    if embeddings and embeddings[-1]['time'] + embeddings[-1]['duration'] < duration:
+        embeddings.append({"time": duration, "duration": 0, "embedding": np.zeros_like(embeddings[0]['embedding']), "speaker": "silence"})
+
+    return embeddings, duration
 
 def align_voice_embeddings(voice_embeddings, frame_count, fps):
     aligned_embeddings = []
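
For context, a minimal usage sketch (not part of the commit) of how the updated function's new return signature, a (embeddings, duration) tuple, might be consumed together with the module's other helpers; the audio path and frame rate below are hypothetical.

# Hypothetical usage sketch; "example.wav" and fps=30 are illustrative assumptions.
audio_path = "example.wav"
diarization = diarize_speakers(audio_path)                               # helper defined earlier in voice_analysis.py
embeddings, duration = get_speaker_embeddings(audio_path, diarization)   # now also returns total audio duration in seconds
fps = 30                                                                 # assumed frame rate of the paired video
frame_count = int(duration * fps)
aligned = align_voice_embeddings(embeddings, frame_count, fps)           # align per-turn speaker embeddings to video frames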