Spaces:
Sleeping
Sleeping
hengjie yang
committed on
Commit
·
4ecc033
1
Parent(s):
3544cbd
Improve tensor dimension handling and add debug info
Browse files
- src/deploy/voice_clone.py +24 -11
src/deploy/voice_clone.py
CHANGED
@@ -71,14 +71,19 @@ class VoiceCloneSystem:
|
|
71 |
# 提取特征
|
72 |
with torch.no_grad():
|
73 |
embedding = self.speaker_encoder.encode_batch(waveform.to(self.device))
|
74 |
-
#
|
75 |
-
embedding = embedding.squeeze(
|
|
|
|
|
76 |
embeddings.append(embedding)
|
77 |
|
78 |
# 计算平均特征
|
79 |
mean_embedding = torch.mean(torch.stack(embeddings), dim=0)
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
82 |
return mean_embedding
|
83 |
|
84 |
def generate_speech(
|
@@ -99,6 +104,10 @@ class VoiceCloneSystem:
|
|
99 |
# 处理输入文本
|
100 |
inputs = self.processor(text=text, return_tensors="pt")
|
101 |
|
|
|
|
|
|
|
|
|
102 |
# 生成语音
|
103 |
speech = self.tts_model.generate_speech(
|
104 |
inputs["input_ids"].to(self.device),
|
@@ -123,13 +132,17 @@ class VoiceCloneSystem:
|
|
123 |
Returns:
|
124 |
生成的语音波形
|
125 |
"""
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
|
|
133 |
|
134 |
def save_audio(
|
135 |
self,
|
|
|
71 |
# 提取特征
|
72 |
with torch.no_grad():
|
73 |
embedding = self.speaker_encoder.encode_batch(waveform.to(self.device))
|
74 |
+
# 调整维度:从 [1, 1, 1, 512] 转换为 [1, 512]
|
75 |
+
embedding = embedding.squeeze() # 移除所有维度为1的维度
|
76 |
+
if embedding.dim() == 1:
|
77 |
+
embedding = embedding.unsqueeze(0) # 确保是 [1, 512]
|
78 |
embeddings.append(embedding)
|
79 |
|
80 |
# 计算平均特征
|
81 |
mean_embedding = torch.mean(torch.stack(embeddings), dim=0)
|
82 |
+
if mean_embedding.dim() == 1:
|
83 |
+
mean_embedding = mean_embedding.unsqueeze(0) # 确保是 [1, 512]
|
84 |
+
|
85 |
+
# 打印维度信息以便调试
|
86 |
+
print(f"Final embedding shape: {mean_embedding.shape}")
|
87 |
return mean_embedding
|
88 |
|
89 |
def generate_speech(
|
|
|
104 |
# 处理输入文本
|
105 |
inputs = self.processor(text=text, return_tensors="pt")
|
106 |
|
107 |
+
# 确保说话人特征维度正确
|
108 |
+
if speaker_embedding.dim() != 2 or speaker_embedding.size(1) != 512:
|
109 |
+
raise ValueError(f"Speaker embedding should have shape [1, 512], but got {speaker_embedding.shape}")
|
110 |
+
|
111 |
# 生成语音
|
112 |
speech = self.tts_model.generate_speech(
|
113 |
inputs["input_ids"].to(self.device),
|
|
|
132 |
Returns:
|
133 |
生成的语音波形
|
134 |
"""
|
135 |
+
try:
|
136 |
+
# 1. 提取说话人特征
|
137 |
+
speaker_embedding = self.extract_speaker_embedding(reference_audio_paths)
|
138 |
+
|
139 |
+
# 2. 生成语音
|
140 |
+
speech = self.generate_speech(text, speaker_embedding)
|
141 |
+
|
142 |
+
return speech
|
143 |
+
except Exception as e:
|
144 |
+
print(f"Error in clone_voice: {str(e)}")
|
145 |
+
raise
|
146 |
|
147 |
def save_audio(
|
148 |
self,
|