Image-to-3D
Hunyuan3D-2
Diffusers
Safetensors
English
Chinese
text-to-3d
SeanYoungxh ShaunSZ commited on
Commit
efb40d8
·
verified ·
1 Parent(s): 8878bbe

Upload 19 files (#36)

Browse files

- Upload 19 files (7c284a50093e4e48e1d4bde19eff8b176bb2ed5b)


Co-authored-by: zz <[email protected]>

hunyuan3d-paint-v2-0-turbo/unet/diffusion_pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:690a5fc63c4e263ba07dd41dc95f86f702c059c4361b863e2e21af88d8f75714
3
- size 3722674238
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24e7f1aea8a7c94cee627eb06f5265f19eeff4e19568636c5eaef050cc19ba3d
3
+ size 7325432923
hunyuan3d-paint-v2-0-turbo/unet/modules.py CHANGED
@@ -22,7 +22,6 @@
22
  # fine-tuning enabling code and other elements of the foregoing made publicly available
23
  # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
 
25
-
26
  import copy
27
  import json
28
  import os
@@ -41,7 +40,9 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim:
41
  # "feed_forward_chunk_size" can be used to save memory
42
  if hidden_states.shape[chunk_dim] % chunk_size != 0:
43
  raise ValueError(
44
- f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
 
 
45
  )
46
 
47
  num_chunks = hidden_states.shape[chunk_dim] // chunk_size
@@ -51,329 +52,16 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim:
51
  )
52
  return ff_output
53
 
