Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,972 Bytes
4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
import torch
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
ConvPositionEmbedding,
MMDiTBlock,
AdaLayerNorm_Final,
precompute_freqs_cis,
get_pos_embed_indices,
)
# text embedding
class TextEmbedding(nn.Module):
    """Embed text token ids and add a precomputed positional embedding.

    Token ids are shifted by one inside ``forward`` so that index 0 acts as
    the filler token; the embedding table therefore holds
    ``text_num_embeds + 1`` rows.

    Args:
        out_dim: Embedding dimension ``d``.
        text_num_embeds: Number of text tokens, not counting the filler token.
        mask_padding: When true, output positions whose shifted token id is 0
            (filler / batch padding) are set to zero.
    """

    def __init__(self, out_dim, text_num_embeds, mask_padding=True):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)  # will use 0 as filler token

        self.mask_padding = mask_padding  # mask filler and batch padding tokens or not

        # Positional table for up to `precompute_max_pos` positions. It is
        # registered as a non-persistent buffer: it follows the module across
        # devices but is not written to the state dict.
        self.precompute_max_pos = 1024
        self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)

    def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]:  # noqa: F722
        """Return position-aware text embeddings.

        Args:
            text: Token ids, ``b nt``. Per the original note, batch padding is
                encoded as -1 (see ``list_str_to_idx()``), which becomes the
                filler id 0 after the shift below.
            drop_text: When true, every token is replaced by the filler token
                before embedding (classifier-free guidance for text).

        Returns:
            Tensor of shape ``b nt d``.
        """
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        if self.mask_padding:
            # Computed before the optional drop below, so the mask always
            # reflects the real padding layout of the input.
            text_mask = text == 0

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b nt -> b nt d

        # sinus pos emb
        # Allocate the start offsets on the same device as the embeddings so the
        # position indices do not have to be copied from the host when indexing
        # the `freqs_cis` buffer.
        # NOTE(review): assumes get_pos_embed_indices builds its indices on the
        # device of `start` - confirm against f5_tts.model.modules.
        batch_start = torch.zeros((text.shape[0],), dtype=torch.long, device=text.device)
        batch_text_len = text.shape[1]
        pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
        text_pos_embed = self.freqs_cis[pos_idx]

        text = text + text_pos_embed

        if self.mask_padding:
            # The (b, nt, 1) mask broadcasts over the embedding dimension.
            text = text.masked_fill(text_mask.unsqueeze(-1), 0.0)

        return text
