CosyVoice committed on
Commit 20e0715 · unverified · 2 Parent(s): dcc943d 6620129

Merge pull request #327 from FunAudioLLM/inference_streaming

.gitignore CHANGED
@@ -43,6 +43,8 @@ compile_commands.json
43
 
44
  # train/inference files
45
  *.wav
 
 
46
  *.pt
47
  pretrained_models/*
48
  *_pb2_grpc.py
 
43
 
44
  # train/inference files
45
  *.wav
46
+ *.m4a
47
+ *.aac
48
  *.pt
49
  pretrained_models/*
50
  *_pb2_grpc.py
README.md CHANGED
@@ -116,23 +116,24 @@ import torchaudio
116
  cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
117
  # sft usage
118
  print(cosyvoice.list_avaliable_spks())
119
- output = cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女')
120
- torchaudio.save('sft.wav', output['tts_speech'], 22050)
 
121
 
122
  cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
123
  # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
124
  prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
125
- output = cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k)
126
- torchaudio.save('zero_shot.wav', output['tts_speech'], 22050)
127
  # cross_lingual usage
128
  prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
129
- output = cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k)
130
- torchaudio.save('cross_lingual.wav', output['tts_speech'], 22050)
131
 
132
  cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
133
  # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
134
- output = cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
135
- torchaudio.save('instruct.wav', output['tts_speech'], 22050)
136
  ```
137
 
138
  **Start web demo**
@@ -163,10 +164,10 @@ docker build -t cosyvoice:v1.0 .
163
  # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
164
  # for grpc usage
165
  docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
166
- python3 grpc/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
167
  # for fastapi usage
168
  docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity"
169
- python3 fastapi/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
170
  ```
171
 
172
  ## Discussion & Communication
 
116
  cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
117
  # sft usage
118
  print(cosyvoice.list_avaliable_spks())
119
+ # set stream=True for chunked streaming inference
120
+ for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
121
+ torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)
122
 
123
  cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
124
  # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
125
  prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
126
+ for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
127
+ torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], 22050)
128
  # cross_lingual usage
129
  prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
130
+ for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
131
+ torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], 22050)
132
 
133
  cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
134
  # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
135
+ for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
136
+ torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], 22050)
137
  ```
138
 
139
  **Start web demo**
 
164
  # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
165
  # for grpc usage
166
  docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
167
+ cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
168
  # for fastapi usage
169
  docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity"
170
+ cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
171
  ```
172
 
173
  ## Discussion & Communication
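Note that every `inference_*` call above is now a generator. Callers that still want a single audio file can collect the chunks and concatenate them along the time axis; a minimal sketch, assuming the API shown in this README (each chunk is a dict with a `tts_speech` tensor at 22050 Hz):

```
import torch
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
chunks = []
# stream=True yields audio chunk by chunk as tokens are decoded
for chunk in cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=True):
    chunks.append(chunk['tts_speech'])
# concatenate along the time axis and write a single wav
torchaudio.save('sft_full.wav', torch.concat(chunks, dim=1), 22050)
```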
cosyvoice/bin/export_jit.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import sys
22
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
23
+ sys.path.append('{}/../..'.format(ROOT_DIR))
24
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
25
+ import torch
26
+ from cosyvoice.cli.cosyvoice import CosyVoice
27
+
28
+ def get_args():
29
+ parser = argparse.ArgumentParser(description='export your model for deployment')
30
+ parser.add_argument('--model_dir',
31
+ type=str,
32
+ default='pretrained_models/CosyVoice-300M',
33
+ help='local path')
34
+ args = parser.parse_args()
35
+ print(args)
36
+ return args
37
+
38
+ def main():
39
+ args = get_args()
40
+ logging.basicConfig(level=logging.DEBUG,
41
+ format='%(asctime)s %(levelname)s %(message)s')
42
+
43
+ torch._C._jit_set_fusion_strategy([('STATIC', 1)])
44
+ torch._C._jit_set_profiling_mode(False)
45
+ torch._C._jit_set_profiling_executor(False)
46
+
47
+ cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
48
+
49
+ # 1. export llm text_encoder
50
+ llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
51
+ script = torch.jit.script(llm_text_encoder)
52
+ script = torch.jit.freeze(script)
53
+ script = torch.jit.optimize_for_inference(script)
54
+ script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
55
+
56
+ # 2. export llm llm
57
+ llm_llm = cosyvoice.model.llm.llm.half()
58
+ script = torch.jit.script(llm_llm)
59
+ script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
60
+ script = torch.jit.optimize_for_inference(script)
61
+ script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
62
+
63
+ if __name__ == '__main__':
64
+ main()
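The archives written above are plain TorchScript modules, so they can be reloaded with `torch.jit.load`; the new `load_jit` path in `cosyvoice/cli/cosyvoice.py` below does exactly this swap. A minimal sketch, assuming the file names produced by this script:

```
import torch

model_dir = 'pretrained_models/CosyVoice-300M'
# fp16 TorchScript modules written by export_jit.py
llm_text_encoder = torch.jit.load('{}/llm.text_encoder.fp16.zip'.format(model_dir))
llm_llm = torch.jit.load('{}/llm.llm.fp16.zip'.format(model_dir))
```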
cosyvoice/bin/export_trt.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # TODO: follow the same logic as export_jit and export the flow estimator to ONNX.
2
+ # Also write the TensorRT installation steps here. If TensorRT is not installed, do not run this script; prompt the user to install it first, with no fallback.
3
+ try:
4
+ import tensorrt
5
+ except ImportError:
6
+ print('step 1: download TensorRT\n step 2: unpack it and install the whl')
7
+ # Pass the TensorRT root directory to the export command via an environment variable, e.g. os.environ['tensorrt_root_dir']/bin/exetrace, then run the export command from Python with subprocess.
8
+ # The run command will be added to run.sh later: tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx
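The script above is still a stub. A minimal sketch of the intended shape, assuming the `tensorrt_root_dir` environment variable from the TODO and TensorRT's standard `trtexec` tool; the ONNX and engine file names below are hypothetical, and the actual export of the flow estimator is still to be written:

```
import os
import subprocess
import sys

try:
    import tensorrt  # noqa: F401
except ImportError:
    print('TensorRT is not installed.\nstep 1: download the TensorRT package\nstep 2: unpack it and install the bundled wheel')
    sys.exit(1)

# hypothetical: convert an already exported ONNX estimator into a TensorRT engine
tensorrt_root_dir = os.environ['tensorrt_root_dir']
subprocess.run(['{}/bin/trtexec'.format(tensorrt_root_dir),
                '--onnx=flow.decoder.estimator.onnx',
                '--saveEngine=flow.decoder.estimator.plan'], check=True)
```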
cosyvoice/bin/inference.py CHANGED
@@ -100,10 +100,13 @@ def main():
100
  'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
101
  'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
102
  'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
103
- model_output = model.inference(**model_input)
 
 
 
104
  tts_key = '{}_{}'.format(utts[0], tts_index[0])
105
  tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
106
- torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
107
  f.write('{} {}\n'.format(tts_key, tts_fn))
108
  f.flush()
109
  f.close()
 
100
  'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
101
  'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
102
  'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
103
+ tts_speeches = []
104
+ for model_output in model.inference(**model_input):
105
+ tts_speeches.append(model_output['tts_speech'])
106
+ tts_speeches = torch.concat(tts_speeches, dim=1)
107
  tts_key = '{}_{}'.format(utts[0], tts_index[0])
108
  tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
109
+ torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
110
  f.write('{} {}\n'.format(tts_key, tts_fn))
111
  f.flush()
112
  f.close()
cosyvoice/cli/cosyvoice.py CHANGED
@@ -12,15 +12,16 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  import os
15
- import torch
16
  from hyperpyyaml import load_hyperpyyaml
17
  from modelscope import snapshot_download
18
  from cosyvoice.cli.frontend import CosyVoiceFrontEnd
19
  from cosyvoice.cli.model import CosyVoiceModel
 
20
 
21
  class CosyVoice:
22
 
23
- def __init__(self, model_dir):
24
  instruct = True if '-Instruct' in model_dir else False
25
  self.model_dir = model_dir
26
  if not os.path.exists(model_dir):
@@ -38,46 +39,61 @@ class CosyVoice:
38
  self.model.load('{}/llm.pt'.format(model_dir),
39
  '{}/flow.pt'.format(model_dir),
40
  '{}/hift.pt'.format(model_dir))
 
 
 
41
  del configs
42
 
43
  def list_avaliable_spks(self):
44
  spks = list(self.frontend.spk2info.keys())
45
  return spks
46
 
47
- def inference_sft(self, tts_text, spk_id):
48
- tts_speeches = []
49
  for i in self.frontend.text_normalize(tts_text, split=True):
50
  model_input = self.frontend.frontend_sft(i, spk_id)
51
- model_output = self.model.inference(**model_input)
52
- tts_speeches.append(model_output['tts_speech'])
53
- return {'tts_speech': torch.concat(tts_speeches, dim=1)}
 
 
 
 
54
 
55
- def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
56
  prompt_text = self.frontend.text_normalize(prompt_text, split=False)
57
- tts_speeches = []
58
  for i in self.frontend.text_normalize(tts_text, split=True):
59
  model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
60
- model_output = self.model.inference(**model_input)
61
- tts_speeches.append(model_output['tts_speech'])
62
- return {'tts_speech': torch.concat(tts_speeches, dim=1)}
 
 
 
 
63
 
64
- def inference_cross_lingual(self, tts_text, prompt_speech_16k):
65
  if self.frontend.instruct is True:
66
  raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
67
- tts_speeches = []
68
  for i in self.frontend.text_normalize(tts_text, split=True):
69
  model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
70
- model_output = self.model.inference(**model_input)
71
- tts_speeches.append(model_output['tts_speech'])
72
- return {'tts_speech': torch.concat(tts_speeches, dim=1)}
 
 
 
 
73
 
74
- def inference_instruct(self, tts_text, spk_id, instruct_text):
75
  if self.frontend.instruct is False:
76
  raise ValueError('{} do not support instruct inference'.format(self.model_dir))
77
  instruct_text = self.frontend.text_normalize(instruct_text, split=False)
78
- tts_speeches = []
79
  for i in self.frontend.text_normalize(tts_text, split=True):
80
  model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
81
- model_output = self.model.inference(**model_input)
82
- tts_speeches.append(model_output['tts_speech'])
83
- return {'tts_speech': torch.concat(tts_speeches, dim=1)}
 
 
 
 
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  import os
15
+ import time
16
  from hyperpyyaml import load_hyperpyyaml
17
  from modelscope import snapshot_download
18
  from cosyvoice.cli.frontend import CosyVoiceFrontEnd
19
  from cosyvoice.cli.model import CosyVoiceModel
20
+ from cosyvoice.utils.file_utils import logging
21
 
22
  class CosyVoice:
23
 
24
+ def __init__(self, model_dir, load_jit=True):
25
  instruct = True if '-Instruct' in model_dir else False
26
  self.model_dir = model_dir
27
  if not os.path.exists(model_dir):
 
39
  self.model.load('{}/llm.pt'.format(model_dir),
40
  '{}/flow.pt'.format(model_dir),
41
  '{}/hift.pt'.format(model_dir))
42
+ if load_jit:
43
+ self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
44
+ '{}/llm.llm.fp16.zip'.format(model_dir))
45
  del configs
46
 
47
  def list_avaliable_spks(self):
48
  spks = list(self.frontend.spk2info.keys())
49
  return spks
50
 
51
+ def inference_sft(self, tts_text, spk_id, stream=False):
 
52
  for i in self.frontend.text_normalize(tts_text, split=True):
53
  model_input = self.frontend.frontend_sft(i, spk_id)
54
+ start_time = time.time()
55
+ logging.info('synthesis text {}'.format(i))
56
+ for model_output in self.model.inference(**model_input, stream=stream):
57
+ speech_len = model_output['tts_speech'].shape[1] / 22050
58
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
59
+ yield model_output
60
+ start_time = time.time()
61
 
62
+ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
63
  prompt_text = self.frontend.text_normalize(prompt_text, split=False)
 
64
  for i in self.frontend.text_normalize(tts_text, split=True):
65
  model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
66
+ start_time = time.time()
67
+ logging.info('synthesis text {}'.format(i))
68
+ for model_output in self.model.inference(**model_input, stream=stream):
69
+ speech_len = model_output['tts_speech'].shape[1] / 22050
70
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
71
+ yield model_output
72
+ start_time = time.time()
73
 
74
+ def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
75
  if self.frontend.instruct is True:
76
  raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
 
77
  for i in self.frontend.text_normalize(tts_text, split=True):
78
  model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
79
+ start_time = time.time()
80
+ logging.info('synthesis text {}'.format(i))
81
+ for model_output in self.model.inference(**model_input, stream=stream):
82
+ speech_len = model_output['tts_speech'].shape[1] / 22050
83
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
84
+ yield model_output
85
+ start_time = time.time()
86
 
87
+ def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
88
  if self.frontend.instruct is False:
89
  raise ValueError('{} do not support instruct inference'.format(self.model_dir))
90
  instruct_text = self.frontend.text_normalize(instruct_text, split=False)
 
91
  for i in self.frontend.text_normalize(tts_text, split=True):
92
  model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
93
+ start_time = time.time()
94
+ logging.info('synthesis text {}'.format(i))
95
+ for model_output in self.model.inference(**model_input, stream=stream):
96
+ speech_len = model_output['tts_speech'].shape[1] / 22050
97
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
98
+ yield model_output
99
+ start_time = time.time()
cosyvoice/cli/model.py CHANGED
@@ -12,6 +12,13 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  import torch
 
 
 
 
 
 
 
15
 
16
  class CosyVoiceModel:
17
 
@@ -23,38 +30,143 @@ class CosyVoiceModel:
23
  self.llm = llm
24
  self.flow = flow
25
  self.hift = hift
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def load(self, llm_model, flow_model, hift_model):
28
  self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
29
  self.llm.to(self.device).eval()
 
30
  self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
31
  self.flow.to(self.device).eval()
32
  self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
33
  self.hift.to(self.device).eval()
34
 
35
- def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
36
- prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
37
- llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
38
- flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
39
- prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
40
- tts_speech_token = self.llm.inference(text=text.to(self.device),
41
- text_len=text_len.to(self.device),
42
- prompt_text=prompt_text.to(self.device),
43
- prompt_text_len=prompt_text_len.to(self.device),
44
- prompt_speech_token=llm_prompt_speech_token.to(self.device),
45
- prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
46
- embedding=llm_embedding.to(self.device),
47
- beam_size=1,
48
- sampling=25,
49
- max_token_text_ratio=30,
50
- min_token_text_ratio=3)
51
- tts_mel = self.flow.inference(token=tts_speech_token,
52
- token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
53
- prompt_token=flow_prompt_speech_token.to(self.device),
54
- prompt_token_len=flow_prompt_speech_token_len.to(self.device),
55
- prompt_feat=prompt_speech_feat.to(self.device),
56
- prompt_feat_len=prompt_speech_feat_len.to(self.device),
57
- embedding=flow_embedding.to(self.device))
58
- tts_speech = self.hift.inference(mel=tts_mel).cpu()
59
- torch.cuda.empty_cache()
60
- return {'tts_speech': tts_speech}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  import torch
15
+ import numpy as np
16
+ import threading
17
+ import time
18
+ from contextlib import nullcontext
19
+ import uuid
20
+ from cosyvoice.utils.common import fade_in_out
21
+
22
 
23
  class CosyVoiceModel:
24
 
 
30
  self.llm = llm
31
  self.flow = flow
32
  self.hift = hift
33
+ self.token_min_hop_len = 100
34
+ self.token_max_hop_len = 200
35
+ self.token_overlap_len = 20
36
+ # mel fade in out
37
+ self.mel_overlap_len = 34
38
+ self.mel_window = np.hamming(2 * self.mel_overlap_len)
39
+ # hift cache
40
+ self.mel_cache_len = 20
41
+ self.source_cache_len = int(self.mel_cache_len * 256)
42
+ # rtf and decoding related
43
+ self.stream_scale_factor = 1
44
+ assert self.stream_scale_factor >= 1, 'stream_scale_factor should be at least 1, change it according to your actual rtf'
45
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
46
+ self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
47
+ self.lock = threading.Lock()
48
+ # dict used to store session related variable
49
+ self.tts_speech_token_dict = {}
50
+ self.llm_end_dict = {}
51
+ self.mel_overlap_dict = {}
52
+ self.hift_cache_dict = {}
53
 
54
  def load(self, llm_model, flow_model, hift_model):
55
  self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
56
  self.llm.to(self.device).eval()
57
+ self.llm.half()
58
  self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
59
  self.flow.to(self.device).eval()
60
  self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
61
  self.hift.to(self.device).eval()
62
 
63
+ def load_jit(self, llm_text_encoder_model, llm_llm_model):
64
+ llm_text_encoder = torch.jit.load(llm_text_encoder_model)
65
+ self.llm.text_encoder = llm_text_encoder
66
+ llm_llm = torch.jit.load(llm_llm_model)
67
+ self.llm.llm = llm_llm
68
+
69
+ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
70
+ with self.llm_context:
71
+ for i in self.llm.inference(text=text.to(self.device),
72
+ text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
73
+ prompt_text=prompt_text.to(self.device),
74
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
75
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
76
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
77
+ embedding=llm_embedding.to(self.device).half(),
78
+ sampling=25,
79
+ max_token_text_ratio=30,
80
+ min_token_text_ratio=3):
81
+ self.tts_speech_token_dict[uuid].append(i)
82
+ self.llm_end_dict[uuid] = True
83
+
84
+ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
85
+ with self.flow_hift_context:
86
+ tts_mel = self.flow.inference(token=token.to(self.device),
87
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
88
+ prompt_token=prompt_token.to(self.device),
89
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
90
+ prompt_feat=prompt_feat.to(self.device),
91
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
92
+ embedding=embedding.to(self.device))
93
+ # mel overlap fade in out
94
+ if self.mel_overlap_dict[uuid] is not None:
95
+ tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
96
+ # append hift cache
97
+ if self.hift_cache_dict[uuid] is not None:
98
+ hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
99
+ tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
100
+ else:
101
+ hift_cache_source = torch.zeros(1, 1, 0)
102
+ # keep overlap mel and hift cache
103
+ if finalize is False:
104
+ self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
105
+ tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
106
+ tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
107
+ self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
108
+ tts_speech = tts_speech[:, :-self.source_cache_len]
109
+ else:
110
+ tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
111
+ return tts_speech
112
+
113
+ def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
114
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32),
115
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
116
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
117
+ prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
118
+ # this_uuid is used to track variables related to this inference thread
119
+ this_uuid = str(uuid.uuid1())
120
+ with self.lock:
121
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
122
+ p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
123
+ p.start()
124
+ if stream is True:
125
+ token_hop_len = self.token_min_hop_len
126
+ while True:
127
+ time.sleep(0.1)
128
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
129
+ this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
130
+ with self.flow_hift_context:
131
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
132
+ prompt_token=flow_prompt_speech_token,
133
+ prompt_feat=prompt_speech_feat,
134
+ embedding=flow_embedding,
135
+ uuid=this_uuid,
136
+ finalize=False)
137
+ yield {'tts_speech': this_tts_speech.cpu()}
138
+ with self.lock:
139
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
140
+ # increase token_hop_len for better speech quality
141
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
142
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
143
+ break
144
+ p.join()
145
+ # deal with the remaining tokens, make sure the remaining token length equals token_hop_len when cache_speech is not None
146
+ this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
147
+ with self.flow_hift_context:
148
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
149
+ prompt_token=flow_prompt_speech_token,
150
+ prompt_feat=prompt_speech_feat,
151
+ embedding=flow_embedding,
152
+ uuid=this_uuid,
153
+ finalize=True)
154
+ yield {'tts_speech': this_tts_speech.cpu()}
155
+ else:
156
+ # deal with all tokens
157
+ p.join()
158
+ this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
159
+ with self.flow_hift_context:
160
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
161
+ prompt_token=flow_prompt_speech_token,
162
+ prompt_feat=prompt_speech_feat,
163
+ embedding=flow_embedding,
164
+ uuid=this_uuid,
165
+ finalize=True)
166
+ yield {'tts_speech': this_tts_speech.cpu()}
167
+ with self.lock:
168
+ self.tts_speech_token_dict.pop(this_uuid)
169
+ self.llm_end_dict.pop(this_uuid)
170
+ self.mel_overlap_dict.pop(this_uuid)
171
+ self.hift_cache_dict.pop(this_uuid)
172
+ torch.cuda.synchronize()
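The streaming `inference` above is a producer/consumer loop: `llm_job` runs in a background thread and appends speech tokens to a per-uuid list, while the main generator drains the list in hops of `token_hop_len + token_overlap_len` tokens and vocodes each hop. A simplified sketch of the same pattern with the model calls replaced by placeholder callables (not the actual implementation):

```
import threading
import time

def stream_chunks(token_source, tokens_to_audio, hop_len=100, overlap_len=20):
    # token_source: any iterable of speech tokens (the LLM in CosyVoice)
    # tokens_to_audio: callable turning a token slice into an audio chunk (flow + hift in CosyVoice)
    tokens, done = [], False

    def producer():
        nonlocal done
        for t in token_source:
            tokens.append(t)
        done = True

    threading.Thread(target=producer).start()
    offset = 0
    while True:
        time.sleep(0.1)
        if len(tokens) - offset >= hop_len + overlap_len:
            yield tokens_to_audio(tokens[offset:offset + hop_len + overlap_len], finalize=False)
            offset += hop_len
        if done and len(tokens) - offset < hop_len + overlap_len:
            break
    # flush the remaining tokens once the producer has finished
    yield tokens_to_audio(tokens[offset:], finalize=True)
```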
cosyvoice/flow/flow.py CHANGED
@@ -111,6 +111,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
111
  embedding = self.spk_embed_affine_layer(embedding)
112
 
113
  # concat text and prompt_text
 
114
  token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
115
  mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
116
  token = self.input_embedding(torch.clamp(token, min=0)) * mask
@@ -118,17 +119,16 @@ class MaskedDiffWithXvec(torch.nn.Module):
118
  # text encode
119
  h, h_lengths = self.encoder(token, token_len)
120
  h = self.encoder_proj(h)
121
- feat_len = (token_len / 50 * 22050 / 256).int()
122
- h, h_lengths = self.length_regulator(h, feat_len)
123
 
124
  # get conditions
125
- conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
126
- if prompt_feat.shape[1] != 0:
127
- for i, j in enumerate(prompt_feat_len):
128
- conds[i, :j] = prompt_feat[i]
129
  conds = conds.transpose(1, 2)
130
 
131
- mask = (~make_pad_mask(feat_len)).to(h)
 
132
  feat = self.decoder(
133
  mu=h.transpose(1, 2).contiguous(),
134
  mask=mask.unsqueeze(1),
@@ -136,6 +136,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
136
  cond=conds,
137
  n_timesteps=10
138
  )
139
- if prompt_feat.shape[1] != 0:
140
- feat = feat[:, :, prompt_feat.shape[1]:]
141
  return feat
 
111
  embedding = self.spk_embed_affine_layer(embedding)
112
 
113
  # concat text and prompt_text
114
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
115
  token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
116
  mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
117
  token = self.input_embedding(torch.clamp(token, min=0)) * mask
 
119
  # text encode
120
  h, h_lengths = self.encoder(token, token_len)
121
  h = self.encoder_proj(h)
122
+ mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
123
+ h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
124
 
125
  # get conditions
126
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
127
+ conds[:, :mel_len1] = prompt_feat
 
 
128
  conds = conds.transpose(1, 2)
129
 
130
+ # mask = (~make_pad_mask(feat_len)).to(h)
131
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
132
  feat = self.decoder(
133
  mu=h.transpose(1, 2).contiguous(),
134
  mask=mask.unsqueeze(1),
 
136
  cond=conds,
137
  n_timesteps=10
138
  )
139
+ feat = feat[:, :, mel_len1:]
140
+ assert feat.shape[2] == mel_len2
141
  return feat
cosyvoice/flow/length_regulator.py CHANGED
@@ -13,6 +13,7 @@
13
  # limitations under the License.
14
  from typing import Tuple
15
  import torch.nn as nn
 
16
  from torch.nn import functional as F
17
  from cosyvoice.utils.mask import make_pad_mask
18
 
@@ -43,7 +44,25 @@ class InterpolateRegulator(nn.Module):
43
  def forward(self, x, ylens=None):
44
  # x in (B, T, D)
45
  mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
46
- x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
47
  out = self.model(x).transpose(1, 2).contiguous()
48
  olens = ylens
49
  return out * mask, olens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # limitations under the License.
14
  from typing import Tuple
15
  import torch.nn as nn
16
+ import torch
17
  from torch.nn import functional as F
18
  from cosyvoice.utils.mask import make_pad_mask
19
 
 
44
  def forward(self, x, ylens=None):
45
  # x in (B, T, D)
46
  mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
47
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
48
  out = self.model(x).transpose(1, 2).contiguous()
49
  olens = ylens
50
  return out * mask, olens
51
+
52
+ def inference(self, x1, x2, mel_len1, mel_len2):
53
+ # in inference mode, interpolate the prompt token and the token (head/mid/tail) separately, so we get a clean separation point in the mel
54
+ # x in (B, T, D)
55
+ if x2.shape[1] > 40:
56
+ x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
57
+ x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
58
+ x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
59
+ x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
60
+ else:
61
+ x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
62
+ if x1.shape[1] != 0:
63
+ x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
64
+ x = torch.concat([x1, x2], dim=2)
65
+ else:
66
+ x = x2
67
+ out = self.model(x).transpose(1, 2).contiguous()
68
+ return out, mel_len1 + mel_len2
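Interpolating the prompt part and the generated part separately means frame `mel_len1` becomes an exact boundary, which `flow.py` above relies on when it slices `feat[:, :, mel_len1:]`. A small sanity sketch on dummy encoder outputs (shapes assumed):

```
import torch
import torch.nn.functional as F

x1 = torch.randn(1, 80, 30)    # dummy prompt encoding, (B, D, T)
x2 = torch.randn(1, 80, 120)   # dummy generated encoding
mel_len1, mel_len2 = 130, 520
x1 = F.interpolate(x1, size=mel_len1, mode='linear')
x2 = F.interpolate(x2, size=mel_len2, mode='linear')
x = torch.concat([x1, x2], dim=2)
assert x.shape[2] == mel_len1 + mel_len2   # frame mel_len1 is the exact prompt/generated boundary
```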
cosyvoice/hifigan/generator.py CHANGED
@@ -335,10 +335,14 @@ class HiFTGenerator(nn.Module):
335
  inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
336
  return inverse_transform
337
 
338
- def forward(self, x: torch.Tensor) -> torch.Tensor:
339
  f0 = self.f0_predictor(x)
340
  s = self._f02source(f0)
341
 
 
 
 
 
342
  s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
343
  s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
344
 
@@ -370,7 +374,7 @@ class HiFTGenerator(nn.Module):
370
 
371
  x = self._istft(magnitude, phase)
372
  x = torch.clamp(x, -self.audio_limit, self.audio_limit)
373
- return x
374
 
375
  def remove_weight_norm(self):
376
  print('Removing weight norm...')
@@ -387,5 +391,5 @@ class HiFTGenerator(nn.Module):
387
  l.remove_weight_norm()
388
 
389
  @torch.inference_mode()
390
- def inference(self, mel: torch.Tensor) -> torch.Tensor:
391
- return self.forward(x=mel)
 
335
  inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
336
  return inverse_transform
337
 
338
+ def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
339
  f0 = self.f0_predictor(x)
340
  s = self._f02source(f0)
341
 
342
+ # use cache_source to avoid glitch
343
+ if cache_source.shape[2] != 0:
344
+ s[:, :, :cache_source.shape[2]] = cache_source
345
+
346
  s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
347
  s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
348
 
 
374
 
375
  x = self._istft(magnitude, phase)
376
  x = torch.clamp(x, -self.audio_limit, self.audio_limit)
377
+ return x, s
378
 
379
  def remove_weight_norm(self):
380
  print('Removing weight norm...')
 
391
  l.remove_weight_norm()
392
 
393
  @torch.inference_mode()
394
+ def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
395
+ return self.forward(x=mel, cache_source=cache_source)
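With this change the vocoder returns the excitation source `s` alongside the waveform and accepts a `cache_source` tensor, so a streaming caller can splice the previous chunk's excitation into the next chunk (as `token2wav` does via `hift_cache_dict`). A minimal sketch of the call pattern, assuming a loaded `hift` generator and mel chunks that already include the cached mel frames:

```
import torch

def vocode_chunks(hift, mel_chunks, source_cache_len=20 * 256):
    # hift: a loaded HiFTGenerator; mel_chunks: (1, 80, T) mel tensors that already
    # include the cached mel frames, as prepared by token2wav above
    cache_source = torch.zeros(1, 1, 0)          # empty cache for the very first chunk
    for mel in mel_chunks:
        tts_speech, tts_source = hift.inference(mel=mel, cache_source=cache_source)
        cache_source = tts_source[:, :, -source_cache_len:]   # keep the excitation tail
        yield tts_speech
```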
cosyvoice/llm/llm.py CHANGED
@@ -11,7 +11,7 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
- from typing import Dict, Optional, Union
15
  import torch
16
  from torch import nn
17
  import torch.nn.functional as F
@@ -31,6 +31,7 @@ class TransformerLM(torch.nn.Module):
31
  speech_token_size: int,
32
  text_encoder: torch.nn.Module,
33
  llm: torch.nn.Module,
 
34
  length_normalized_loss: bool = True,
35
  lsm_weight: float = 0.0,
36
  spk_embed_dim: int = 192,
@@ -63,6 +64,9 @@ class TransformerLM(torch.nn.Module):
63
  self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
64
  self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
65
 
 
 
 
66
  def encode(
67
  self,
68
  text: torch.Tensor,
@@ -132,14 +136,12 @@ class TransformerLM(torch.nn.Module):
132
  def sampling_ids(
133
  self,
134
  weighted_scores: torch.Tensor,
135
- sampling: Union[bool, int, float] = True,
136
- beam_size: int = 1,
137
  ignore_eos: bool = True,
138
  ):
139
  while True:
140
- prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
141
- top_ids = prob.multinomial(beam_size, replacement=True)
142
- top_ids = indices[top_ids]
143
  if (not ignore_eos) or (self.speech_token_size not in top_ids):
144
  break
145
  return top_ids
@@ -154,11 +156,10 @@ class TransformerLM(torch.nn.Module):
154
  prompt_speech_token: torch.Tensor,
155
  prompt_speech_token_len: torch.Tensor,
156
  embedding: torch.Tensor,
157
- beam_size: int = 1,
158
  sampling: int = 25,
159
  max_token_text_ratio: float = 20,
160
  min_token_text_ratio: float = 2,
161
- ) -> torch.Tensor:
162
  device = text.device
163
  text = torch.concat([prompt_text, text], dim=1)
164
  text_len += prompt_text_len
@@ -173,7 +174,7 @@ class TransformerLM(torch.nn.Module):
173
  embedding = self.spk_embed_affine_layer(embedding)
174
  embedding = embedding.unsqueeze(dim=1)
175
  else:
176
- embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
177
 
178
  # 3. concat llm_input
179
  sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
@@ -181,7 +182,7 @@ class TransformerLM(torch.nn.Module):
181
  if prompt_speech_token_len != 0:
182
  prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
183
  else:
184
- prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
185
  lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
186
 
187
  # 4. cal min/max_length
@@ -196,11 +197,11 @@ class TransformerLM(torch.nn.Module):
196
  y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
197
  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
198
  logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
199
- top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
200
  if top_ids == self.speech_token_size:
201
  break
 
 
202
  out_tokens.append(top_ids)
203
  offset += lm_input.size(1)
204
  lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
205
-
206
- return torch.tensor([out_tokens], dtype=torch.int64, device=device)
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+ from typing import Dict, Optional, Callable, List, Generator
15
  import torch
16
  from torch import nn
17
  import torch.nn.functional as F
 
31
  speech_token_size: int,
32
  text_encoder: torch.nn.Module,
33
  llm: torch.nn.Module,
34
+ sampling: Callable,
35
  length_normalized_loss: bool = True,
36
  lsm_weight: float = 0.0,
37
  spk_embed_dim: int = 192,
 
64
  self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
65
  self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
66
 
67
+ # 4. sampling method
68
+ self.sampling = sampling
69
+
70
  def encode(
71
  self,
72
  text: torch.Tensor,
 
136
  def sampling_ids(
137
  self,
138
  weighted_scores: torch.Tensor,
139
+ decoded_tokens: List,
140
+ sampling: int,
141
  ignore_eos: bool = True,
142
  ):
143
  while True:
144
+ top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
 
 
145
  if (not ignore_eos) or (self.speech_token_size not in top_ids):
146
  break
147
  return top_ids
 
156
  prompt_speech_token: torch.Tensor,
157
  prompt_speech_token_len: torch.Tensor,
158
  embedding: torch.Tensor,
 
159
  sampling: int = 25,
160
  max_token_text_ratio: float = 20,
161
  min_token_text_ratio: float = 2,
162
+ ) -> Generator[torch.Tensor, None, None]:
163
  device = text.device
164
  text = torch.concat([prompt_text, text], dim=1)
165
  text_len += prompt_text_len
 
174
  embedding = self.spk_embed_affine_layer(embedding)
175
  embedding = embedding.unsqueeze(dim=1)
176
  else:
177
+ embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
178
 
179
  # 3. concat llm_input
180
  sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
 
182
  if prompt_speech_token_len != 0:
183
  prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
184
  else:
185
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
186
  lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
187
 
188
  # 4. cal min/max_length
 
197
  y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
198
  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
199
  logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
200
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
201
  if top_ids == self.speech_token_size:
202
  break
203
+ # in stream mode, yield tokens one by one
204
+ yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
205
  out_tokens.append(top_ids)
206
  offset += lm_input.size(1)
207
  lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
 
 
cosyvoice/transformer/attention.py CHANGED
@@ -222,7 +222,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
222
  torch.nn.init.xavier_uniform_(self.pos_bias_u)
223
  torch.nn.init.xavier_uniform_(self.pos_bias_v)
224
 
225
- def rel_shift(self, x):
226
  """Compute relative positional encoding.
227
 
228
  Args:
@@ -233,10 +233,14 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
233
  torch.Tensor: Output tensor.
234
 
235
  """
236
- zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
 
 
237
  x_padded = torch.cat([zero_pad, x], dim=-1)
238
 
239
- x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
 
 
240
  x = x_padded[:, :, 1:].view_as(x)[
241
  :, :, :, : x.size(-1) // 2 + 1
242
  ] # only keep the positions from 0 to time2
 
222
  torch.nn.init.xavier_uniform_(self.pos_bias_u)
223
  torch.nn.init.xavier_uniform_(self.pos_bias_v)
224
 
225
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
226
  """Compute relative positional encoding.
227
 
228
  Args:
 
233
  torch.Tensor: Output tensor.
234
 
235
  """
236
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
237
+ device=x.device,
238
+ dtype=x.dtype)
239
  x_padded = torch.cat([zero_pad, x], dim=-1)
240
 
241
+ x_padded = x_padded.view(x.size()[0],
242
+ x.size()[1],
243
+ x.size(3) + 1, x.size(2))
244
  x = x_padded[:, :, 1:].view_as(x)[
245
  :, :, :, : x.size(-1) // 2 + 1
246
  ] # only keep the positions from 0 to time2
cosyvoice/transformer/decoder.py CHANGED
@@ -174,7 +174,7 @@ class TransformerDecoder(torch.nn.Module):
174
  memory_mask)
175
  return x
176
 
177
- @torch.jit.ignore(drop=True)
178
  def forward_layers_checkpointed(self, x: torch.Tensor,
179
  tgt_mask: torch.Tensor,
180
  memory: torch.Tensor,
 
174
  memory_mask)
175
  return x
176
 
177
+ @torch.jit.unused
178
  def forward_layers_checkpointed(self, x: torch.Tensor,
179
  tgt_mask: torch.Tensor,
180
  memory: torch.Tensor,
cosyvoice/transformer/embedding.py CHANGED
@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
212
 
213
  """
214
 
215
- def __init__(self, d_model, dropout_rate, max_len=5000):
216
  """Construct an PositionalEncoding object."""
217
  super(EspnetRelPositionalEncoding, self).__init__()
218
  self.d_model = d_model
@@ -221,7 +221,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
221
  self.pe = None
222
  self.extend_pe(torch.tensor(0.0).expand(1, max_len))
223
 
224
- def extend_pe(self, x):
225
  """Reset the positional encodings."""
226
  if self.pe is not None:
227
  # self.pe contains both positive and negative parts
@@ -253,7 +253,8 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
253
  pe = torch.cat([pe_positive, pe_negative], dim=1)
254
  self.pe = pe.to(device=x.device, dtype=x.dtype)
255
 
256
- def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0):
 
257
  """Add positional encoding.
258
 
259
  Args:
 
212
 
213
  """
214
 
215
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
216
  """Construct an PositionalEncoding object."""
217
  super(EspnetRelPositionalEncoding, self).__init__()
218
  self.d_model = d_model
 
221
  self.pe = None
222
  self.extend_pe(torch.tensor(0.0).expand(1, max_len))
223
 
224
+ def extend_pe(self, x: torch.Tensor):
225
  """Reset the positional encodings."""
226
  if self.pe is not None:
227
  # self.pe contains both positive and negative parts
 
253
  pe = torch.cat([pe_positive, pe_negative], dim=1)
254
  self.pe = pe.to(device=x.device, dtype=x.dtype)
255
 
256
+ def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
257
+ -> Tuple[torch.Tensor, torch.Tensor]:
258
  """Add positional encoding.
259
 
260
  Args:
cosyvoice/transformer/encoder.py CHANGED
@@ -169,7 +169,7 @@ class BaseEncoder(torch.nn.Module):
169
  xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
170
  return xs
171
 
172
- @torch.jit.ignore(drop=True)
173
  def forward_layers_checkpointed(self, xs: torch.Tensor,
174
  chunk_masks: torch.Tensor,
175
  pos_emb: torch.Tensor,
@@ -180,6 +180,7 @@ class BaseEncoder(torch.nn.Module):
180
  mask_pad)
181
  return xs
182
 
 
183
  def forward_chunk(
184
  self,
185
  xs: torch.Tensor,
@@ -270,6 +271,7 @@ class BaseEncoder(torch.nn.Module):
270
 
271
  return (xs, r_att_cache, r_cnn_cache)
272
 
 
273
  def forward_chunk_by_chunk(
274
  self,
275
  xs: torch.Tensor,
 
169
  xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
170
  return xs
171
 
172
+ @torch.jit.unused
173
  def forward_layers_checkpointed(self, xs: torch.Tensor,
174
  chunk_masks: torch.Tensor,
175
  pos_emb: torch.Tensor,
 
180
  mask_pad)
181
  return xs
182
 
183
+ @torch.jit.export
184
  def forward_chunk(
185
  self,
186
  xs: torch.Tensor,
 
271
 
272
  return (xs, r_att_cache, r_cnn_cache)
273
 
274
+ @torch.jit.unused
275
  def forward_chunk_by_chunk(
276
  self,
277
  xs: torch.Tensor,
cosyvoice/utils/common.py CHANGED
@@ -101,3 +101,39 @@ def init_weights(m, mean=0.0, std=0.01):
101
  classname = m.__class__.__name__
102
  if classname.find("Conv") != -1:
103
  m.weight.data.normal_(mean, std)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  classname = m.__class__.__name__
102
  if classname.find("Conv") != -1:
103
  m.weight.data.normal_(mean, std)
104
+
105
+ # Repetition Aware Sampling in VALL-E 2
106
+ def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
107
+ top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
108
+ rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
109
+ if rep_num >= win_size * tau_r:
110
+ top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
111
+ return top_ids
112
+
113
+ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
114
+ prob, indices = [], []
115
+ cum_prob = 0.0
116
+ sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
117
+ for i in range(len(sorted_idx)):
118
+ # sample within both the top-p mass and the top-k count
119
+ if cum_prob < top_p and len(prob) < top_k:
120
+ cum_prob += sorted_value[i]
121
+ prob.append(sorted_value[i])
122
+ indices.append(sorted_idx[i])
123
+ else:
124
+ break
125
+ prob = torch.tensor(prob).to(weighted_scores)
126
+ indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
127
+ top_ids = indices[prob.multinomial(1, replacement=True)]
128
+ return top_ids
129
+
130
+ def random_sampling(weighted_scores, decoded_tokens, sampling):
131
+ top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
132
+ return top_ids
133
+
134
+ def fade_in_out(fade_in_mel, fade_out_mel, window):
135
+ device = fade_in_mel.device
136
+ fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
137
+ mel_overlap_len = int(window.shape[0] / 2)
138
+ fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
139
+ return fade_in_mel.to(device)
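`ras_sampling` falls back to plain multinomial sampling whenever the nucleus-sampled token has appeared at least `win_size * tau_r` times in the last `win_size` decoded tokens. A small usage sketch on dummy logits, using the signatures defined above:

```
import torch
from cosyvoice.utils.common import ras_sampling

weighted_scores = torch.randn(4096)       # per-token logits for one decoding step (dummy)
decoded_tokens = [12, 873, 12, 12, 55]    # speech tokens emitted so far
top_ids = ras_sampling(weighted_scores, decoded_tokens, sampling=25, top_p=0.8, top_k=25)
print(top_ids.item())
```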
cosyvoice/utils/file_utils.py CHANGED
@@ -15,6 +15,10 @@
15
 
16
  import json
17
  import torchaudio
 
 
 
 
18
 
19
 
20
  def read_lists(list_file):
 
15
 
16
  import json
17
  import torchaudio
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ logging.basicConfig(level=logging.DEBUG,
21
+ format='%(asctime)s %(levelname)s %(message)s')
22
 
23
 
24
  def read_lists(list_file):
examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml CHANGED
@@ -31,7 +31,7 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
31
  num_blocks: 3
32
  dropout_rate: 0.1
33
  positional_dropout_rate: 0.1
34
- attention_dropout_rate: 0
35
  normalize_before: True
36
  input_layer: 'linear'
37
  pos_enc_layer_type: 'rel_pos_espnet'
@@ -49,11 +49,16 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
49
  num_blocks: 7
50
  dropout_rate: 0.1
51
  positional_dropout_rate: 0.1
52
- attention_dropout_rate: 0
53
  input_layer: 'linear_legacy'
54
  pos_enc_layer_type: 'rel_pos_espnet'
55
  selfattention_layer_type: 'rel_selfattn'
56
  static_chunk_size: 1
 
 
 
 
 
57
 
58
  flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
59
  input_size: 512
@@ -97,7 +102,7 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
97
  in_channels: 320
98
  out_channels: 80
99
  channels: [256, 256]
100
- dropout: 0
101
  attention_head_dim: 64
102
  n_blocks: 4
103
  num_mid_blocks: 8
 
31
  num_blocks: 3
32
  dropout_rate: 0.1
33
  positional_dropout_rate: 0.1
34
+ attention_dropout_rate: 0.0
35
  normalize_before: True
36
  input_layer: 'linear'
37
  pos_enc_layer_type: 'rel_pos_espnet'
 
49
  num_blocks: 7
50
  dropout_rate: 0.1
51
  positional_dropout_rate: 0.1
52
+ attention_dropout_rate: 0.0
53
  input_layer: 'linear_legacy'
54
  pos_enc_layer_type: 'rel_pos_espnet'
55
  selfattention_layer_type: 'rel_selfattn'
56
  static_chunk_size: 1
57
+ sampling: !name:cosyvoice.utils.common.ras_sampling
58
+ top_p: 0.8
59
+ top_k: 25
60
+ win_size: 10
61
+ tau_r: 0.1
62
 
63
  flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
64
  input_size: 512
 
102
  in_channels: 320
103
  out_channels: 80
104
  channels: [256, 256]
105
+ dropout: 0.0
106
  attention_head_dim: 64
107
  n_blocks: 4
108
  num_mid_blocks: 8
examples/libritts/cosyvoice/conf/cosyvoice.yaml CHANGED
@@ -31,7 +31,7 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
31
  num_blocks: 6
32
  dropout_rate: 0.1
33
  positional_dropout_rate: 0.1
34
- attention_dropout_rate: 0
35
  normalize_before: True
36
  input_layer: 'linear'
37
  pos_enc_layer_type: 'rel_pos_espnet'
@@ -49,11 +49,16 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
49
  num_blocks: 14
50
  dropout_rate: 0.1
51
  positional_dropout_rate: 0.1
52
- attention_dropout_rate: 0
53
  input_layer: 'linear_legacy'
54
  pos_enc_layer_type: 'rel_pos_espnet'
55
  selfattention_layer_type: 'rel_selfattn'
56
  static_chunk_size: 1
 
 
 
 
 
57
 
58
  flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
59
  input_size: 512
@@ -97,7 +102,7 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
97
  in_channels: 320
98
  out_channels: 80
99
  channels: [256, 256]
100
- dropout: 0
101
  attention_head_dim: 64
102
  n_blocks: 4
103
  num_mid_blocks: 12
 
31
  num_blocks: 6
32
  dropout_rate: 0.1
33
  positional_dropout_rate: 0.1
34
+ attention_dropout_rate: 0.0
35
  normalize_before: True
36
  input_layer: 'linear'
37
  pos_enc_layer_type: 'rel_pos_espnet'
 
49
  num_blocks: 14
50
  dropout_rate: 0.1
51
  positional_dropout_rate: 0.1
52
+ attention_dropout_rate: 0.0
53
  input_layer: 'linear_legacy'
54
  pos_enc_layer_type: 'rel_pos_espnet'
55
  selfattention_layer_type: 'rel_selfattn'
56
  static_chunk_size: 1
57
+ sampling: !name:cosyvoice.utils.common.ras_sampling
58
+ top_p: 0.8
59
+ top_k: 25
60
+ win_size: 10
61
+ tau_r: 0.1
62
 
63
  flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
64
  input_size: 512
 
102
  in_channels: 320
103
  out_channels: 80
104
  channels: [256, 256]
105
+ dropout: 0.0
106
  attention_head_dim: 64
107
  n_blocks: 4
108
  num_mid_blocks: 12
runtime/python/grpc/client.py CHANGED
@@ -61,8 +61,11 @@ def main():
61
  request.instruct_request.CopyFrom(instruct_request)
62
 
63
  response = stub.Inference(request)
 
 
 
 
64
  logging.info('save response to {}'.format(args.tts_wav))
65
- tts_speech = torch.from_numpy(np.array(np.frombuffer(response.tts_audio, dtype=np.int16))).unsqueeze(dim=0)
66
  torchaudio.save(args.tts_wav, tts_speech, target_sr)
67
  logging.info('get response')
68
 
 
61
  request.instruct_request.CopyFrom(instruct_request)
62
 
63
  response = stub.Inference(request)
64
+ tts_audio = b''
65
+ for r in response:
66
+ tts_audio += r.tts_audio
67
+ tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
68
  logging.info('save response to {}'.format(args.tts_wav))
 
69
  torchaudio.save(args.tts_wav, tts_speech, target_sr)
70
  logging.info('get response')
71
 
runtime/python/grpc/cosyvoice.proto CHANGED
@@ -4,7 +4,7 @@ package cosyvoice;
4
  option go_package = "protos/";
5
 
6
  service CosyVoice{
7
- rpc Inference(Request) returns (Response) {}
8
  }
9
 
10
  message Request{
 
4
  option go_package = "protos/";
5
 
6
  service CosyVoice{
7
+ rpc Inference(Request) returns (stream Response) {}
8
  }
9
 
10
  message Request{
runtime/python/grpc/server.py CHANGED
@@ -54,9 +54,10 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
54
  model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
55
 
56
  logging.info('send inference response')
57
- response = cosyvoice_pb2.Response()
58
- response.tts_audio = (model_output['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
59
- return response
 
60
 
61
  def main():
62
  grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
 
54
  model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
55
 
56
  logging.info('send inference response')
57
+ for i in model_output:
58
+ response = cosyvoice_pb2.Response()
59
+ response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
60
+ yield response
61
 
62
  def main():
63
  grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
webui.py CHANGED
@@ -24,14 +24,8 @@ import torchaudio
24
  import random
25
  import librosa
26
 
27
- import logging
28
- logging.getLogger('matplotlib').setLevel(logging.WARNING)
29
-
30
  from cosyvoice.cli.cosyvoice import CosyVoice
31
- from cosyvoice.utils.file_utils import load_wav, speed_change
32
-
33
- logging.basicConfig(level=logging.DEBUG,
34
- format='%(asctime)s %(levelname)s %(message)s')
35
 
36
  def generate_seed():
37
  seed = random.randint(1, 100000000)
@@ -63,10 +57,11 @@ instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成
63
  '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
64
  '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
65
  '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
 
66
  def change_instruction(mode_checkbox_group):
67
  return instruct_dict[mode_checkbox_group]
68
 
69
- def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor):
70
  if prompt_wav_upload is not None:
71
  prompt_wav = prompt_wav_upload
72
  elif prompt_wav_record is not None:
@@ -117,32 +112,25 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
117
  if mode_checkbox_group == '预训练音色':
118
  logging.info('get sft inference request')
119
  set_all_random_seed(seed)
120
- output = cosyvoice.inference_sft(tts_text, sft_dropdown)
 
121
  elif mode_checkbox_group == '3s极速复刻':
122
  logging.info('get zero_shot inference request')
123
  prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
124
  set_all_random_seed(seed)
125
- output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
 
126
  elif mode_checkbox_group == '跨语种复刻':
127
  logging.info('get cross_lingual inference request')
128
  prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
129
  set_all_random_seed(seed)
130
- output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
 
131
  else:
132
  logging.info('get instruct inference request')
133
  set_all_random_seed(seed)
134
- output = cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text)
135
-
136
- if speed_factor != 1.0:
137
- try:
138
- audio_data, sample_rate = speed_change(output["tts_speech"], target_sr, str(speed_factor))
139
- audio_data = audio_data.numpy().flatten()
140
- except Exception as e:
141
- print(f"Failed to change speed of audio: \n{e}")
142
- else:
143
- audio_data = output['tts_speech'].numpy().flatten()
144
-
145
- return (target_sr, audio_data)
146
 
147
  def main():
148
  with gr.Blocks() as demo:
@@ -155,6 +143,7 @@ def main():
155
  mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
156
  instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
157
  sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
 
158
  with gr.Column(scale=0.25):
159
  seed_button = gr.Button(value="\U0001F3B2")
160
  seed = gr.Number(value=0, label="随机推理种子")
@@ -167,11 +156,11 @@ def main():
167
 
168
  generate_button = gr.Button("生成音频")
169
 
170
- audio_output = gr.Audio(label="合成音频")
171
 
172
  seed_button.click(generate_seed, inputs=[], outputs=seed)
173
  generate_button.click(generate_audio,
174
- inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor],
175
  outputs=[audio_output])
176
  mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
177
  demo.queue(max_size=4, default_concurrency_limit=2)
@@ -184,7 +173,7 @@ if __name__ == '__main__':
184
  default=8000)
185
  parser.add_argument('--model_dir',
186
  type=str,
187
- default='iic/CosyVoice-300M',
188
  help='local path or modelscope repo id')
189
  args = parser.parse_args()
190
  cosyvoice = CosyVoice(args.model_dir)
 
24
  import random
25
  import librosa
26
 
 
 
 
27
  from cosyvoice.cli.cosyvoice import CosyVoice
28
+ from cosyvoice.utils.file_utils import load_wav, speed_change, logging
 
 
 
29
 
30
  def generate_seed():
31
  seed = random.randint(1, 100000000)
 
57
  '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
58
  '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
59
  '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
60
+ stream_mode_list = [('否', False), ('是', True)]
61
  def change_instruction(mode_checkbox_group):
62
  return instruct_dict[mode_checkbox_group]
63
 
64
+ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor):
65
  if prompt_wav_upload is not None:
66
  prompt_wav = prompt_wav_upload
67
  elif prompt_wav_record is not None:
 
112
  if mode_checkbox_group == '预训练音色':
113
  logging.info('get sft inference request')
114
  set_all_random_seed(seed)
115
+ for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream):
116
+ yield (target_sr, i['tts_speech'].numpy().flatten())
117
  elif mode_checkbox_group == '3s极速复刻':
118
  logging.info('get zero_shot inference request')
119
  prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
120
  set_all_random_seed(seed)
121
+ for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream):
122
+ yield (target_sr, i['tts_speech'].numpy().flatten())
123
  elif mode_checkbox_group == '跨语种复刻':
124
  logging.info('get cross_lingual inference request')
125
  prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
126
  set_all_random_seed(seed)
127
+ for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream):
128
+ yield (target_sr, i['tts_speech'].numpy().flatten())
129
  else:
130
  logging.info('get instruct inference request')
131
  set_all_random_seed(seed)
132
+ for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream):
133
+ yield (target_sr, i['tts_speech'].numpy().flatten())
 
 
 
 
 
 
 
 
 
 
134
 
135
  def main():
136
  with gr.Blocks() as demo:
 
143
  mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
144
  instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
145
  sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
146
+ stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
147
  with gr.Column(scale=0.25):
148
  seed_button = gr.Button(value="\U0001F3B2")
149
  seed = gr.Number(value=0, label="随机推理种子")
 
156
 
157
  generate_button = gr.Button("生成音频")
158
 
159
+ audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True)
160
 
161
  seed_button.click(generate_seed, inputs=[], outputs=seed)
162
  generate_button.click(generate_audio,
163
+ inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor],
164
  outputs=[audio_output])
165
  mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
166
  demo.queue(max_size=4, default_concurrency_limit=2)
 
173
  default=8000)
174
  parser.add_argument('--model_dir',
175
  type=str,
176
+ default='pretrained_models/CosyVoice-300M',
177
  help='local path or modelscope repo id')
178
  args = parser.parse_args()
179
  cosyvoice = CosyVoice(args.model_dir)