Dakerqi committed
Commit 9bd1d2d · verified · 1 Parent(s): a3e0ef3

Delete model.py

Files changed (1)
  1. model.py +0 -930
model.py DELETED
@@ -1,930 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
-
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- # --------------------------------------------------------
- # References:
- # GLIDE: https://github.com/openai/glide-text2im
- # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
- # --------------------------------------------------------
-
- import math
- from typing import List, Optional, Tuple
-
- from flash_attn import flash_attn_varlen_func
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from .components import RMSNorm
-
-
- def modulate(x, scale):
-     return x * (1 + scale.unsqueeze(1))
-
-
- #############################################################################
- # Embedding Layers for Timesteps and Class Labels #
- #############################################################################
-
-
- class TimestepEmbedder(nn.Module):
-     """
-     Embeds scalar timesteps into vector representations.
-     """
-
-     def __init__(self, hidden_size, frequency_embedding_size=256):
-         super().__init__()
-         self.mlp = nn.Sequential(
-             nn.Linear(
-                 frequency_embedding_size,
-                 hidden_size,
-                 bias=True,
-             ),
-             nn.SiLU(),
-             nn.Linear(
-                 hidden_size,
-                 hidden_size,
-                 bias=True,
-             ),
-         )
-         nn.init.normal_(self.mlp[0].weight, std=0.02)
-         nn.init.zeros_(self.mlp[0].bias)
-         nn.init.normal_(self.mlp[2].weight, std=0.02)
-         nn.init.zeros_(self.mlp[2].bias)
-
-         self.frequency_embedding_size = frequency_embedding_size
-
-     @staticmethod
-     def timestep_embedding(t, dim, max_period=10000):
-         """
-         Create sinusoidal timestep embeddings.
-         :param t: a 1-D Tensor of N indices, one per batch element.
-                   These may be fractional.
-         :param dim: the dimension of the output.
-         :param max_period: controls the minimum frequency of the embeddings.
-         :return: an (N, D) Tensor of positional embeddings.
-         """
-         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
-         half = dim // 2
-         freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-             device=t.device
-         )
-         args = t[:, None].float() * freqs[None]
-         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-         if dim % 2:
-             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-         return embedding
-
-     def forward(self, t):
-         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
-         t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
-         return t_emb
-
-
- #############################################################################
- # Core NextDiT Model #
- #############################################################################
-
-
- class JointAttention(nn.Module):
-     """Multi-head attention module."""
-
-     def __init__(
-         self,
-         dim: int,
-         n_heads: int,
-         n_kv_heads: Optional[int],
-         qk_norm: bool,
-     ):
-         """
-         Initialize the Attention module.
-
-         Args:
-             dim (int): Number of input dimensions.
-             n_heads (int): Number of heads.
-             n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
-
-         """
-         super().__init__()
-         self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
-         self.n_local_heads = n_heads
-         self.n_local_kv_heads = self.n_kv_heads
-         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-         self.head_dim = dim // n_heads
-
-         self.qkv = nn.Linear(
-             dim,
-             (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
-             bias=False,
-         )
-         nn.init.xavier_uniform_(self.qkv.weight)
-
-         self.out = nn.Linear(
-             n_heads * self.head_dim,
-             dim,
-             bias=False,
-         )
-         nn.init.xavier_uniform_(self.out.weight)
-
-         if qk_norm:
-             self.q_norm = RMSNorm(self.head_dim)
-             self.k_norm = RMSNorm(self.head_dim)
-         else:
-             self.q_norm = self.k_norm = nn.Identity()
-
-     @staticmethod
-     def apply_rotary_emb(
-         x_in: torch.Tensor,
-         freqs_cis: torch.Tensor,
-     ) -> torch.Tensor:
-         """
-         Apply rotary embeddings to input tensors using the given frequency
-         tensor.
-
-         This function applies rotary embeddings to the given query 'xq' and
-         key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
-         input tensors are reshaped as complex numbers, and the frequency tensor
-         is reshaped for broadcasting compatibility. The resulting tensors
-         contain rotary embeddings and are returned as real tensors.
-
-         Args:
-             x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
-             freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
-                 exponentials.
-
-         Returns:
-             Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
-                 and key tensor with rotary embeddings.
-         """
-         with torch.cuda.amp.autocast(enabled=False):
-             x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
-             freqs_cis = freqs_cis.unsqueeze(2)
-             x_out = torch.view_as_real(x * freqs_cis).flatten(3)
-             return x_out.type_as(x_in)
-
-     # copied from huggingface modeling_llama.py
-     def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
-         def _get_unpad_data(attention_mask):
-             seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-             indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-             max_seqlen_in_batch = seqlens_in_batch.max().item()
-             cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-             return (
-                 indices,
-                 cu_seqlens,
-                 max_seqlen_in_batch,
-             )
-
-         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-         batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
-         key_layer = index_first_axis(
-             key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-             indices_k,
-         )
-         value_layer = index_first_axis(
-             value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-             indices_k,
-         )
-         if query_length == kv_seq_len:
-             query_layer = index_first_axis(
-                 query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim),
-                 indices_k,
-             )
-             cu_seqlens_q = cu_seqlens_k
-             max_seqlen_in_batch_q = max_seqlen_in_batch_k
-             indices_q = indices_k
-         elif query_length == 1:
-             max_seqlen_in_batch_q = 1
-             cu_seqlens_q = torch.arange(
-                 batch_size + 1, dtype=torch.int32, device=query_layer.device
-             ) # There is a memcpy here, that is very bad.
-             indices_q = cu_seqlens_q[:-1]
-             query_layer = query_layer.squeeze(1)
-         else:
-             # The -q_len: slice assumes left padding.
-             attention_mask = attention_mask[:, -query_length:]
-             query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
-         return (
-             query_layer,
-             key_layer,
-             value_layer,
-             indices_q,
-             (cu_seqlens_q, cu_seqlens_k),
-             (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-         )
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         x_mask: torch.Tensor,
-         freqs_cis: torch.Tensor,
-     ) -> torch.Tensor:
-         """
-
-         Args:
-             x:
-             x_mask:
-             freqs_cis:
-
-         Returns:
-
-         """
-         bsz, seqlen, _ = x.shape
-         dtype = x.dtype
-
-         xq, xk, xv = torch.split(
-             self.qkv(x),
-             [
-                 self.n_local_heads * self.head_dim,
-                 self.n_local_kv_heads * self.head_dim,
-                 self.n_local_kv_heads * self.head_dim,
-             ],
-             dim=-1,
-         )
-         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-         xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-         xq = self.q_norm(xq)
-         xk = self.k_norm(xk)
-         xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
-         xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
-         xq, xk = xq.to(dtype), xk.to(dtype)
-
-         softmax_scale = math.sqrt(1 / self.head_dim)
-
-         if dtype in [torch.float16, torch.bfloat16]:
-             # begin var_len flash attn
-             (
-                 query_states,
-                 key_states,
-                 value_states,
-                 indices_q,
-                 cu_seq_lens,
-                 max_seq_lens,
-             ) = self._upad_input(xq, xk, xv, x_mask, seqlen)
-
-             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-             attn_output_unpad = flash_attn_varlen_func(
-                 query_states,
-                 key_states,
-                 value_states,
-                 cu_seqlens_q=cu_seqlens_q,
-                 cu_seqlens_k=cu_seqlens_k,
-                 max_seqlen_q=max_seqlen_in_batch_q,
-                 max_seqlen_k=max_seqlen_in_batch_k,
-                 dropout_p=0.0,
-                 causal=False,
-                 softmax_scale=softmax_scale,
-             )
-             output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
-             # end var_len_flash_attn
-
-         else:
-             n_rep = self.n_local_heads // self.n_local_kv_heads
-             if n_rep >= 1:
-                 xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
-                 xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
-             output = (
-                 F.scaled_dot_product_attention(
-                     xq.permute(0, 2, 1, 3),
-                     xk.permute(0, 2, 1, 3),
-                     xv.permute(0, 2, 1, 3),
-                     attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
-                     scale=softmax_scale,
-                 )
-                 .permute(0, 2, 1, 3)
-                 .to(dtype)
-             )
-
-         output = output.flatten(-2)
-
-         return self.out(output)
-
-
- class FeedForward(nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         hidden_dim: int,
-         multiple_of: int,
-         ffn_dim_multiplier: Optional[float],
-     ):
-         """
-         Initialize the FeedForward module.
-
-         Args:
-             dim (int): Input dimension.
-             hidden_dim (int): Hidden dimension of the feedforward layer.
-             multiple_of (int): Value to ensure hidden dimension is a multiple
-                 of this value.
-             ffn_dim_multiplier (float, optional): Custom multiplier for hidden
-                 dimension. Defaults to None.
-
-         """
-         super().__init__()
-         # custom dim factor multiplier
-         if ffn_dim_multiplier is not None:
-             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
-         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-
-         self.w1 = nn.Linear(
-             dim,
-             hidden_dim,
-             bias=False,
-         )
-         nn.init.xavier_uniform_(self.w1.weight)
-         self.w2 = nn.Linear(
-             hidden_dim,
-             dim,
-             bias=False,
-         )
-         nn.init.xavier_uniform_(self.w2.weight)
-         self.w3 = nn.Linear(
-             dim,
-             hidden_dim,
-             bias=False,
-         )
-         nn.init.xavier_uniform_(self.w3.weight)
-
-     # @torch.compile
-     def _forward_silu_gating(self, x1, x3):
-         return F.silu(x1) * x3
-
-     def forward(self, x):
-         return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
-
-
- class JointTransformerBlock(nn.Module):
-     def __init__(
-         self,
-         layer_id: int,
-         dim: int,
-         n_heads: int,
-         n_kv_heads: int,
-         multiple_of: int,
-         ffn_dim_multiplier: float,
-         norm_eps: float,
-         qk_norm: bool,
-         modulation=True
-     ) -> None:
-         """
-         Initialize a TransformerBlock.
-
-         Args:
-             layer_id (int): Identifier for the layer.
-             dim (int): Embedding dimension of the input features.
-             n_heads (int): Number of attention heads.
-             n_kv_heads (Optional[int]): Number of attention heads in key and
-                 value features (if using GQA), or set to None for the same as
-                 query.
-             multiple_of (int):
-             ffn_dim_multiplier (float):
-             norm_eps (float):
-
-         """
-         super().__init__()
-         self.dim = dim
-         self.head_dim = dim // n_heads
-         self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm)
-         self.feed_forward = FeedForward(
-             dim=dim,
-             hidden_dim=4 * dim,
-             multiple_of=multiple_of,
-             ffn_dim_multiplier=ffn_dim_multiplier,
-         )
-         self.layer_id = layer_id
-         self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
-         self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
-
-         self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
-         self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
-
-         self.modulation = modulation
-         if modulation:
-             self.adaLN_modulation = nn.Sequential(
-                 nn.SiLU(),
-                 nn.Linear(
-                     min(dim, 1024),
-                     4 * dim,
-                     bias=True,
-                 ),
-             )
-             nn.init.zeros_(self.adaLN_modulation[1].weight)
-             nn.init.zeros_(self.adaLN_modulation[1].bias)
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         x_mask: torch.Tensor,
-         freqs_cis: torch.Tensor,
-         adaln_input: Optional[torch.Tensor]=None,
-     ):
-         """
-         Perform a forward pass through the TransformerBlock.
-
-         Args:
-             x (torch.Tensor): Input tensor.
-             freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
-
-         Returns:
-             torch.Tensor: Output tensor after applying attention and
-                 feedforward layers.
-
-         """
-         if self.modulation:
-             assert adaln_input is not None
-             scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
-
-             x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
-                 self.attention(
-                     modulate(self.attention_norm1(x), scale_msa),
-                     x_mask,
-                     freqs_cis,
-                 )
-             )
-             x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
-                 self.feed_forward(
-                     modulate(self.ffn_norm1(x), scale_mlp),
-                 )
-             )
-         else:
-             assert adaln_input is None
-             x = x + self.attention_norm2(
-                 self.attention(
-                     self.attention_norm1(x),
-                     x_mask,
-                     freqs_cis,
-                 )
-             )
-             x = x + self.ffn_norm2(
-                 self.feed_forward(
-                     self.ffn_norm1(x),
-                 )
-             )
-         return x
-
-
- class FinalLayer(nn.Module):
-     """
-     The final layer of NextDiT.
-     """
-
-     def __init__(self, hidden_size, patch_size, out_channels):
-         super().__init__()
-         self.norm_final = nn.LayerNorm(
-             hidden_size,
-             elementwise_affine=False,
-             eps=1e-6,
-         )
-         self.linear = nn.Linear(
-             hidden_size,
-             patch_size * patch_size * out_channels,
-             bias=True,
-         )
-         nn.init.zeros_(self.linear.weight)
-         nn.init.zeros_(self.linear.bias)
-
-         self.adaLN_modulation = nn.Sequential(
-             nn.SiLU(),
-             nn.Linear(
-                 min(hidden_size, 1024),
-                 hidden_size,
-                 bias=True,
-             ),
-         )
-         nn.init.zeros_(self.adaLN_modulation[1].weight)
-         nn.init.zeros_(self.adaLN_modulation[1].bias)
-
-     def forward(self, x, c):
-         scale = self.adaLN_modulation(c)
-         x = modulate(self.norm_final(x), scale)
-         x = self.linear(x)
-         return x
-
-
- class RopeEmbedder:
-     def __init__(
-         self, theta: float = 10000.0, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512)
-     ):
-         super().__init__()
-         self.theta = theta
-         self.axes_dims = axes_dims
-         self.axes_lens = axes_lens
-         self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
-
-     def __call__(self, ids: torch.Tensor):
-         self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
-         result = []
-         for i in range(len(self.axes_dims)):
-             # import torch.distributed as dist
-             # if not dist.is_initialized() or dist.get_rank() == 0:
-             #     import pdb
-             #     pdb.set_trace()
-             index = ids[:, :, i:i+1].repeat(1, 1, self.freqs_cis[i].shape[-1]).to(torch.int64)
-             result.append(torch.gather(self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
-         return torch.cat(result, dim=-1)
-
-
- class NextDiT(nn.Module):
-     """
-     Diffusion model with a Transformer backbone.
-     """
-
-     def __init__(
-         self,
-         patch_size: int = 2,
-         in_channels: int = 4,
-         dim: int = 4096,
-         n_layers: int = 32,
-         n_refiner_layers: int = 2,
-         n_heads: int = 32,
-         n_kv_heads: Optional[int] = None,
-         multiple_of: int = 256,
-         ffn_dim_multiplier: Optional[float] = None,
-         norm_eps: float = 1e-5,
-         qk_norm: bool = False,
-         cap_feat_dim: int = 5120,
-         axes_dims: List[int] = (16, 56, 56),
-         axes_lens: List[int] = (1, 512, 512),
-     ) -> None:
-         super().__init__()
-         self.in_channels = in_channels
-         self.out_channels = in_channels
-         self.patch_size = patch_size
-
-         self.x_embedder = nn.Linear(
-             in_features=patch_size * patch_size * in_channels,
-             out_features=dim,
-             bias=True,
-         )
-         nn.init.xavier_uniform_(self.x_embedder.weight)
-         nn.init.constant_(self.x_embedder.bias, 0.0)
-
-         self.noise_refiner = nn.ModuleList(
-             [
-                 JointTransformerBlock(
-                     layer_id,
-                     dim,
-                     n_heads,
-                     n_kv_heads,
-                     multiple_of,
-                     ffn_dim_multiplier,
-                     norm_eps,
-                     qk_norm,
-                     modulation=True,
-                 )
-                 for layer_id in range(n_refiner_layers)
-             ]
-         )
-         self.context_refiner = nn.ModuleList(
-             [
-                 JointTransformerBlock(
-                     layer_id,
-                     dim,
-                     n_heads,
-                     n_kv_heads,
-                     multiple_of,
-                     ffn_dim_multiplier,
-                     norm_eps,
-                     qk_norm,
-                     modulation=False,
-                 )
-                 for layer_id in range(n_refiner_layers)
-             ]
-         )
-
-         self.t_embedder = TimestepEmbedder(min(dim, 1024))
-         self.cap_embedder = nn.Sequential(
-             RMSNorm(cap_feat_dim, eps=norm_eps),
-             nn.Linear(
-                 cap_feat_dim,
-                 dim,
-                 bias=True,
-             ),
-         )
-         nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
-         # nn.init.zeros_(self.cap_embedder[1].weight)
-         nn.init.zeros_(self.cap_embedder[1].bias)
-
-         self.layers = nn.ModuleList(
-             [
-                 JointTransformerBlock(
-                     layer_id,
-                     dim,
-                     n_heads,
-                     n_kv_heads,
-                     multiple_of,
-                     ffn_dim_multiplier,
-                     norm_eps,
-                     qk_norm,
-                 )
-                 for layer_id in range(n_layers)
-             ]
-         )
-         self.norm_final = RMSNorm(dim, eps=norm_eps)
-         self.final_layer = FinalLayer(dim, patch_size, self.out_channels)
-
-         assert (dim // n_heads) == sum(axes_dims)
-         self.axes_dims = axes_dims
-         self.axes_lens = axes_lens
-         self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
-         self.dim = dim
-         self.n_heads = n_heads
-
-     def unpatchify(
-         self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
-     ) -> List[torch.Tensor]:
-         """
-         x: (N, T, patch_size**2 * C)
-         imgs: (N, H, W, C)
-         """
-         pH = pW = self.patch_size
-         imgs = []
-         for i in range(x.size(0)):
-             H, W = img_size[i]
-             begin = cap_size[i]
-             end = begin + (H // pH) * (W // pW)
-             imgs.append(
-                 x[i][begin:end]
-                 .view(H // pH, W // pW, pH, pW, self.out_channels)
-                 .permute(4, 0, 2, 1, 3)
-                 .flatten(3, 4)
-                 .flatten(1, 2)
-             )
-
-         if return_tensor:
-             imgs = torch.stack(imgs, dim=0)
-         return imgs
-
-     def patchify_and_embed(
-         self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor
-     ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
-         bsz = len(x)
-         pH = pW = self.patch_size
-         device = x[0].device
-
-         l_effective_cap_len = cap_mask.sum(dim=1).tolist()
-         img_sizes = [(img.size(1), img.size(2)) for img in x]
-         l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
-
-         max_seq_len = max(
-             (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
-         )
-         max_cap_len = max(l_effective_cap_len)
-         max_img_len = max(l_effective_img_len)
-
-         position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
-
-         for i in range(bsz):
-             cap_len = l_effective_cap_len[i]
-             img_len = l_effective_img_len[i]
-             H, W = img_sizes[i]
-             H_tokens, W_tokens = H // pH, W // pW
-             assert H_tokens * W_tokens == img_len
-
-             position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
-             position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
-             row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
-             col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
-             position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
-             position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
-
-         freqs_cis = self.rope_embedder(position_ids)
-
-         # build freqs_cis for cap and image individually
-         cap_freqs_cis_shape = list(freqs_cis.shape)
-         # cap_freqs_cis_shape[1] = max_cap_len
-         cap_freqs_cis_shape[1] = cap_feats.shape[1]
-         cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-         img_freqs_cis_shape = list(freqs_cis.shape)
-         img_freqs_cis_shape[1] = max_img_len
-         img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-         for i in range(bsz):
-             cap_len = l_effective_cap_len[i]
-             img_len = l_effective_img_len[i]
-             cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
-             img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
-
-         # refine context
-         for layer in self.context_refiner:
-             cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
-
-         # refine image
-         flat_x = []
-         for i in range(bsz):
-             img = x[i]
-             C, H, W = img.size()
-             img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
-             flat_x.append(img)
-         x = flat_x
-         padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
-         padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device)
-         for i in range(bsz):
-             padded_img_embed[i, :l_effective_img_len[i]] = x[i]
-             padded_img_mask[i, :l_effective_img_len[i]] = True
-
-         padded_img_embed = self.x_embedder(padded_img_embed)
-         for layer in self.noise_refiner:
-             padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
-
-         mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
-         padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
-         for i in range(bsz):
-             cap_len = l_effective_cap_len[i]
-             img_len = l_effective_img_len[i]
-
-             mask[i, :cap_len+img_len] = True
-             padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
-             padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
-
-         return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
-
-
-     def forward(self, x, t, cap_feats, cap_mask):
-         """
-         Forward pass of NextDiT.
-         t: (N,) tensor of diffusion timesteps
-         y: (N,) tensor of text tokens/features
-         """
-
-         # import torch.distributed as dist
-         # if not dist.is_initialized() or dist.get_rank() == 0:
-         #     import pdb
-         #     pdb.set_trace()
-         # torch.save([x, t, cap_feats, cap_mask], "./fake_input.pt")
-         t = self.t_embedder(t) # (N, D)
-         adaln_input = t
-
-         cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
-
-         x_is_tensor = isinstance(x, torch.Tensor)
-         x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t)
-         freqs_cis = freqs_cis.to(x.device)
-
-         for layer in self.layers:
-             x = layer(x, mask, freqs_cis, adaln_input)
-
-         x = self.final_layer(x, adaln_input)
-         x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)
-
-         return x
-
-     def forward_with_cfg(
-         self,
-         x,
-         t,
-         cap_feats,
-         cap_mask,
-         cfg_scale,
-         cfg_trunc=1,
-         renorm_cfg=1
-     ):
-         """
-         Forward pass of NextDiT, but also batches the unconditional forward pass
-         for classifier-free guidance.
-         """
-         # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
-         half = x[: len(x) // 2]
-         if t[0] < cfg_trunc:
-             combined = torch.cat([half, half], dim=0) # [2, 16, 128, 128]
-             model_out = self.forward(combined, t, cap_feats, cap_mask) # [2, 16, 128, 128]
-             # For exact reproducibility reasons, we apply classifier-free guidance on only
-             # three channels by default. The standard approach to cfg applies it to all channels.
-             # This can be done by uncommenting the following line and commenting-out the line following that.
-             eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
-             cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
-             half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
-             if float(renorm_cfg) > 0.0:
-                 ori_pos_norm = torch.linalg.vector_norm(cond_eps
-                     , dim=tuple(range(1, len(cond_eps.shape))), keepdim=True
-                 )
-                 max_new_norm = ori_pos_norm * float(renorm_cfg)
-                 new_pos_norm = torch.linalg.vector_norm(
-                     half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
-                 )
-                 if new_pos_norm >= max_new_norm:
-                     half_eps = half_eps * (max_new_norm / new_pos_norm)
-         else:
-             combined = half
-             model_out = self.forward(combined, t[:len(x) // 2], cap_feats[:len(x) // 2], cap_mask[:len(x) // 2])
-             eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
-             half_eps = eps
-
-         output = torch.cat([half_eps, half_eps], dim=0)
-         return output
-
-     @staticmethod
-     def precompute_freqs_cis(
-         dim: List[int],
-         end: List[int],
-         theta: float = 10000.0,
-     ):
-         """
-         Precompute the frequency tensor for complex exponentials (cis) with
-         given dimensions.
-
-         This function calculates a frequency tensor with complex exponentials
-         using the given dimension 'dim' and the end index 'end'. The 'theta'
-         parameter scales the frequencies. The returned tensor contains complex
-         values in complex64 data type.
-
-         Args:
-             dim (list): Dimension of the frequency tensor.
-             end (list): End index for precomputing frequencies.
-             theta (float, optional): Scaling factor for frequency computation.
-                 Defaults to 10000.0.
-
-         Returns:
-             torch.Tensor: Precomputed frequency tensor with complex
-                 exponentials.
-         """
-         freqs_cis = []
-         for i, (d, e) in enumerate(zip(dim, end)):
-             freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
-             timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
-             freqs = torch.outer(timestep, freqs).float()
-             freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
-             freqs_cis.append(freqs_cis_i)
-
-         return freqs_cis
-
-     def parameter_count(self) -> int:
-         total_params = 0
-
-         def _recursive_count_params(module):
-             nonlocal total_params
-             for param in module.parameters(recurse=False):
-                 total_params += param.numel()
-             for submodule in module.children():
-                 _recursive_count_params(submodule)
-
-         _recursive_count_params(self)
-         return total_params
-
-     def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
-         return list(self.layers)
-
-     def get_checkpointing_wrap_module_list(self) -> List[nn.Module]:
-         return list(self.layers)
-
-
- #############################################################################
- # NextDiT Configs #
- #############################################################################
-
- def NextDiT_2B_GQA_patch2_Adaln_Refiner(**kwargs):
-     return NextDiT(
-         patch_size=2,
-         dim=2304,
-         n_layers=26,
-         n_heads=24,
-         n_kv_heads=8,
-         axes_dims=[32, 32, 32],
-         axes_lens=[300, 512, 512],
-         **kwargs
-     )
-
- def NextDiT_3B_GQA_patch2_Adaln_Refiner(**kwargs):
-     return NextDiT(
-         patch_size=2,
-         dim=2592,
-         n_layers=30,
-         n_heads=24,
-         n_kv_heads=8,
-         axes_dims=[36, 36, 36],
-         axes_lens=[300, 512, 512],
-         **kwargs,
-     )
-
- def NextDiT_4B_GQA_patch2_Adaln_Refiner(**kwargs):
-     return NextDiT(
-         patch_size=2,
-         dim=2880,
-         n_layers=32,
-         n_heads=24,
-         n_kv_heads=8,
-         axes_dims=[40, 40, 40],
-         axes_lens=[300, 512, 512],
-         **kwargs,
-     )
-
- def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs):
-     return NextDiT(
-         patch_size=2,
-         dim=3840,
-         n_layers=32,
-         n_heads=32,
-         n_kv_heads=8,
-         axes_dims=[40, 40, 40],
-         axes_lens=[300, 512, 512],
-         **kwargs,
-     )
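
For reference, the sketch below shows how the model defined in the deleted file could be constructed and called, inferred only from the signatures visible in the diff above. The import path, latent shape, and caption-feature width are illustrative assumptions rather than anything recorded in this commit, and flash_attn must be installed because model.py imports it at module level even when the float32 SDPA path is used.

import torch

# Hypothetical import path; model.py lived in a package that also provided
# .components.RMSNorm, so point this at wherever a copy of the file is kept.
from models.model import NextDiT_2B_GQA_patch2_Adaln_Refiner

model = NextDiT_2B_GQA_patch2_Adaln_Refiner(
    in_channels=16,     # assumed latent channels; the file's default is 4
    qk_norm=True,
    cap_feat_dim=2304,  # must match the width of the caption features passed below
)

bsz = 2
x = [torch.randn(16, 64, 64) for _ in range(bsz)]  # per-sample latents, (C, H, W)
t = torch.rand(bsz)                                # one diffusion timestep per sample
cap_feats = torch.randn(bsz, 77, 2304)             # caption token features, (N, L, D)
cap_mask = torch.ones(bsz, 77, dtype=torch.int32)  # 1 marks valid caption tokens

with torch.no_grad():
    preds = model(x, t, cap_feats, cap_mask)       # list of (in_channels, H, W) outputs

print(model.parameter_count())

Casting the model and inputs to float16 or bfloat16 switches attention to the variable-length FlashAttention branch in JointAttention.forward; in float32 it falls back to scaled_dot_product_attention with a padding mask.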