# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import argparse

import librosa
import numpy as np
import torch
from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer
from langdetect import detect as classify_language
from pydub import AudioSegment
import pyloudnorm as pyln

from tts.modules.ar_dur.commons.nar_tts_modules import LengthRegulator
from tts.frontend_function import g2p, align, make_dur_prompt, dur_pred, prepare_inputs_for_dit
from tts.utils.audio_utils.io import save_wav, to_wav_bytes, convert_to_wav_bytes, combine_audio_segments
from tts.utils.commons.ckpt_utils import load_ckpt
from tts.utils.commons.hparams import set_hparams, hparams
from tts.utils.text_utils.text_encoder import TokenTextEncoder
from tts.utils.text_utils.split_text import chunk_text_chinese, chunk_text_english

if "TOKENIZERS_PARALLELISM" not in os.environ:
    os.environ["TOKENIZERS_PARALLELISM"] = "false"


def convert_to_wav(wav_path):
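    """Convert an arbitrary audio file to WAV, writing next to the source file.

    Files that already end in .wav are left untouched.
    """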
    # Check if the file exists
    if not os.path.exists(wav_path):
        print(f"The file '{wav_path}' does not exist.")
        return
    # Check if the file already has a .wav extension
    if not wav_path.endswith(".wav"):
        # Define the output path with a .wav extension
        out_path = os.path.splitext(wav_path)[0] + ".wav"
        # Load the audio file using pydub and convert it to WAV
        audio = AudioSegment.from_file(wav_path)
        audio.export(out_path, format="wav")
        print(f"Converted '{wav_path}' to '{out_path}'")


def cut_wav(wav_path, max_len=28):
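    """Truncate the audio at `wav_path` to at most `max_len` seconds, overwriting the file."""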
    audio = AudioSegment.from_file(wav_path)
    audio = audio[:int(max_len * 1000)]
    audio.export(wav_path, format="wav")


class MegaTTS3DiTInfer:
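    """End-to-end MegaTTS3 inference pipeline.

    Wires together the aligner LM, G2P LM, autoregressive duration predictor,
    diffusion transformer (DiT), and WavVAE, exposing `preprocess` (voice-prompt
    analysis) and `forward` (text-to-speech generation).
    """
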
    def __init__(
        self,
        device=None,
        ckpt_root='./checkpoints',
        dit_exp_name='diffusion_transformer',
        frontend_exp_name='aligner_lm',
        wavvae_exp_name='wavvae',
        dur_ckpt_path='duration_lm',
        g2p_exp_name='g2p',
        precision=torch.float16,
        **kwargs
    ):
        self.sr = 24000
        self.fm = 8
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.precision = precision

        # build models
        self.dit_exp_name = os.path.join(ckpt_root, dit_exp_name)
        self.frontend_exp_name = os.path.join(ckpt_root, frontend_exp_name)
        self.wavvae_exp_name = os.path.join(ckpt_root, wavvae_exp_name)
        self.dur_exp_name = os.path.join(ckpt_root, dur_ckpt_path)
        self.g2p_exp_name = os.path.join(ckpt_root, g2p_exp_name)
        self.build_model(self.device)

        # init text normalizer
        self.zh_normalizer = ZhNormalizer(overwrite_cache=False, remove_erhua=False, remove_interjections=False)
        self.en_normalizer = EnNormalizer(overwrite_cache=False)

        # loudness meter
        self.loudness_meter = pyln.Meter(self.sr)

    def build_model(self, device):
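        """Instantiate and load all sub-models (phone/tone dictionaries, duration LM,
        DiT, aligner LM, G2P LM, and WavVAE) onto `device`."""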
        set_hparams(exp_name=self.dit_exp_name, print_hparams=False)

        ''' Load Dict '''
        current_dir = os.path.dirname(os.path.abspath(__file__))
        ling_dict = json.load(open(f"{current_dir}/utils/text_utils/dict.json", encoding='utf-8-sig'))
        self.ling_dict = {k: TokenTextEncoder(None, vocab_list=ling_dict[k], replace_oov='<UNK>') for k in ['phone', 'tone']}
        self.token_encoder = token_encoder = self.ling_dict['phone']
        ph_dict_size = len(token_encoder)

        ''' Load Duration LM '''
        from tts.modules.ar_dur.ar_dur_predictor import ARDurPredictor
        hp_dur_model = self.hp_dur_model = set_hparams(f'{self.dur_exp_name}/config.yaml', global_hparams=False)
        hp_dur_model['frames_multiple'] = hparams['frames_multiple']
        self.dur_model = ARDurPredictor(
            hp_dur_model, hp_dur_model['dur_txt_hs'], hp_dur_model['dur_model_hidden_size'],
            hp_dur_model['dur_model_layers'], ph_dict_size,
            hp_dur_model['dur_code_size'],
            use_rot_embed=hp_dur_model.get('use_rot_embed', False))
        self.length_regulator = LengthRegulator()
        load_ckpt(self.dur_model, f'{self.dur_exp_name}', 'dur_model')
        self.dur_model.eval()
        self.dur_model.to(device)

        ''' Load Diffusion Transformer '''
        from tts.modules.llm_dit.dit import Diffusion
        self.dit = Diffusion()
        load_ckpt(self.dit, f'{self.dit_exp_name}', 'dit', strict=False)
        self.dit.eval()
        self.dit.to(device)
        self.cfg_mask_token_phone = 302 - 1
        self.cfg_mask_token_tone = 32 - 1

        ''' Load Frontend LM '''
        from tts.modules.aligner.whisper_small import Whisper
        self.aligner_lm = Whisper()
        load_ckpt(self.aligner_lm, f'{self.frontend_exp_name}', 'model')
        self.aligner_lm.eval()
        self.aligner_lm.to(device)
        self.kv_cache = None
        self.hooks = None

        ''' Load G2P LM '''
        from transformers import AutoTokenizer, AutoModelForCausalLM
        g2p_tokenizer = AutoTokenizer.from_pretrained(self.g2p_exp_name, padding_side="right")
        g2p_tokenizer.padding_side = "right"
        self.g2p_model = AutoModelForCausalLM.from_pretrained(self.g2p_exp_name).eval().to(device)
        self.g2p_tokenizer = g2p_tokenizer
        self.speech_start_idx = g2p_tokenizer.encode('<Reserved_TTS_0>')[0]

        ''' Wav VAE '''
        self.hp_wavvae = hp_wavvae = set_hparams(f'{self.wavvae_exp_name}/config.yaml', global_hparams=False)
        from tts.modules.wavvae.decoder.wavvae_v3 import WavVAE_V3
        self.wavvae = WavVAE_V3(hparams=hp_wavvae)
        if os.path.exists(f'{self.wavvae_exp_name}/model_only_last.ckpt'):
            load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/model_only_last.ckpt', 'model_gen', strict=True)
            self.has_vae_encoder = True
        else:
            load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/decoder.ckpt', 'model_gen', strict=False)
            self.has_vae_encoder = False
        self.wavvae.eval()
        self.wavvae.to(device)
        self.vae_stride = hp_wavvae.get('vae_stride', 4)
        self.hop_size = hp_wavvae.get('hop_size', 4)

    def preprocess(self, audio_bytes, latent_file=None, topk_dur=1, **kwargs):
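        """Analyze a voice prompt and build the reusable conditioning context.

        Args:
            audio_bytes: raw bytes of the prompt recording (any pydub-readable format).
            latent_file: path to a precomputed .npy VAE latent; required when the
                WavVAE checkpoint is decoder-only (no encoder weights shipped).
            topk_dur: top-k sampling width for the duration LM (1 = greedy).

        Returns a dict of prompt tensors consumed by `forward`.
        """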
        wav_bytes = convert_to_wav_bytes(audio_bytes)

        ''' Load wav '''
        wav, _ = librosa.core.load(wav_bytes, sr=self.sr)
        # Pad wav if necessary
        ws = hparams['win_size']
        if len(wav) % ws < ws - 1:
            wav = np.pad(wav, (0, ws - 1 - (len(wav) % ws)), mode='constant', constant_values=0.0).astype(np.float32)
        wav = np.pad(wav, (0, 12000), mode='constant', constant_values=0.0).astype(np.float32)
        self.loudness_prompt = self.loudness_meter.integrated_loudness(wav.astype(float))

        ''' Obtain alignments with aligner_lm '''
        ph_ref, tone_ref, mel2ph_ref = align(self, wav)

        with torch.inference_mode():
            ''' Forward WavVAE to obtain the prompt latent '''
            if self.has_vae_encoder:
                wav = torch.FloatTensor(wav)[None].to(self.device)
                vae_latent = self.wavvae.encode_latent(wav)
                vae_latent = vae_latent[:, :mel2ph_ref.size(1) // 4]
            else:
                assert latent_file is not None, "Please provide latent_file in WavVAE decoder-only mode"
                vae_latent = torch.from_numpy(np.load(latent_file)).to(self.device)
                vae_latent = vae_latent[:, :mel2ph_ref.size(1) // 4]

            ''' Duration Prompting '''
            self.dur_model.hparams["infer_top_k"] = topk_dur if topk_dur > 1 else None
            incremental_state_dur_prompt, ctx_dur_tokens = make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref)

        return {
            'ph_ref': ph_ref,
            'tone_ref': tone_ref,
            'mel2ph_ref': mel2ph_ref,
            'vae_latent': vae_latent,
            'incremental_state_dur_prompt': incremental_state_dur_prompt,
            'ctx_dur_tokens': ctx_dur_tokens,
        }

    def forward(self, resource_context, input_text, time_step, p_w, t_w, dur_disturb=0.1, dur_alpha=1.0, **kwargs):
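        """Synthesize speech for `input_text`, conditioned on a preprocessed prompt.

        Args:
            resource_context: dict returned by `preprocess`.
            time_step: number of diffusion transformer inference steps.
            p_w: intelligibility CFG weight.
            t_w: similarity CFG weight.
            dur_disturb, dur_alpha: perturbation and scaling passed to duration
                prediction.

        Returns the generated audio as WAV bytes.
        """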
        device = self.device
        ph_ref = resource_context['ph_ref'].to(device)
        tone_ref = resource_context['tone_ref'].to(device)
        mel2ph_ref = resource_context['mel2ph_ref'].to(device)
        vae_latent = resource_context['vae_latent'].to(device)
        ctx_dur_tokens = resource_context['ctx_dur_tokens'].to(device)
        incremental_state_dur_prompt = resource_context['incremental_state_dur_prompt']

        with torch.inference_mode():
            ''' Generating '''
            wav_pred_ = []
            language_type = classify_language(input_text)
            if language_type == 'en':
                input_text = self.en_normalizer.normalize(input_text)
                text_segs = chunk_text_english(input_text, max_chars=130)
            else:
                input_text = self.zh_normalizer.normalize(input_text)
                text_segs = chunk_text_chinese(input_text, limit=60)

            for seg_i, text in enumerate(text_segs):
                ''' G2P '''
                ph_pred, tone_pred = g2p(self, text)

                ''' Duration Prediction '''
                mel2ph_pred = dur_pred(
                    self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i,
                    dur_disturb, dur_alpha, is_first=seg_i == 0, is_final=seg_i == len(text_segs) - 1)

                inputs = prepare_inputs_for_dit(
                    self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent)
                # Speech DiT inference
                with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
                    x = self.dit.inference(inputs, timesteps=time_step, seq_cfg_w=[p_w, t_w]).float()

                # WavVAE decode
                x[:, :vae_latent.size(1)] = vae_latent
                wav_pred = self.wavvae.decode(x)[0, 0].to(torch.float32)

                ''' Post-processing '''
                # Trim the prompt segment from the generated wav
                wav_pred = wav_pred[vae_latent.size(1) * self.vae_stride * self.hop_size:].cpu().numpy()
                # Normalize generated wav to the prompt wav's loudness (BS.1770)
                loudness_pred = self.loudness_meter.integrated_loudness(wav_pred.astype(float))
                wav_pred = pyln.normalize.loudness(wav_pred, loudness_pred, self.loudness_prompt)
                if np.abs(wav_pred).max() >= 1:
                    wav_pred = wav_pred / np.abs(wav_pred).max() * 0.95

                wav_pred_.append(wav_pred)

        wav_pred = combine_audio_segments(wav_pred_, sr=self.sr).astype(float)
        return to_wav_bytes(wav_pred, self.sr)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_wav', type=str, required=True)
    parser.add_argument('--input_text', type=str, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--time_step', type=int, default=32, help='Inference steps of the Diffusion Transformer')
    parser.add_argument('--p_w', type=float, default=1.6, help='Intelligibility Weight')
    parser.add_argument('--t_w', type=float, default=2.5, help='Similarity Weight')
    args = parser.parse_args()
    wav_path, input_text, out_path = args.input_wav, args.input_text, args.output_dir

    infer_ins = MegaTTS3DiTInfer()

    with open(wav_path, 'rb') as file:
        file_content = file.read()

    print(f"| Start processing {wav_path}+{input_text}")
    resource_context = infer_ins.preprocess(file_content, latent_file=wav_path.replace('.wav', '.npy'))
    wav_bytes = infer_ins.forward(resource_context, input_text, time_step=args.time_step, p_w=args.p_w, t_w=args.t_w)

    os.makedirs(out_path, exist_ok=True)
    print(f"| Saving results to {out_path}/[P]{input_text[:20]}.wav")
    save_wav(wav_bytes, f'{out_path}/[P]{input_text[:20]}.wav')
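
# Example invocation (script and file names are illustrative; in WavVAE
# decoder-only mode a matching ./prompt.npy latent must sit next to the prompt wav):
#   python infer.py --input_wav ./prompt.wav --input_text "Your text here" --output_dir ./outputs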