LPX55 committed
Commit 3880905 · verified · 1 Parent(s): ddaa3ca

Create pipeline_hidream_image.py

Files changed (1)
  1. pipeline_hidream_image.py +526 -0
pipeline_hidream_image.py ADDED
@@ -0,0 +1,526 @@
from typing import Any, Dict, Optional, Tuple, List

import torch
import torch.nn as nn
import einops
from einops import repeat

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from models.embeddings import PatchEmbed, PooledEmbed, TimestepEmbed, EmbedND, OutEmbed
from models.attention import HiDreamAttention, FeedForwardSwiGLU
from models.attention_processor import HiDreamAttnProcessor_flashattn
from models.moe import MOEFeedForwardSwiGLU

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

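# Projects caption features of a given width into the transformer's inner dimension.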
class TextProjection(nn.Module):
    def __init__(self, in_features, hidden_size):
        super().__init__()
        self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)

    def forward(self, caption):
        hidden_states = self.linear(caption)
        return hidden_states

class BlockType:
    TransformerBlock = 1
    SingleTransformerBlock = 2

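# Single-stream block: one adaLN-modulated self-attention layer plus a SwiGLU
# feed-forward (optionally a routed MoE variant), applied to a single token stream.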
@maybe_allow_in_graph
class HiDreamImageSingleTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True)
        )
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

        # 1. Attention
        self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor=HiDreamAttnProcessor_flashattn(),
            single=True
        )

        # 2. Feed-forward
        self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim=dim,
                hidden_dim=4 * dim,
                num_routed_experts=num_routed_experts,
                num_activated_experts=num_activated_experts,
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        wtype = image_tokens.dtype
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
            self.adaLN_modulation(adaln_input)[:, None].chunk(6, dim=-1)

        # 1. MM-Attention
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        attn_output_i = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            rope=rope,
        )
        image_tokens = gate_msa_i * attn_output_i + image_tokens

        # 2. Feed-forward
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
        image_tokens = ff_output_i + image_tokens
        return image_tokens

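# Dual-stream block: image and text tokens get separate adaLN modulation and layer
# norms, attend jointly through HiDreamAttention, then pass through separate
# feed-forward paths (MoE-capable for the image stream, dense SwiGLU for the text stream).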
@maybe_allow_in_graph
class HiDreamImageTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 12 * dim, bias=True)
        )
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

        # 1. Attention
        self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
        self.norm1_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor=HiDreamAttnProcessor_flashattn(),
            single=False
        )

        # 2. Feed-forward
        self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim=dim,
                hidden_dim=4 * dim,
                num_routed_experts=num_routed_experts,
                num_activated_experts=num_activated_experts,
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
        self.norm3_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
        self.ff_t = FeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        wtype = image_tokens.dtype
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
        shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
            self.adaLN_modulation(adaln_input)[:, None].chunk(12, dim=-1)

        # 1. MM-Attention
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t

        attn_output_i, attn_output_t = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            norm_text_tokens,
            rope=rope,
        )

        image_tokens = gate_msa_i * attn_output_i + image_tokens
        text_tokens = gate_msa_t * attn_output_t + text_tokens

        # 2. Feed-forward
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t

        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
        ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
        image_tokens = ff_output_i + image_tokens
        text_tokens = ff_output_t + text_tokens
        return image_tokens, text_tokens

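# Thin wrapper that instantiates either block variant based on `block_type`,
# so both can be stored in uniform nn.ModuleLists.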
@maybe_allow_in_graph
class HiDreamImageBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        block_type: BlockType = BlockType.TransformerBlock,
    ):
        super().__init__()
        block_classes = {
            BlockType.TransformerBlock: HiDreamImageTransformerBlock,
            BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
        }
        self.block = block_classes[block_type](
            dim,
            num_attention_heads,
            attention_head_dim,
            num_routed_experts,
            num_activated_experts
        )

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: torch.FloatTensor = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        return self.block(
            image_tokens,
            image_tokens_masks,
            text_tokens,
            adaln_input,
            rope,
        )

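# Top-level HiDream image transformer: patch-embeds the latent, embeds timestep and
# pooled text, routes per-layer Llama hidden states (plus a T5 sequence) through
# per-block caption projections, then runs the double-stream blocks followed by the
# single-stream blocks and unpatchifies the result.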
class HiDreamImageTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
):
    _supports_gradient_checkpointing = True
    _no_split_modules = ["HiDreamImageBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: Optional[int] = None,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 16,
        num_single_layers: int = 32,
        attention_head_dim: int = 128,
        num_attention_heads: int = 20,
        caption_channels: List[int] = None,
        text_emb_dim: int = 2048,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        axes_dims_rope: Tuple[int, int] = (32, 32),
        max_resolution: Tuple[int, int] = (128, 128),
        llama_layers: List[int] = None,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.llama_layers = llama_layers

        self.t_embedder = TimestepEmbed(self.inner_dim)
        self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim)
        self.x_embedder = PatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            out_channels=self.inner_dim,
        )
        self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)

        self.double_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    num_routed_experts=num_routed_experts,
                    num_activated_experts=num_activated_experts,
                    block_type=BlockType.TransformerBlock
                )
                for i in range(self.config.num_layers)
            ]
        )

        self.single_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    num_routed_experts=num_routed_experts,
                    num_activated_experts=num_activated_experts,
                    block_type=BlockType.SingleTransformerBlock
                )
                for i in range(self.config.num_single_layers)
            ]
        )

        self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels)

        caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
        caption_projection = []
        for caption_channel in caption_channels:
            caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
        self.caption_projection = nn.ModuleList(caption_projection)
        self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def expand_timesteps(self, timesteps, batch_size, device):
        if not torch.is_tensor(timesteps):
            is_mps = device.type == "mps"
            if isinstance(timesteps, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(batch_size)
        return timesteps

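    # Inverse of `patchify`: folds the token sequence back into image-shaped latents,
    # cropping each sample to its recorded (patched) height/width at inference time.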
    def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> torch.Tensor:
        if is_training:
            x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size)
        else:
            x_arr = []
            for i, img_size in enumerate(img_sizes):
                pH, pW = img_size
                x_arr.append(
                    einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
                                     p1=self.config.patch_size, p2=self.config.patch_size)
                )
            x = torch.cat(x_arr, dim=0)
        return x

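    # Flattens latents into a padded token sequence. With `img_sizes` given, inputs are
    # assumed pre-patchified and a per-sample validity mask is built up to `max_seq`;
    # otherwise a plain B C H W tensor is patchified directly and no mask is needed.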
    def patchify(self, x, max_seq, img_sizes=None):
        pz2 = self.config.patch_size * self.config.patch_size
        if isinstance(x, torch.Tensor):
            B, C = x.shape[0], x.shape[1]
            device = x.device
            dtype = x.dtype
        else:
            B, C = len(x), x[0].shape[0]
            device = x[0].device
            dtype = x[0].dtype
        x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)

        if img_sizes is not None:
            for i, img_size in enumerate(img_sizes):
                x_masks[i, 0:img_size[0] * img_size[1]] = 1
            x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
        elif isinstance(x, torch.Tensor):
            pH, pW = x.shape[-2] // self.config.patch_size, x.shape[-1] // self.config.patch_size
            x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.config.patch_size, p2=self.config.patch_size)
            img_sizes = [[pH, pW]] * B
            x_masks = None
        else:
            raise NotImplementedError
        return x, x_masks, img_sizes

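    # Full denoising forward pass. `encoder_hidden_states` is expected to carry the T5
    # sequence at index 0 and the stacked Llama hidden states at index -1, from which
    # `llama_layers` selects one slice per transformer block.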
    def forward(
        self,
        hidden_states: torch.Tensor,
        timesteps: torch.LongTensor = None,
        encoder_hidden_states: torch.Tensor = None,
        pooled_embeds: torch.Tensor = None,
        img_sizes: Optional[List[Tuple[int, int]]] = None,
        img_ids: Optional[torch.Tensor] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        # spatial forward
        batch_size = hidden_states.shape[0]
        hidden_states_type = hidden_states.dtype

        # 0. time
        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
        timesteps = self.t_embedder(timesteps, hidden_states_type)
        p_embedder = self.p_embedder(pooled_embeds)
        adaln_input = timesteps + p_embedder

        hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
        if image_tokens_masks is None:
            pH, pW = img_sizes[0]
            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
            img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
        hidden_states = self.x_embedder(hidden_states)

        T5_encoder_hidden_states = encoder_hidden_states[0]
        encoder_hidden_states = encoder_hidden_states[-1]
        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]

        if self.caption_projection is not None:
            new_encoder_hidden_states = []
            for i, enc_hidden_state in enumerate(encoder_hidden_states):
                enc_hidden_state = self.caption_projection[i](enc_hidden_state)
                enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                new_encoder_hidden_states.append(enc_hidden_state)
            encoder_hidden_states = new_encoder_hidden_states
            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
            encoder_hidden_states.append(T5_encoder_hidden_states)

        txt_ids = torch.zeros(
            batch_size,
            encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
            3,
            device=img_ids.device, dtype=img_ids.dtype
        )
        ids = torch.cat((img_ids, txt_ids), dim=1)
        rope = self.pe_embedder(ids)

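        # Double-stream pass: image and text tokens are kept as separate streams; the
        # shared text prefix (T5 plus the final selected Llama slice) is concatenated
        # with a per-block Llama layer slice before entering each block.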
        # 2. Blocks
        block_id = 0
        initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
        initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
        for bid, block in enumerate(self.double_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            if self.training and self.gradient_checkpointing:
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)
                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    image_tokens_masks,
                    cur_encoder_hidden_states,
                    adaln_input,
                    rope,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, initial_encoder_hidden_states = block(
                    image_tokens=hidden_states,
                    image_tokens_masks=image_tokens_masks,
                    text_tokens=cur_encoder_hidden_states,
                    adaln_input=adaln_input,
                    rope=rope,
                )
            initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
            block_id += 1

        image_tokens_seq_len = hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
        hidden_states_seq_len = hidden_states.shape[1]
        if image_tokens_masks is not None:
            encoder_attention_mask_ones = torch.ones(
                (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
                device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
            )
            image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)

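        # Single-stream pass: the text prefix has been folded into the image stream above,
        # so each block runs self-attention over the combined sequence, again concatenating
        # a per-block Llama slice that is trimmed off after the block.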
        for bid, block in enumerate(self.single_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            if self.training and self.gradient_checkpointing:
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)
                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    image_tokens_masks,
                    None,
                    adaln_input,
                    rope,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    image_tokens=hidden_states,
                    image_tokens_masks=image_tokens_masks,
                    text_tokens=None,
                    adaln_input=adaln_input,
                    rope=rope,
                )
            hidden_states = hidden_states[:, :hidden_states_seq_len]
            block_id += 1

        hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
        output = self.final_layer(hidden_states, adaln_input)
        output = self.unpatchify(output, img_sizes, self.training)
        if image_tokens_masks is not None:
            image_tokens_masks = image_tokens_masks[:, :image_tokens_seq_len]

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output, image_tokens_masks)
        return Transformer2DModelOutput(sample=output, mask=image_tokens_masks)