Gregniuki committed on
Commit 3983b32 · verified · 1 Parent(s): 5b1e4df

Delete model/modules.py

Files changed (1)
  1. model/modules.py +0 -658
model/modules.py DELETED
@@ -1,658 +0,0 @@
- """
- ein notation:
- b - batch
- n - sequence
- nt - text sequence
- nw - raw wave length
- d - dimension
- """
-
- from __future__ import annotations
-
- import math
- from typing import Optional
-
- import torch
- import torch.nn.functional as F
- import torchaudio
- from librosa.filters import mel as librosa_mel_fn
- from torch import nn
- from x_transformers.x_transformers import apply_rotary_pos_emb
-
-
- # raw wav to mel spec
-
-
- mel_basis_cache = {}
- hann_window_cache = {}
-
-
- def get_bigvgan_mel_spectrogram(
-     waveform,
-     n_fft=1024,
-     n_mel_channels=100,
-     target_sample_rate=24000,
-     hop_length=256,
-     win_length=1024,
-     fmin=0,
-     fmax=None,
-     center=False,
- ):  # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
-     device = waveform.device
-     key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
-
-     if key not in mel_basis_cache:
-         mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
-         mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)  # TODO: why do they need .float()?
-         hann_window_cache[key] = torch.hann_window(win_length).to(device)
-
-     mel_basis = mel_basis_cache[key]
-     hann_window = hann_window_cache[key]
-
-     padding = (n_fft - hop_length) // 2
-     waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
-
-     spec = torch.stft(
-         waveform,
-         n_fft,
-         hop_length=hop_length,
-         win_length=win_length,
-         window=hann_window,
-         center=center,
-         pad_mode="reflect",
-         normalized=False,
-         onesided=True,
-         return_complex=True,
-     )
-     spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
-
-     mel_spec = torch.matmul(mel_basis, spec)
-     mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
-
-     return mel_spec
-
-
- def get_vocos_mel_spectrogram(
-     waveform,
-     n_fft=1024,
-     n_mel_channels=100,
-     target_sample_rate=24000,
-     hop_length=256,
-     win_length=1024,
- ):
-     mel_stft = torchaudio.transforms.MelSpectrogram(
-         sample_rate=target_sample_rate,
-         n_fft=n_fft,
-         win_length=win_length,
-         hop_length=hop_length,
-         n_mels=n_mel_channels,
-         power=1,
-         center=True,
-         normalized=False,
-         norm=None,
-     ).to(waveform.device)
-     if len(waveform.shape) == 3:
-         waveform = waveform.squeeze(1)  # 'b 1 nw -> b nw'
-
-     assert len(waveform.shape) == 2
-
-     mel = mel_stft(waveform)
-     mel = mel.clamp(min=1e-5).log()
-     return mel
-
-
- class MelSpec(nn.Module):
-     def __init__(
-         self,
-         n_fft=1024,
-         hop_length=256,
-         win_length=1024,
-         n_mel_channels=100,
-         target_sample_rate=24_000,
-         mel_spec_type="vocos",
-     ):
-         super().__init__()
-         assert mel_spec_type in ["vocos", "bigvgan"], "We only support two mel extraction backends: vocos or bigvgan"
-
-         self.n_fft = n_fft
-         self.hop_length = hop_length
-         self.win_length = win_length
-         self.n_mel_channels = n_mel_channels
-         self.target_sample_rate = target_sample_rate
-
-         if mel_spec_type == "vocos":
-             self.extractor = get_vocos_mel_spectrogram
-         elif mel_spec_type == "bigvgan":
-             self.extractor = get_bigvgan_mel_spectrogram
-
-         self.register_buffer("dummy", torch.tensor(0), persistent=False)
-
-     def forward(self, wav):
-         if self.dummy.device != wav.device:
-             self.to(wav.device)
-
-         mel = self.extractor(
-             waveform=wav,
-             n_fft=self.n_fft,
-             n_mel_channels=self.n_mel_channels,
-             target_sample_rate=self.target_sample_rate,
-             hop_length=self.hop_length,
-             win_length=self.win_length,
-         )
-
-         return mel
-
-
- # sinusoidal position embedding
-
-
- class SinusPositionEmbedding(nn.Module):
-     def __init__(self, dim):
-         super().__init__()
-         self.dim = dim
-
-     def forward(self, x, scale=1000):
-         device = x.device
-         half_dim = self.dim // 2
-         emb = math.log(10000) / (half_dim - 1)
-         emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
-         emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
-         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-         return emb
-
-
- # convolutional position embedding
-
-
- class ConvPositionEmbedding(nn.Module):
-     def __init__(self, dim, kernel_size=31, groups=16):
-         super().__init__()
-         assert kernel_size % 2 != 0
-         self.conv1d = nn.Sequential(
-             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
-             nn.Mish(),
-             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
-             nn.Mish(),
-         )
-
-     def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
-         if mask is not None:
-             mask = mask[..., None]
-             x = x.masked_fill(~mask, 0.0)
-
-         x = x.permute(0, 2, 1)
-         x = self.conv1d(x)
-         out = x.permute(0, 2, 1)
-
-         if mask is not None:
-             out = out.masked_fill(~mask, 0.0)
-
-         return out
-
-
- # rotary positional embedding related
-
-
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
-     # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
-     # has some connection to NTK literature
-     # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
-     # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
-     theta *= theta_rescale_factor ** (dim / (dim - 2))
-     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-     t = torch.arange(end, device=freqs.device)  # type: ignore
-     freqs = torch.outer(t, freqs).float()  # type: ignore
-     freqs_cos = torch.cos(freqs)  # real part
-     freqs_sin = torch.sin(freqs)  # imaginary part
-     return torch.cat([freqs_cos, freqs_sin], dim=-1)
-
-
- def get_pos_embed_indices(start, length, max_pos, scale=1.0):
-     # length = length if isinstance(length, int) else length.max()
-     scale = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
-     pos = (
-         start.unsqueeze(1)
-         + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
-     )
-     # avoid extra long error.
-     pos = torch.where(pos < max_pos, pos, max_pos - 1)
-     return pos
-
-
- # Global Response Normalization layer (Instance Normalization ?)
-
-
- class GRN(nn.Module):
-     def __init__(self, dim):
-         super().__init__()
-         self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
-         self.beta = nn.Parameter(torch.zeros(1, 1, dim))
-
-     def forward(self, x):
-         Gx = torch.norm(x, p=2, dim=1, keepdim=True)
-         Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
-         return self.gamma * (x * Nx) + self.beta + x
-
-
- # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
- # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
-
-
- class ConvNeXtV2Block(nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         intermediate_dim: int,
-         dilation: int = 1,
-     ):
-         super().__init__()
-         padding = (dilation * (7 - 1)) // 2
-         self.dwconv = nn.Conv1d(
-             dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
-         )  # depthwise conv
-         self.norm = nn.LayerNorm(dim, eps=1e-6)
-         self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
-         self.act = nn.GELU()
-         self.grn = GRN(intermediate_dim)
-         self.pwconv2 = nn.Linear(intermediate_dim, dim)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         residual = x
-         x = x.transpose(1, 2)  # b n d -> b d n
-         x = self.dwconv(x)
-         x = x.transpose(1, 2)  # b d n -> b n d
-         x = self.norm(x)
-         x = self.pwconv1(x)
-         x = self.act(x)
-         x = self.grn(x)
-         x = self.pwconv2(x)
-         return residual + x
-
-
- # AdaLayerNormZero
- # return with modulated x for attn input, and params for later mlp modulation
-
-
- class AdaLayerNormZero(nn.Module):
-     def __init__(self, dim):
-         super().__init__()
-
-         self.silu = nn.SiLU()
-         self.linear = nn.Linear(dim, dim * 6)
-
-         self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
-
-     def forward(self, x, emb=None):
-         emb = self.linear(self.silu(emb))
-         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
-
-         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
-         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
-
-
- # AdaLayerNormZero for final layer
- # return only with modulated x for attn input, cuz no more mlp modulation
-
-
- class AdaLayerNormZero_Final(nn.Module):
-     def __init__(self, dim):
-         super().__init__()
-
-         self.silu = nn.SiLU()
-         self.linear = nn.Linear(dim, dim * 2)
-
-         self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
-
-     def forward(self, x, emb):
-         emb = self.linear(self.silu(emb))
-         scale, shift = torch.chunk(emb, 2, dim=1)
-
-         x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
-         return x
-
-
- # FeedForward
-
-
- class FeedForward(nn.Module):
-     def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
-         super().__init__()
-         inner_dim = int(dim * mult)
-         dim_out = dim_out if dim_out is not None else dim
-
-         activation = nn.GELU(approximate=approximate)
-         project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
-         self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
-
-     def forward(self, x):
-         return self.ff(x)
-
-
- # Attention with possible joint part
- # modified from diffusers/src/diffusers/models/attention_processor.py
-
-
- class Attention(nn.Module):
-     def __init__(
-         self,
-         processor: JointAttnProcessor | AttnProcessor,
-         dim: int,
-         heads: int = 8,
-         dim_head: int = 64,
-         dropout: float = 0.0,
-         context_dim: Optional[int] = None,  # if not None -> joint attention
-         context_pre_only=None,
-     ):
-         super().__init__()
-
-         if not hasattr(F, "scaled_dot_product_attention"):
-             raise ImportError("Attention requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")
-
-         self.processor = processor
-
-         self.dim = dim
-         self.heads = heads
-         self.inner_dim = dim_head * heads
-         self.dropout = dropout
-
-         self.context_dim = context_dim
-         self.context_pre_only = context_pre_only
-
-         self.to_q = nn.Linear(dim, self.inner_dim)
-         self.to_k = nn.Linear(dim, self.inner_dim)
-         self.to_v = nn.Linear(dim, self.inner_dim)
-
-         if self.context_dim is not None:
-             self.to_k_c = nn.Linear(context_dim, self.inner_dim)
-             self.to_v_c = nn.Linear(context_dim, self.inner_dim)
-             if self.context_pre_only is not None:
-                 self.to_q_c = nn.Linear(context_dim, self.inner_dim)
-
-         self.to_out = nn.ModuleList([])
-         self.to_out.append(nn.Linear(self.inner_dim, dim))
-         self.to_out.append(nn.Dropout(dropout))
-
-         if self.context_pre_only is not None and not self.context_pre_only:
-             self.to_out_c = nn.Linear(self.inner_dim, dim)
-
-     def forward(
-         self,
-         x: float["b n d"],  # noised input x  # noqa: F722
-         c: float["b n d"] = None,  # context c  # noqa: F722
-         mask: bool["b n"] | None = None,  # noqa: F722
-         rope=None,  # rotary position embedding for x
-         c_rope=None,  # rotary position embedding for c
-     ) -> torch.Tensor:
-         if c is not None:
-             return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
-         else:
-             return self.processor(self, x, mask=mask, rope=rope)
-
-
- # Attention processor
-
-
- class AttnProcessor:
-     def __init__(self):
-         pass
-
-     def __call__(
-         self,
-         attn: Attention,
-         x: float["b n d"],  # noised input x  # noqa: F722
-         mask: bool["b n"] | None = None,  # noqa: F722
-         rope=None,  # rotary position embedding
-     ) -> torch.FloatTensor:
-         batch_size = x.shape[0]
-
-         # `sample` projections.
-         query = attn.to_q(x)
-         key = attn.to_k(x)
-         value = attn.to_v(x)
-
-         # apply rotary position embedding
-         if rope is not None:
-             freqs, xpos_scale = rope
-             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
-
-             query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
-             key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
-
-         # attention
-         inner_dim = key.shape[-1]
-         head_dim = inner_dim // attn.heads
-         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-         # mask. e.g. inference got a batch with different target durations, mask out the padding
-         if mask is not None:
-             attn_mask = mask
-             attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
-             attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
-         else:
-             attn_mask = None
-
-         x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
-         x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-         x = x.to(query.dtype)
-
-         # linear proj
-         x = attn.to_out[0](x)
-         # dropout
-         x = attn.to_out[1](x)
-
-         if mask is not None:
-             mask = mask.unsqueeze(-1)
-             x = x.masked_fill(~mask, 0.0)
-
-         return x
-
-
- # Joint Attention processor for MM-DiT
- # modified from diffusers/src/diffusers/models/attention_processor.py
-
-
- class JointAttnProcessor:
-     def __init__(self):
-         pass
-
-     def __call__(
-         self,
-         attn: Attention,
-         x: float["b n d"],  # noised input x  # noqa: F722
-         c: float["b nt d"] = None,  # context c, here text  # noqa: F722
-         mask: bool["b n"] | None = None,  # noqa: F722
-         rope=None,  # rotary position embedding for x
-         c_rope=None,  # rotary position embedding for c
-     ) -> torch.FloatTensor:
-         residual = x
-
-         batch_size = c.shape[0]
-
-         # `sample` projections.
-         query = attn.to_q(x)
-         key = attn.to_k(x)
-         value = attn.to_v(x)
-
-         # `context` projections.
-         c_query = attn.to_q_c(c)
-         c_key = attn.to_k_c(c)
-         c_value = attn.to_v_c(c)
-
-         # apply rope for context and noised input independently
-         if rope is not None:
-             freqs, xpos_scale = rope
-             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
-             query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
-             key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
-         if c_rope is not None:
-             freqs, xpos_scale = c_rope
-             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
-             c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
-             c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
-
-         # attention
-         query = torch.cat([query, c_query], dim=1)
-         key = torch.cat([key, c_key], dim=1)
-         value = torch.cat([value, c_value], dim=1)
-
-         inner_dim = key.shape[-1]
-         head_dim = inner_dim // attn.heads
-         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-         # mask. e.g. inference got a batch with different target durations, mask out the padding
-         if mask is not None:
-             attn_mask = F.pad(mask, (0, c.shape[1]), value=True)  # no mask for c (text)
-             attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
-             attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
-         else:
-             attn_mask = None
-
-         x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
-         x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-         x = x.to(query.dtype)
-
-         # Split the attention outputs.
-         x, c = (
-             x[:, : residual.shape[1]],
-             x[:, residual.shape[1] :],
-         )
-
-         # linear proj
-         x = attn.to_out[0](x)
-         # dropout
-         x = attn.to_out[1](x)
-         if not attn.context_pre_only:
-             c = attn.to_out_c(c)
-
-         if mask is not None:
-             mask = mask.unsqueeze(-1)
-             x = x.masked_fill(~mask, 0.0)
-             # c = c.masked_fill(~mask, 0.)  # no mask for c (text)
-
-         return x, c
-
-
- # DiT Block
-
-
- class DiTBlock(nn.Module):
-     def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
-         super().__init__()
-
-         self.attn_norm = AdaLayerNormZero(dim)
-         self.attn = Attention(
-             processor=AttnProcessor(),
-             dim=dim,
-             heads=heads,
-             dim_head=dim_head,
-             dropout=dropout,
-         )
-
-         self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
-         self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
-
-     def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: time embedding
-         # pre-norm & modulation for attention input
-         norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
-
-         # attention
-         attn_output = self.attn(x=norm, mask=mask, rope=rope)
-
-         # process attention output for input x
-         x = x + gate_msa.unsqueeze(1) * attn_output
-
-         norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-         ff_output = self.ff(norm)
-         x = x + gate_mlp.unsqueeze(1) * ff_output
-
-         return x
-
-
- # MMDiT Block https://arxiv.org/abs/2403.03206
-
-
- class MMDiTBlock(nn.Module):
-     r"""
-     modified from diffusers/src/diffusers/models/attention.py
-
-     notes.
-     _c: context related. text, cond, etc. (left part in sd3 fig2.b)
-     _x: noised input related. (right part)
-     context_pre_only: last layer only do prenorm + modulation cuz no more ffn
-     """
-
-     def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
-         super().__init__()
-
-         self.context_pre_only = context_pre_only
-
-         self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
-         self.attn_norm_x = AdaLayerNormZero(dim)
-         self.attn = Attention(
-             processor=JointAttnProcessor(),
-             dim=dim,
-             heads=heads,
-             dim_head=dim_head,
-             dropout=dropout,
-             context_dim=dim,
-             context_pre_only=context_pre_only,
-         )
-
-         if not context_pre_only:
-             self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
-             self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
-         else:
-             self.ff_norm_c = None
-             self.ff_c = None
-         self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
-         self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
-
-     def forward(self, x, c, t, mask=None, rope=None, c_rope=None):  # x: noised input, c: context, t: time embedding
-         # pre-norm & modulation for attention input
-         if self.context_pre_only:
-             norm_c = self.attn_norm_c(c, t)
-         else:
-             norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
-         norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
-
-         # attention
-         x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
-
-         # process attention output for context c
-         if self.context_pre_only:
-             c = None
-         else:  # if not last layer
-             c = c + c_gate_msa.unsqueeze(1) * c_attn_output
-
-             norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
-             c_ff_output = self.ff_c(norm_c)
-             c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
-
-         # process attention output for input x
-         x = x + x_gate_msa.unsqueeze(1) * x_attn_output
-
-         norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
-         x_ff_output = self.ff_x(norm_x)
-         x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
-
-         return c, x
-
-
- # time step conditioning embedding
-
-
- class TimestepEmbedding(nn.Module):
-     def __init__(self, dim, freq_embed_dim=256):
-         super().__init__()
-         self.time_embed = SinusPositionEmbedding(freq_embed_dim)
-         self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
-
-     def forward(self, timestep: float["b"]):  # noqa: F821
-         time_hidden = self.time_embed(timestep)
-         time_hidden = time_hidden.to(timestep.dtype)
-         time = self.time_mlp(time_hidden)  # b d
-         return time