davidvgilmore committed
Commit 83f798d · verified · 1 Parent(s): 1e5b20c

Upload hy3dgen/texgen/hunyuanpaint/unet/modules.py with huggingface_hub
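For context, the commit message refers to an upload made through the huggingface_hub client. Below is a minimal, hedged sketch of such a single-file upload; it assumes the standard HfApi.upload_file call, and the repo id is a placeholder, not taken from this page.

# Hypothetical upload call matching the commit message; "user/repo" is a placeholder.
from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="hy3dgen/texgen/hunyuanpaint/unet/modules.py",
    path_in_repo="hy3dgen/texgen/hunyuanpaint/unet/modules.py",
    repo_id="user/repo",
    commit_message="Upload hy3dgen/texgen/hunyuanpaint/unet/modules.py with huggingface_hub",
)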

hy3dgen/texgen/hunyuanpaint/unet/modules.py ADDED
@@ -0,0 +1,440 @@
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import copy
import json
import os
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_2d import BasicTransformerBlock
from einops import rearrange


def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
    # "feed_forward_chunk_size" can be used to save memory
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
        )

    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    ff_output = torch.cat(
        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )
    return ff_output


class Basic2p5DTransformerBlock(torch.nn.Module):
    # Wraps a diffusers BasicTransformerBlock and adds two optional attention layers:
    # multiview attention (use_ma) across all generated views and reference attention
    # (use_ra) over features cached from the reference branch.
    def __init__(self, transformer: BasicTransformerBlock, layer_name, use_ma=True, use_ra=True) -> None:
        super().__init__()
        self.transformer = transformer
        self.layer_name = layer_name
        self.use_ma = use_ma
        self.use_ra = use_ra

        # multiview attn
        if self.use_ma:
            self.attn_multiview = Attention(
                query_dim=self.dim,
                heads=self.num_attention_heads,
                dim_head=self.attention_head_dim,
                dropout=self.dropout,
                bias=self.attention_bias,
                cross_attention_dim=None,
                upcast_attention=self.attn1.upcast_attention,
                out_bias=True,
            )

        # ref attn
        if self.use_ra:
            self.attn_refview = Attention(
                query_dim=self.dim,
                heads=self.num_attention_heads,
                dim_head=self.attention_head_dim,
                dropout=self.dropout,
                bias=self.attention_bias,
                cross_attention_dim=None,
                upcast_attention=self.attn1.upcast_attention,
                out_bias=True,
            )

    def __getattr__(self, name: str):
        # Fall back to the wrapped transformer block for any attribute not defined here
        # (dim, num_attention_heads, norm layers, attn1/attn2, ff, ...).
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.transformer, name)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        num_in_batch = cross_attention_kwargs.pop('num_in_batch', 1)
        mode = cross_attention_kwargs.pop('mode', None)
        mva_scale = cross_attention_kwargs.pop('mva_scale', 1.0)
        ref_scale = cross_attention_kwargs.pop('ref_scale', 1.0)
        condition_embed_dict = cross_attention_kwargs.pop("condition_embed_dict", None)

        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.norm_type == "ada_norm_zero":
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif self.norm_type == "ada_norm_single":
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Prepare GLIGEN inputs
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        if self.norm_type == "ada_norm_zero":
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.norm_type == "ada_norm_single":
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 1.2 Reference Attention
        # Write mode ('w'): cache the normalized hidden states of the reference branch.
        if 'w' in mode:
            condition_embed_dict[self.layer_name] = rearrange(
                norm_hidden_states, '(b n) l c -> b (n l) c', n=num_in_batch
            )  # B, (N L), C

        # Read mode ('r'): attend to the cached reference features.
        if 'r' in mode and self.use_ra:
            condition_embed = condition_embed_dict[self.layer_name].unsqueeze(1).repeat(
                1, num_in_batch, 1, 1
            )  # B N L C
            condition_embed = rearrange(condition_embed, 'b n l c -> (b n) l c')

            attn_output = self.attn_refview(
                norm_hidden_states,
                encoder_hidden_states=condition_embed,
                attention_mask=None,
                **cross_attention_kwargs
            )
            ref_scale_timing = ref_scale
            if isinstance(ref_scale, torch.Tensor):
                ref_scale_timing = ref_scale.unsqueeze(1).repeat(1, num_in_batch).view(-1)
                for _ in range(attn_output.ndim - 1):
                    ref_scale_timing = ref_scale_timing.unsqueeze(-1)
            hidden_states = ref_scale_timing * attn_output + hidden_states
            if hidden_states.ndim == 4:
                hidden_states = hidden_states.squeeze(1)

        # 1.3 Multiview Attention
        if num_in_batch > 1 and self.use_ma:
            multiview_hidden_states = rearrange(norm_hidden_states, '(b n) l c -> b (n l) c', n=num_in_batch)

            attn_output = self.attn_multiview(
                multiview_hidden_states,
                encoder_hidden_states=multiview_hidden_states,
                **cross_attention_kwargs
            )

            attn_output = rearrange(attn_output, 'b (n l) c -> (b n) l c', n=num_in_batch)

            hidden_states = mva_scale * attn_output + hidden_states
            if hidden_states.ndim == 4:
                hidden_states = hidden_states.squeeze(1)

        # 1.2 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.norm_type == "ada_norm":
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.norm_type == "ada_norm_single":
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            elif self.norm_type == "ada_norm_continuous":
                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
            else:
                raise ValueError("Incorrect norm")

            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )

            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        # i2vgen doesn't have this norm 🤷‍♂️
        if self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif not self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm3(hidden_states)

        if self.norm_type == "ada_norm_zero":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.norm_type == "ada_norm_zero":
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.norm_type == "ada_norm_single":
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states


class UNet2p5DConditionModel(torch.nn.Module):
    # Wraps a standard UNet2DConditionModel so it can denoise several views of the same
    # object at once, with an optional frozen-timestep reference branch (dual stream),
    # per-view camera embeddings and multiview/reference attention in every transformer block.
    def __init__(self, unet: UNet2DConditionModel) -> None:
        super().__init__()
        self.unet = unet

        self.use_ma = True
        self.use_ra = True
        self.use_camera_embedding = True
        self.use_dual_stream = True

        if self.use_dual_stream:
            self.unet_dual = copy.deepcopy(unet)
            self.init_attention(self.unet_dual)
        self.init_attention(self.unet, use_ma=self.use_ma, use_ra=self.use_ra)
        self.init_condition()
        self.init_camera_embedding()

    @staticmethod
    def from_pretrained(pretrained_model_name_or_path, **kwargs):
        torch_dtype = kwargs.pop('torch_dtype', torch.float32)
        config_path = os.path.join(pretrained_model_name_or_path, 'config.json')
        unet_ckpt_path = os.path.join(pretrained_model_name_or_path, 'diffusion_pytorch_model.bin')
        with open(config_path, 'r', encoding='utf-8') as file:
            config = json.load(file)
        unet = UNet2DConditionModel(**config)
        unet = UNet2p5DConditionModel(unet)
        unet_ckpt = torch.load(unet_ckpt_path, map_location='cpu', weights_only=True)
        unet.load_state_dict(unet_ckpt, strict=True)
        unet = unet.to(torch_dtype)
        return unet

    def init_condition(self):
        # Widen conv_in to 12 input channels: noisy latents, normal-map latents and
        # position-map latents (4 channels each) are concatenated along the channel dimension.
        self.unet.conv_in = torch.nn.Conv2d(
            12,
            self.unet.conv_in.out_channels,
            kernel_size=self.unet.conv_in.kernel_size,
            stride=self.unet.conv_in.stride,
            padding=self.unet.conv_in.padding,
            dilation=self.unet.conv_in.dilation,
            groups=self.unet.conv_in.groups,
            bias=self.unet.conv_in.bias is not None)

        self.unet.learned_text_clip_gen = nn.Parameter(torch.randn(1, 77, 1024))
        self.unet.learned_text_clip_ref = nn.Parameter(torch.randn(1, 77, 1024))

    def init_camera_embedding(self):

        if self.use_camera_embedding:
            time_embed_dim = 1280
            self.max_num_ref_image = 5
            self.max_num_gen_image = 12 * 3 + 4 * 2
            self.unet.class_embedding = nn.Embedding(self.max_num_ref_image + self.max_num_gen_image, time_embed_dim)

    def init_attention(self, unet, use_ma=False, use_ra=False):
        # Replace every BasicTransformerBlock in the down, mid and up blocks with a
        # Basic2p5DTransformerBlock carrying a unique layer name for feature caching.
        for down_block_i, down_block in enumerate(unet.down_blocks):
            if hasattr(down_block, "has_cross_attention") and down_block.has_cross_attention:
                for attn_i, attn in enumerate(down_block.attentions):
                    for transformer_i, transformer in enumerate(attn.transformer_blocks):
                        if isinstance(transformer, BasicTransformerBlock):
                            attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
                                transformer, f'down_{down_block_i}_{attn_i}_{transformer_i}', use_ma, use_ra)

        if hasattr(unet.mid_block, "has_cross_attention") and unet.mid_block.has_cross_attention:
            for attn_i, attn in enumerate(unet.mid_block.attentions):
                for transformer_i, transformer in enumerate(attn.transformer_blocks):
                    if isinstance(transformer, BasicTransformerBlock):
                        attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
                            transformer, f'mid_{attn_i}_{transformer_i}', use_ma, use_ra)

        for up_block_i, up_block in enumerate(unet.up_blocks):
            if hasattr(up_block, "has_cross_attention") and up_block.has_cross_attention:
                for attn_i, attn in enumerate(up_block.attentions):
                    for transformer_i, transformer in enumerate(attn.transformer_blocks):
                        if isinstance(transformer, BasicTransformerBlock):
                            attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
                                transformer, f'up_{up_block_i}_{attn_i}_{transformer_i}', use_ma, use_ra)

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward(
        self, sample, timestep, encoder_hidden_states,
        *args, down_intrablock_additional_residuals=None,
        down_block_res_samples=None, mid_block_res_sample=None,
        **cached_condition,
    ):
        # sample: (B, N_gen, C, H, W) noisy latents for all generated views.
        # cached_condition may carry: camera_info_gen, camera_info_ref, ref_latents,
        # normal_imgs, position_imgs, mva_scale, ref_scale, condition_embed_dict.
        B, N_gen, _, H, W = sample.shape
        assert H == W

        if self.use_camera_embedding:
            camera_info_gen = cached_condition['camera_info_gen'] + self.max_num_ref_image
            camera_info_gen = rearrange(camera_info_gen, 'b n -> (b n)')
        else:
            camera_info_gen = None

        sample = [sample]
        if 'normal_imgs' in cached_condition:
            sample.append(cached_condition["normal_imgs"])
        if 'position_imgs' in cached_condition:
            sample.append(cached_condition["position_imgs"])
        sample = torch.cat(sample, dim=2)

        sample = rearrange(sample, 'b n c h w -> (b n) c h w')

        encoder_hidden_states_gen = encoder_hidden_states.unsqueeze(1).repeat(1, N_gen, 1, 1)
        encoder_hidden_states_gen = rearrange(encoder_hidden_states_gen, 'b n l c -> (b n) l c')

        if self.use_ra:
            if 'condition_embed_dict' in cached_condition:
                condition_embed_dict = cached_condition['condition_embed_dict']
            else:
                # Run the reference branch once (mode 'w') to cache per-layer features.
                condition_embed_dict = {}
                ref_latents = cached_condition['ref_latents']
                N_ref = ref_latents.shape[1]
                if self.use_camera_embedding:
                    camera_info_ref = cached_condition['camera_info_ref']
                    camera_info_ref = rearrange(camera_info_ref, 'b n -> (b n)')
                else:
                    camera_info_ref = None

                ref_latents = rearrange(ref_latents, 'b n c h w -> (b n) c h w')

                encoder_hidden_states_ref = self.unet.learned_text_clip_ref.unsqueeze(1).repeat(B, N_ref, 1, 1)
                encoder_hidden_states_ref = rearrange(encoder_hidden_states_ref, 'b n l c -> (b n) l c')

                noisy_ref_latents = ref_latents
                timestep_ref = 0

                if self.use_dual_stream:
                    unet_ref = self.unet_dual
                else:
                    unet_ref = self.unet
                unet_ref(
                    noisy_ref_latents, timestep_ref,
                    encoder_hidden_states=encoder_hidden_states_ref,
                    class_labels=camera_info_ref,
                    # **kwargs
                    return_dict=False,
                    cross_attention_kwargs={
                        'mode': 'w', 'num_in_batch': N_ref,
                        'condition_embed_dict': condition_embed_dict},
                )
                cached_condition['condition_embed_dict'] = condition_embed_dict
        else:
            condition_embed_dict = None

        mva_scale = cached_condition.get('mva_scale', 1.0)
        ref_scale = cached_condition.get('ref_scale', 1.0)

        return self.unet(
            sample, timestep,
            encoder_hidden_states_gen, *args,
            class_labels=camera_info_gen,
            down_intrablock_additional_residuals=[
                sample.to(dtype=self.unet.dtype) for sample in down_intrablock_additional_residuals
            ] if down_intrablock_additional_residuals is not None else None,
            down_block_additional_residuals=[
                sample.to(dtype=self.unet.dtype) for sample in down_block_res_samples
            ] if down_block_res_samples is not None else None,
            mid_block_additional_residual=(
                mid_block_res_sample.to(dtype=self.unet.dtype)
                if mid_block_res_sample is not None else None
            ),
            return_dict=False,
            cross_attention_kwargs={
                'mode': 'r', 'num_in_batch': N_gen,
                'condition_embed_dict': condition_embed_dict,
                'mva_scale': mva_scale,
                'ref_scale': ref_scale,
            },
        )