# noised input & masked cond audio embedding
class AudioEmbedding(nn.Module):
    """Fuse the noised input with the masked conditioning audio.

    The two inputs are concatenated on the last axis, projected by a linear
    layer from ``2 * in_dim`` to ``out_dim``, and the output of
    ``ConvPositionEmbedding`` is added back as a residual.

    Args:
        in_dim: Feature size of each of ``x`` and ``cond``.
        out_dim: Feature size of the returned embedding.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(2 * in_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False):  # noqa: F722
        """Embed ``x`` together with ``cond``.

        Args:
            x: Noised input audio, ``b n d``.
            cond: Masked conditioning audio, ``b n d``.
            drop_audio_cond: When true, ``cond`` is replaced by zeros
                (classifier-free guidance for the audio condition).

        Returns:
            The projected features plus their convolutional position embedding.
        """
        cond_input = torch.zeros_like(cond) if drop_audio_cond else cond
        projected = self.linear(torch.cat((x, cond_input), dim=-1))
        return self.conv_pos_embed(projected) + projected
# Transformer backbone using MM-DiT blocks
class MMDiT(nn.Module):
    """Transformer backbone built from a stack of ``MMDiTBlock`` layers.

    The model embeds the time step, the text and the (noised + conditioning)
    audio, runs both streams through ``depth`` blocks and projects the audio
    stream to ``mel_dim`` features.

    All constructor arguments are keyword-only.

    Args:
        dim: Model width shared by all embeddings and blocks.
        depth: Number of ``MMDiTBlock`` layers.
        heads: Passed to every block as ``heads``.
        dim_head: Passed to every block and used to size the rotary embedding.
        dropout: Passed to every block as ``dropout``.
        ff_mult: Passed to every block as ``ff_mult``.
        mel_dim: Feature size of the audio input and of the output.
        text_num_embeds: Text vocabulary size given to ``TextEmbedding``.
        text_mask_padding: Forwarded to ``TextEmbedding`` as ``mask_padding``.
        qk_norm: Passed to every block as ``qk_norm``.
    """

    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=100,
        text_num_embeds=256,
        text_mask_padding=True,
        qk_norm=None,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
        # text cache: filled lazily by forward(cache=True), reset by clear_cache()
        self.text_cond = None
        self.text_uncond = None
        self.audio_embed = AudioEmbedding(mel_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        # Only the last block is built with context_pre_only=True.
        final_index = depth - 1
        self.transformer_blocks = nn.ModuleList(
            MMDiTBlock(
                dim=dim,
                heads=heads,
                dim_head=dim_head,
                dropout=dropout,
                ff_mult=ff_mult,
                context_pre_only=(layer_index == final_index),
                qk_norm=qk_norm,
            )
            for layer_index in range(depth)
        )
        self.norm_out = AdaLayerNorm_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

        self.initialize_weights()

    def initialize_weights(self):
        """Set the AdaLN linears of every block and the output layers to zero."""
        zeroed_layers = []

        # AdaLN layers of both streams in each MMDiT block
        for block in self.transformer_blocks:
            zeroed_layers.append(block.attn_norm_x.linear)
            zeroed_layers.append(block.attn_norm_c.linear)

        # Output layers
        zeroed_layers.append(self.norm_out.linear)
        zeroed_layers.append(self.proj_out)

        for layer in zeroed_layers:
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

    def clear_cache(self):
        """Drop both cached text embeddings."""
        self.text_cond = None
        self.text_uncond = None

    def forward(
        self,
        x: float["b n d"],  # noised input audio  # noqa: F722
        cond: float["b n d"],  # masked cond audio  # noqa: F722
        text: int["b nt"],  # text  # noqa: F722
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool["b n"] | None = None,  # noqa: F722
        cache=False,
    ):
        """Run the backbone and return the output projection.

        Args:
            x: Noised input audio, ``b n d``.
            cond: Masked conditioning audio, ``b n d``.
            text: Text token ids, ``b nt``.
            time: Time step, either one value per batch item or a 0-dim
                tensor that is repeated across the batch.
            drop_audio_cond: Forwarded to the audio embedding (cfg for audio).
            drop_text: Selects the text-dropped embedding (cfg for text).
            mask: Optional ``b n`` mask handed to every block.
            cache: When true, the text embedding is computed once per
                ``drop_text`` branch and reused on later calls until
                ``clear_cache()`` is called; ``text`` is not compared against
                the cached input.

        Returns:
            ``proj_out`` applied to the normalised audio stream.
        """
        batch_size = x.shape[0]
        if time.ndim == 0:
            time = time.repeat(batch_size)

        # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)

        if not cache:
            c = self.text_embed(text, drop_text=drop_text)
        elif drop_text:
            if self.text_uncond is None:
                self.text_uncond = self.text_embed(text, drop_text=True)
            c = self.text_uncond
        else:
            if self.text_cond is None:
                self.text_cond = self.text_embed(text, drop_text=False)
            c = self.text_cond

        x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)

        # Separate rotary embeddings for the audio and text sequence lengths.
        rope_audio = self.rotary_embed.forward_from_seq_len(x.shape[1])
        rope_text = self.rotary_embed.forward_from_seq_len(text.shape[1])

        for block in self.transformer_blocks:
            c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)

        return self.proj_out(self.norm_out(x, t))
|