# Hugging Face Spaces page banner (scrape residue, not part of the module): Running on Zero
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch | |
from torch import nn | |
from tts.modules.llm_dit.cfm import ConditionalFlowMatcher | |
from tts.modules.ar_dur.commons.layers import Embedding | |
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb | |
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder | |
from tts.modules.ar_dur.ar_dur_predictor import expand_states | |
from tts.modules.llm_dit.transformer import Transformer | |
from tts.modules.llm_dit.time_embedding import TimestepEmbedding | |
class Diffusion(nn.Module):
    """Conditional flow-matching model over latent frames.

    The network sums three frame-level streams -- the (noisy) latent, a local
    condition built from a context latent plus its mask, and an upsampled
    linguistic encoding of phone/tone tokens -- and feeds the sum through a
    transformer that is conditioned on a timestep embedding.  ``forward``
    produces a training pair ``(prediction, target)``; ``inference`` samples
    a latent with multi-condition classifier-free guidance.
    """

    def __init__(self):
        super().__init__()
        # Hparams
        # cond dim
        self.local_cond_dim = 512
        self.ctx_mask_dim = 16
        self.in_channels = 32
        self.out_channels = 32

        # LLM
        self.encoder_dim = 1024
        self.encoder_n_layers = 24
        self.encoder_n_heads = 16
        self.max_seq_len = 16384
        self.multiple_of = 256

        # The context mask has a single channel and is lifted to ctx_mask_dim
        # before being concatenated with the context latent.
        self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim)
        self.local_cond_project = nn.Linear(
            self.out_channels + self.ctx_mask_dim, self.local_cond_dim)

        self.encoder = Transformer(
            self.encoder_n_layers, self.encoder_dim, self.encoder_n_heads, self.max_seq_len)

        self.x_prenet = nn.Linear(self.in_channels, self.encoder_dim)
        self.prenet = nn.Linear(self.local_cond_dim, self.encoder_dim)
        self.postnet = nn.Linear(self.encoder_dim, self.out_channels)

        self.flow_matcher = ConditionalFlowMatcher(sigma=0.0)
        # The implementation of TimestepEmbedding is a modified version from F5-TTS (https://github.com/SWivid/F5-TTS),
        # which is licensed under the MIT License.
        self.f5_time_embed = TimestepEmbedding(self.encoder_dim)

        # text encoder
        self.ph_encoder = RelTransformerEncoder(
            302, self.encoder_dim, self.encoder_dim,
            self.encoder_dim * 2, 4, 6,
            3, 0.0, prenet=True, pre_ln=True)
        self.tone_embed = Embedding(32, self.encoder_dim, padding_idx=0)
        self.ph_pos_embed = PosEmb(self.encoder_dim)
        # Two strided convolutions (kernel 4, stride 2, padding 1 each) applied
        # along the time axis of the expanded linguistic states.
        self.ling_pre_net = torch.nn.Sequential(*[
            torch.nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=s * 2, stride=s, padding=s // 2)
            for s in (2, 2)
        ])

    def forward(self, inputs, sigmas=None, x_noisy=None):
        """Build one flow-matching training pair.

        Args:
            inputs: Mapping read for the keys ``'ctx_mask'``, ``'lat_ctx'``,
                ``'lat'``, ``'phone'``, ``'tone'``, ``'mel2ph'`` and
                ``'text_mel_mask'``.  ``'ctx_mask'`` must have a trailing
                dimension of 1 (it is fed to ``nn.Linear(1, ...)``).
                NOTE(review): tensors are presumably laid out as
                (batch, frames, channels) -- confirm against the data loader.
            sigmas: Unused; kept for interface compatibility.
            x_noisy: Unused; the noisy input is always re-sampled here.

        Returns:
            Tuple ``(pred, target)`` where ``pred`` is the network output and
            ``target`` is the conditional flow returned by the flow matcher.
        """
        ctx_mask = inputs['ctx_mask']
        ctx_feature = inputs['lat_ctx'] * ctx_mask

        # local conditioning (prompt_latent + spk_embed)
        ctx_mask_emb = self.ctx_mask_proj(ctx_mask)
        # ctx_feature = ctx_feature * (1 - inputs["spk_cfg_mask"][:, :, None])
        local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
        local_cond = self.local_cond_project(local_cond)

        # diffusion target latent
        x = inputs['lat']

        # Here, x is x1 in CFM
        x0 = torch.randn_like(x)
        t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x)

        # define noisy_input and target
        t = t.bfloat16()
        # Positions covered by the context mask are zeroed in the noisy input.
        x_noisy = (xt * (1 - ctx_mask)).bfloat16()
        target = ut

        # concat condition.
        x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
        x_ling = self.ling_pre_net(expand_states(x_ling, inputs['mel2ph']).transpose(1, 2)).transpose(1, 2)
        x_noisy = self.x_prenet(x_noisy) + self.prenet(local_cond) + x_ling
        encoder_out = self.encoder(
            x_noisy, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"], do_checkpoint=False)
        pred = self.postnet(encoder_out)

        return pred, target

    def forward_ling_encoder(self, txt_tokens, tone_tokens):
        """Encode phone tokens together with tone and position embeddings.

        Args:
            txt_tokens: Integer phone ids; id 0 is treated as padding.
                Indexed as ``[B, T_phone]``.
            tone_tokens: Integer tone ids passed to the tone embedding.

        Returns:
            The phone encoder output with padded positions zeroed.
        """
        ph_tokens = txt_tokens
        ph_nonpadding = (ph_tokens > 0).float()[:, :, None]  # [B, T_phone, 1]

        # Position indices are created directly on the token device instead of
        # being built on CPU and copied over.
        positions = torch.arange(ph_tokens.shape[1], device=ph_tokens.device)[None]

        # enc_ph
        ph_enc_oembed = self.tone_embed(tone_tokens) + self.ph_pos_embed(positions)
        ph_enc_oembed = ph_enc_oembed * ph_nonpadding
        x_ling = self.ph_encoder(ph_tokens, other_embeds=ph_enc_oembed) * ph_nonpadding
        return x_ling

    def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=(1.0, 1.0)):
        """Predict the guided velocity for one sampler step.

        When we use torchdiffeq, we need to include the CFG process inside
        ``_forward()``.

        Args:
            x: Current sample, already stacked along the batch axis so it can
                be split into three equal groups, in the order
                (speaker+text conditioned, text conditioned, unconditioned).
            local_cond: Projected local condition, same batch layout as ``x``.
            x_ling: Frame-level linguistic features, same batch layout.
            timesteps: Timestep value(s) passed to the timestep embedding.
            ctx_mask: Context mask; masked positions of ``x`` are zeroed.
            dur: Unused; kept for interface compatibility.
            seq_cfg_w: Two guidance weights ``(w_text, w_speaker)``.

        Returns:
            ``uncond + w_text * (txt - uncond) + w_speaker * (spk_txt - txt)``.
        """
        x = x * (1 - ctx_mask)
        x = self.x_prenet(x) + self.prenet(local_cond) + x_ling
        # Inference attends over every frame, hence an all-ones mask.
        full_mask = torch.ones((x.size(0), x.size(1)), device=x.device)
        pred_v = self.encoder(x, self.f5_time_embed(timesteps), attn_mask=full_mask)
        pred = self.postnet(pred_v)

        # Perform multi-cond CFG
        cond_spk_txt, cond_txt, uncond = pred.chunk(3)
        pred = uncond + seq_cfg_w[0] * (cond_txt - uncond) + seq_cfg_w[1] * (cond_spk_txt - cond_txt)
        return pred

    @staticmethod
    def _amo_step(z_t, t, t_next, v, c=3):
        """Advance one step with the AMO overshooting sampler.

        AMO sampling implementation for "AMO Sampler: Enhancing Text Rendering
        with Overshooting" (https://arxiv.org/pdf/2411.19415).

        Args:
            z_t: Current sample.
            t: Current time.
            t_next: Time to step to.
            v: Predicted velocity at ``t``.
            c: Overshooting constant from Algorithm 1.

        Returns:
            The sample at ``t_next``, cast to the dtype of ``v``.
        """
        # Upcast to avoid precision issues when computing prev_sample
        z_t = z_t.to(torch.float32)

        # Constant definition in Algorithm 1
        s = t_next

        # Line 7 in Algorithm 1: overshoot past t_next, clamped to the end time.
        o = min(t_next + c * (t_next - t), 1)
        pred_z_o = z_t + (o - t) * v

        # Line 11 in Algorithm 1: come back from o to s by re-injecting noise.
        a = s / o
        b = ((1 - s) ** 2 - (a * (1 - o)) ** 2) ** 0.5
        noise_i = torch.randn(size=z_t.shape, device=z_t.device)
        z_t_next = a * pred_z_o + b * noise_i
        return z_t_next.to(v.dtype)

    def inference(self, inputs, timesteps=20, seq_cfg_w=(1.0, 1.0), **kwargs):
        """Sample a latent sequence with multi-condition CFG.

        Args:
            inputs: Mapping read for the keys ``'phone'``, ``'tone'``,
                ``'dur'``, ``'lat_ctx'`` and ``'ctx_mask'``.  The batch axis is
                split into three groups by ``_forward`` (speaker+text, text,
                unconditioned), so in practice the batch size is expected to
                be 3 -- NOTE(review): confirm against the caller.
            timesteps: Number of sampler steps.
            seq_cfg_w: Two guidance weights ``(w_text, w_speaker)``.
            **kwargs: Accepted and ignored.

        Returns:
            The sampled latent of shape ``[1, frames, out_channels]``.
        """
        # txt embedding
        x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
        x_ling = self.ling_pre_net(expand_states(x_ling, inputs['dur']).transpose(1, 2)).transpose(1, 2)

        # speaker embedding
        # Work on a copy so the caller's 'lat_ctx' tensor is not overwritten.
        ctx_feature = inputs['lat_ctx'].clone()
        ctx_feature[1:, :, :] = 0  # prefix spk cfg
        ctx_mask_emb = self.ctx_mask_proj(inputs['ctx_mask'])

        # local conditioning.
        local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
        local_cond = self.local_cond_project(local_cond)

        bsz, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1))

        # Sway sampling from F5-TTS (https://github.com/SWivid/F5-TTS),
        # which is licensed under the MIT License.
        sway_sampling_coef = -1.0
        t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype)
        if sway_sampling_coef is not None:
            t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule)

        # A single noise sample is shared by all CFG branches.
        x = torch.randn([1, frm_len, self.out_channels], device=device)
        for step_index in range(timesteps):
            x = x.to(torch.float32)
            sigma = t_schedule[step_index].to(x_ling.dtype)
            sigma_next = t_schedule[step_index + 1]
            model_out = self._forward(
                torch.cat([x] * bsz), local_cond, x_ling,
                timesteps=sigma.unsqueeze(0), ctx_mask=inputs['ctx_mask'],
                dur=inputs['dur'], seq_cfg_w=seq_cfg_w)
            # _amo_step already returns the sample in the model output dtype.
            x = self._amo_step(x, sigma, sigma_next, model_out)

        return x