Spaces:
Running
Running
Create frontend.py
Browse files- frontend.py +215 -0
frontend.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from functools import partial
|
| 15 |
+
from typing import Generator
|
| 16 |
+
import json
|
| 17 |
+
import onnxruntime
|
| 18 |
+
import torch
|
| 19 |
+
import numpy as np
|
| 20 |
+
import whisper
|
| 21 |
+
from typing import Callable
|
| 22 |
+
import torchaudio.compliance.kaldi as kaldi
|
| 23 |
+
import torchaudio
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
import inflect
|
| 27 |
+
try:
|
| 28 |
+
import ttsfrd
|
| 29 |
+
use_ttsfrd = True
|
| 30 |
+
except ImportError:
|
| 31 |
+
print("failed to import ttsfrd, use WeTextProcessing instead")
|
| 32 |
+
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
| 33 |
+
from tn.english.normalizer import Normalizer as EnNormalizer
|
| 34 |
+
use_ttsfrd = False
|
| 35 |
+
from cosyvoice.utils.file_utils import logging
|
| 36 |
+
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class CosyVoiceFrontEnd:
|
| 40 |
+
|
| 41 |
+
def __init__(self,
|
| 42 |
+
get_tokenizer: Callable,
|
| 43 |
+
feat_extractor: Callable,
|
| 44 |
+
campplus_model: str,
|
| 45 |
+
speech_tokenizer_model: str,
|
| 46 |
+
spk2info: str = '',
|
| 47 |
+
allowed_special: str = 'all'):
|
| 48 |
+
self.tokenizer = get_tokenizer()
|
| 49 |
+
self.feat_extractor = feat_extractor
|
| 50 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 51 |
+
option = onnxruntime.SessionOptions()
|
| 52 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 53 |
+
option.intra_op_num_threads = 1
|
| 54 |
+
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
| 55 |
+
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
|
| 56 |
+
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
|
| 57 |
+
"CPUExecutionProvider"])
|
| 58 |
+
if os.path.exists(spk2info):
|
| 59 |
+
self.spk2info = torch.load(spk2info, map_location=self.device)
|
| 60 |
+
else:
|
| 61 |
+
self.spk2info = {}
|
| 62 |
+
self.allowed_special = allowed_special
|
| 63 |
+
self.use_ttsfrd = use_ttsfrd
|
| 64 |
+
if self.use_ttsfrd:
|
| 65 |
+
self.frd = ttsfrd.TtsFrontendEngine()
|
| 66 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 67 |
+
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
|
| 68 |
+
'failed to initialize ttsfrd resource'
|
| 69 |
+
self.frd.set_lang_type('pinyinvg')
|
| 70 |
+
else:
|
| 71 |
+
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
|
| 72 |
+
self.en_tn_model = EnNormalizer()
|
| 73 |
+
self.inflect_parser = inflect.engine()
|
| 74 |
+
|
| 75 |
+
def _extract_text_token(self, text):
|
| 76 |
+
if isinstance(text, Generator):
|
| 77 |
+
logging.info('get tts_text generator, will return _extract_text_token_generator!')
|
| 78 |
+
# NOTE add a dummy text_token_len for compatibility
|
| 79 |
+
return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
|
| 80 |
+
else:
|
| 81 |
+
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
|
| 82 |
+
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
|
| 83 |
+
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
|
| 84 |
+
return text_token, text_token_len
|
| 85 |
+
|
| 86 |
+
def _extract_text_token_generator(self, text_generator):
|
| 87 |
+
for text in text_generator:
|
| 88 |
+
text_token, _ = self._extract_text_token(text)
|
| 89 |
+
for i in range(text_token.shape[1]):
|
| 90 |
+
yield text_token[:, i: i + 1]
|
| 91 |
+
|
| 92 |
+
def _extract_speech_token(self, speech):
|
| 93 |
+
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
|
| 94 |
+
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
|
| 95 |
+
speech_token = self.speech_tokenizer_session.run(None,
|
| 96 |
+
{self.speech_tokenizer_session.get_inputs()[0].name:
|
| 97 |
+
feat.detach().cpu().numpy(),
|
| 98 |
+
self.speech_tokenizer_session.get_inputs()[1].name:
|
| 99 |
+
np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
|
| 100 |
+
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
|
| 101 |
+
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
|
| 102 |
+
return speech_token, speech_token_len
|
| 103 |
+
|
| 104 |
+
def _extract_spk_embedding(self, speech):
|
| 105 |
+
feat = kaldi.fbank(speech,
|
| 106 |
+
num_mel_bins=80,
|
| 107 |
+
dither=0,
|
| 108 |
+
sample_frequency=16000)
|
| 109 |
+
feat = feat - feat.mean(dim=0, keepdim=True)
|
| 110 |
+
embedding = self.campplus_session.run(None,
|
| 111 |
+
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
| 112 |
+
embedding = torch.tensor([embedding]).to(self.device)
|
| 113 |
+
return embedding
|
| 114 |
+
|
| 115 |
+
def _extract_speech_feat(self, speech):
|
| 116 |
+
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
|
| 117 |
+
speech_feat = speech_feat.unsqueeze(dim=0)
|
| 118 |
+
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
|
| 119 |
+
return speech_feat, speech_feat_len
|
| 120 |
+
|
| 121 |
+
def text_normalize(self, text, split=True, text_frontend=True):
|
| 122 |
+
if isinstance(text, Generator):
|
| 123 |
+
logging.info('get tts_text generator, will skip text_normalize!')
|
| 124 |
+
return [text]
|
| 125 |
+
if text_frontend is False or text == '':
|
| 126 |
+
return [text] if split is True else text
|
| 127 |
+
text = text.strip()
|
| 128 |
+
if self.use_ttsfrd:
|
| 129 |
+
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
|
| 130 |
+
text = ''.join(texts)
|
| 131 |
+
else:
|
| 132 |
+
if contains_chinese(text):
|
| 133 |
+
text = self.zh_tn_model.normalize(text)
|
| 134 |
+
text = text.replace("\n", "")
|
| 135 |
+
text = replace_blank(text)
|
| 136 |
+
text = replace_corner_mark(text)
|
| 137 |
+
text = text.replace(".", "。")
|
| 138 |
+
text = text.replace(" - ", ",")
|
| 139 |
+
text = remove_bracket(text)
|
| 140 |
+
text = re.sub(r'[,,、]+$', '。', text)
|
| 141 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
|
| 142 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
| 143 |
+
else:
|
| 144 |
+
text = self.en_tn_model.normalize(text)
|
| 145 |
+
text = spell_out_number(text, self.inflect_parser)
|
| 146 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
|
| 147 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
| 148 |
+
texts = [i for i in texts if not is_only_punctuation(i)]
|
| 149 |
+
return texts if split is True else text
|
| 150 |
+
|
| 151 |
+
def frontend_sft(self, tts_text, spk_id):
|
| 152 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
| 153 |
+
embedding = self.spk2info[spk_id]['embedding']
|
| 154 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
|
| 155 |
+
return model_input
|
| 156 |
+
|
| 157 |
+
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
|
| 158 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
| 159 |
+
if zero_shot_spk_id == '':
|
| 160 |
+
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
| 161 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
| 162 |
+
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
| 163 |
+
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
| 164 |
+
if resample_rate == 24000:
|
| 165 |
+
# cosyvoice2, force speech_feat % speech_token = 2
|
| 166 |
+
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
|
| 167 |
+
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
|
| 168 |
+
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
|
| 169 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
| 170 |
+
model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
|
| 171 |
+
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
| 172 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
| 173 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
| 174 |
+
'llm_embedding': embedding, 'flow_embedding': embedding}
|
| 175 |
+
else:
|
| 176 |
+
model_input = self.spk2info[zero_shot_spk_id]
|
| 177 |
+
model_input['text'] = tts_text_token
|
| 178 |
+
model_input['text_len'] = tts_text_token_len
|
| 179 |
+
return model_input
|
| 180 |
+
|
| 181 |
+
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
|
| 182 |
+
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
|
| 183 |
+
# in cross lingual mode, we remove prompt in llm
|
| 184 |
+
del model_input['prompt_text']
|
| 185 |
+
del model_input['prompt_text_len']
|
| 186 |
+
del model_input['llm_prompt_speech_token']
|
| 187 |
+
del model_input['llm_prompt_speech_token_len']
|
| 188 |
+
return model_input
|
| 189 |
+
|
| 190 |
+
def frontend_instruct(self, tts_text, spk_id, instruct_text):
|
| 191 |
+
model_input = self.frontend_sft(tts_text, spk_id)
|
| 192 |
+
# in instruct mode, we remove spk_embedding in llm due to information leakage
|
| 193 |
+
del model_input['llm_embedding']
|
| 194 |
+
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
|
| 195 |
+
model_input['prompt_text'] = instruct_text_token
|
| 196 |
+
model_input['prompt_text_len'] = instruct_text_token_len
|
| 197 |
+
return model_input
|
| 198 |
+
|
| 199 |
+
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
|
| 200 |
+
model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
|
| 201 |
+
del model_input['llm_prompt_speech_token']
|
| 202 |
+
del model_input['llm_prompt_speech_token_len']
|
| 203 |
+
return model_input
|
| 204 |
+
|
| 205 |
+
def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
|
| 206 |
+
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
| 207 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
| 208 |
+
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
| 209 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
| 210 |
+
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
|
| 211 |
+
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
|
| 212 |
+
'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
|
| 213 |
+
'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
|
| 214 |
+
'flow_embedding': embedding}
|
| 215 |
+
return model_input
|