hengjie yang committed on
Commit
4ecc033
·
1 Parent(s): 3544cbd

Improve tensor dimension handling and add debug info

Browse files
Files changed (1) hide show
  1. src/deploy/voice_clone.py +24 -11
src/deploy/voice_clone.py CHANGED
@@ -71,14 +71,19 @@ class VoiceCloneSystem:
71
  # 提取特征
72
  with torch.no_grad():
73
  embedding = self.speaker_encoder.encode_batch(waveform.to(self.device))
74
- # 调整维度
75
- embedding = embedding.squeeze(0) # 移除批次维度
 
 
76
  embeddings.append(embedding)
77
 
78
  # 计算平均特征
79
  mean_embedding = torch.mean(torch.stack(embeddings), dim=0)
80
- # 调整维度以匹配模型要求
81
- mean_embedding = mean_embedding.view(1, -1) # [1, 512]
 
 
 
82
  return mean_embedding
83
 
84
  def generate_speech(
@@ -99,6 +104,10 @@ class VoiceCloneSystem:
99
  # 处理输入文本
100
  inputs = self.processor(text=text, return_tensors="pt")
101
 
 
 
 
 
102
  # 生成语音
103
  speech = self.tts_model.generate_speech(
104
  inputs["input_ids"].to(self.device),
@@ -123,13 +132,17 @@ class VoiceCloneSystem:
123
  Returns:
124
  生成的语音波形
125
  """
126
- # 1. 提取说话人特征
127
- speaker_embedding = self.extract_speaker_embedding(reference_audio_paths)
128
-
129
- # 2. 生成语音
130
- speech = self.generate_speech(text, speaker_embedding)
131
-
132
- return speech
 
 
 
 
133
 
134
  def save_audio(
135
  self,
 
71
  # 提取特征
72
  with torch.no_grad():
73
  embedding = self.speaker_encoder.encode_batch(waveform.to(self.device))
74
+ # 调整维度:从 [1, 1, 1, 512] 转换为 [1, 512]
75
+ embedding = embedding.squeeze() # 移除所有维度为1的维度
76
+ if embedding.dim() == 1:
77
+ embedding = embedding.unsqueeze(0) # 确保是 [1, 512]
78
  embeddings.append(embedding)
79
 
80
  # 计算平均特征
81
  mean_embedding = torch.mean(torch.stack(embeddings), dim=0)
82
+ if mean_embedding.dim() == 1:
83
+ mean_embedding = mean_embedding.unsqueeze(0) # 确保是 [1, 512]
84
+
85
+ # 打印维度信息以便调试
86
+ print(f"Final embedding shape: {mean_embedding.shape}")
87
  return mean_embedding
88
 
89
  def generate_speech(
 
104
  # 处理输入文本
105
  inputs = self.processor(text=text, return_tensors="pt")
106
 
107
+ # 确保说话人特征维度正确
108
+ if speaker_embedding.dim() != 2 or speaker_embedding.size(1) != 512:
109
+ raise ValueError(f"Speaker embedding should have shape [1, 512], but got {speaker_embedding.shape}")
110
+
111
  # 生成语音
112
  speech = self.tts_model.generate_speech(
113
  inputs["input_ids"].to(self.device),
 
132
  Returns:
133
  生成的语音波形
134
  """
135
+ try:
136
+ # 1. 提取说话人特征
137
+ speaker_embedding = self.extract_speaker_embedding(reference_audio_paths)
138
+
139
+ # 2. 生成语音
140
+ speech = self.generate_speech(text, speaker_embedding)
141
+
142
+ return speech
143
+ except Exception as e:
144
+ print(f"Error in clone_voice: {str(e)}")
145
+ raise
146
 
147
  def save_audio(
148
  self,