JackyZhuo committed (verified)
Commit 49a119b · 1 parent: 9b1a463

Upload folder using huggingface_hub

models/__init__.py ADDED
@@ -0,0 +1 @@
from .model import NextDiT_2B_GQA_patch2_Adaln_Refiner, NextDiT_3B_GQA_patch2_Adaln_Refiner, NextDiT_4B_GQA_patch2_Adaln_Refiner, NextDiT_7B_GQA_patch2_Adaln_Refiner
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (361 Bytes).
 
models/__pycache__/components.cpython-310.pyc ADDED
Binary file (2.16 kB).
 
models/__pycache__/model.cpython-310.pyc ADDED
Binary file (23.3 kB).
 
models/components.py ADDED
@@ -0,0 +1,54 @@
import warnings

import torch
import torch.nn as nn

try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    warnings.warn("Cannot import apex RMSNorm, switching to vanilla implementation")

    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            """
            Initialize the RMSNorm normalization layer.

            Args:
                dim (int): The dimension of the input tensor.
                eps (float, optional): A small value added to the denominator
                    for numerical stability. Default is 1e-6.

            Attributes:
                eps (float): A small value added to the denominator for numerical stability.
                weight (nn.Parameter): Learnable scaling parameter.

            """
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            """
            Apply the RMSNorm normalization to the input tensor.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The normalized tensor.

            """
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            """
            Forward pass through the RMSNorm layer.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The output tensor after applying RMSNorm.

            """
            output = self._norm(x.float()).type_as(x)
            return output * self.weight
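
A minimal usage sketch for the vanilla RMSNorm fallback above (a sketch, assuming apex is not installed; shapes are illustrative):

import torch
from models.components import RMSNorm

norm = RMSNorm(dim=64, eps=1e-6)   # normalizes over the last dimension
x = torch.randn(2, 16, 64)
y = norm(x)                        # same shape as x: unit-RMS features, then the learnable scale
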
models/model.py ADDED
@@ -0,0 +1,930 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import math
from typing import List, Optional, Tuple

from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
import torch
import torch.nn as nn
import torch.nn.functional as F

from .components import RMSNorm


def modulate(x, scale):
    return x * (1 + scale.unsqueeze(1))


#############################################################################
#             Embedding Layers for Timesteps and Class Labels               #
#############################################################################


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(
                frequency_embedding_size,
                hidden_size,
                bias=True,
            ),
            nn.SiLU(),
            nn.Linear(
                hidden_size,
                hidden_size,
                bias=True,
            ),
        )
        nn.init.normal_(self.mlp[0].weight, std=0.02)
        nn.init.zeros_(self.mlp[0].bias)
        nn.init.normal_(self.mlp[2].weight, std=0.02)
        nn.init.zeros_(self.mlp[2].bias)

        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
            device=t.device
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
        return t_emb


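# Example (illustrative): sanity-checking the sinusoidal embedding above.
# The dim of 64 is arbitrary; any even dim splits into cos and sin halves.
#
#     t = torch.tensor([0.0, 250.0, 999.0])
#     emb = TimestepEmbedder.timestep_embedding(t, dim=64)
#     emb.shape  # torch.Size([3, 64]) -- cos terms first, then sin terms

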
#############################################################################
#                            Core NextDiT Model                             #
#############################################################################


class JointAttention(nn.Module):
    """Multi-head attention module."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int],
        qk_norm: bool,
    ):
        """
        Initialize the Attention module.

        Args:
            dim (int): Number of input dimensions.
            n_heads (int): Number of heads.
            n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
            qk_norm (bool): Whether to apply RMSNorm to queries and keys.

        """
        super().__init__()
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        self.n_local_heads = n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = dim // n_heads

        self.qkv = nn.Linear(
            dim,
            (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
            bias=False,
        )
        nn.init.xavier_uniform_(self.qkv.weight)

        self.out = nn.Linear(
            n_heads * self.head_dim,
            dim,
            bias=False,
        )
        nn.init.xavier_uniform_(self.out.weight)

        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)
        else:
            self.q_norm = self.k_norm = nn.Identity()

    @staticmethod
    def apply_rotary_emb(
        x_in: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply rotary embeddings to the input tensor using the given frequency
        tensor.

        This function applies rotary embeddings to the given query or key
        tensor 'x_in' using the provided frequency tensor 'freqs_cis'. The
        input tensor is reshaped as complex numbers, and the frequency tensor
        is reshaped for broadcasting compatibility. The resulting tensor
        contains rotary embeddings and is returned as a real tensor.

        Args:
            x_in (torch.Tensor): Query or key tensor to apply rotary
                embeddings to.
            freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
                exponentials.

        Returns:
            torch.Tensor: The input tensor with rotary embeddings applied.
        """
        with torch.cuda.amp.autocast(enabled=False):
            x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
            freqs_cis = freqs_cis.unsqueeze(2)
            x_out = torch.view_as_real(x * freqs_cis).flatten(3)
            return x_out.type_as(x_in)

    # copied from huggingface modeling_llama.py
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        def _get_unpad_data(attention_mask):
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return (
                indices,
                cu_seqlens,
                max_seqlen_in_batch,
            )

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """

        Args:
            x (torch.Tensor): Input tensor of shape (bsz, seqlen, dim).
            x_mask (torch.Tensor): Boolean padding mask of shape (bsz, seqlen).
            freqs_cis (torch.Tensor): Precomputed rotary frequencies.

        Returns:
            torch.Tensor: Attention output of shape (bsz, seqlen, dim).

        """
        bsz, seqlen, _ = x.shape
        dtype = x.dtype

        xq, xk, xv = torch.split(
            self.qkv(x),
            [
                self.n_local_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
            ],
            dim=-1,
        )
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xq = self.q_norm(xq)
        xk = self.k_norm(xk)
        xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
        xq, xk = xq.to(dtype), xk.to(dtype)

        softmax_scale = math.sqrt(1 / self.head_dim)

        if dtype in [torch.float16, torch.bfloat16]:
            # begin var_len flash attn
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(xq, xk, xv, x_mask, seqlen)

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=0.0,
                causal=False,
                softmax_scale=softmax_scale,
            )
            output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
            # end var_len_flash_attn

        else:
            n_rep = self.n_local_heads // self.n_local_kv_heads
            if n_rep >= 1:
                xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            output = (
                F.scaled_dot_product_attention(
                    xq.permute(0, 2, 1, 3),
                    xk.permute(0, 2, 1, 3),
                    xv.permute(0, 2, 1, 3),
                    attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
                    scale=softmax_scale,
                )
                .permute(0, 2, 1, 3)
                .to(dtype)
            )

        output = output.flatten(-2)

        return self.out(output)


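# Example (illustrative): the grouped-query expansion used by the SDPA
# fallback above. With n_kv_heads=2 serving n_heads=6 queries (n_rep=3):
#
#     xk = torch.randn(1, 4, 2, 8)  # (bsz, seqlen, n_kv_heads, head_dim)
#     xk = xk.unsqueeze(3).repeat(1, 1, 1, 3, 1).flatten(2, 3)
#     xk.shape  # torch.Size([1, 4, 6, 8]) -- each KV head repeated 3 times

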
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple
                of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden
                dimension. Defaults to None.

        """
        super().__init__()
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(
            dim,
            hidden_dim,
            bias=False,
        )
        nn.init.xavier_uniform_(self.w1.weight)
        self.w2 = nn.Linear(
            hidden_dim,
            dim,
            bias=False,
        )
        nn.init.xavier_uniform_(self.w2.weight)
        self.w3 = nn.Linear(
            dim,
            hidden_dim,
            bias=False,
        )
        nn.init.xavier_uniform_(self.w3.weight)

    # @torch.compile
    def _forward_silu_gating(self, x1, x3):
        return F.silu(x1) * x3

    def forward(self, x):
        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))


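# Worked example (illustrative): the hidden_dim rounding above is ceiling
# division to the next multiple of multiple_of. With multiple_of=256 and a
# hidden_dim of 6000 (an arbitrary value), 256 * ((6000 + 255) // 256)
# = 256 * 24 = 6144, the smallest multiple of 256 that is >= 6000.

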
class JointTransformerBlock(nn.Module):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
    ) -> None:
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            dim (int): Embedding dimension of the input features.
            n_heads (int): Number of attention heads.
            n_kv_heads (Optional[int]): Number of attention heads in key and
                value features (if using GQA), or set to None for the same as
                query.
            multiple_of (int): Rounding factor for the feedforward hidden
                dimension.
            ffn_dim_multiplier (float): Custom multiplier for the feedforward
                hidden dimension.
            norm_eps (float): Epsilon used by the RMSNorm layers.
            qk_norm (bool): Whether to apply RMSNorm to queries and keys.
            modulation (bool): Whether to apply adaLN timestep modulation.

        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm)
        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)

        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.modulation = modulation
        if modulation:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(
                    min(dim, 1024),
                    4 * dim,
                    bias=True,
                ),
            )
            nn.init.zeros_(self.adaLN_modulation[1].weight)
            nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            x_mask (torch.Tensor): Boolean padding mask for the input sequence.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            adaln_input (Optional[torch.Tensor]): Conditioning vector for adaLN
                modulation; required if and only if modulation is enabled.

        Returns:
            torch.Tensor: Output tensor after applying attention and
            feedforward layers.

        """
        if self.modulation:
            assert adaln_input is not None
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)

            x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
                self.attention(
                    modulate(self.attention_norm1(x), scale_msa),
                    x_mask,
                    freqs_cis,
                )
            )
            x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
                self.feed_forward(
                    modulate(self.ffn_norm1(x), scale_mlp),
                )
            )
        else:
            assert adaln_input is None
            x = x + self.attention_norm2(
                self.attention(
                    self.attention_norm1(x),
                    x_mask,
                    freqs_cis,
                )
            )
            x = x + self.ffn_norm2(
                self.feed_forward(
                    self.ffn_norm1(x),
                )
            )
        return x


class FinalLayer(nn.Module):
    """
    The final layer of NextDiT.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(
            hidden_size,
            elementwise_affine=False,
            eps=1e-6,
        )
        self.linear = nn.Linear(
            hidden_size,
            patch_size * patch_size * out_channels,
            bias=True,
        )
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                min(hidden_size, 1024),
                hidden_size,
                bias=True,
            ),
        )
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(self, x, c):
        scale = self.adaLN_modulation(c)
        x = modulate(self.norm_final(x), scale)
        x = self.linear(x)
        return x


class RopeEmbedder:
    def __init__(
        self, theta: float = 10000.0, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512)
    ):
        super().__init__()
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)

    def __call__(self, ids: torch.Tensor):
        self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
        result = []
        for i in range(len(self.axes_dims)):
            # Gather the per-token frequencies for axis i from its precomputed table.
            index = ids[:, :, i:i+1].repeat(1, 1, self.freqs_cis[i].shape[-1]).to(torch.int64)
            result.append(torch.gather(self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
        return torch.cat(result, dim=-1)


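# Example (illustrative): the 3-axis position ids consumed above, matching
# the layout built in NextDiT.patchify_and_embed for a 3-token caption
# followed by a 2x2 grid of image patches:
#
#     axis 0 (caption index): 0, 1, 2, 3, 3, 3, 3   # image tokens share id cap_len
#     axis 1 (image row):     0, 0, 0, 0, 0, 1, 1
#     axis 2 (image col):     0, 0, 0, 0, 1, 0, 1
#
# RopeEmbedder gathers each axis table at those ids and concatenates along
# the last dim, yielding complex frequencies of shape
# (bsz, seq_len, sum(axes_dims) // 2).

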
class NextDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        dim: int = 4096,
        n_layers: int = 32,
        n_refiner_layers: int = 2,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        qk_norm: bool = False,
        cap_feat_dim: int = 5120,
        axes_dims: List[int] = (16, 56, 56),
        axes_lens: List[int] = (1, 512, 512),
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size

        self.x_embedder = nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=dim,
            bias=True,
        )
        nn.init.xavier_uniform_(self.x_embedder.weight)
        nn.init.constant_(self.x_embedder.bias, 0.0)

        self.noise_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )
        self.context_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=False,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

        self.t_embedder = TimestepEmbedder(min(dim, 1024))
        self.cap_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps),
            nn.Linear(
                cap_feat_dim,
                dim,
                bias=True,
            ),
        )
        nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
        # nn.init.zeros_(self.cap_embedder[1].weight)
        nn.init.zeros_(self.cap_embedder[1].bias)

        self.layers = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                )
                for layer_id in range(n_layers)
            ]
        )
        self.norm_final = RMSNorm(dim, eps=norm_eps)
        self.final_layer = FinalLayer(dim, patch_size, self.out_channels)

        assert (dim // n_heads) == sum(axes_dims)
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
        self.dim = dim
        self.n_heads = n_heads

    def unpatchify(
        self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
    ) -> List[torch.Tensor]:
        """
        x: (N, T, patch_size**2 * C)
        imgs: list of (C, H, W) tensors, or (N, C, H, W) if return_tensor is True
        """
        pH = pW = self.patch_size
        imgs = []
        for i in range(x.size(0)):
            H, W = img_size[i]
            begin = cap_size[i]
            end = begin + (H // pH) * (W // pW)
            imgs.append(
                x[i][begin:end]
                .view(H // pH, W // pW, pH, pW, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )

        if return_tensor:
            imgs = torch.stack(imgs, dim=0)
        return imgs

    def patchify_and_embed(
        self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
        bsz = len(x)
        pH = pW = self.patch_size
        device = x[0].device

        l_effective_cap_len = cap_mask.sum(dim=1).tolist()
        img_sizes = [(img.size(1), img.size(2)) for img in x]
        l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]

        max_seq_len = max(
            (cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
        )
        max_cap_len = max(l_effective_cap_len)
        max_img_len = max(l_effective_img_len)

        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)

        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            H, W = img_sizes[i]
            H_tokens, W_tokens = H // pH, W // pW
            assert H_tokens * W_tokens == img_len

            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
            position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
            row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
            col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
            position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
            position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

        freqs_cis = self.rope_embedder(position_ids)

        # build freqs_cis for cap and image individually
        cap_freqs_cis_shape = list(freqs_cis.shape)
        # cap_freqs_cis_shape[1] = max_cap_len
        cap_freqs_cis_shape[1] = cap_feats.shape[1]
        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        img_freqs_cis_shape = list(freqs_cis.shape)
        img_freqs_cis_shape[1] = max_img_len
        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]

        # refine context
        for layer in self.context_refiner:
            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)

        # refine image
        flat_x = []
        for i in range(bsz):
            img = x[i]
            C, H, W = img.size()
            img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
            flat_x.append(img)
        x = flat_x
        padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device)
        for i in range(bsz):
            padded_img_embed[i, :l_effective_img_len[i]] = x[i]
            padded_img_mask[i, :l_effective_img_len[i]] = True

        padded_img_embed = self.x_embedder(padded_img_embed)
        for layer in self.noise_refiner:
            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)

        mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
        padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]

            mask[i, :cap_len+img_len] = True
            padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
            padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]

        return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis


    def forward(self, x, t, cap_feats, cap_mask):
        """
        Forward pass of NextDiT.
        t: (N,) tensor of diffusion timesteps
        cap_feats: (N, L, cap_feat_dim) tensor of caption features
        cap_mask: (N, L) mask over valid caption tokens
        """
        t = self.t_embedder(t)  # (N, D)
        adaln_input = t

        cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute

        x_is_tensor = isinstance(x, torch.Tensor)
        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t)
        freqs_cis = freqs_cis.to(x.device)

        for layer in self.layers:
            x = layer(x, mask, freqs_cis, adaln_input)

        x = self.final_layer(x, adaln_input)
        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)

        return x

    def forward_with_cfg(
        self,
        x,
        t,
        cap_feats,
        cap_mask,
        cfg_scale,
        cfg_trunc=1,
        renorm_cfg=1,
    ):
        """
        Forward pass of NextDiT, but also batches the unconditional forward pass
        for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        if t[0] < cfg_trunc:
            combined = torch.cat([half, half], dim=0)  # [2, 16, 128, 128]
            model_out = self.forward(combined, t, cap_feats, cap_mask)  # [2, 16, 128, 128]
            # Guidance is applied to the first `in_channels` channels; since
            # out_channels == in_channels here, `rest` is empty and this
            # mirrors the original GLIDE/DiT implementation.
            eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
            if float(renorm_cfg) > 0.0:
                # Rescale the guided prediction so its norm does not exceed
                # renorm_cfg times the norm of the conditional prediction.
                ori_pos_norm = torch.linalg.vector_norm(
                    cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True
                )
                max_new_norm = ori_pos_norm * float(renorm_cfg)
                new_pos_norm = torch.linalg.vector_norm(
                    half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
                )
                # Note: this scalar comparison assumes the conditional half has batch size 1.
                if new_pos_norm >= max_new_norm:
                    half_eps = half_eps * (max_new_norm / new_pos_norm)
        else:
            # Past the truncation threshold, skip the unconditional branch entirely.
            combined = half
            model_out = self.forward(combined, t[: len(x) // 2], cap_feats[: len(x) // 2], cap_mask[: len(x) // 2])
            eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
            half_eps = eps

        output = torch.cat([half_eps, half_eps], dim=0)
        return output

    @staticmethod
    def precompute_freqs_cis(
        dim: List[int],
        end: List[int],
        theta: float = 10000.0,
    ):
        """
        Precompute the frequency tensors for complex exponentials (cis) with
        the given dimensions.

        This function calculates, for each axis, a frequency tensor with
        complex exponentials using the given dimension 'dim' and the end index
        'end'. The 'theta' parameter scales the frequencies. The returned
        tensors contain complex values in complex64 data type.

        Args:
            dim (list): Dimension of the frequency tensor for each axis.
            end (list): End index for precomputing frequencies for each axis.
            theta (float, optional): Scaling factor for frequency computation.
                Defaults to 10000.0.

        Returns:
            List[torch.Tensor]: One precomputed frequency tensor with complex
            exponentials per axis.
        """
        freqs_cis = []
        for i, (d, e) in enumerate(zip(dim, end)):
            freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
            timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
            freqs = torch.outer(timestep, freqs).float()
            freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)  # complex64
            freqs_cis.append(freqs_cis_i)

        return freqs_cis

    def parameter_count(self) -> int:
        total_params = 0

        def _recursive_count_params(module):
            nonlocal total_params
            for param in module.parameters(recurse=False):
                total_params += param.numel()
            for submodule in module.children():
                _recursive_count_params(submodule)

        _recursive_count_params(self)
        return total_params

    def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
        return list(self.layers)

    def get_checkpointing_wrap_module_list(self) -> List[nn.Module]:
        return list(self.layers)


#############################################################################
#                              NextDiT Configs                              #
#############################################################################


def NextDiT_2B_GQA_patch2_Adaln_Refiner(**kwargs):
    return NextDiT(
        patch_size=2,
        dim=2304,
        n_layers=26,
        n_heads=24,
        n_kv_heads=8,
        axes_dims=[32, 32, 32],
        axes_lens=[300, 512, 512],
        **kwargs,
    )


def NextDiT_3B_GQA_patch2_Adaln_Refiner(**kwargs):
    return NextDiT(
        patch_size=2,
        dim=2592,
        n_layers=30,
        n_heads=24,
        n_kv_heads=8,
        axes_dims=[36, 36, 36],
        axes_lens=[300, 512, 512],
        **kwargs,
    )


def NextDiT_4B_GQA_patch2_Adaln_Refiner(**kwargs):
    return NextDiT(
        patch_size=2,
        dim=2880,
        n_layers=32,
        n_heads=24,
        n_kv_heads=8,
        axes_dims=[40, 40, 40],
        axes_lens=[300, 512, 512],
        **kwargs,
    )


def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs):
    return NextDiT(
        patch_size=2,
        dim=3840,
        n_layers=32,
        n_heads=32,
        n_kv_heads=8,
        axes_dims=[40, 40, 40],
        axes_lens=[300, 512, 512],
        **kwargs,
    )
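
A hedged construction sketch for the factory functions above (importing models pulls in flash_attn at module level, so this assumes flash_attn is installed; cap_feat_dim must match the text encoder's hidden size, and the 2304 below is an illustrative override of the 5120 default). The assert in NextDiT.__init__ requires dim // n_heads == sum(axes_dims), e.g. 2304 // 24 = 96 = 32 + 32 + 32 for the 2B config:

from models import NextDiT_2B_GQA_patch2_Adaln_Refiner

model = NextDiT_2B_GQA_patch2_Adaln_Refiner(qk_norm=True, cap_feat_dim=2304)
print(model.parameter_count())  # rough size check before loading weights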