54
- class PoseRoPEAttnProcessor2_0:
55
- r"""
56
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
57
- """
58
-
59
- def __init__(self):
60
- if not hasattr(F, "scaled_dot_product_attention"):
61
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
62
-
63
- def get_1d_rotary_pos_embed(
64
- self,
65
- dim: int,
66
- pos: torch.Tensor,
67
- theta: float = 10000.0,
68
- linear_factor=1.0,
69
- ntk_factor=1.0,
70
- ):
71
- assert dim % 2 == 0
72
-
73
- theta = theta * ntk_factor
74
- freqs = (
75
- 1.0
76
- / (theta ** (torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device)[: (dim // 2)] / dim))
77
- / linear_factor
78
- ) # [D/2]
79
- freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
80
- # flux, hunyuan-dit, cogvideox
81
- freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
82
- freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
83
- return freqs_cos, freqs_sin
84
-
85
-
86
- def get_3d_rotary_pos_embed(
87
- self,
88
- position,
89
- embed_dim,
90
- voxel_resolution,
91
- theta: int = 10000,
92
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
93
- """
94
- RoPE for video tokens with 3D structure.
95
-
96
- Args:
97
- voxel_resolution (`int`):
98
- The grid size of the spatial positional embedding (height, width).
99
- theta (`float`):
100
- Scaling factor for frequency computation.
101
-
102
- Returns:
103
- `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
104
- """
105
- assert position.shape[-1]==3
106
-
107
- # Compute dimensions for each axis
108
- dim_xy = embed_dim // 8 * 3
109
- dim_z = embed_dim // 8 * 2
110
-
111
- # Temporal frequencies
112
- grid = torch.arange(voxel_resolution, dtype=torch.float32, device=position.device)
113
- freqs_xy = self.get_1d_rotary_pos_embed(dim_xy, grid, theta=theta)
114
- freqs_z = self.get_1d_rotary_pos_embed(dim_z, grid, theta=theta)
115
-
116
- xy_cos, xy_sin = freqs_xy # both t_cos and t_sin has shape: voxel_resolution, dim_xy
117
- z_cos, z_sin = freqs_z # both w_cos and w_sin has shape: voxel_resolution, dim_z
118
-
119
- embed_flattn = position.view(-1, position.shape[-1])
120
- x_cos = xy_cos[embed_flattn[:,0], :]
121
- x_sin = xy_sin[embed_flattn[:,0], :]
122
- y_cos = xy_cos[embed_flattn[:,1], :]
123
- y_sin = xy_sin[embed_flattn[:,1], :]
124
- z_cos = z_cos[embed_flattn[:,2], :]
125
- z_sin = z_sin[embed_flattn[:,2], :]
126
-
127
- cos = torch.cat((x_cos, y_cos, z_cos), dim=-1)
128
- sin = torch.cat((x_sin, y_sin, z_sin), dim=-1)
129
-
130
- cos = cos.view(*position.shape[:-1], embed_dim)
131
- sin = sin.view(*position.shape[:-1], embed_dim)
132
- return cos, sin
133
-
134
- def apply_rotary_emb(
135
- self,
136
- x: torch.Tensor,
137
- freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
138
- ):
139
- cos, sin = freqs_cis # [S, D]
140
- cos, sin = cos.to(x.device), sin.to(x.device)
141
- cos = cos.unsqueeze(1)
142
- sin = sin.unsqueeze(1)
143
-
144
- x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
145
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
146
-
147
- out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
148
-
149
- return out
150
-
151
- def __call__(
152
- self,
153
- attn: Attention,
154
- hidden_states: torch.Tensor,
155
- encoder_hidden_states: Optional[torch.Tensor] = None,
156
- attention_mask: Optional[torch.Tensor] = None,
157
- position_indices: Dict = None,
158
- temb: Optional[torch.Tensor] = None,
159
- *args,
160
- **kwargs,
161
- ) -> torch.Tensor:
162
- if len(args) > 0 or kwargs.get("scale", None) is not None:
163
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
164
- deprecate("scale", "1.0.0", deprecation_message)
165
-
166
- residual = hidden_states
167
- if attn.spatial_norm is not None:
168
- hidden_states = attn.spatial_norm(hidden_states, temb)
169
-
170
- input_ndim = hidden_states.ndim
171
-
172
- if input_ndim == 4:
173
- batch_size, channel, height, width = hidden_states.shape
174
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
175
-
176
- batch_size, sequence_length, _ = (
177
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
178
- )
179
-
180
- if attention_mask is not None:
181
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
182
- # scaled_dot_product_attention expects attention_mask shape to be
183
- # (batch, heads, source_length, target_length)
184
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
185
-
186
- if attn.group_norm is not None:
187
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
188
-
189
- query = attn.to_q(hidden_states)
190
-
191
- if encoder_hidden_states is None:
192
- encoder_hidden_states = hidden_states
193
- elif attn.norm_cross:
194
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
195
-
196
- key = attn.to_k(encoder_hidden_states)
197
- value = attn.to_v(encoder_hidden_states)
198
-
199
- inner_dim = key.shape[-1]
200
- head_dim = inner_dim // attn.heads
201
-
202
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
203
-
204
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
205
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
206
-
207
- if attn.norm_q is not None:
208
- query = attn.norm_q(query)
209
- if attn.norm_k is not None:
210
- key = attn.norm_k(key)
211
-
212
- if position_indices is not None:
213
- if head_dim in position_indices:
214
- image_rotary_emb = position_indices[head_dim]
215
- else:
216
- image_rotary_emb = self.get_3d_rotary_pos_embed(position_indices['voxel_indices'], head_dim, voxel_resolution=position_indices['voxel_resolution'])
217
- position_indices[head_dim] = image_rotary_emb
218
- query = self.apply_rotary_emb(query, image_rotary_emb)
219
- key = self.apply_rotary_emb(key, image_rotary_emb)
220
-
221
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
222
- # TODO: add support for attn.scale when we move to Torch 2.1
223
- hidden_states = F.scaled_dot_product_attention(
224
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
225
- )
226
-
227
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
228
- hidden_states = hidden_states.to(query.dtype)
229
-
230
- # linear proj
231
- hidden_states = attn.to_out[0](hidden_states)
232
- # dropout
233
- hidden_states = attn.to_out[1](hidden_states)
234
-
235
- if input_ndim == 4:
236
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
237
-
238
- if attn.residual_connection:
239
- hidden_states = hidden_states + residual
240
-
241
- hidden_states = hidden_states / attn.rescale_output_factor
242
-
243
- return hidden_states
244
-
245
- class IPAttnProcessor2_0:
246
- r"""
247
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
248
- """
249
-
250
- def __init__(self, scale=0.0):
251
- if not hasattr(F, "scaled_dot_product_attention"):
252
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
253
-
254
- self.scale = scale
255
-
256
- def __call__(
257
- self,
258
- attn: Attention,
259
- hidden_states: torch.Tensor,
260
- encoder_hidden_states: Optional[torch.Tensor] = None,
261
- ip_hidden_states: Optional[torch.Tensor] = None,
262
- attention_mask: Optional[torch.Tensor] = None,
263
- temb: Optional[torch.Tensor] = None,
264
- *args,
265
- **kwargs,
266
- ) -> torch.Tensor:
267
- if len(args) > 0 or kwargs.get("scale", None) is not None:
268
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
269
- deprecate("scale", "1.0.0", deprecation_message)
270
-
271
- residual = hidden_states
272
- if attn.spatial_norm is not None:
273
- hidden_states = attn.spatial_norm(hidden_states, temb)
274
-
275
- input_ndim = hidden_states.ndim
276
-
277
- if input_ndim == 4:
278
- batch_size, channel, height, width = hidden_states.shape
279
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
280
-
281
- batch_size, sequence_length, _ = (
282
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
283
- )
284
-
285
- if attention_mask is not None:
286
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
287
- # scaled_dot_product_attention expects attention_mask shape to be
288
- # (batch, heads, source_length, target_length)
289
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
290
-
291
- if attn.group_norm is not None:
292
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
293
-
294
- query = attn.to_q(hidden_states)
295
-
296
- if encoder_hidden_states is None:
297
- encoder_hidden_states = hidden_states
298
- elif attn.norm_cross:
299
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
300
-
301
- key = attn.to_k(encoder_hidden_states)
302
- value = attn.to_v(encoder_hidden_states)
303
-
304
- inner_dim = key.shape[-1]
305
- head_dim = inner_dim // attn.heads
306
-
307
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
308
-
309
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
310
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
311
-
312
- if attn.norm_q is not None:
313
- query = attn.norm_q(query)
314
- if attn.norm_k is not None:
315
- key = attn.norm_k(key)
316
-
317
-
318
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
319
- # TODO: add support for attn.scale when we move to Torch 2.1
320
- hidden_states = F.scaled_dot_product_attention(
321
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
322
- )
323
-
324
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
325
- hidden_states = hidden_states.to(query.dtype)
326
-
327
- # for ip adapter
328
- if ip_hidden_states is not None:
329
-
330
- ip_key = attn.to_k_ip(ip_hidden_states)
331
- ip_value = attn.to_v_ip(ip_hidden_states)
332
-
333
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
334
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
335
-
336
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
337
- ip_hidden_states = F.scaled_dot_product_attention(
338
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
339
- )
340
-
341
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
342
- ip_hidden_states = ip_hidden_states.to(query.dtype)
343
-
344
- hidden_states = hidden_states + self.scale * ip_hidden_states
345
-
346
- # linear proj
347
- hidden_states = attn.to_out[0](hidden_states)
348
- # dropout
349
- hidden_states = attn.to_out[1](hidden_states)
350
-
351
- if input_ndim == 4:
352
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
353
-
354
- if attn.residual_connection:
355
- hidden_states = hidden_states + residual
356
-
357
- hidden_states = hidden_states / attn.rescale_output_factor
358
-
359
- return hidden_states
360
-
361
 
362
  class Basic2p5DTransformerBlock(torch.nn.Module):
363
- def __init__(self, transformer: BasicTransformerBlock, layer_name, use_ipa=True, use_ma=True, use_ra=True) -> None:
364
  super().__init__()
365
  self.transformer = transformer
366
  self.layer_name = layer_name
367
- self.use_ipa = use_ipa
368
  self.use_ma = use_ma
369
  self.use_ra = use_ra
 
370
 
371
- if use_ipa:
372
- self.attn2.set_processor(IPAttnProcessor2_0())
373
- cross_attention_dim = 1024
374
- self.attn2.to_k_ip = nn.Linear(cross_attention_dim, self.dim, bias=False)
375
- self.attn2.to_v_ip = nn.Linear(cross_attention_dim, self.dim, bias=False)
376
-
377
  # multiview attn
378
  if self.use_ma:
379
  self.attn_multiview = Attention(
@@ -385,7 +73,6 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
385
  cross_attention_dim=None,
386
  upcast_attention=self.attn1.upcast_attention,
387
  out_bias=True,
388
- processor=PoseRoPEAttnProcessor2_0(),
389
  )
390
 
391
  # ref attn
@@ -400,8 +87,8 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
400
  upcast_attention=self.attn1.upcast_attention,
401
  out_bias=True,
402
  )
403
-
404
- self._initialize_attn_weights()
405
 
406
  def _initialize_attn_weights(self):
407
 
@@ -418,10 +105,6 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
418
  for param in layer.parameters():
419
  param.zero_()
420
 
421
- if self.use_ipa:
422
- self.attn2.to_k_ip.load_state_dict(self.attn2.to_k.state_dict())
423
- self.attn2.to_v_ip.load_state_dict(self.attn2.to_v.state_dict())
424
-
425
  def __getattr__(self, name: str):
426
  try:
427
  return super().__getattr__(name)
@@ -447,10 +130,16 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
447
  cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
448
  num_in_batch = cross_attention_kwargs.pop('num_in_batch', 1)
449
  mode = cross_attention_kwargs.pop('mode', None)
 
 
 
 
 
 
 
 
 
450
  condition_embed_dict = cross_attention_kwargs.pop("condition_embed_dict", None)
451
- ip_hidden_states = cross_attention_kwargs.pop("ip_hidden_states", None)
452
- position_attn_mask = cross_attention_kwargs.pop("position_attn_mask", None)
453
- position_voxel_indices = cross_attention_kwargs.pop("position_voxel_indices", None)
454
 
455
  if self.norm_type == "ada_norm":
456
  norm_hidden_states = self.norm1(hidden_states, timestep)
@@ -470,10 +159,10 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
470
  norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
471
  else:
472
  raise ValueError("Incorrect norm used")
473
-
474
  if self.pos_embed is not None:
475
  norm_hidden_states = self.pos_embed(norm_hidden_states)
476
-
477
  # 1. Prepare GLIGEN inputs
478
  cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
479
  gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
@@ -484,6 +173,7 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
484
  attention_mask=attention_mask,
485
  **cross_attention_kwargs,
486
  )
 
487
  if self.norm_type == "ada_norm_zero":
488
  attn_output = gate_msa.unsqueeze(1) * attn_output
489
  elif self.norm_type == "ada_norm_single":
@@ -492,13 +182,17 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
492
  hidden_states = attn_output + hidden_states
493
  if hidden_states.ndim == 4:
494
  hidden_states = hidden_states.squeeze(1)
495
-
496
  # 1.2 Reference Attention
497
  if 'w' in mode:
498
- condition_embed_dict[self.layer_name] = rearrange(norm_hidden_states, '(b n) l c -> b (n l) c', n=num_in_batch) # B, (N L), C
499
-
500
- if 'r' in mode:
501
- condition_embed = condition_embed_dict[self.layer_name].unsqueeze(1).repeat(1,num_in_batch,1,1) # B N L C
 
 
 
 
502
  condition_embed = rearrange(condition_embed, 'b n l c -> (b n) l c')
503
 
504
  attn_output = self.attn_refview(
@@ -507,35 +201,48 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
507
  attention_mask=None,
508
  **cross_attention_kwargs
509
  )
 
 
 
 
 
 
 
 
510
 
511
- hidden_states = attn_output + hidden_states
512
  if hidden_states.ndim == 4:
513
  hidden_states = hidden_states.squeeze(1)
514
-
515
 
516
  # 1.3 Multiview Attention
517
  if num_in_batch > 1 and self.use_ma:
518
  multivew_hidden_states = rearrange(norm_hidden_states, '(b n) l c -> b (n l) c', n=num_in_batch)
519
- position_mask = None
520
- if position_attn_mask is not None:
521
- if multivew_hidden_states.shape[1] in position_attn_mask:
522
- position_mask = position_attn_mask[multivew_hidden_states.shape[1]]
523
- position_indices = None
524
- if position_voxel_indices is not None:
525
- if multivew_hidden_states.shape[1] in position_voxel_indices:
526
- position_indices = position_voxel_indices[multivew_hidden_states.shape[1]]
527
-
528
- attn_output = self.attn_multiview(
529
- multivew_hidden_states,
530
- encoder_hidden_states=multivew_hidden_states,
531
- attention_mask=position_mask,
532
- position_indices=position_indices,
533
- **cross_attention_kwargs
534
- )
535
 
536
- attn_output = rearrange(attn_output, 'b (n l) c -> (b n) l c', n=num_in_batch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
 
538
- hidden_states = attn_output + hidden_states
 
 
539
  if hidden_states.ndim == 4:
540
  hidden_states = hidden_states.squeeze(1)
541
 
@@ -561,25 +268,12 @@ class Basic2p5DTransformerBlock(torch.nn.Module):
561
  if self.pos_embed is not None and self.norm_type != "ada_norm_single":
562
  norm_hidden_states = self.pos_embed(norm_hidden_states)
563
 
564
- if ip_hidden_states is not None:
565
- ip_hidden_states = ip_hidden_states.unsqueeze(1).repeat(1,num_in_batch,1,1) # B N L C
566
- ip_hidden_states = rearrange(ip_hidden_states, 'b n l c -> (b n) l c')
567
-
568
- if self.use_ipa:
569
- attn_output = self.attn2(
570
- norm_hidden_states,
571
- encoder_hidden_states=encoder_hidden_states,
572
- ip_hidden_states=ip_hidden_states,
573
- attention_mask=encoder_attention_mask,
574
- **cross_attention_kwargs,
575
- )
576
- else:
577
- attn_output = self.attn2(
578
- norm_hidden_states,
579
- encoder_hidden_states=encoder_hidden_states,
580
- attention_mask=encoder_attention_mask,
581
- **cross_attention_kwargs,
582
- )
583
 
584
  hidden_states = attn_output + hidden_states
585
 
@@ -626,8 +320,16 @@ def compute_voxel_grid_mask(position, grid_resolution=8):
626
  position[valid_mask==False] = 0
627
 
628
 
629
- position = rearrange(position, 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', num_h=grid_resolution, num_w=grid_resolution)
630
- valid_mask = rearrange(valid_mask, 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', num_h=grid_resolution, num_w=grid_resolution)
 
 
 
 
 
 
 
 
631
 
632
  grid_position = position.sum(dim=(-2, -1))
633
  count_masked = valid_mask.sum(dim=(-2, -1))
@@ -674,8 +376,16 @@ def compute_discrete_voxel_indice(position, grid_resolution=8, voxel_resolution=
674
  valid_mask = valid_mask.expand_as(position)
675
  position[valid_mask==False] = 0
676
 
677
- position = rearrange(position, 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', num_h=grid_resolution, num_w=grid_resolution)
678
- valid_mask = rearrange(valid_mask, 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', num_h=grid_resolution, num_w=grid_resolution)
 
 
 
 
 
 
 
 
679
 
680
  grid_position = position.sum(dim=(-2, -1))
681
  count_masked = valid_mask.sum(dim=(-2, -1))
@@ -688,45 +398,36 @@ def compute_discrete_voxel_indice(position, grid_resolution=8, voxel_resolution=
688
  voxel_indices = torch.round(voxel_indices).long()
689
  return voxel_indices
690
 
691
- def compute_multi_resolution_discrete_voxel_indice(position_maps, grid_resolutions=[64, 32, 16, 8], voxel_resolutions=[512, 256, 128, 64]):
 
 
 
 
692
  voxel_indices = {}
693
  with torch.no_grad():
694
  for grid_resolution, voxel_resolution in zip(grid_resolutions, voxel_resolutions):
695
  voxel_indice = compute_discrete_voxel_indice(position_maps, grid_resolution, voxel_resolution)
696
  voxel_indice = rearrange(voxel_indice, 'b n c h w -> b (n h w) c')
697
  voxel_indices[voxel_indice.shape[1]] = {'voxel_indices':voxel_indice, 'voxel_resolution':voxel_resolution}
698
- return voxel_indices
699
-
700
- class ImageProjModel(torch.nn.Module):
701
- """Projection Model"""
702
-
703
- def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
704
- super().__init__()
705
-
706
- self.generator = None
707
- self.cross_attention_dim = cross_attention_dim
708
- self.clip_extra_context_tokens = clip_extra_context_tokens
709
- self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
710
- self.norm = torch.nn.LayerNorm(cross_attention_dim)
711
 
712
- def forward(self, image_embeds):
713
- embeds = image_embeds
714
- clip_extra_context_tokens = self.proj(embeds).reshape(
715
- -1, self.clip_extra_context_tokens, self.cross_attention_dim
716
- )
717
- clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
718
- return clip_extra_context_tokens
719
-
720
  class UNet2p5DConditionModel(torch.nn.Module):
721
  def __init__(self, unet: UNet2DConditionModel) -> None:
722
  super().__init__()
723
  self.unet = unet
724
- self.unet_dual = copy.deepcopy(unet)
725
 
726
- self.init_camera_embedding()
727
- self.init_attention(self.unet, use_ipa=True, use_ma=True, use_ra=True)
728
- self.init_attention(self.unet_dual, use_ipa=False, use_ma=False, use_ra=False)
 
 
 
 
 
 
 
729
  self.init_condition()
 
730
 
731
  @staticmethod
732
  def from_pretrained(pretrained_model_name_or_path, **kwargs):
@@ -737,170 +438,158 @@ class UNet2p5DConditionModel(torch.nn.Module):
737
  config = json.load(file)
738
  unet = UNet2DConditionModel(**config)
739
  unet = UNet2p5DConditionModel(unet)
740
-
741
- unet.unet.conv_in = torch.nn.Conv2d(
742
- 12,
743
- unet.unet.conv_in.out_channels,
744
- kernel_size=unet.unet.conv_in.kernel_size,
745
- stride=unet.unet.conv_in.stride,
746
- padding=unet.unet.conv_in.padding,
747
- dilation=unet.unet.conv_in.dilation,
748
- groups=unet.unet.conv_in.groups,
749
- bias=unet.unet.conv_in.bias is not None)
750
-
751
  unet_ckpt = torch.load(unet_ckpt_path, map_location='cpu', weights_only=True)
752
  unet.load_state_dict(unet_ckpt, strict=True)
753
  unet = unet.to(torch_dtype)
754
  return unet
755
-
756
- def init_condition(self):
757
- self.unet.learned_text_clip_gen = nn.Parameter(torch.randn(1,77,1024))
758
- self.unet.learned_text_clip_ref = nn.Parameter(torch.randn(1,77,1024))
759
 
760
- self.unet.image_proj_model = ImageProjModel(
761
- cross_attention_dim=self.unet.config.cross_attention_dim,
762
- clip_embeddings_dim=1024,
763
- )
 
 
 
 
 
 
764
 
 
 
765
 
766
  def init_camera_embedding(self):
767
- self.max_num_ref_image = 5
768
- self.max_num_gen_image = 12*3+4*2
769
 
770
- time_embed_dim = 1280
771
- self.unet.class_embedding = nn.Embedding(self.max_num_ref_image+self.max_num_gen_image, time_embed_dim)
772
- # 将嵌入层的权重初始化为全零
773
- nn.init.zeros_(self.unet.class_embedding.weight)
774
-
775
- def init_attention(self, unet, use_ipa=True, use_ma=True, use_ra=True):
 
776
 
777
  for down_block_i, down_block in enumerate(unet.down_blocks):
778
  if hasattr(down_block, "has_cross_attention") and down_block.has_cross_attention:
779
  for attn_i, attn in enumerate(down_block.attentions):
780
  for transformer_i, transformer in enumerate(attn.transformer_blocks):
781
  if isinstance(transformer, BasicTransformerBlock):
782
- attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(transformer, f'down_{down_block_i}_{attn_i}_{transformer_i}',use_ipa,use_ma,use_ra)
 
 
 
 
783
 
784
  if hasattr(unet.mid_block, "has_cross_attention") and unet.mid_block.has_cross_attention:
785
  for attn_i, attn in enumerate(unet.mid_block.attentions):
786
  for transformer_i, transformer in enumerate(attn.transformer_blocks):
787
  if isinstance(transformer, BasicTransformerBlock):
788
- attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(transformer, f'mid_{attn_i}_{transformer_i}',use_ipa,use_ma,use_ra)
 
 
 
 
789
 
790
  for up_block_i, up_block in enumerate(unet.up_blocks):
791
  if hasattr(up_block, "has_cross_attention") and up_block.has_cross_attention:
792
  for attn_i, attn in enumerate(up_block.attentions):
793
  for transformer_i, transformer in enumerate(attn.transformer_blocks):
794
  if isinstance(transformer, BasicTransformerBlock):
795
- attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(transformer, f'up_{up_block_i}_{attn_i}_{transformer_i}',use_ipa,use_ma,use_ra)
796
-
 
 
 
797
 
798
  def __getattr__(self, name: str):
799
  try:
800
  return super().__getattr__(name)
801
  except AttributeError:
802
  return getattr(self.unet, name)
803
-
804
  def forward(
805
- self, sample, timestep, encoder_hidden_states, class_labels=None,
806
- *args, cross_attention_kwargs=None, down_intrablock_additional_residuals=None,
807
  down_block_res_samples=None, mid_block_res_sample=None,
808
  **cached_condition,
809
  ):
810
  B, N_gen, _, H, W = sample.shape
811
- camera_info_gen = cached_condition['camera_info_gen'] + self.max_num_ref_image
812
- camera_info_gen = rearrange(camera_info_gen, 'b n -> (b n)')
 
 
 
 
 
 
813
  sample = [sample]
814
-
815
  if 'normal_imgs' in cached_condition:
816
  sample.append(cached_condition["normal_imgs"])
817
  if 'position_imgs' in cached_condition:
818
  sample.append(cached_condition["position_imgs"])
819
-
820
  sample = torch.cat(sample, dim=2)
 
821
  sample = rearrange(sample, 'b n c h w -> (b n) c h w')
822
 
823
  encoder_hidden_states_gen = encoder_hidden_states.unsqueeze(1).repeat(1, N_gen, 1, 1)
824
  encoder_hidden_states_gen = rearrange(encoder_hidden_states_gen, 'b n l c -> (b n) l c')
825
-
826
-
827
- use_position_mask = False
828
- use_position_rope = True
829
-
830
- position_attn_mask = None
831
- if use_position_mask:
832
- if 'position_attn_mask' in cached_condition:
833
- position_attn_mask = cached_condition['position_attn_mask']
834
- else:
835
- if 'position_maps' in cached_condition:
836
- position_attn_mask = compute_multi_resolution_mask(cached_condition['position_maps'])
837
-
838
- position_voxel_indices = None
839
- if use_position_rope:
840
- if 'position_voxel_indices' in cached_condition:
841
- position_voxel_indices = cached_condition['position_voxel_indices']
842
- else:
843
- if 'position_maps' in cached_condition:
844
- position_voxel_indices = compute_multi_resolution_discrete_voxel_indice(cached_condition['position_maps'])
845
 
846
- if 'ip_hidden_states' in cached_condition:
847
- ip_hidden_states = cached_condition['ip_hidden_states']
848
- else:
849
- if 'clip_embeds' in cached_condition:
850
- ip_hidden_states = self.image_proj_model(cached_condition['clip_embeds'])
851
  else:
852
- ip_hidden_states = None
853
- cached_condition['ip_hidden_states'] = ip_hidden_states
854
-
855
- if 'condition_embed_dict' in cached_condition:
856
- condition_embed_dict = cached_condition['condition_embed_dict']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
857
  else:
858
- condition_embed_dict = {}
859
- ref_latents = cached_condition['ref_latents']
860
- N_ref = ref_latents.shape[1]
861
- camera_info_ref = cached_condition['camera_info_ref']
862
- camera_info_ref = rearrange(camera_info_ref, 'b n -> (b n)')
863
-
864
- #ref_latents = [ref_latents]
865
- #if 'normal_imgs' in cached_condition:
866
- # ref_latents.append(torch.zeros_like(ref_latents[0]))
867
- #if 'position_imgs' in cached_condition:
868
- # ref_latents.append(torch.zeros_like(ref_latents[0]))
869
- #ref_latents = torch.cat(ref_latents, dim=2)
870
-
871
- ref_latents = rearrange(ref_latents, 'b n c h w -> (b n) c h w')
872
 
873
- encoder_hidden_states_ref = self.learned_text_clip_ref.unsqueeze(1).repeat(B, N_ref, 1, 1)
874
- encoder_hidden_states_ref = rearrange(encoder_hidden_states_ref, 'b n l c -> (b n) l c')
875
-
876
- noisy_ref_latents = ref_latents
877
- timestep_ref = 0
878
- '''
879
- if timestep.dim()>0:
880
- timestep_ref = rearrange(timestep, '(b n) -> b n', b=B)[:,:1].repeat(1, N_ref)
881
- timestep_ref = rearrange(timestep_ref, 'b n -> (b n)')
882
- else:
883
- timestep_ref = timestep
884
- noise = torch.randn_like(noisy_ref_latents[:,:4,...])
885
- if self.training:
886
- noisy_ref_latents[:,:4,...] = self.train_sched.add_noise(noisy_ref_latents[:,:4,...], noise, timestep_ref)
887
- noisy_ref_latents[:,:4,...] = self.train_sched.scale_model_input(noisy_ref_latents[:,:4,...], timestep_ref)
888
- else:
889
- noisy_ref_latents[:,:4,...] = self.val_sched.add_noise(noisy_ref_latents[:,:4,...], noise, timestep_ref.reshape(-1))
890
- noisy_ref_latents[:,:4,...] = self.val_sched.scale_model_input(noisy_ref_latents[:,:4,...], timestep_ref.reshape(-1))
891
- '''
892
- self.unet_dual(
893
- noisy_ref_latents, timestep_ref,
894
- encoder_hidden_states=encoder_hidden_states_ref,
895
- #class_labels=camera_info_ref,
896
- # **kwargs
897
- return_dict=False,
898
- cross_attention_kwargs={
899
- 'mode':'w', 'num_in_batch':N_ref,
900
- 'condition_embed_dict':condition_embed_dict},
901
- )
902
- cached_condition['condition_embed_dict'] = condition_embed_dict
903
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
904
  return self.unet(
905
  sample, timestep,
906
  encoder_hidden_states_gen, *args,
@@ -916,11 +605,6 @@ class UNet2p5DConditionModel(torch.nn.Module):
916
  if mid_block_res_sample is not None else None
917
  ),
918
  return_dict=False,
919
- cross_attention_kwargs={
920
- 'mode':'r', 'num_in_batch':N_gen,
921
- 'ip_hidden_states':ip_hidden_states,
922
- 'condition_embed_dict':condition_embed_dict,
923
- 'position_attn_mask':position_attn_mask,
924
- 'position_voxel_indices':position_voxel_indices
925
- },
926
- )
 
22
  # fine-tuning enabling code and other elements of the foregoing made publicly available
23
  # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
 
 
25
  import copy
26
  import json
27
  import os
 
40
  # "feed_forward_chunk_size" can be used to save memory
41
  if hidden_states.shape[chunk_dim] % chunk_size != 0:
42
  raise ValueError(
43
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]}"
44
+ f"has to be divisible by chunk size: {chunk_size}."
45
+ f" Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
46
  )
47
 
48
  num_chunks = hidden_states.shape[chunk_dim] // chunk_size
 
52
  )
53
  return ff_output
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  class Basic2p5DTransformerBlock(torch.nn.Module):
57
+ def __init__(self, transformer: BasicTransformerBlock, layer_name, use_ma=True, use_ra=True, is_turbo=False) -> None:
58
  super().__init__()
59
  self.transformer = transformer
60
  self.layer_name = layer_name
 
61
  self.use_ma = use_ma
62
  self.use_ra = use_ra
63
+ self.is_turbo = is_turbo
64
 
 
 
 
 
 
 
65
  # multiview attn
66
  if self.use_ma:
67
  self.attn_multiview = Attention(
 
73
  cross_attention_dim=None,
74
  upcast_attention=self.attn1.upcast_attention,
75
  out_bias=True,
 
76
  )
77
 
78
  # ref attn
 
87
  upcast_attention=self.attn1.upcast_attention,
88
  out_bias=True,
89
  )
90
+ if self.is_turbo:
91
+ self._initialize_attn_weights()
92
 
93
  def _initialize_attn_weights(self):
94
 
 
105
  for param in layer.parameters():
106
  param.zero_()
107
 
 
 
 
 
108
  def __getattr__(self, name: str):
109
  try:
110
  return super().__getattr__(name)
 
130
  cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
131
  num_in_batch = cross_attention_kwargs.pop('num_in_batch', 1)
132
  mode = cross_attention_kwargs.pop('mode', None)
133
+ if not self.is_turbo:
134
+ mva_scale = cross_attention_kwargs.pop('mva_scale', 1.0)
135
+ ref_scale = cross_attention_kwargs.pop('ref_scale', 1.0)
136
+ else:
137
+ position_attn_mask = cross_attention_kwargs.pop("position_attn_mask", None)
138
+ position_voxel_indices = cross_attention_kwargs.pop("position_voxel_indices", None)
139
+ mva_scale = 1.0
140
+ ref_scale = 1.0
141
+
142
  condition_embed_dict = cross_attention_kwargs.pop("condition_embed_dict", None)
 
 
 
143
 
144
  if self.norm_type == "ada_norm":
145
  norm_hidden_states = self.norm1(hidden_states, timestep)
 
159
  norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
160
  else:
161
  raise ValueError("Incorrect norm used")
162
+
163
  if self.pos_embed is not None:
164
  norm_hidden_states = self.pos_embed(norm_hidden_states)
165
+
166
  # 1. Prepare GLIGEN inputs
167
  cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
168
  gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
 
173
  attention_mask=attention_mask,
174
  **cross_attention_kwargs,
175
  )
176
+
177
  if self.norm_type == "ada_norm_zero":
178
  attn_output = gate_msa.unsqueeze(1) * attn_output
179
  elif self.norm_type == "ada_norm_single":
 
182
  hidden_states = attn_output + hidden_states
183
  if hidden_states.ndim == 4:
184
  hidden_states = hidden_states.squeeze(1)
185
+
186
  # 1.2 Reference Attention
187
  if 'w' in mode:
188
+ condition_embed_dict[self.layer_name] = rearrange(
189
+ norm_hidden_states, '(b n) l c -> b (n l) c',
190
+ n=num_in_batch
191
+ ) # B, (N L), C
192
+
193
+ if 'r' in mode and self.use_ra:
194
+ condition_embed = condition_embed_dict[self.layer_name].unsqueeze(1).repeat(1, num_in_batch, 1,
195
+ 1) # B N L C
196
  condition_embed = rearrange(condition_embed, 'b n l c -> (b n) l c')
197
 
198
  attn_output = self.attn_refview(
 
201
  attention_mask=None,
202
  **cross_attention_kwargs
203
  )
204
+ if not self.is_turbo:
205
+ ref_scale_timing = ref_scale
206
+ if isinstance(ref_scale, torch.Tensor):
207
+ ref_scale_timing = ref_scale.unsqueeze(1).repeat(1, num_in_batch).view(-1)
208
+ for _ in range(attn_output.ndim - 1):
209
+ ref_scale_timing = ref_scale_timing.unsqueeze(-1)
210
+
211
+ hidden_states = ref_scale_timing * attn_output + hidden_states
212
 
 
213
  if hidden_states.ndim == 4:
214
  hidden_states = hidden_states.squeeze(1)
 
215
 
216
  # 1.3 Multiview Attention
217
  if num_in_batch > 1 and self.use_ma:
218
  multivew_hidden_states = rearrange(norm_hidden_states, '(b n) l c -> b (n l) c', n=num_in_batch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
+ if self.is_turbo:
221
+ position_mask = None
222
+ if position_attn_mask is not None:
223
+ if multivew_hidden_states.shape[1] in position_attn_mask:
224
+ position_mask = position_attn_mask[multivew_hidden_states.shape[1]]
225
+ position_indices = None
226
+ if position_voxel_indices is not None:
227
+ if multivew_hidden_states.shape[1] in position_voxel_indices:
228
+ position_indices = position_voxel_indices[multivew_hidden_states.shape[1]]
229
+ attn_output = self.attn_multiview(
230
+ multivew_hidden_states,
231
+ encoder_hidden_states=multivew_hidden_states,
232
+ attention_mask=position_mask,
233
+ position_indices=position_indices,
234
+ **cross_attention_kwargs
235
+ )
236
+ else:
237
+ attn_output = self.attn_multiview(
238
+ multivew_hidden_states,
239
+ encoder_hidden_states=multivew_hidden_states,
240
+ **cross_attention_kwargs
241
+ )
242
 
243
+ attn_output = rearrange(attn_output, 'b (n l) c -> (b n) l c', n=num_in_batch)
244
+
245
+ hidden_states = mva_scale * attn_output + hidden_states
246
  if hidden_states.ndim == 4:
247
  hidden_states = hidden_states.squeeze(1)
248
 
 
268
  if self.pos_embed is not None and self.norm_type != "ada_norm_single":
269
  norm_hidden_states = self.pos_embed(norm_hidden_states)
270
 
271
+ attn_output = self.attn2(
272
+ norm_hidden_states,
273
+ encoder_hidden_states=encoder_hidden_states,
274
+ attention_mask=encoder_attention_mask,
275
+ **cross_attention_kwargs,
276
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
  hidden_states = attn_output + hidden_states
279
 
 
320
  position[valid_mask==False] = 0
321
 
322
 
323
+ position = rearrange(
324
+ position,
325
+ 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w',
326
+ num_h=grid_resolution, num_w=grid_resolution
327
+ )
328
+ valid_mask = rearrange(
329
+ valid_mask,
330
+ 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w',
331
+ num_h=grid_resolution, num_w=grid_resolution
332
+ )
333
 
334
  grid_position = position.sum(dim=(-2, -1))
335
  count_masked = valid_mask.sum(dim=(-2, -1))
 
376
  valid_mask = valid_mask.expand_as(position)
377
  position[valid_mask==False] = 0
378
 
379
+ position = rearrange(
380
+ position,
381
+ 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w',
382
+ num_h=grid_resolution, num_w=grid_resolution
383
+ )
384
+ valid_mask = rearrange(
385
+ valid_mask,
386
+ 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w',
387
+ num_h=grid_resolution, num_w=grid_resolution
388
+ )
389
 
390
  grid_position = position.sum(dim=(-2, -1))
391
  count_masked = valid_mask.sum(dim=(-2, -1))
 
398
  voxel_indices = torch.round(voxel_indices).long()
399
  return voxel_indices
400
 
401
+ def compute_multi_resolution_discrete_voxel_indice(
402
+ position_maps,
403
+ grid_resolutions=[64, 32, 16, 8],
404
+ voxel_resolutions=[512, 256, 128, 64]
405
+ ):
406
  voxel_indices = {}
407
  with torch.no_grad():
408
  for grid_resolution, voxel_resolution in zip(grid_resolutions, voxel_resolutions):
409
  voxel_indice = compute_discrete_voxel_indice(position_maps, grid_resolution, voxel_resolution)
410
  voxel_indice = rearrange(voxel_indice, 'b n c h w -> b (n h w) c')
411
  voxel_indices[voxel_indice.shape[1]] = {'voxel_indices':voxel_indice, 'voxel_resolution':voxel_resolution}
412
+ return voxel_indices
 
 
 
 
 
 
 
 
 
 
 
 
413
 
 
 
 
 
 
 
 
 
414
  class UNet2p5DConditionModel(torch.nn.Module):
415
  def __init__(self, unet: UNet2DConditionModel) -> None:
416
  super().__init__()
417
  self.unet = unet
 
418
 
419
+ self.use_ma = True
420
+ self.use_ra = True
421
+ self.use_camera_embedding = True
422
+ self.use_dual_stream = True
423
+ self.is_turbo = False
424
+
425
+ if self.use_dual_stream:
426
+ self.unet_dual = copy.deepcopy(unet)
427
+ self.init_attention(self.unet_dual)
428
+ self.init_attention(self.unet, use_ma=self.use_ma, use_ra=self.use_ra, is_turbo=self.is_turbo)
429
  self.init_condition()
430
+ self.init_camera_embedding()
431
 
432
  @staticmethod
433
  def from_pretrained(pretrained_model_name_or_path, **kwargs):
 
438
  config = json.load(file)
439
  unet = UNet2DConditionModel(**config)
440
  unet = UNet2p5DConditionModel(unet)
 
 
 
 
 
 
 
 
 
 
 
441
  unet_ckpt = torch.load(unet_ckpt_path, map_location='cpu', weights_only=True)
442
  unet.load_state_dict(unet_ckpt, strict=True)
443
  unet = unet.to(torch_dtype)
444
  return unet
 
 
 
 
445
 
446
+ def init_condition(self):
447
+ self.unet.conv_in = torch.nn.Conv2d(
448
+ 12,
449
+ self.unet.conv_in.out_channels,
450
+ kernel_size=self.unet.conv_in.kernel_size,
451
+ stride=self.unet.conv_in.stride,
452
+ padding=self.unet.conv_in.padding,
453
+ dilation=self.unet.conv_in.dilation,
454
+ groups=self.unet.conv_in.groups,
455
+ bias=self.unet.conv_in.bias is not None)
456
 
457
+ self.unet.learned_text_clip_gen = nn.Parameter(torch.randn(1, 77, 1024))
458
+ self.unet.learned_text_clip_ref = nn.Parameter(torch.randn(1, 77, 1024))
459
 
460
  def init_camera_embedding(self):
 
 
461
 
462
+ if self.use_camera_embedding:
463
+ time_embed_dim = 1280
464
+ self.max_num_ref_image = 5
465
+ self.max_num_gen_image = 12 * 3 + 4 * 2
466
+ self.unet.class_embedding = nn.Embedding(self.max_num_ref_image + self.max_num_gen_image, time_embed_dim)
467
+
468
+ def init_attention(self, unet, use_ma=False, use_ra=False, is_turbo=False):
469
 
470
  for down_block_i, down_block in enumerate(unet.down_blocks):
471
  if hasattr(down_block, "has_cross_attention") and down_block.has_cross_attention:
472
  for attn_i, attn in enumerate(down_block.attentions):
473
  for transformer_i, transformer in enumerate(attn.transformer_blocks):
474
  if isinstance(transformer, BasicTransformerBlock):
475
+ attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
476
+ transformer,
477
+ f'down_{down_block_i}_{attn_i}_{transformer_i}',
478
+ use_ma, use_ra, is_turbo
479
+ )
480
 
481
  if hasattr(unet.mid_block, "has_cross_attention") and unet.mid_block.has_cross_attention:
482
  for attn_i, attn in enumerate(unet.mid_block.attentions):
483
  for transformer_i, transformer in enumerate(attn.transformer_blocks):
484
  if isinstance(transformer, BasicTransformerBlock):
485
+ attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
486
+ transformer,
487
+ f'mid_{attn_i}_{transformer_i}',
488
+ use_ma, use_ra, is_turbo
489
+ )
490
 
491
  for up_block_i, up_block in enumerate(unet.up_blocks):
492
  if hasattr(up_block, "has_cross_attention") and up_block.has_cross_attention:
493
  for attn_i, attn in enumerate(up_block.attentions):
494
  for transformer_i, transformer in enumerate(attn.transformer_blocks):
495
  if isinstance(transformer, BasicTransformerBlock):
496
+ attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
497
+ transformer,
498
+ f'up_{up_block_i}_{attn_i}_{transformer_i}',
499
+ use_ma, use_ra, is_turbo
500
+ )
501
 
502
  def __getattr__(self, name: str):
503
  try:
504
  return super().__getattr__(name)
505
  except AttributeError:
506
  return getattr(self.unet, name)
507
+
508
  def forward(
509
+ self, sample, timestep, encoder_hidden_states,
510
+ *args, down_intrablock_additional_residuals=None,
511
  down_block_res_samples=None, mid_block_res_sample=None,
512
  **cached_condition,
513
  ):
514
  B, N_gen, _, H, W = sample.shape
515
+ assert H == W
516
+
517
+ if self.use_camera_embedding:
518
+ camera_info_gen = cached_condition['camera_info_gen'] + self.max_num_ref_image
519
+ camera_info_gen = rearrange(camera_info_gen, 'b n -> (b n)')
520
+ else:
521
+ camera_info_gen = None
522
+
523
  sample = [sample]
 
524
  if 'normal_imgs' in cached_condition:
525
  sample.append(cached_condition["normal_imgs"])
526
  if 'position_imgs' in cached_condition:
527
  sample.append(cached_condition["position_imgs"])
 
528
  sample = torch.cat(sample, dim=2)
529
+
530
  sample = rearrange(sample, 'b n c h w -> (b n) c h w')
531
 
532
  encoder_hidden_states_gen = encoder_hidden_states.unsqueeze(1).repeat(1, N_gen, 1, 1)
533
  encoder_hidden_states_gen = rearrange(encoder_hidden_states_gen, 'b n l c -> (b n) l c')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
 
535
+ if self.use_ra:
536
+ if 'condition_embed_dict' in cached_condition:
537
+ condition_embed_dict = cached_condition['condition_embed_dict']
 
 
538
  else:
539
+ condition_embed_dict = {}
540
+ ref_latents = cached_condition['ref_latents']
541
+ N_ref = ref_latents.shape[1]
542
+ if self.use_camera_embedding:
543
+ camera_info_ref = cached_condition['camera_info_ref']
544
+ camera_info_ref = rearrange(camera_info_ref, 'b n -> (b n)')
545
+ else:
546
+ camera_info_ref = None
547
+
548
+ ref_latents = rearrange(ref_latents, 'b n c h w -> (b n) c h w')
549
+
550
+ encoder_hidden_states_ref = self.unet.learned_text_clip_ref.unsqueeze(1).repeat(B, N_ref, 1, 1)
551
+ encoder_hidden_states_ref = rearrange(encoder_hidden_states_ref, 'b n l c -> (b n) l c')
552
+
553
+ noisy_ref_latents = ref_latents
554
+ timestep_ref = 0
555
+
556
+ if self.use_dual_stream:
557
+ unet_ref = self.unet_dual
558
+ else:
559
+ unet_ref = self.unet
560
+ unet_ref(
561
+ noisy_ref_latents, timestep_ref,
562
+ encoder_hidden_states=encoder_hidden_states_ref,
563
+ class_labels=camera_info_ref,
564
+ # **kwargs
565
+ return_dict=False,
566
+ cross_attention_kwargs={
567
+ 'mode': 'w', 'num_in_batch': N_ref,
568
+ 'condition_embed_dict': condition_embed_dict},
569
+ )
570
+ cached_condition['condition_embed_dict'] = condition_embed_dict
571
  else:
572
+ condition_embed_dict = None
 
 
 
 
 
 
 
 
 
 
 
 
 
573
 
574
+ mva_scale = cached_condition.get('mva_scale', 1.0)
575
+ ref_scale = cached_condition.get('ref_scale', 1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
 
577
+ if self.is_turbo:
578
+ cross_attention_kwargs_ = {
579
+ 'mode': 'r', 'num_in_batch': N_gen,
580
+ 'condition_embed_dict': condition_embed_dict,
581
+ 'position_attn_mask':position_attn_mask,
582
+ 'position_voxel_indices':position_voxel_indices,
583
+ 'mva_scale': mva_scale,
584
+ 'ref_scale': ref_scale,
585
+ }
586
+ else:
587
+ cross_attention_kwargs_ = {
588
+ 'mode': 'r', 'num_in_batch': N_gen,
589
+ 'condition_embed_dict': condition_embed_dict,
590
+ 'mva_scale': mva_scale,
591
+ 'ref_scale': ref_scale,
592
+ }
593
  return self.unet(
594
  sample, timestep,
595
  encoder_hidden_states_gen, *args,
 
605
  if mid_block_res_sample is not None else None
606
  ),
607
  return_dict=False,
608
+ cross_attention_kwargs=cross_attention_kwargs_,
609
+ )
610
+