reab5555 committed on
Commit
08d515b
·
verified ·
1 Parent(s): 931d60e

Update voice_analysis.py

Browse files
Files changed (1) hide show
  1. voice_analysis.py +32 -6
voice_analysis.py CHANGED
@@ -32,6 +32,10 @@ def get_speaker_embeddings(audio_path, diarization, model_name="pyannote/embeddi
32
  if waveform.shape[0] == 2:
33
  waveform = torch.mean(waveform, dim=0, keepdim=True)
34
 
 
 
 
 
35
  embeddings = []
36
  for turn, _, speaker in diarization.itertracks(yield_label=True):
37
  start_frame = int(turn.start * sample_rate)
@@ -39,16 +43,38 @@ def get_speaker_embeddings(audio_path, diarization, model_name="pyannote/embeddi
39
  segment = waveform[:, start_frame:end_frame]
40
 
41
  if segment.shape[1] > 0:
42
- # Ensure the segment is on the correct device
43
- segment = segment.to(model.device)
 
 
44
 
45
- with torch.no_grad():
46
- embedding = model(segment)
47
- embeddings.append({"time": turn.start, "duration": turn.duration, "embedding": embedding.cpu().numpy(), "speaker": speaker})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  # Ensure embeddings cover the entire duration
50
  if embeddings and embeddings[-1]['time'] + embeddings[-1]['duration'] < duration:
51
- embeddings.append({"time": duration, "duration": 0, "embedding": np.zeros_like(embeddings[0]['embedding']), "speaker": "silence"})
 
 
 
 
 
52
 
53
  return embeddings, duration
54
 
 
32
  if waveform.shape[0] == 2:
33
  waveform = torch.mean(waveform, dim=0, keepdim=True)
34
 
35
+ # Minimum segment duration (in seconds)
36
+ min_segment_duration = 0.5
37
+ min_segment_length = int(min_segment_duration * sample_rate)
38
+
39
  embeddings = []
40
  for turn, _, speaker in diarization.itertracks(yield_label=True):
41
  start_frame = int(turn.start * sample_rate)
 
43
  segment = waveform[:, start_frame:end_frame]
44
 
45
  if segment.shape[1] > 0:
46
+ # Pad short segments
47
+ if segment.shape[1] < min_segment_length:
48
+ padding = torch.zeros(1, min_segment_length - segment.shape[1])
49
+ segment = torch.cat([segment, padding], dim=1)
50
 
51
+ # Split long segments
52
+ for i in range(0, segment.shape[1], min_segment_length):
53
+ sub_segment = segment[:, i:i+min_segment_length]
54
+ if sub_segment.shape[1] < min_segment_length:
55
+ padding = torch.zeros(1, min_segment_length - sub_segment.shape[1])
56
+ sub_segment = torch.cat([sub_segment, padding], dim=1)
57
+
58
+ # Ensure the segment is on the correct device
59
+ sub_segment = sub_segment.to(model.device)
60
+
61
+ with torch.no_grad():
62
+ embedding = model(sub_segment)
63
+ embeddings.append({
64
+ "time": turn.start + i / sample_rate,
65
+ "duration": min_segment_duration,
66
+ "embedding": embedding.cpu().numpy(),
67
+ "speaker": speaker
68
+ })
69
 
70
  # Ensure embeddings cover the entire duration
71
  if embeddings and embeddings[-1]['time'] + embeddings[-1]['duration'] < duration:
72
+ embeddings.append({
73
+ "time": duration,
74
+ "duration": 0,
75
+ "embedding": np.zeros_like(embeddings[0]['embedding']),
76
+ "speaker": "silence"
77
+ })
78
 
79
  return embeddings, duration
80