import json
import os
import re

import torch
import torchaudio
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList

from cosyvoice.cli.cosyvoice import CosyVoice
class RepetitionAwareLogitsProcessor(LogitsProcessor):
    """Bans the most recent token when it dominates a trailing window,
    discouraging degenerate repetition during audio-token generation."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        window_size = 10
        threshold = 0.1

        window = input_ids[:, -window_size:]
        if window.shape[1] < window_size:
            return scores

        # How often does each sequence's last token occur in its window?
        last_tokens = window[:, -1].unsqueeze(-1)
        repeat_counts = (window == last_tokens).sum(dim=1)
        repeat_ratios = repeat_counts.float() / window_size

        # Set the last token's score to -inf wherever the ratio is too high.
        mask = repeat_ratios > threshold
        scores[mask, last_tokens[mask].squeeze(-1)] = float("-inf")
        return scores
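
# A minimal, self-contained sanity check for the processor above. The token
# ids and the 16-entry vocabulary are made up for illustration only.
def _demo_repetition_mask():
    proc = RepetitionAwareLogitsProcessor()
    # The last token (9) occurs twice in the trailing 10-token window, so its
    # repeat ratio (0.2) exceeds the 0.1 threshold and it gets banned.
    ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 9]])
    scores = proc(ids, torch.zeros(1, 16))
    assert scores[0, 9] == float("-inf")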
class StepAudioTTS:
    def __init__(
        self,
        model_path,
        encoder,
    ):
        # Text-to-audio-token LLM and its tokenizer.
        self.llm = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True
        )
        # Two CosyVoice vocoders: one for ordinary speech, one for RAP/humming.
        self.common_cosy_model = CosyVoice(
            os.path.join(model_path, "CosyVoice-300M-25Hz")
        )
        self.music_cosy_model = CosyVoice(
            os.path.join(model_path, "CosyVoice-300M-25Hz-Music")
        )
        self.encoder = encoder
        # System prompts (Chinese) steering the LLM. In brief:
        # - sys_prompt_for_rap: "mimic the timbre from the dialogue history and
        #   perform the text as RAP, out loud."
        # - sys_prompt_for_vocal: "mimic the timbre from the dialogue history
        #   and hum/sing the text, out loud."
        # - sys_prompt_wo_spk / sys_prompt_with_spk: "as an outstanding voice
        #   actor, read the text following the emotion, language/dialect,
        #   humming, and speed tags given in (…) or （…） brackets; if there are
        #   no brackets, improvise from the text's meaning." The with_spk
        #   variant additionally requests the voice of a named speaker, filled
        #   in via str.format in tokenize() below.
        self.sys_prompt_dict = {
            "sys_prompt_for_rap": "请参考对话历史里的音色，用RAP方式将文本内容大声说唱出来。",
            "sys_prompt_for_vocal": "请参考对话历史里的音色，用哼唱的方式将文本内容大声唱出来。",
            "sys_prompt_wo_spk": '作为一名卓越的声优演员，你的任务是根据文本中（）或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签，以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态，包括但不限于：\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言，包括但不限于：\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱，包括但不限于：\n- "RAP"\n- "哼唱"\n\n# 语音调整标签，包括但不限于：\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时，根据这些情感标签的指示，调整你的情感、语气、语调和哼唱节奏，以确保文本的情感和意义得到准确而生动的传达，如果没有()或（）括号，则根据文本语义内容自由演绎。',
            "sys_prompt_with_spk": '作为一名卓越的声优演员，你的任务是根据文本中（）或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签，以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态，包括但不限于：\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言，包括但不限于：\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱，包括但不限于：\n- "RAP"\n- "哼唱"\n\n# 语音调整标签，包括但不限于：\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时，使用[{}]的声音，根据这些情感标签的指示，调整你的情感、语气、语调和哼唱节奏，以确保文本的情感和意义得到准确而生动的传达，如果没有()或（）括号，则根据文本语义内容自由演绎。',
        }
        self.register_speakers()
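
    # Note: register_speakers() below expects a speakers/speakers_info.json
    # file mapping each built-in speaker id to the transcript of its reference
    # clip, e.g. {"TingTing": "<text spoken in speakers/TingTing_prompt.wav>"}.
    # The concrete speaker ids shipped with the model are an assumption here,
    # not part of this file.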
    def __call__(self, text: str, prompt_speaker: str, clone_dict: dict | None = None):
        # Optional zero-shot cloning: derive prompt features from a user wav
        # and register them as a temporary speaker entry.
        if clone_dict:
            (
                clone_prompt_code,
                clone_prompt_token,
                clone_prompt_token_len,
                clone_speech_feat,
                clone_speech_feat_len,
                clone_speech_embedding,
            ) = self.preprocess_prompt_wav(clone_dict["wav_path"])
            prompt_speaker = clone_dict["speaker"]
            self.speakers_info[prompt_speaker] = {
                "prompt_text": clone_dict["prompt_text"],
                "prompt_code": clone_prompt_code,
                "cosy_speech_feat": clone_speech_feat.to(torch.bfloat16),
                "cosy_speech_feat_len": clone_speech_feat_len,
                "cosy_speech_embedding": clone_speech_embedding.to(torch.bfloat16),
                "cosy_prompt_token": clone_prompt_token,
                "cosy_prompt_token_len": clone_prompt_token_len,
            }

        # RAP / "哼唱" (humming) requests use dedicated per-speaker prompt
        # entries and the music vocoder; everything else uses the common one.
        instruction_name = self.detect_instruction_name(text)
        if instruction_name in ("RAP", "哼唱"):
            prompt_speaker_info = self.speakers_info[
                f"{prompt_speaker}{instruction_name}"
            ]
            cosy_model = self.music_cosy_model
        else:
            prompt_speaker_info = self.speakers_info[prompt_speaker]
            cosy_model = self.common_cosy_model

        # A cloned voice has no named speaker to mention in the system prompt.
        if clone_dict:
            prompt_speaker = ""

        token_ids = self.tokenize(
            text,
            prompt_speaker_info["prompt_text"],
            prompt_speaker,
            prompt_speaker_info["prompt_code"],
        )
        output_ids = self.llm.generate(
            torch.tensor([token_ids]).to(torch.long).to("cuda"),
            max_length=8192,
            temperature=0.7,
            do_sample=True,
            logits_processor=LogitsProcessorList([RepetitionAwareLogitsProcessor()]),
        )
        output_ids = output_ids[:, len(token_ids) : -1]  # keep only new tokens, drop eos

        # The LLM emits audio tokens at an offset of 65536 in its vocabulary
        # (see preprocess_prompt_wav); shift them back into the CosyVoice
        # codebook range before vocoding. Output audio is 22.05 kHz.
        return (
            cosy_model.token_to_wav_offline(
                output_ids - 65536,
                prompt_speaker_info["cosy_speech_feat"].to(torch.bfloat16),
                prompt_speaker_info["cosy_speech_feat_len"],
                prompt_speaker_info["cosy_prompt_token"],
                prompt_speaker_info["cosy_prompt_token_len"],
                prompt_speaker_info["cosy_speech_embedding"].to(torch.bfloat16),
            ),
            22050,
        )
    def register_speakers(self):
        """Load every built-in speaker's reference transcript and wav, and
        cache the prompt features needed for generation."""
        self.speakers_info = {}

        with open("speakers/speakers_info.json", "r") as f:
            speakers_info = json.load(f)

        for speaker_id, prompt_text in speakers_info.items():
            prompt_wav_path = f"speakers/{speaker_id}_prompt.wav"
            (
                prompt_code,
                prompt_token,
                prompt_token_len,
                speech_feat,
                speech_feat_len,
                speech_embedding,
            ) = self.preprocess_prompt_wav(prompt_wav_path)

            self.speakers_info[speaker_id] = {
                "prompt_text": prompt_text,
                "prompt_code": prompt_code,
                "cosy_speech_feat": speech_feat.to(torch.bfloat16),
                "cosy_speech_feat_len": speech_feat_len,
                "cosy_speech_embedding": speech_embedding.to(torch.bfloat16),
                "cosy_prompt_token": prompt_token,
                "cosy_prompt_token_len": prompt_token_len,
            }
            print(f"Registered speaker: {speaker_id}")
    def detect_instruction_name(self, text):
        """Return the leading tag enclosed in half- or full-width parentheses,
        or "" if the text is untagged."""
        instruction_name = ""
        match_group = re.match(r"^([（(][^（）()]*[）)]).*$", text, re.DOTALL)
        if match_group is not None:
            instruction = match_group.group(1)
            instruction_name = instruction.strip("()（）")
        return instruction_name
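
    # Illustrative behavior: detect_instruction_name("(RAP)你好") and
    # detect_instruction_name("（哼唱）你好") return the bracketed tag
    # ("RAP" / "哼唱"); text without a leading tag returns "".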
    def tokenize(
        self, text: str, prompt_text: str, prompt_speaker: str, prompt_code: list
    ):
        # Pick the system prompt: RAP/humming requests use dedicated prompts;
        # otherwise prefer the speaker-aware prompt when a speaker is named.
        rap_or_vocal = self.detect_instruction_name(text) in ("RAP", "哼唱")
        if rap_or_vocal:
            if "哼唱" in text:
                prompt = self.sys_prompt_dict["sys_prompt_for_vocal"]
            else:
                prompt = self.sys_prompt_dict["sys_prompt_for_rap"]
        elif prompt_speaker:
            prompt = self.sys_prompt_dict["sys_prompt_with_spk"].format(prompt_speaker)
        else:
            prompt = self.sys_prompt_dict["sys_prompt_wo_spk"]

        sys_tokens = self.tokenizer.encode(f"system\n{prompt}")

        # 1, 3, and 4 appear to be this checkpoint's special ids: sequence
        # start plus turn-end/turn-start delimiters.
        history = [1]
        history.extend([4] + sys_tokens + [3])

        # Encode with a leading "\n" and strip its tokens so each segment is
        # tokenized exactly as it would appear mid-sequence.
        _prefix_tokens = self.tokenizer.encode("\n")
        prompt_token_encode = self.tokenizer.encode("\n" + prompt_text)
        prompt_tokens = prompt_token_encode[len(_prefix_tokens) :]

        target_token_encode = self.tokenizer.encode("\n" + text)
        target_tokens = target_token_encode[len(_prefix_tokens) :]

        qrole_toks = self.tokenizer.encode("human\n")
        arole_toks = self.tokenizer.encode("assistant\n")

        history.extend(
            [4]
            + qrole_toks
            + prompt_tokens
            + [3]
            + [4]
            + arole_toks
            + prompt_code
            + [3]
            + [4]
            + qrole_toks
            + target_tokens
            + [3]
            + [4]
            + arole_toks
        )
        return history
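
    # The resulting layout is a one-shot prompt (sketch; [N] marks the special
    # ids used above, and the final assistant turn is left open so generation
    # continues with audio tokens):
    #   [1] [4] system\n{system prompt} [3]
    #   [4] human\n{prompt_text} [3] [4] assistant\n{prompt audio codes} [3]
    #   [4] human\n{text} [3] [4] assistant\n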
    def preprocess_prompt_wav(self, prompt_wav_path: str):
        prompt_wav, prompt_wav_sr = torchaudio.load(prompt_wav_path)

        # CosyVoice uses 16 kHz audio for speaker embeddings and 22.05 kHz
        # audio for speech features.
        prompt_wav_16k = torchaudio.transforms.Resample(
            orig_freq=prompt_wav_sr, new_freq=16000
        )(prompt_wav)
        prompt_wav_22k = torchaudio.transforms.Resample(
            orig_freq=prompt_wav_sr, new_freq=22050
        )(prompt_wav)

        speech_feat, speech_feat_len = (
            self.common_cosy_model.frontend._extract_speech_feat(prompt_wav_22k)
        )
        speech_embedding = self.common_cosy_model.frontend._extract_spk_embedding(
            prompt_wav_16k
        )

        # prompt_code holds audio tokens in the LLM's vocabulary (offset by
        # 65536); prompt_token holds the same codes shifted back into the
        # CosyVoice codebook range.
        prompt_code, _, _ = self.encoder.wav2token(prompt_wav, prompt_wav_sr)
        prompt_token = torch.tensor([prompt_code], dtype=torch.long) - 65536
        prompt_token_len = torch.tensor([prompt_token.shape[1]], dtype=torch.long)

        return (
            prompt_code,
            prompt_token,
            prompt_token_len,
            speech_feat,
            speech_feat_len,
            speech_embedding,
        )
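
# A minimal usage sketch, assuming a local Step-Audio-TTS checkpoint and an
# encoder object exposing wav2token(); the import, paths, and speaker id below
# are assumptions for illustration, not part of this file:
#
#     from speech_encoder import SpeechEncoder  # hypothetical import
#
#     encoder = SpeechEncoder(...)
#     tts = StepAudioTTS("Step-Audio-TTS-3B", encoder)
#     audio, sr = tts("(RAP)今天天气真好", prompt_speaker="TingTing")
#     torchaudio.save("output.wav", audio, sr)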