# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from tts.modules.llm_dit.cfm import ConditionalFlowMatcher
from tts.modules.ar_dur.commons.layers import Embedding
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
from tts.modules.ar_dur.ar_dur_predictor import expand_states
from tts.modules.llm_dit.transformer import Transformer
from tts.modules.llm_dit.time_embedding import TimestepEmbedding


class Diffusion(nn.Module):
    def __init__(self):
        super().__init__()
        # Hparams
        # conditioning dims
        self.local_cond_dim = 512
        self.ctx_mask_dim = 16
        # latent channels
        self.in_channels = 32
        self.out_channels = 32
        # LLM-style DiT backbone
        self.encoder_dim = 1024
        self.encoder_n_layers = 24
        self.encoder_n_heads = 16
        self.max_seq_len = 16384
        self.multiple_of = 256

        # embed the binary context mask and fuse it with the (masked) context
        # latent to form frame-level local conditioning
        self.ctx_mask_proj = nn.Linear(1, self.ctx_mask_dim)
        self.local_cond_project = nn.Linear(
            self.out_channels + self.ctx_mask_dim, self.local_cond_dim)

        self.encoder = Transformer(self.encoder_n_layers, self.encoder_dim, self.encoder_n_heads, self.max_seq_len)

        # project the noisy latent and local cond into the encoder width;
        # postnet projects the encoder output back to latent channels
        self.x_prenet = nn.Linear(self.in_channels, self.encoder_dim)
        self.prenet = nn.Linear(self.local_cond_dim, self.encoder_dim)
        self.postnet = nn.Linear(self.encoder_dim, self.out_channels)
  
        self.flow_matcher = ConditionalFlowMatcher(sigma=0.0)
        # The implementation of TimestepEmbedding is a modified version from F5-TTS (https://github.com/SWivid/F5-TTS), 
        # which is licensed under the MIT License.
        self.f5_time_embed = TimestepEmbedding(self.encoder_dim)

        # text encoder over phone tokens (302-way phone vocab)
        self.ph_encoder = RelTransformerEncoder(
            302, self.encoder_dim, self.encoder_dim,
            self.encoder_dim * 2, 4, 6,
            3, 0.0, prenet=True, pre_ln=True)
        self.tone_embed = Embedding(32, self.encoder_dim, padding_idx=0)
        self.ph_pos_embed = PosEmb(self.encoder_dim)
        self.ling_pre_net = torch.nn.Sequential(*[
            torch.nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=s * 2, stride=s, padding=s // 2)
            for s in [2, 2]  # two stride-2 convs: 4x downsampling to the latent frame rate
        ])
    
    def forward(self, inputs, sigmas=None, x_noisy=None):
        """Training pass: sample a flow-matching location and predict its velocity.

        Local conditioning (masked context latent + mask embedding) and the
        linguistic features are added to the noisy latent before the DiT
        encoder; returns (pred, target) where target is the conditional flow.
        """
        ctx_mask = inputs['ctx_mask']
        ctx_feature = inputs['lat_ctx'] * ctx_mask

        """ local conditioning (prompt_latent + spk_embed) """
        ctx_mask_emb = self.ctx_mask_proj(ctx_mask)
        # ctx_feature = ctx_feature * (1 - inputs["spk_cfg_mask"][:, :, None])
        local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
        local_cond = self.local_cond_project(local_cond)

        """ diffusion target latent """
        x = inputs['lat']
    
        # Here, x is x1 in CFM
        x0 = torch.randn_like(x)
        t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(x0, x)
        
        # define noisy input and target; prompt (context) frames are zeroed out
        # of the noisy input, so the model sees them only via local conditioning
        t = t.bfloat16()
        x_noisy = (xt * (1 - ctx_mask)).bfloat16()
        target = ut

        # concat condition.
        x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
        x_ling = self.ling_pre_net(expand_states(x_ling, inputs['mel2ph']).transpose(1, 2)).transpose(1, 2)
        x_noisy = self.x_prenet(x_noisy) + self.prenet(local_cond) + x_ling
        encoder_out = self.encoder(x_noisy, self.f5_time_embed(t), attn_mask=inputs["text_mel_mask"], do_checkpoint=False)
        pred = self.postnet(encoder_out)

        return pred, target
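
    # Training-loss sketch (an assumption, not part of this file): the
    # (pred, target) pair returned above would typically feed a masked
    # flow-matching regression, scoring only non-prompt frames, e.g.:
    #
    #     pred, target = model(inputs)
    #     mask = 1 - inputs['ctx_mask']  # [B, T, 1]; 1 where frames are generated
    #     loss = ((pred - target) ** 2 * mask).sum() / (mask.sum() * pred.size(-1))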
    
    def forward_ling_encoder(self, txt_tokens, tone_tokens):
        ph_tokens = txt_tokens
        ph_nonpadding = (ph_tokens > 0).float()[:, :, None]  # [B, T_phone, 1]

        # enc_ph: tone embedding plus absolute position embedding are fed to the
        # phone encoder as auxiliary embeddings; padding positions are zeroed
        ph_enc_oembed = self.tone_embed(tone_tokens)
        ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
            torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
        ph_enc_oembed = ph_enc_oembed * ph_nonpadding
        x_ling = self.ph_encoder(ph_tokens, other_embeds=ph_enc_oembed) * ph_nonpadding
        return x_ling
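
    # Shape note (inferred from usage above): txt_tokens and tone_tokens are
    # [B, T_phone] long tensors with 0 as padding; x_ling comes back as
    # [B, T_phone, encoder_dim].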

    def _forward(self, x, local_cond, x_ling, timesteps, ctx_mask, dur=None, seq_cfg_w=(1.0, 1.0)):
        """ When we use torchdiffeq, we need to include the CFG process inside _forward() """
        x = x * (1 - ctx_mask)
        x = self.x_prenet(x) + self.prenet(local_cond) + x_ling
        pred_v = self.encoder(x, self.f5_time_embed(timesteps), attn_mask=torch.ones((x.size(0), x.size(1)), device=x.device))
        pred = self.postnet(pred_v)

        """ Perform multi-cond CFG """
        cond_spk_txt, cond_txt, uncond = pred.chunk(3)
        pred = uncond + seq_cfg_w[0] * (cond_txt - uncond) + seq_cfg_w[1] * (cond_spk_txt - cond_txt)
        return pred

    @torch.no_grad()
    def inference(self, inputs, timesteps=20, seq_cfg_w=(1.0, 1.0), **kwargs):
        # txt embedding
        x_ling = self.forward_ling_encoder(inputs["phone"], inputs["tone"])
        x_ling = self.ling_pre_net(expand_states(x_ling, inputs['dur']).transpose(1, 2)).transpose(1, 2)

        # speaker prompt (context latent); zero it for every CFG branch except
        # the first, so those branches are speaker-unconditional
        ctx_feature = inputs['lat_ctx']
        ctx_feature[1:, :, :] = 0  # prefix spk cfg (note: mutates inputs in place)
        ctx_mask_emb = self.ctx_mask_proj(inputs['ctx_mask'])

        # local conditioning.
        local_cond = torch.cat([ctx_feature, ctx_mask_emb], dim=-1)
        local_cond = self.local_cond_project(local_cond)
        
        """ Euler ODE solver """
        bsz, device, frm_len = (local_cond.size(0), local_cond.device, local_cond.size(1))
        # Sway sampling from F5-TTS (https://github.com/SWivid/F5-TTS), 
        # which is licensed under the MIT License.
        sway_sampling_coef = -1.0
        t_schedule = torch.linspace(0, 1, timesteps + 1, device=device, dtype=x_ling.dtype)
        if sway_sampling_coef is not None:
            t_schedule = t_schedule + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_schedule) - 1 + t_schedule)
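            # with sway_sampling_coef = -1.0 this reduces to t' = 1 - cos(pi/2 * t)
            # (e.g. t = 0.5 -> t' ~= 0.293), so the solver spends more of its
            # steps at small t, the noisier early phase of generation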
        
        # AMO sampling implementation for "AMO Sampler: Enhancing Text Rendering with Overshooting" (https://arxiv.org/pdf/2411.19415)
        def amo_sampling(z_t, t, t_next, v):
            # Upcast to avoid precision issues when computing prev_sample
            z_t = z_t.to(torch.float32)

            # Constant definition in Algorithm 1
            s = t_next
            c = 3

            # Line 7 in Algorithm 1
            o = min(t_next + c * (t_next - t), 1)
            pred_z_o = z_t + (o - t) * v

            # Line 11 in Algorithm 1
            a = s / o
            b = ((1 - s) ** 2 - (a * (1 - o)) ** 2) ** 0.5
            noise_i = torch.randn(size=z_t.shape, device=z_t.device)
            z_t_next = a * pred_z_o + b * noise_i
            return z_t_next.to(v.dtype)

        x = torch.randn([1, frm_len, self.out_channels], device=device)
        for step_index in range(timesteps):
            x = x.to(torch.float32)
            sigma = t_schedule[step_index].to(x_ling.dtype)
            sigma_next = t_schedule[step_index + 1]
            model_out = self._forward(torch.cat([x] * bsz), local_cond, x_ling, timesteps=sigma.unsqueeze(0), ctx_mask=inputs['ctx_mask'], dur=inputs['dur'], seq_cfg_w=seq_cfg_w)
            x = amo_sampling(x, sigma, sigma_next, model_out)
            # Cast sample back to model compatible dtype
            x = x.to(model_out.dtype)
        
        return x
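

# Usage sketch (assumptions, not from this file: the repo's tts.modules.*
# dependencies are importable, and the shapes/keys below are inferred from the
# code above; 'dur' is treated as a mel2ph-style index map at 4x the latent
# frame rate, since ling_pre_net downsamples by 4).
if __name__ == "__main__":
    model = Diffusion().eval()
    T_ph, T_frm = 20, 80
    bsz = 3  # CFG expects [cond_spk_txt, cond_txt, uncond] stacked in the batch
    ctx_mask = torch.zeros(bsz, T_frm, 1)
    ctx_mask[:, :20] = 1.0  # first 20 frames act as the speaker prompt
    inputs = {
        "phone": torch.randint(1, 302, (bsz, T_ph)),
        "tone": torch.randint(1, 32, (bsz, T_ph)),
        "dur": torch.randint(1, T_ph + 1, (bsz, T_frm * 4)),
        "lat_ctx": torch.randn(bsz, T_frm, 32),
        "ctx_mask": ctx_mask,
    }
    lat = model.inference(inputs, timesteps=20, seq_cfg_w=(2.0, 2.5))
    print(lat.shape)  # expected: [1, T_frm, 32]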