# Provenance (file-viewer page header captured with the source, not program code):
# uploaded by ZiyueJiang, commit 593f3bc ("first commit for huggingface space"), 8.04 kB.
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from tts.modules.llm_dit.cfm import ConditionalFlowMatcher
from tts.modules.ar_dur.commons.layers import Embedding
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
from tts.modules.ar_dur.ar_dur_predictor import expand_states
from tts.modules.llm_dit.transformer import Transformer
from tts.modules.llm_dit.time_embedding import TimestepEmbedding
class Diffusion(nn.Module):
    """Flow-matching generator over latent frames.

    The network is conditioned on phoneme/tone tokens (through
    ``forward_ling_encoder``) and on a masked context latent
    (``inputs['lat_ctx']`` together with ``inputs['ctx_mask']``).

    ``forward`` builds the training pair (prediction, flow target);
    ``inference`` integrates the learned flow with an Euler-style loop that
    uses the AMO overshooting update and classifier-free guidance.
    """

    def __init__(self):
        super().__init__()
        # Hparams
        # cond dim
        self.local_cond_dim = 512
        self.ctx_mask_dim = 16
        self.in_channels = 32
        self.out_channels = 32
        # LLM
        self.encoder_dim = 1024
        self.encoder_n_layers = 24
        self.encoder_n_heads = 16
        self.max_seq_len = 16384
        self.multiple_of = 256  # not referenced anywhere else in this class

        # ctx_mask is projected from a single channel, so its last dim must be 1.
        self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim)
        # Input is cat([context latent, mask embedding]) on the last dim.
        self.local_cond_project = nn.Linear(
            self.out_channels + self.ctx_mask_dim, self.local_cond_dim)

        self.encoder = Transformer(
            self.encoder_n_layers, self.encoder_dim, self.encoder_n_heads, self.max_seq_len)

        self.x_prenet = nn.Linear(self.in_channels, self.encoder_dim)
        self.prenet = nn.Linear(self.local_cond_dim, self.encoder_dim)
        self.postnet = nn.Linear(self.encoder_dim, self.out_channels)

        self.flow_matcher = ConditionalFlowMatcher(sigma=0.0)
        # The implementation of TimestepEmbedding is a modified version from F5-TTS (https://github.com/SWivid/F5-TTS),
        # which is licensed under the MIT License.
        self.f5_time_embed = TimestepEmbedding(self.encoder_dim)

        # text encoder
        # NOTE(review): the positional arguments are passed through unchanged; their
        # meaning (presumably vocab size, channel sizes, heads, layers, kernel size,
        # dropout) must be confirmed against RelTransformerEncoder's signature.
        self.ph_encoder = RelTransformerEncoder(
            302, self.encoder_dim, self.encoder_dim,
            self.encoder_dim * 2, 4, 6,
            3, 0.0, prenet=True, pre_ln=True)
        self.tone_embed = Embedding(32, self.encoder_dim, padding_idx=0)
        self.ph_pos_embed = PosEmb(self.encoder_dim)
        # Two Conv1d layers, each with kernel_size=4, stride=2, padding=1, applied
        # along the time axis (so the sequence is downsampled by stride 2 twice).
        self.ling_pre_net = torch.nn.Sequential(*[
            torch.nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=s * 2, stride=s, padding=s // 2)
            for s in (2, 2)
        ])

    def forward(self, inputs, sigmas=None, x_noisy=None):
        """Run one training forward pass.

        Args:
            inputs: dict read for the keys ``'ctx_mask'``, ``'lat_ctx'``,
                ``'lat'``, ``'phone'``, ``'tone'``, ``'mel2ph'`` and
                ``'text_mel_mask'``.
            sigmas: unused; kept for interface compatibility.
            x_noisy: ignored; the noisy input is always re-sampled here.

        Returns:
            Tuple ``(pred, target)``: the post-net output and the conditional
            flow ``ut`` returned by the flow matcher.
        """
        ctx_mask = inputs['ctx_mask']
        # Keep the context latent only where ctx_mask is non-zero.
        ctx_feature = inputs['lat_ctx'] * ctx_mask

        # local conditioning (prompt_latent + spk_embed)
        ctx_mask_emb = self.ctx_mask_proj(ctx_mask)
        # ctx_feature = ctx_feature * (1 - inputs["spk_cfg_mask"][:, :, None])
        local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
        local_cond = self.local_cond_project(local_cond)

        # diffusion target latent
        x = inputs['lat']

        # Here, x is x1 in CFM
        x0 = torch.randn_like(x)
        t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x)

        # define noisy_input and target
        # NOTE(review): the bf16 casts are hard-coded; presumably training runs in
        # bf16 -- confirm before changing precision.
        t = t.bfloat16()
        # The noisy sample is zeroed wherever ctx_mask is 1.
        noisy_latent = (xt * (1 - ctx_mask)).bfloat16()
        target = ut

        # concat condition.
        x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
        # expand_states is defined elsewhere; presumably it expands phone-level
        # states to frame level using mel2ph -- confirm against its definition.
        x_ling = self.ling_pre_net(expand_states(x_ling, inputs['mel2ph']).transpose(1, 2)).transpose(1, 2)
        hidden = self.x_prenet(noisy_latent) + self.prenet(local_cond) + x_ling
        encoder_out = self.encoder(
            hidden, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"], do_checkpoint=False)
        pred = self.postnet(encoder_out)

        return pred, target

    def forward_ling_encoder(self, txt_tokens, tone_tokens):
        """Encode phoneme and tone tokens.

        Args:
            txt_tokens: phoneme ids indexed as ``[B, T_phone]``; ids ``<= 0``
                are treated as padding.
            tone_tokens: tone ids fed to ``tone_embed``.

        Returns:
            The ``ph_encoder`` output with padded positions zeroed.
        """
        ph_tokens = txt_tokens
        ph_nonpadding = (ph_tokens > 0).float()[:, :, None]  # [B, T_phone, 1]

        # enc_ph: tone embedding plus positional embedding, zeroed on padding.
        # The position ids are created directly on the token device.
        positions = torch.arange(ph_tokens.shape[1], device=ph_tokens.device)[None]
        ph_enc_oembed = self.tone_embed(tone_tokens) + self.ph_pos_embed(positions)
        ph_enc_oembed = ph_enc_oembed * ph_nonpadding

        x_ling = self.ph_encoder(ph_tokens, other_embeds=ph_enc_oembed) * ph_nonpadding
        return x_ling

    def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=(1.0, 1.0)):
        """Predict the guided velocity for one solver step.

        When we use torchdiffeq, we need to include the CFG process inside _forward().

        Args:
            x: current sample, stacked along the batch axis so that
                ``chunk(3)`` yields the three guidance branches in the order
                (speaker+text, text-only, unconditional).
            local_cond: projected local conditioning, added after ``prenet``.
            x_ling: frame-level text conditioning, added as-is.
            timesteps: time value(s) passed to ``f5_time_embed``.
            ctx_mask: mask; ``x`` is zeroed where it is 1.
            dur: unused; kept for interface compatibility.
            seq_cfg_w: two guidance weights; index 0 scales
                (text-only - unconditional), index 1 scales
                (speaker+text - text-only).

        Returns:
            The guided prediction (one chunk of the stacked batch).
        """
        x = x * (1 - ctx_mask)
        x = self.x_prenet(x) + self.prenet(local_cond) + x_ling
        # Every position is attended to: the mask is all ones.
        full_mask = torch.ones((x.size(0), x.size(1)), device=x.device)
        pred_v = self.encoder(x, self.f5_time_embed(timesteps), attn_mask=full_mask)
        pred = self.postnet(pred_v)

        # Perform multi-cond CFG
        cond_spk_txt, cond_txt, uncond = pred.chunk(3)
        pred = uncond + seq_cfg_w[0] * (cond_txt - uncond) + seq_cfg_w[1] * (cond_spk_txt - cond_txt)
        return pred

    @torch.no_grad()
    def inference(self, inputs, timesteps=20, seq_cfg_w=(1.0, 1.0), **kwargs):
        """Sample latents by integrating the flow with AMO overshooting steps.

        Args:
            inputs: dict read for the keys ``'phone'``, ``'tone'``, ``'dur'``,
                ``'lat_ctx'`` and ``'ctx_mask'``. The batch is expected to
                stack the three guidance branches consumed by ``_forward``.
                ``inputs['lat_ctx']`` is not modified.
            timesteps: number of solver steps.
            seq_cfg_w: two guidance weights forwarded to ``_forward``.
            **kwargs: accepted and ignored.

        Returns:
            Tensor created as ``[1, frame_len, out_channels]`` and updated by
            the solver loop, in the dtype of the last model output.
        """
        # txt embedding
        x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
        x_ling = self.ling_pre_net(expand_states(x_ling, inputs['dur']).transpose(1, 2)).transpose(1, 2)

        # speaker embedding
        # Work on a copy: zeroing in place would silently alter the caller's tensor.
        ctx_feature = inputs['lat_ctx'].clone()
        ctx_feature[1:, :, :] = 0  # prefix spk cfg
        ctx_mask_emb = self.ctx_mask_proj(inputs['ctx_mask'])

        # local conditioning.
        local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
        local_cond = self.local_cond_project(local_cond)

        # Euler ODE solver
        bsz, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1))
        # Sway sampling from F5-TTS (https://github.com/SWivid/F5-TTS),
        # which is licensed under the MIT License.
        sway_sampling_coef = -1.0
        t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype)
        if sway_sampling_coef is not None:
            t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule)

        # AMO sampling implementation for "AMO Sampler: Enhancing Text Rendering with Overshooting" (https://arxiv.org/pdf/2411.19415)
        def amo_sampling(z_t, t, t_next, v):
            # Upcast to avoid precision issues when computing prev_sample
            z_t = z_t.to(torch.float32)

            # Constant definition in Algorithm 1
            s = t_next
            c = 3

            # Line 7 in Algorithm 1: overshoot past t_next, capped at 1.
            o = min(t_next + c * (t_next - t), 1)
            pred_z_o = z_t + (o - t) * v

            # Line 11 in Algorithm 1: come back from o to s by re-noising.
            a = s / o
            # The difference is non-negative analytically (o >= s); clamp so that
            # rounding in low precision cannot turn the square root into NaN.
            b = torch.clamp((1 - s) ** 2 - (a * (1 - o)) ** 2, min=0) ** 0.5

            noise_i = torch.randn(size=z_t.shape, device=z_t.device)
            z_t_next = a * pred_z_o + b * noise_i
            return z_t_next.to(v.dtype)

        x = torch.randn([1, frm_len, self.out_channels], device=device)
        for step_index in range(timesteps):
            x = x.to(torch.float32)
            sigma = t_schedule[step_index].to(x_ling.dtype)
            sigma_next = t_schedule[step_index + 1]
            # The same sample is replicated for every guidance branch in the batch.
            model_out = self._forward(
                torch.cat([x] * bsz), local_cond, x_ling,
                timesteps=sigma.unsqueeze(0), ctx_mask=inputs['ctx_mask'],
                dur=inputs['dur'], seq_cfg_w=seq_cfg_w)
            x = amo_sampling(x, sigma, sigma_next, model_out)
            # Cast sample back to model compatible dtype
            x = x.to(model_out.dtype)

        return x