Update voice_analysis.py
voice_analysis.py  +10 -4  CHANGED
@@ -43,10 +43,17 @@ def get_speaker_embeddings(audio_path, diarization, model_name="pyannote/embedding"):
         if segment.shape[1] == 0:
             continue
 
-        # Ensure the segment is long enough (at least 1 second)
-        if segment.shape[1] < sample_rate:
-            padding = torch.zeros(1, sample_rate - segment.shape[1])
+        # Ensure the segment is long enough (at least 2 seconds)
+        if segment.shape[1] < 2 * sample_rate:
+            padding = torch.zeros(1, 2 * sample_rate - segment.shape[1])
             segment = torch.cat([segment, padding], dim=1)
+
+        # Ensure the segment is not too long (maximum 10 seconds)
+        if segment.shape[1] > 10 * sample_rate:
+            segment = segment[:, :10 * sample_rate]
+
+        # Reshape the segment to match the model's expected input
+        segment = segment.unsqueeze(0)  # Add batch dimension
 
         with torch.no_grad():
             embedding = model(segment)  # Pass the tensor directly, not a dictionary
@@ -56,7 +63,6 @@ def get_speaker_embeddings(audio_path, diarization, model_name="pyannote/embedding"):
     return embeddings
 
 def align_voice_embeddings(voice_embeddings, frame_count, fps):
-    import numpy as np
     aligned_embeddings = []
     current_embedding_index = 0
 
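For reference, the padding, trimming, and batching steps added in the first hunk can be read as one self-contained helper. The sketch below is not code from this repository: the helper name normalize_segment and the 16 kHz example rate are assumptions, while the 2-second minimum, 10-second maximum, and added batch dimension come directly from the diff.

import torch

def normalize_segment(segment: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """Pad or trim a (1, num_samples) waveform to 2-10 seconds, then add a
    batch dimension so it can be passed to the embedding model."""
    min_len = 2 * sample_rate    # pad anything shorter than 2 seconds
    max_len = 10 * sample_rate   # trim anything longer than 10 seconds

    if segment.shape[1] < min_len:
        padding = torch.zeros(1, min_len - segment.shape[1])
        segment = torch.cat([segment, padding], dim=1)

    if segment.shape[1] > max_len:
        segment = segment[:, :max_len]

    return segment.unsqueeze(0)  # shape (1, 1, num_samples)

# Example: a 0.5 s mono segment at 16 kHz (hypothetical rate) is padded to 2 s.
sr = 16000
short = torch.randn(1, sr // 2)
print(normalize_segment(short, sr).shape)  # torch.Size([1, 1, 32000])

Fixed bounds like these keep very short diarization turns from failing inside the embedding model and cap memory use on long turns. Whether the extra batch dimension is strictly required depends on the embedding model version, so the unsqueeze here follows the diff rather than a documented pyannote contract.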