CosyVoice committed
Commit cd26f11 (unverified)
Parents: f6b5c42 9e0b99e

Merge pull request #379 from boji123/bj_dev_stream_fix

cosyvoice/cli/model.py CHANGED
@@ -50,6 +50,7 @@ class CosyVoiceModel:
         self.llm_end_dict = {}
         self.mel_overlap_dict = {}
         self.hift_cache_dict = {}
+        self.speech_window = np.hamming(2 * self.source_cache_len)
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
@@ -114,13 +115,20 @@ class CosyVoiceModel:
             self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
             tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
             tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
-            self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {
+                'mel': tts_mel[:, :, -self.mel_cache_len:],
+                'source': tts_source[:, :, -self.source_cache_len:],
+                'speech': tts_speech[:, -self.source_cache_len:]}
             tts_speech = tts_speech[:, :-self.source_cache_len]
         else:
             if speed != 1.0:
                 assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                 tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
             tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
         return tts_speech
 
     def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
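What the new 'speech' cache entry buys: in streaming mode the last source_cache_len samples of each chunk are held back in hift_cache_dict and cross-faded into the head of the next chunk under a Hamming window, so chunk seams no longer click. Below is a minimal sketch of that blend in isolation; the fade_in_out here mirrors the helper in cosyvoice/utils/common.py, while the chunk contents and the source_cache_len value are made up for illustration.

import numpy as np
import torch

def fade_in_out(fade_in, fade_out, window):
    # Mirrors the patched helper in cosyvoice/utils/common.py: rising
    # half-window on the head of the new chunk, falling half-window on
    # the cached tail of the previous chunk.
    overlap = int(window.shape[0] / 2)
    fade_in[..., :overlap] = fade_in[..., :overlap] * window[:overlap] + \
        fade_out[..., -overlap:] * window[overlap:]
    return fade_in

source_cache_len = 4  # illustrative; the real value comes from the model config
speech_window = torch.from_numpy(np.hamming(2 * source_cache_len)).float()

prev_chunk = torch.ones(1, 16)   # stands in for the previous tts_speech (batch, samples)
next_chunk = torch.zeros(1, 16)  # stands in for the next tts_speech

# Equivalent of caching hift_cache_dict[uuid]['speech'] = tts_speech[:, -source_cache_len:]
cached_tail = prev_chunk[:, -source_cache_len:]
smoothed = fade_in_out(next_chunk.clone(), cached_tail, speech_window)
print(smoothed[0, :source_cache_len])  # ramps down smoothly instead of jumping from 1 to 0

Note that the patch also trims the cached tail off the emitted audio (tts_speech = tts_speech[:, :-self.source_cache_len]), so the overlap region is played exactly once, after blending.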
cosyvoice/utils/common.py CHANGED
@@ -139,6 +139,7 @@ def fade_in_out(fade_in_mel, fade_out_mel, window):
     device = fade_in_mel.device
     fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
     mel_overlap_len = int(window.shape[0] / 2)
-    fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + \
-        fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
+
+    fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
+        fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel.to(device)
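The switch from [:, :, :n] to [..., :n] is what lets model.py feed waveforms through the same helper: the old slicing assumed a 3-D mel tensor of shape (batch, n_mels, frames), while Ellipsis slices the last axis at any rank, so the 2-D (batch, samples) speech tensors from the streaming path above now pass through as well. A quick check of both call shapes; the tensor sizes are made up for illustration.

import numpy as np
import torch

def fade_in_out(fade_in, fade_out, window):
    # Same last-axis blend as the patched helper, rank-agnostic via Ellipsis.
    n = int(window.shape[0] / 2)
    fade_in[..., :n] = fade_in[..., :n] * window[:n] + fade_out[..., -n:] * window[n:]
    return fade_in

window = torch.from_numpy(np.hamming(8)).float()

# 3-D mel tensors (batch, n_mels, frames): the original call site.
mel_a, mel_b = torch.randn(1, 80, 20), torch.randn(1, 80, 20)
print(fade_in_out(mel_a, mel_b, window).shape)  # torch.Size([1, 80, 20])

# 2-D speech tensors (batch, samples): the new streaming call site.
sp_a, sp_b = torch.randn(1, 200), torch.randn(1, 200)
print(fade_in_out(sp_a, sp_b, window).shape)    # torch.Size([1, 200])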