Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright 2025 ByteDance and/or its affiliates. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| from torch import nn | |
| from tts.modules.llm_dit.cfm import ConditionalFlowMatcher | |
| from tts.modules.ar_dur.commons.layers import Embedding | |
| from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb | |
| from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder | |
| from tts.modules.ar_dur.ar_dur_predictor import expand_states | |
| from tts.modules.llm_dit.transformer import Transformer | |
| from tts.modules.llm_dit.time_embedding import TimestepEmbedding | |
class Diffusion(nn.Module):
    """Flow-matching latent generator conditioned on text and a latent prompt.

    Phone/tone tokens are encoded by ``forward_ling_encoder``, expanded to frame
    level and downsampled by ``ling_pre_net``. A (masked) latent context is
    projected into ``local_cond``. Both are summed with the noisy latent and run
    through a ``Transformer`` conditioned on a timestep embedding, after which
    ``postnet`` maps back to ``out_channels`` latent channels.
    """

    def __init__(self):
        super().__init__()
        # Hparams
        # cond dim
        self.local_cond_dim = 512
        self.ctx_mask_dim = 16
        self.in_channels = 32
        self.out_channels = 32
        # LLM
        self.encoder_dim = 1024
        self.encoder_n_layers = 24
        self.encoder_n_heads = 16
        self.max_seq_len = 16384
        self.multiple_of = 256  # NOTE(review): not referenced inside this class
        # NOTE: submodule attribute names and creation order define the
        # state_dict keys; keep them unchanged so existing checkpoints load.
        # ctx_mask carries a trailing dim of 1 (it is fed to nn.Linear(1, ...)).
        self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim)
        self.local_cond_project = nn.Linear(
            self.out_channels + self.ctx_mask_dim, self.local_cond_dim)
        self.encoder = Transformer(self.encoder_n_layers, self.encoder_dim, self.encoder_n_heads, self.max_seq_len)
        self.x_prenet = nn.Linear(self.in_channels, self.encoder_dim)
        self.prenet = nn.Linear(self.local_cond_dim, self.encoder_dim)
        self.postnet = nn.Linear(self.encoder_dim, self.out_channels)
        self.flow_matcher = ConditionalFlowMatcher(sigma=0.0)
        # The implementation of TimestepEmbedding is a modified version from F5-TTS (https://github.com/SWivid/F5-TTS),
        # which is licensed under the MIT License.
        self.f5_time_embed = TimestepEmbedding(self.encoder_dim)
        # text encoder
        # NOTE(review): positional args are interpreted by RelTransformerEncoder;
        # 302 looks like the phone vocabulary size — confirm there.
        self.ph_encoder = RelTransformerEncoder(
            302, self.encoder_dim, self.encoder_dim,
            self.encoder_dim * 2, 4, 6,
            3, 0.0, prenet=True, pre_ln=True)
        # presumably 32 tone ids with 0 reserved for padding — confirm with the tokenizer
        self.tone_embed = Embedding(32, self.encoder_dim, padding_idx=0)
        self.ph_pos_embed = PosEmb(self.encoder_dim)
        # Two strided convs (kernel 2s, stride s, padding s // 2 with s = 2):
        # each one halves (floors) the time axis, so frame-level text states
        # are downsampled about 4x in total.
        self.ling_pre_net = torch.nn.Sequential(*[
            torch.nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=s * 2, stride=s, padding=s // 2)
            for s in [2, 2]
        ])

    def forward(self, inputs, sigmas=None, x_noisy=None):
        """Training step: sample a flow-matching point and predict its velocity.

        Args:
            inputs: dict read for the keys ``'ctx_mask'``, ``'lat_ctx'``,
                ``'lat'``, ``'phone'``, ``'tone'``, ``'mel2ph'`` and
                ``'text_mel_mask'``.
            sigmas: unused; kept for interface compatibility.
            x_noisy: ignored (recomputed internally); kept for compatibility.

        Returns:
            ``(pred, target)``: the postnet output and the conditional flow
            ``ut`` returned by ``ConditionalFlowMatcher``.
        """
        ctx_mask = inputs['ctx_mask']
        # local conditioning (prompt latent restricted to the context region)
        ctx_feature = inputs['lat_ctx'] * ctx_mask
        local_cond = self._project_local_cond(ctx_feature, ctx_mask)

        # diffusion target latent; here x is x1 in CFM
        x = inputs['lat']
        x0 = torch.randn_like(x)
        t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x)
        # define noisy_input and target; the context region is zeroed out
        t = t.bfloat16()
        x_noisy = (xt * (1 - ctx_mask)).bfloat16()
        target = ut

        # concat condition.
        x_ling = self._ling_condition(inputs["phone"], inputs["tone"], inputs['mel2ph'])
        x_noisy = self.x_prenet(x_noisy) + self.prenet(local_cond) + x_ling
        encoder_out = self.encoder(x_noisy, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"], do_checkpoint=False)
        pred = self.postnet(encoder_out)
        return pred, target

    def forward_ling_encoder(self, txt_tokens, tone_tokens):
        """Encode phone tokens with tone and positional embeddings.

        Token id 0 is treated as padding: padded positions are zeroed both in
        the extra embeddings and in the encoder output.
        """
        ph_tokens = txt_tokens
        ph_nonpadding = (ph_tokens > 0).float()[:, :, None]  # [B, T_phone, 1]
        positions = torch.arange(ph_tokens.shape[1], device=ph_tokens.device)[None]
        ph_enc_oembed = self.tone_embed(tone_tokens) + self.ph_pos_embed(positions)
        ph_enc_oembed = ph_enc_oembed * ph_nonpadding
        x_ling = self.ph_encoder(ph_tokens, other_embeds=ph_enc_oembed) * ph_nonpadding
        return x_ling

    def _ling_condition(self, phone, tone, durations):
        """Encode text, expand it to frames with ``durations``, then downsample via ``ling_pre_net``."""
        x_ling = self.forward_ling_encoder(phone, tone)
        x_ling = expand_states(x_ling, durations)
        # Conv1d expects channels first, so swap the last two axes around it.
        return self.ling_pre_net(x_ling.transpose(1, 2)).transpose(1, 2)

    def _project_local_cond(self, ctx_feature, ctx_mask):
        """Concatenate the context latent with the embedded mask and project to ``local_cond_dim``."""
        ctx_mask_emb = self.ctx_mask_proj(ctx_mask)
        return self.local_cond_project(torch.cat([ctx_feature, ctx_mask_emb], dim=-1))

    @staticmethod
    def _drop_speaker_prompt(lat_ctx):
        """Return a copy of ``lat_ctx`` with every batch row except the first zeroed.

        The copy keeps the caller's tensor intact (the previous code zeroed it
        in place).
        """
        ctx_feature = lat_ctx.clone()
        ctx_feature[1:] = 0  # prefix spk cfg
        return ctx_feature

    @staticmethod
    def _amo_step(z_t, t, t_next, v, overshoot=3):
        """One AMO (overshooting) sampler step, Algorithm 1 of arXiv:2411.19415.

        Moves ``z_t`` along velocity ``v`` past ``t_next`` to
        ``o = min(t_next + overshoot * (t_next - t), 1)``, then rescales and
        re-noises back to ``t_next``. The result is cast to ``v.dtype``.
        """
        # Upcast to avoid precision issues when computing prev_sample
        z_t = z_t.to(torch.float32)
        s = t_next
        # Line 7 in Algorithm 1
        o = min(t_next + overshoot * (t_next - t), 1)
        pred_z_o = z_t + (o - t) * v
        # Line 11 in Algorithm 1
        a = s / o
        radicand = (1 - s) ** 2 - (a * (1 - o)) ** 2
        # Clamp at 0: rounding or a non-monotonic schedule can make the
        # radicand slightly negative, which would otherwise yield NaN.
        b = max(radicand, 0.0) ** 0.5
        noise_i = torch.randn(size=z_t.shape, device=z_t.device)
        z_t_next = a * pred_z_o + b * noise_i
        return z_t_next.to(v.dtype)

    def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=(1.0, 1.0)):
        """Predict velocity for a CFG-stacked batch and combine it into one guided prediction.

        The network output is split into three equal batch chunks, treated as
        (speaker+text, text-only, unconditional). ``seq_cfg_w[0]`` scales the
        text guidance and ``seq_cfg_w[1]`` the speaker guidance. ``dur`` is
        unused. CFG is applied here so that an external ODE solver
        (torchdiffeq) can call this directly.
        """
        x = x * (1 - ctx_mask)
        x = self.x_prenet(x) + self.prenet(local_cond) + x_ling
        pred_v = self.encoder(x, self.f5_time_embed(timesteps), attn_mask=torch.ones((x.size(0), x.size(1)), device=x.device))
        pred = self.postnet(pred_v)
        # Perform multi-cond CFG
        cond_spk_txt, cond_txt, uncond = pred.chunk(3)
        text_guidance = seq_cfg_w[0] * (cond_txt - uncond)
        speaker_guidance = seq_cfg_w[1] * (cond_spk_txt - cond_txt)
        return uncond + text_guidance + speaker_guidance

    def inference(self, inputs, timesteps=20, seq_cfg_w=(1.0, 1.0), *, sway_sampling_coef=-1.0, **kwargs):
        """Sample a latent with the AMO sampler over a (sway-warped) time schedule.

        Args:
            inputs: dict read for ``'phone'``, ``'tone'``, ``'dur'``,
                ``'lat_ctx'`` and ``'ctx_mask'``. It is not modified. Batch
                rows 1: of ``lat_ctx`` are zeroed on a copy (speaker CFG).
                NOTE(review): ``_forward`` splits into three chunks, so the
                batch is presumably 3 (spk+txt, txt, uncond) — confirm with the
                caller.
            timesteps: number of sampler steps.
            seq_cfg_w: (text, speaker) CFG weights passed to ``_forward``.
            sway_sampling_coef: F5-TTS sway-sampling coefficient applied to the
                linear schedule; ``None`` keeps the schedule linear. The default
                matches the previous hard-coded value.
            **kwargs: accepted and ignored.

        Returns:
            The sampled latent, batch size 1, in the dtype of the model output.
        """
        # txt embedding
        x_ling = self._ling_condition(inputs["phone"], inputs["tone"], inputs['dur'])
        # speaker embedding + local conditioning (on a copy of lat_ctx)
        ctx_mask = inputs['ctx_mask']
        ctx_feature = self._drop_speaker_prompt(inputs['lat_ctx'])
        local_cond = self._project_local_cond(ctx_feature, ctx_mask)

        bsz, device, frm_len = local_cond.size(0), local_cond.device, local_cond.size(1)
        # Sway sampling from F5-TTS (https://github.com/SWivid/F5-TTS),
        # which is licensed under the MIT License.
        t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype)
        if sway_sampling_coef is not None:
            t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule)

        # AMO sampling ("AMO Sampler: Enhancing Text Rendering with Overshooting",
        # https://arxiv.org/pdf/2411.19415); a single noise sample is replicated
        # across the CFG batch at every step.
        x = torch.randn([1, frm_len, self.out_channels], device=device)
        for step_index in range(timesteps):
            x = x.to(torch.float32)
            sigma = t_schedule[step_index].to(x_ling.dtype)
            sigma_next = t_schedule[step_index + 1]
            model_out = self._forward(torch.cat([x] * bsz), local_cond, x_ling, timesteps=sigma.unsqueeze(0), ctx_mask=ctx_mask, dur=inputs['dur'], seq_cfg_w=seq_cfg_w)
            x = self._amo_step(x, sigma, sigma_next, model_out)
            # Cast sample back to model compatible dtype
            x = x.to(model_out.dtype)
        return x