import os |
|
import json |
|
import copy |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange |
|
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, Literal |
|
import diffusers |
|
from diffusers.utils import deprecate |
|
from diffusers import ( |
|
DDPMScheduler, |
|
EulerAncestralDiscreteScheduler, |
|
UNet2DConditionModel, |
|
) |
|
from diffusers.models.attention_processor import Attention, AttnProcessor |
|
from diffusers.models.transformers.transformer_2d import BasicTransformerBlock |
|
from .attn_processor import SelfAttnProcessor2_0, RefAttnProcessor2_0, PoseRoPEAttnProcessor2_0 |
|
|
|
from transformers import AutoImageProcessor, AutoModel |
|
|
|
|
|
class Dino_v2(nn.Module): |
|
|
|
"""Wrapper for DINOv2 vision transformer (frozen weights). |
|
|
|
Provides feature extraction for reference images. |
|
|
|
Args: |
|
        dino_v2_path: Local path or Hugging Face model id of the pretrained DINOv2 checkpoint to load
|
""" |
|
|
|
|
|
def __init__(self, dino_v2_path): |
|
super(Dino_v2, self).__init__() |
|
self.dino_processor = AutoImageProcessor.from_pretrained(dino_v2_path) |
|
self.dino_v2 = AutoModel.from_pretrained(dino_v2_path) |
|
|
|
for param in self.parameters(): |
|
param.requires_grad = False |
|
|
|
self.dino_v2.eval() |
|
|
|
def forward(self, images): |
|
|
|
"""Processes input images through DINOv2 ViT. |
|
|
|
Handles both tensor input (B, N, C, H, W) and PIL image lists. |
|
Extracts patch embeddings and flattens spatial dimensions. |
|
|
|
Returns: |
|
            torch.Tensor: Feature tokens [B, N * num_tokens, feature_dim], where num_tokens is the CLS token plus the patch tokens of each image
|
""" |
|
|
|
if isinstance(images, torch.Tensor): |
|
batch_size = images.shape[0] |
|
            dino_processed_images = self.dino_processor(
                images=rearrange(images, "b n c h w -> (b n) c h w"), return_tensors="pt", do_rescale=False
            ).pixel_values
        else:
            batch_size = 1
            dino_processed_images = self.dino_processor(images=images, return_tensors="pt").pixel_values
            dino_processed_images = torch.stack(
                [torch.from_numpy(np.array(image)) for image in dino_processed_images], dim=0
            )
        dino_param = next(self.dino_v2.parameters())
        dino_processed_images = dino_processed_images.to(dino_param)
        dino_hidden_states = self.dino_v2(dino_processed_images)[0]
|
dino_hidden_states = rearrange(dino_hidden_states.to(dino_param), "(b n) l c -> b (n l) c", b=batch_size) |
|
|
|
return dino_hidden_states |
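
# Usage sketch (illustrative, not part of the module API). The checkpoint id below is an
# assumption; any local path or Hugging Face repo id accepted by AutoModel works. With
# dinov2-giant, feature_dim is 1536, which is what the DINO projector further down assumes.
def _example_dino_feature_extraction():
    dino = Dino_v2("facebook/dinov2-giant")
    ref_images = torch.rand(2, 3, 3, 512, 512)  # (B, N_ref, C, H, W), values already in [0, 1]
    features = dino(ref_images)  # -> (B, N_ref * num_tokens, 1536); num_tokens = CLS + patch tokens
    return features.shape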
|
|
|
|
|
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): |
|
|
|
|
|
"""Memory-efficient feedforward execution via chunking. |
|
|
|
Divides input along specified dimension for sequential processing. |
|
|
|
Args: |
|
ff: Feedforward module to apply |
|
hidden_states: Input tensor |
|
chunk_dim: Dimension to split |
|
chunk_size: Size of each chunk |
|
|
|
Returns: |
|
torch.Tensor: Reassembled output tensor |
|
""" |
|
|
|
if hidden_states.shape[chunk_dim] % chunk_size != 0: |
|
        raise ValueError(
            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be "
            f"divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` "
            "when calling `unet.enable_forward_chunking`."
        )
|
|
|
num_chunks = hidden_states.shape[chunk_dim] // chunk_size |
|
ff_output = torch.cat( |
|
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], |
|
dim=chunk_dim, |
|
) |
|
return ff_output |
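
# Minimal sanity check (illustrative): chunking along the sequence dimension is exact for
# position-wise feed-forward layers, because every token is transformed independently.
def _example_chunked_feed_forward_equivalence():
    ff = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))
    hidden_states = torch.randn(2, 16, 8)  # (batch, seq_len, dim)
    chunked = _chunked_feed_forward(ff, hidden_states, chunk_dim=1, chunk_size=4)
    assert torch.allclose(chunked, ff(hidden_states), atol=1e-6)
    return chunked.shape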
|
|
|
|
|
@torch.no_grad() |
|
def compute_voxel_grid_mask(position, grid_resolution=8): |
|
|
|
"""Generates view-to-view attention mask based on 3D position similarity. |
|
|
|
Uses voxel grid downsampling to determine spatially adjacent regions. |
|
Mask indicates where features should interact across different views. |
|
|
|
Args: |
|
position: Position maps [B, N, 3, H, W] (normalized 0-1) |
|
grid_resolution: Spatial reduction factor |
|
|
|
Returns: |
|
        torch.Tensor: Boolean attention mask [B, N, N, grid_resolution**2, grid_resolution**2]
|
""" |
|
|
|
position = position.half() |
|
B, N, _, H, W = position.shape |
|
assert H % grid_resolution == 0 and W % grid_resolution == 0 |
|
|
|
valid_mask = (position != 1).all(dim=2, keepdim=True) |
|
valid_mask = valid_mask.expand_as(position) |
|
    position[~valid_mask] = 0
|
|
|
position = rearrange( |
|
position, |
|
"b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w", |
|
num_h=grid_resolution, |
|
num_w=grid_resolution, |
|
) |
|
valid_mask = rearrange( |
|
valid_mask, |
|
"b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w", |
|
num_h=grid_resolution, |
|
num_w=grid_resolution, |
|
) |
|
|
|
grid_position = position.sum(dim=(-2, -1)) |
|
count_masked = valid_mask.sum(dim=(-2, -1)) |
|
|
|
grid_position = grid_position / count_masked.clamp(min=1) |
|
grid_position[count_masked < 5] = 0 |
|
|
|
grid_position = grid_position.permute(0, 1, 4, 2, 3) |
|
grid_position = rearrange(grid_position, "b n c h w -> b n (h w) c") |
|
|
|
grid_position_expanded_1 = grid_position.unsqueeze(2).unsqueeze(4) |
|
grid_position_expanded_2 = grid_position.unsqueeze(1).unsqueeze(3) |
|
|
|
|
|
distances = torch.norm(grid_position_expanded_1 - grid_position_expanded_2, dim=-1) |
|
|
|
    # Two grid cells may attend to each other when their mean positions lie within one voxel
    # diagonal (sqrt(3) ~= 1.73) at the current grid resolution.
    grid_distance = 1.73 / grid_resolution
    weights = distances < grid_distance
|
|
|
return weights |
|
|
|
|
|
def compute_multi_resolution_mask(position_maps, grid_resolutions=[32, 16, 8]): |
|
|
|
"""Generates attention masks at multiple spatial resolutions. |
|
|
|
Creates pyramid of position-based masks for hierarchical attention. |
|
|
|
Args: |
|
position_maps: Position maps [B, N, 3, H, W] |
|
grid_resolutions: List of downsampling factors |
|
|
|
Returns: |
|
dict: Resolution-specific masks keyed by flattened dimension size |
|
""" |
|
|
|
position_attn_mask = {} |
|
with torch.no_grad(): |
|
for grid_resolution in grid_resolutions: |
|
position_mask = compute_voxel_grid_mask(position_maps, grid_resolution) |
|
position_mask = rearrange(position_mask, "b ni nj li lj -> b (ni li) (nj lj)") |
|
position_attn_mask[position_mask.shape[1]] = position_mask |
|
return position_attn_mask |
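
# Shape sketch (illustrative): the returned dict is keyed by the flattened multiview token count
# (N_views * grid_resolution**2), which is how attention processors look the mask up at runtime.
# Note: the mask computation casts to half precision internally; on older PyTorch builds the
# CPU half-precision reductions may be unsupported, in which case run this on GPU.
def _example_multi_resolution_masks():
    position_maps = torch.rand(1, 6, 3, 64, 64)  # (B, N_views, xyz, H, W), normalized to [0, 1]
    masks = compute_multi_resolution_mask(position_maps, grid_resolutions=[16, 8])
    # e.g. {6 * 16**2: (1, 1536, 1536), 6 * 8**2: (1, 384, 384)}, boolean masks
    return {k: v.shape for k, v in masks.items()}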
|
|
|
|
|
@torch.no_grad() |
|
def compute_discrete_voxel_indice(position, grid_resolution=8, voxel_resolution=128): |
|
|
|
"""Quantizes position maps to discrete voxel indices. |
|
|
|
Creates sparse 3D coordinate representations for efficient hashing. |
|
|
|
Args: |
|
position: Position maps [B, N, 3, H, W] |
|
grid_resolution: Spatial downsampling factor |
|
voxel_resolution: Quantization resolution |
|
|
|
Returns: |
|
        torch.Tensor: Integer voxel indices [B, N, 3, grid_resolution, grid_resolution]
|
""" |
|
|
|
position = position.half() |
|
B, N, _, H, W = position.shape |
|
assert H % grid_resolution == 0 and W % grid_resolution == 0 |
|
|
|
valid_mask = (position != 1).all(dim=2, keepdim=True) |
|
valid_mask = valid_mask.expand_as(position) |
|
    position[~valid_mask] = 0
|
|
|
position = rearrange( |
|
position, |
|
"b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w", |
|
num_h=grid_resolution, |
|
num_w=grid_resolution, |
|
) |
|
valid_mask = rearrange( |
|
valid_mask, |
|
"b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w", |
|
num_h=grid_resolution, |
|
num_w=grid_resolution, |
|
) |
|
|
|
grid_position = position.sum(dim=(-2, -1)) |
|
count_masked = valid_mask.sum(dim=(-2, -1)) |
|
|
|
grid_position = grid_position / count_masked.clamp(min=1) |
|
voxel_mask_thres = (H // grid_resolution) * (W // grid_resolution) // (4 * 4) |
|
grid_position[count_masked < voxel_mask_thres] = 0 |
|
|
|
grid_position = grid_position.permute(0, 1, 4, 2, 3).clamp(0, 1) |
|
voxel_indices = grid_position * (voxel_resolution - 1) |
|
voxel_indices = torch.round(voxel_indices).long() |
|
return voxel_indices |
|
|
|
|
|
def calc_multires_voxel_idxs(position_maps, grid_resolutions=[64, 32, 16, 8], voxel_resolutions=[512, 256, 128, 64]): |
|
|
|
"""Generates multi-resolution voxel indices for position encoding. |
|
|
|
Creates pyramid of quantized position representations. |
|
|
|
Args: |
|
position_maps: Input position maps |
|
grid_resolutions: Spatial resolution levels |
|
voxel_resolutions: Quantization levels |
|
|
|
Returns: |
|
dict: Voxel indices keyed by flattened dimension size, with resolution metadata |
|
""" |
|
|
|
voxel_indices = {} |
|
with torch.no_grad(): |
|
for grid_resolution, voxel_resolution in zip(grid_resolutions, voxel_resolutions): |
|
voxel_indice = compute_discrete_voxel_indice(position_maps, grid_resolution, voxel_resolution) |
|
voxel_indice = rearrange(voxel_indice, "b n c h w -> b (n h w) c") |
|
voxel_indices[voxel_indice.shape[1]] = {"voxel_indices": voxel_indice, "voxel_resolution": voxel_resolution} |
|
return voxel_indices |
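
# Shape sketch (illustrative): voxel indices are keyed the same way as the masks above, and each
# entry also carries its quantization resolution so the RoPE processor can scale per level.
def _example_multires_voxel_indices():
    position_maps = torch.rand(1, 6, 3, 64, 64)  # (B, N_views, xyz, H, W), normalized to [0, 1]
    indices = calc_multires_voxel_idxs(position_maps, grid_resolutions=[16, 8], voxel_resolutions=[128, 64])
    # e.g. indices[6 * 16**2]["voxel_indices"].shape == (1, 1536, 3), dtype torch.long
    return {k: (v["voxel_indices"].shape, v["voxel_resolution"]) for k, v in indices.items()}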
|
|
|
|
|
class Basic2p5DTransformerBlock(torch.nn.Module): |
|
|
|
|
|
"""Enhanced transformer block for multiview 2.5D image generation. |
|
|
|
Extends standard transformer blocks with: |
|
    - Material-Dimension Attention (MDA)
|
- Multiview attention (MA) |
|
- Reference attention (RA) |
|
- DINO feature integration |
|
|
|
Args: |
|
transformer: Base transformer block |
|
layer_name: Identifier for layer |
|
use_ma: Enable multiview attention |
|
use_ra: Enable reference attention |
|
use_mda: Enable material-aware attention |
|
use_dino: Enable DINO feature integration |
|
pbr_setting: List of PBR materials |
|
""" |
|
|
|
def __init__( |
|
self, |
|
transformer: BasicTransformerBlock, |
|
layer_name, |
|
use_ma=True, |
|
use_ra=True, |
|
use_mda=True, |
|
use_dino=True, |
|
pbr_setting=None, |
|
) -> None: |
|
|
|
""" |
|
Initialization: |
|
1. Material-Dimension Attention (MDA): |
|
- Processes each PBR material with separate projection weights |
|
- Uses custom SelfAttnProcessor2_0 with material awareness |
|
|
|
2. Multiview Attention (MA): |
|
- Adds cross-view attention with PoseRoPE |
|
- Initialized as zero-initialized residual pathway |
|
|
|
3. Reference Attention (RA): |
|
- Conditions on reference view features |
|
- Uses RefAttnProcessor2_0 for material-specific conditioning |
|
|
|
4. DINO Attention: |
|
- Incorporates DINO-ViT features |
|
- Initialized as zero-initialized residual pathway |
|
""" |
|
|
|
super().__init__() |
|
self.transformer = transformer |
|
self.layer_name = layer_name |
|
self.use_ma = use_ma |
|
self.use_ra = use_ra |
|
self.use_mda = use_mda |
|
self.use_dino = use_dino |
|
self.pbr_setting = pbr_setting |
|
|
|
if self.use_mda: |
|
self.attn1.set_processor( |
|
SelfAttnProcessor2_0( |
|
query_dim=self.dim, |
|
heads=self.num_attention_heads, |
|
dim_head=self.attention_head_dim, |
|
dropout=self.dropout, |
|
bias=self.attention_bias, |
|
cross_attention_dim=None, |
|
upcast_attention=self.attn1.upcast_attention, |
|
out_bias=True, |
|
pbr_setting=self.pbr_setting, |
|
) |
|
) |
|
|
|
|
|
if self.use_ma: |
|
self.attn_multiview = Attention( |
|
query_dim=self.dim, |
|
heads=self.num_attention_heads, |
|
dim_head=self.attention_head_dim, |
|
dropout=self.dropout, |
|
bias=self.attention_bias, |
|
cross_attention_dim=None, |
|
upcast_attention=self.attn1.upcast_attention, |
|
out_bias=True, |
|
processor=PoseRoPEAttnProcessor2_0(), |
|
) |
|
|
|
|
|
if self.use_ra: |
|
self.attn_refview = Attention( |
|
query_dim=self.dim, |
|
heads=self.num_attention_heads, |
|
dim_head=self.attention_head_dim, |
|
dropout=self.dropout, |
|
bias=self.attention_bias, |
|
cross_attention_dim=None, |
|
upcast_attention=self.attn1.upcast_attention, |
|
out_bias=True, |
|
processor=RefAttnProcessor2_0( |
|
query_dim=self.dim, |
|
heads=self.num_attention_heads, |
|
dim_head=self.attention_head_dim, |
|
dropout=self.dropout, |
|
bias=self.attention_bias, |
|
cross_attention_dim=None, |
|
upcast_attention=self.attn1.upcast_attention, |
|
out_bias=True, |
|
pbr_setting=self.pbr_setting, |
|
), |
|
) |
|
|
|
|
|
if self.use_dino: |
|
self.attn_dino = Attention( |
|
query_dim=self.dim, |
|
heads=self.num_attention_heads, |
|
dim_head=self.attention_head_dim, |
|
dropout=self.dropout, |
|
bias=self.attention_bias, |
|
cross_attention_dim=self.cross_attention_dim, |
|
upcast_attention=self.attn2.upcast_attention, |
|
out_bias=True, |
|
) |
|
|
|
self._initialize_attn_weights() |
|
|
|
def _initialize_attn_weights(self): |
|
|
|
"""Initializes specialized attention heads with base weights. |
|
|
|
Uses weight sharing strategy: |
|
- Copies base transformer weights to specialized heads |
|
- Initializes newly-added parameters to zero |
|
""" |
|
|
|
if self.use_mda: |
|
for token in self.pbr_setting: |
|
if token == "albedo": |
|
continue |
|
getattr(self.attn1.processor, f"to_q_{token}").load_state_dict(self.attn1.to_q.state_dict()) |
|
getattr(self.attn1.processor, f"to_k_{token}").load_state_dict(self.attn1.to_k.state_dict()) |
|
getattr(self.attn1.processor, f"to_v_{token}").load_state_dict(self.attn1.to_v.state_dict()) |
|
getattr(self.attn1.processor, f"to_out_{token}").load_state_dict(self.attn1.to_out.state_dict()) |
|
|
|
if self.use_ma: |
|
self.attn_multiview.load_state_dict(self.attn1.state_dict(), strict=False) |
|
with torch.no_grad(): |
|
for layer in self.attn_multiview.to_out: |
|
for param in layer.parameters(): |
|
param.zero_() |
|
|
|
if self.use_ra: |
|
self.attn_refview.load_state_dict(self.attn1.state_dict(), strict=False) |
|
for token in self.pbr_setting: |
|
if token == "albedo": |
|
continue |
|
getattr(self.attn_refview.processor, f"to_v_{token}").load_state_dict( |
|
                    self.attn_refview.to_v.state_dict()
|
) |
|
getattr(self.attn_refview.processor, f"to_out_{token}").load_state_dict( |
|
self.attn_refview.to_out.state_dict() |
|
) |
|
with torch.no_grad(): |
|
for layer in self.attn_refview.to_out: |
|
for param in layer.parameters(): |
|
param.zero_() |
|
for token in self.pbr_setting: |
|
if token == "albedo": |
|
continue |
|
for layer in getattr(self.attn_refview.processor, f"to_out_{token}"): |
|
for param in layer.parameters(): |
|
param.zero_() |
|
|
|
if self.use_dino: |
|
self.attn_dino.load_state_dict(self.attn2.state_dict(), strict=False) |
|
with torch.no_grad(): |
|
for layer in self.attn_dino.to_out: |
|
for param in layer.parameters(): |
|
param.zero_() |
|
|
|
|
|
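
    # Design note (sketch): zero-initialising the `to_out` projections makes every newly added
    # attention branch a no-op at the start of fine-tuning, so residual updates of the form
    # `hidden_states = scale * attn(x) + hidden_states` initially reproduce the pretrained
    # behaviour and the multiview/reference/DINO pathways are learned purely as residuals.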
|
def __getattr__(self, name: str): |
|
try: |
|
return super().__getattr__(name) |
|
except AttributeError: |
|
return getattr(self.transformer, name) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
encoder_hidden_states: Optional[torch.Tensor] = None, |
|
encoder_attention_mask: Optional[torch.Tensor] = None, |
|
timestep: Optional[torch.LongTensor] = None, |
|
cross_attention_kwargs: Dict[str, Any] = None, |
|
class_labels: Optional[torch.LongTensor] = None, |
|
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, |
|
) -> torch.Tensor: |
|
|
|
"""Forward pass with multi-mechanism attention. |
|
|
|
        Processing stages:
        1. Material-aware self-attention (MDA)
        2. Reference attention (RA) conditioned on cached reference-view features
        3. Multiview attention (MA) with position-aware RoPE
        4. Text cross-attention (base attention)
        5. DINO feature cross-attention (optional)
        6. Feed-forward network
|
|
|
Args: |
|
hidden_states: Input features [B * N_materials * N_views, Seq_len, Feat_dim] |
|
See base transformer for other parameters |
|
|
|
Returns: |
|
torch.Tensor: Output features |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
batch_size = hidden_states.shape[0] |
|
|
|
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} |
|
num_in_batch = cross_attention_kwargs.pop("num_in_batch", 1) |
|
mode = cross_attention_kwargs.pop("mode", None) |
|
mva_scale = cross_attention_kwargs.pop("mva_scale", 1.0) |
|
ref_scale = cross_attention_kwargs.pop("ref_scale", 1.0) |
|
condition_embed_dict = cross_attention_kwargs.pop("condition_embed_dict", None) |
|
dino_hidden_states = cross_attention_kwargs.pop("dino_hidden_states", None) |
|
position_voxel_indices = cross_attention_kwargs.pop("position_voxel_indices", None) |
|
N_pbr = len(self.pbr_setting) if self.pbr_setting is not None else 1 |
|
|
|
if self.norm_type == "ada_norm": |
|
norm_hidden_states = self.norm1(hidden_states, timestep) |
|
elif self.norm_type == "ada_norm_zero": |
|
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( |
|
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype |
|
) |
|
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: |
|
norm_hidden_states = self.norm1(hidden_states) |
|
elif self.norm_type == "ada_norm_continuous": |
|
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) |
|
elif self.norm_type == "ada_norm_single": |
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( |
|
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) |
|
).chunk(6, dim=1) |
|
norm_hidden_states = self.norm1(hidden_states) |
|
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa |
|
else: |
|
raise ValueError("Incorrect norm used") |
|
|
|
if self.pos_embed is not None: |
|
norm_hidden_states = self.pos_embed(norm_hidden_states) |
|
|
|
|
|
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} |
|
gligen_kwargs = cross_attention_kwargs.pop("gligen", None) |
|
|
|
if self.use_mda: |
|
mda_norm_hidden_states = rearrange( |
|
norm_hidden_states, "(b n_pbr n) l c -> b n_pbr n l c", n=num_in_batch, n_pbr=N_pbr |
|
) |
|
attn_output = self.attn1( |
|
mda_norm_hidden_states, |
|
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, |
|
attention_mask=attention_mask, |
|
**cross_attention_kwargs, |
|
) |
|
attn_output = rearrange(attn_output, "b n_pbr n l c -> (b n_pbr n) l c") |
|
else: |
|
attn_output = self.attn1( |
|
norm_hidden_states, |
|
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, |
|
attention_mask=attention_mask, |
|
**cross_attention_kwargs, |
|
) |
|
|
|
if self.norm_type == "ada_norm_zero": |
|
attn_output = gate_msa.unsqueeze(1) * attn_output |
|
elif self.norm_type == "ada_norm_single": |
|
attn_output = gate_msa * attn_output |
|
|
|
hidden_states = attn_output + hidden_states |
|
if hidden_states.ndim == 4: |
|
hidden_states = hidden_states.squeeze(1) |
|
|
|
|
|
        if mode is not None and "w" in mode:
|
condition_embed_dict[self.layer_name] = rearrange( |
|
norm_hidden_states, "(b n) l c -> b (n l) c", n=num_in_batch |
|
) |
|
|
|
        if mode is not None and "r" in mode and self.use_ra:
|
condition_embed = condition_embed_dict[self.layer_name] |
|
|
|
|
|
ref_norm_hidden_states = rearrange( |
|
norm_hidden_states, "(b n_pbr n) l c -> b n_pbr (n l) c", n=num_in_batch, n_pbr=N_pbr |
|
)[:, 0, ...] |
|
|
|
attn_output = self.attn_refview( |
|
ref_norm_hidden_states, |
|
encoder_hidden_states=condition_embed, |
|
attention_mask=None, |
|
**cross_attention_kwargs, |
|
) |
|
attn_output = rearrange(attn_output, "b n_pbr (n l) c -> (b n_pbr n) l c", n=num_in_batch, n_pbr=N_pbr) |
|
|
|
ref_scale_timing = ref_scale |
|
if isinstance(ref_scale, torch.Tensor): |
|
ref_scale_timing = ref_scale.unsqueeze(1).repeat(1, num_in_batch * N_pbr).view(-1) |
|
for _ in range(attn_output.ndim - 1): |
|
ref_scale_timing = ref_scale_timing.unsqueeze(-1) |
|
hidden_states = ref_scale_timing * attn_output + hidden_states |
|
if hidden_states.ndim == 4: |
|
hidden_states = hidden_states.squeeze(1) |
|
|
|
|
|
if num_in_batch > 1 and self.use_ma: |
|
            multiview_hidden_states = rearrange(
                norm_hidden_states, "(b n_pbr n) l c -> (b n_pbr) (n l) c", n_pbr=N_pbr, n=num_in_batch
            )
            position_indices = None
            if position_voxel_indices is not None:
                if multiview_hidden_states.shape[1] in position_voxel_indices:
                    position_indices = position_voxel_indices[multiview_hidden_states.shape[1]]

            attn_output = self.attn_multiview(
                multiview_hidden_states,
                encoder_hidden_states=multiview_hidden_states,
|
position_indices=position_indices, |
|
n_pbrs=N_pbr, |
|
**cross_attention_kwargs, |
|
) |
|
|
|
attn_output = rearrange(attn_output, "(b n_pbr) (n l) c -> (b n_pbr n) l c", n_pbr=N_pbr, n=num_in_batch) |
|
|
|
hidden_states = mva_scale * attn_output + hidden_states |
|
if hidden_states.ndim == 4: |
|
hidden_states = hidden_states.squeeze(1) |
|
|
|
|
|
if gligen_kwargs is not None: |
|
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) |
|
|
|
|
|
if self.attn2 is not None: |
|
if self.norm_type == "ada_norm": |
|
norm_hidden_states = self.norm2(hidden_states, timestep) |
|
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: |
|
norm_hidden_states = self.norm2(hidden_states) |
|
elif self.norm_type == "ada_norm_single": |
|
|
|
|
|
norm_hidden_states = hidden_states |
|
elif self.norm_type == "ada_norm_continuous": |
|
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) |
|
else: |
|
raise ValueError("Incorrect norm") |
|
|
|
if self.pos_embed is not None and self.norm_type != "ada_norm_single": |
|
norm_hidden_states = self.pos_embed(norm_hidden_states) |
|
|
|
attn_output = self.attn2( |
|
norm_hidden_states, |
|
encoder_hidden_states=encoder_hidden_states, |
|
attention_mask=encoder_attention_mask, |
|
**cross_attention_kwargs, |
|
) |
|
hidden_states = attn_output + hidden_states |
|
|
|
|
|
if self.use_dino: |
|
dino_hidden_states = dino_hidden_states.unsqueeze(1).repeat(1, N_pbr * num_in_batch, 1, 1) |
|
dino_hidden_states = rearrange(dino_hidden_states, "b n l c -> (b n) l c") |
|
attn_output = self.attn_dino( |
|
norm_hidden_states, |
|
encoder_hidden_states=dino_hidden_states, |
|
attention_mask=None, |
|
**cross_attention_kwargs, |
|
) |
|
|
|
hidden_states = attn_output + hidden_states |
|
|
|
|
|
|
|
if self.norm_type == "ada_norm_continuous": |
|
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) |
|
elif not self.norm_type == "ada_norm_single": |
|
norm_hidden_states = self.norm3(hidden_states) |
|
|
|
if self.norm_type == "ada_norm_zero": |
|
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] |
|
|
|
if self.norm_type == "ada_norm_single": |
|
norm_hidden_states = self.norm2(hidden_states) |
|
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp |
|
|
|
if self._chunk_size is not None: |
|
|
|
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) |
|
else: |
|
ff_output = self.ff(norm_hidden_states) |
|
|
|
if self.norm_type == "ada_norm_zero": |
|
ff_output = gate_mlp.unsqueeze(1) * ff_output |
|
elif self.norm_type == "ada_norm_single": |
|
ff_output = gate_mlp * ff_output |
|
|
|
hidden_states = ff_output + hidden_states |
|
if hidden_states.ndim == 4: |
|
hidden_states = hidden_states.squeeze(1) |
|
|
|
return hidden_states |
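
# Contract sketch (illustrative): the extra conditioning consumed by
# `Basic2p5DTransformerBlock.forward` travels through `cross_attention_kwargs`. The keys below
# mirror what `UNet2p5DConditionModel.forward` assembles; the helper and its arguments are
# placeholders for the example, not part of the module API.
def _example_block_cross_attention_kwargs(num_views, condition_embed_dict, dino_tokens, voxel_indices):
    return {
        "mode": "r",  # "w" writes reference features into condition_embed_dict, "r" reads them
        "num_in_batch": num_views,  # generated views packed into the batch dimension
        "mva_scale": 1.0,  # multiview attention residual scale
        "ref_scale": 1.0,  # reference attention residual scale
        "condition_embed_dict": condition_embed_dict,  # per-layer reference features
        "dino_hidden_states": dino_tokens,  # projected DINO tokens, or None
        "position_voxel_indices": voxel_indices,  # multi-resolution voxel indices, or None
    }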
|
|
|
|
|
class ImageProjModel(torch.nn.Module): |
|
|
|
"""Projects image embeddings into cross-attention space. |
|
|
|
Transforms CLIP embeddings into additional context tokens for conditioning. |
|
|
|
Args: |
|
cross_attention_dim: Dimension of attention space |
|
clip_embeddings_dim: Dimension of input CLIP embeddings |
|
clip_extra_context_tokens: Number of context tokens to generate |
|
""" |
|
|
|
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): |
|
super().__init__() |
|
|
|
self.generator = None |
|
self.cross_attention_dim = cross_attention_dim |
|
self.clip_extra_context_tokens = clip_extra_context_tokens |
|
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) |
|
self.norm = torch.nn.LayerNorm(cross_attention_dim) |
|
|
|
def forward(self, image_embeds): |
|
|
|
"""Projects image embeddings to cross-attention context tokens. |
|
|
|
Args: |
|
image_embeds: Input embeddings [B, N, C] or [B, C] |
|
|
|
Returns: |
|
torch.Tensor: Context tokens [B, N*clip_extra_context_tokens, cross_attention_dim] |
|
""" |
|
|
|
embeds = image_embeds |
|
num_token = 1 |
|
if embeds.dim() == 3: |
|
num_token = embeds.shape[1] |
|
embeds = rearrange(embeds, "b n c -> (b n) c") |
|
|
|
clip_extra_context_tokens = self.proj(embeds).reshape( |
|
-1, self.clip_extra_context_tokens, self.cross_attention_dim |
|
) |
|
clip_extra_context_tokens = self.norm(clip_extra_context_tokens) |
|
|
|
clip_extra_context_tokens = rearrange(clip_extra_context_tokens, "(b nt) n c -> b (nt n) c", nt=num_token) |
|
|
|
return clip_extra_context_tokens |
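
# Shape sketch (illustrative): projecting 1536-dim DINOv2 tokens into 4 extra context tokens each
# in a 1024-dim cross-attention space, matching how `init_condition` configures the projector below.
def _example_image_proj_model():
    proj = ImageProjModel(cross_attention_dim=1024, clip_embeddings_dim=1536, clip_extra_context_tokens=4)
    dino_features = torch.randn(2, 10, 1536)  # (B, N_tokens, C)
    context = proj(dino_features)  # -> (2, 10 * 4, 1024)
    return context.shape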
|
|
|
|
|
class UNet2p5DConditionModel(torch.nn.Module): |
|
|
|
"""2.5D UNet extension for multiview PBR generation. |
|
|
|
Enhances standard 2D UNet with: |
|
- Multiview attention mechanisms |
|
- Material-aware processing |
|
- Position-aware conditioning |
|
- Dual-stream reference processing |
|
|
|
Args: |
|
unet: Base 2D UNet model |
|
train_sched: Training scheduler (DDPM) |
|
val_sched: Validation scheduler (EulerAncestral) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
unet: UNet2DConditionModel, |
|
train_sched: DDPMScheduler = None, |
|
val_sched: EulerAncestralDiscreteScheduler = None, |
|
) -> None: |
|
super().__init__() |
|
self.unet = unet |
|
self.train_sched = train_sched |
|
self.val_sched = val_sched |
|
|
|
self.use_ma = True |
|
self.use_ra = True |
|
self.use_mda = True |
|
self.use_dino = True |
|
self.use_position_rope = True |
|
self.use_learned_text_clip = True |
|
self.use_dual_stream = True |
|
self.pbr_setting = ["albedo", "mr"] |
|
self.pbr_token_channels = 77 |
|
|
|
if self.use_dual_stream and self.use_ra: |
|
self.unet_dual = copy.deepcopy(unet) |
|
self.init_attention(self.unet_dual) |
|
|
|
self.init_attention( |
|
self.unet, |
|
use_ma=self.use_ma, |
|
use_ra=self.use_ra, |
|
use_dino=self.use_dino, |
|
use_mda=self.use_mda, |
|
pbr_setting=self.pbr_setting, |
|
) |
|
self.init_condition(use_dino=self.use_dino) |
|
|
|
@staticmethod |
|
def from_pretrained(pretrained_model_name_or_path, **kwargs): |
|
torch_dtype = kwargs.pop("torch_dtype", torch.float32) |
|
config_path = os.path.join(pretrained_model_name_or_path, "config.json") |
|
unet_ckpt_path = os.path.join(pretrained_model_name_or_path, "diffusion_pytorch_model.bin") |
|
with open(config_path, "r", encoding="utf-8") as file: |
|
config = json.load(file) |
|
unet = UNet2DConditionModel(**config) |
|
unet_2p5d = UNet2p5DConditionModel(unet) |
|
unet_2p5d.unet.conv_in = torch.nn.Conv2d( |
|
12, |
|
unet.conv_in.out_channels, |
|
kernel_size=unet.conv_in.kernel_size, |
|
stride=unet.conv_in.stride, |
|
padding=unet.conv_in.padding, |
|
dilation=unet.conv_in.dilation, |
|
groups=unet.conv_in.groups, |
|
bias=unet.conv_in.bias is not None, |
|
) |
|
unet_ckpt = torch.load(unet_ckpt_path, map_location="cpu", weights_only=True) |
|
unet_2p5d.load_state_dict(unet_ckpt, strict=True) |
|
unet_2p5d = unet_2p5d.to(torch_dtype) |
|
return unet_2p5d |
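
    # Loading sketch (illustrative): the checkpoint directory is an assumption for the example; it
    # must contain `config.json` plus `diffusion_pytorch_model.bin` with weights for this wrapper
    # (including the 12-channel conv_in and the added attention modules), e.g.
    #
    #     unet = UNet2p5DConditionModel.from_pretrained("path/to/mv_pbr_unet", torch_dtype=torch.float16)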
|
|
|
def init_condition(self, use_dino): |
|
|
|
"""Initializes conditioning mechanisms for multiview PBR generation. |
|
|
|
Sets up: |
|
1. Learned text embeddings: Material-specific tokens (albedo, mr) initialized to zeros |
|
2. DINO projector: Model to process DINO-ViT features for cross-attention |
|
|
|
Args: |
|
use_dino: Flag to enable DINO feature integration |
|
""" |
|
|
|
if self.use_learned_text_clip: |
|
for token in self.pbr_setting: |
|
self.unet.register_parameter( |
|
f"learned_text_clip_{token}", nn.Parameter(torch.zeros(self.pbr_token_channels, 1024)) |
|
) |
|
self.unet.learned_text_clip_ref = nn.Parameter(torch.zeros(self.pbr_token_channels, 1024)) |
|
|
|
if use_dino: |
|
self.unet.image_proj_model_dino = ImageProjModel( |
|
cross_attention_dim=self.unet.config.cross_attention_dim, |
|
clip_embeddings_dim=1536, |
|
clip_extra_context_tokens=4, |
|
) |
|
|
|
def init_attention(self, unet, use_ma=False, use_ra=False, use_mda=False, use_dino=False, pbr_setting=None): |
|
|
|
"""Recursively replaces standard transformers with enhanced 2.5D blocks. |
|
|
|
Processes UNet architecture: |
|
1. Downsampling blocks: Replaces transformers in attention layers |
|
2. Middle block: Upgrades central transformers |
|
3. Upsampling blocks: Modifies decoder transformers |
|
|
|
Args: |
|
unet: UNet model to enhance |
|
use_ma: Enable multiview attention |
|
use_ra: Enable reference attention |
|
use_mda: Enable material-specific attention |
|
use_dino: Enable DINO feature integration |
|
pbr_setting: List of PBR materials |
|
""" |
|
|
|
for down_block_i, down_block in enumerate(unet.down_blocks): |
|
if hasattr(down_block, "has_cross_attention") and down_block.has_cross_attention: |
|
for attn_i, attn in enumerate(down_block.attentions): |
|
for transformer_i, transformer in enumerate(attn.transformer_blocks): |
|
if isinstance(transformer, BasicTransformerBlock): |
|
attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock( |
|
transformer, |
|
f"down_{down_block_i}_{attn_i}_{transformer_i}", |
|
use_ma, |
|
use_ra, |
|
use_mda, |
|
use_dino, |
|
pbr_setting, |
|
) |
|
|
|
if hasattr(unet.mid_block, "has_cross_attention") and unet.mid_block.has_cross_attention: |
|
for attn_i, attn in enumerate(unet.mid_block.attentions): |
|
for transformer_i, transformer in enumerate(attn.transformer_blocks): |
|
if isinstance(transformer, BasicTransformerBlock): |
|
attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock( |
|
transformer, f"mid_{attn_i}_{transformer_i}", use_ma, use_ra, use_mda, use_dino, pbr_setting |
|
) |
|
|
|
for up_block_i, up_block in enumerate(unet.up_blocks): |
|
if hasattr(up_block, "has_cross_attention") and up_block.has_cross_attention: |
|
for attn_i, attn in enumerate(up_block.attentions): |
|
for transformer_i, transformer in enumerate(attn.transformer_blocks): |
|
if isinstance(transformer, BasicTransformerBlock): |
|
attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock( |
|
transformer, |
|
f"up_{up_block_i}_{attn_i}_{transformer_i}", |
|
use_ma, |
|
use_ra, |
|
use_mda, |
|
use_dino, |
|
pbr_setting, |
|
) |
|
|
|
def __getattr__(self, name: str): |
|
try: |
|
return super().__getattr__(name) |
|
except AttributeError: |
|
return getattr(self.unet, name) |
|
|
|
def forward( |
|
self, |
|
sample, |
|
timestep, |
|
encoder_hidden_states, |
|
*args, |
|
added_cond_kwargs=None, |
|
cross_attention_kwargs=None, |
|
down_intrablock_additional_residuals=None, |
|
down_block_res_samples=None, |
|
mid_block_res_sample=None, |
|
**cached_condition, |
|
): |
|
|
|
"""Forward pass with multiview/material conditioning. |
|
|
|
Key stages: |
|
1. Input preparation (concat normal/position maps) |
|
2. Reference feature extraction (dual-stream) |
|
3. Position encoding (voxel indices) |
|
4. DINO feature projection |
|
5. Main UNet processing with attention conditioning |
|
|
|
Args: |
|
sample: Input latents [B, N_pbr, N_gen, C, H, W] |
|
cached_condition: Dictionary containing: |
|
- embeds_normal: Normal map embeddings |
|
- embeds_position: Position map embeddings |
|
- ref_latents: Reference image latents |
|
- dino_hidden_states: DINO features |
|
- position_maps: 3D position maps |
|
- mva_scale: Multiview attention scale |
|
- ref_scale: Reference attention scale |
|
|
|
Returns: |
|
            Tuple[torch.Tensor]: Base UNet output (predicted noise / sample), returned with return_dict=False
|
""" |
|
|
|
B, N_pbr, N_gen, _, H, W = sample.shape |
|
assert H == W |
|
|
|
if "cache" not in cached_condition: |
|
cached_condition["cache"] = {} |
|
|
|
sample = [sample] |
|
if "embeds_normal" in cached_condition: |
|
sample.append(cached_condition["embeds_normal"].unsqueeze(1).repeat(1, N_pbr, 1, 1, 1, 1)) |
|
if "embeds_position" in cached_condition: |
|
sample.append(cached_condition["embeds_position"].unsqueeze(1).repeat(1, N_pbr, 1, 1, 1, 1)) |
|
sample = torch.cat(sample, dim=-3) |
|
|
|
sample = rearrange(sample, "b n_pbr n c h w -> (b n_pbr n) c h w") |
|
|
|
encoder_hidden_states_gen = encoder_hidden_states.unsqueeze(-3).repeat(1, 1, N_gen, 1, 1) |
|
encoder_hidden_states_gen = rearrange(encoder_hidden_states_gen, "b n_pbr n l c -> (b n_pbr n) l c") |
|
|
|
if added_cond_kwargs is not None: |
|
text_embeds_gen = added_cond_kwargs["text_embeds"].unsqueeze(1).repeat(1, N_gen, 1) |
|
text_embeds_gen = rearrange(text_embeds_gen, "b n c -> (b n) c") |
|
time_ids_gen = added_cond_kwargs["time_ids"].unsqueeze(1).repeat(1, N_gen, 1) |
|
time_ids_gen = rearrange(time_ids_gen, "b n c -> (b n) c") |
|
added_cond_kwargs_gen = {"text_embeds": text_embeds_gen, "time_ids": time_ids_gen} |
|
else: |
|
added_cond_kwargs_gen = None |
|
|
|
if self.use_position_rope: |
|
if "position_voxel_indices" in cached_condition["cache"]: |
|
position_voxel_indices = cached_condition["cache"]["position_voxel_indices"] |
|
else: |
|
if "position_maps" in cached_condition: |
|
position_voxel_indices = calc_multires_voxel_idxs( |
|
cached_condition["position_maps"], |
|
grid_resolutions=[H, H // 2, H // 4, H // 8], |
|
voxel_resolutions=[H * 8, H * 4, H * 2, H], |
|
) |
|
cached_condition["cache"]["position_voxel_indices"] = position_voxel_indices |
|
else: |
|
position_voxel_indices = None |
|
|
|
if self.use_dino: |
|
if "dino_hidden_states_proj" in cached_condition["cache"]: |
|
dino_hidden_states = cached_condition["cache"]["dino_hidden_states_proj"] |
|
else: |
|
assert "dino_hidden_states" in cached_condition |
|
dino_hidden_states = cached_condition["dino_hidden_states"] |
|
dino_hidden_states = self.image_proj_model_dino(dino_hidden_states) |
|
cached_condition["cache"]["dino_hidden_states_proj"] = dino_hidden_states |
|
else: |
|
dino_hidden_states = None |
|
|
|
if self.use_ra: |
|
if "condition_embed_dict" in cached_condition["cache"]: |
|
condition_embed_dict = cached_condition["cache"]["condition_embed_dict"] |
|
else: |
|
condition_embed_dict = {} |
|
ref_latents = cached_condition["ref_latents"] |
|
N_ref = ref_latents.shape[1] |
|
|
|
if not self.use_dual_stream: |
|
ref_latents = [ref_latents] |
|
if "embeds_normal" in cached_condition: |
|
ref_latents.append(torch.zeros_like(ref_latents[0])) |
|
if "embeds_position" in cached_condition: |
|
ref_latents.append(torch.zeros_like(ref_latents[0])) |
|
ref_latents = torch.cat(ref_latents, dim=2) |
|
|
|
ref_latents = rearrange(ref_latents, "b n c h w -> (b n) c h w") |
|
|
|
encoder_hidden_states_ref = self.unet.learned_text_clip_ref.repeat(B, N_ref, 1, 1) |
|
|
|
encoder_hidden_states_ref = rearrange(encoder_hidden_states_ref, "b n l c -> (b n) l c") |
|
|
|
if added_cond_kwargs is not None: |
|
text_embeds_ref = added_cond_kwargs["text_embeds"].unsqueeze(1).repeat(1, N_ref, 1) |
|
text_embeds_ref = rearrange(text_embeds_ref, "b n c -> (b n) c") |
|
time_ids_ref = added_cond_kwargs["time_ids"].unsqueeze(1).repeat(1, N_ref, 1) |
|
time_ids_ref = rearrange(time_ids_ref, "b n c -> (b n) c") |
|
added_cond_kwargs_ref = { |
|
"text_embeds": text_embeds_ref, |
|
"time_ids": time_ids_ref, |
|
} |
|
else: |
|
added_cond_kwargs_ref = None |
|
|
|
noisy_ref_latents = ref_latents |
|
timestep_ref = 0 |
|
if self.use_dual_stream: |
|
unet_ref = self.unet_dual |
|
else: |
|
unet_ref = self.unet |
|
unet_ref( |
|
noisy_ref_latents, |
|
timestep_ref, |
|
encoder_hidden_states=encoder_hidden_states_ref, |
|
class_labels=None, |
|
added_cond_kwargs=added_cond_kwargs_ref, |
|
|
|
return_dict=False, |
|
cross_attention_kwargs={ |
|
"mode": "w", |
|
"num_in_batch": N_ref, |
|
"condition_embed_dict": condition_embed_dict, |
|
}, |
|
) |
|
cached_condition["cache"]["condition_embed_dict"] = condition_embed_dict |
|
else: |
|
condition_embed_dict = None |
|
|
|
mva_scale = cached_condition.get("mva_scale", 1.0) |
|
ref_scale = cached_condition.get("ref_scale", 1.0) |
|
|
|
return self.unet( |
|
sample, |
|
timestep, |
|
encoder_hidden_states_gen, |
|
*args, |
|
class_labels=None, |
|
added_cond_kwargs=added_cond_kwargs_gen, |
|
down_intrablock_additional_residuals=( |
|
[sample.to(dtype=self.unet.dtype) for sample in down_intrablock_additional_residuals] |
|
if down_intrablock_additional_residuals is not None |
|
else None |
|
), |
|
down_block_additional_residuals=( |
|
[sample.to(dtype=self.unet.dtype) for sample in down_block_res_samples] |
|
if down_block_res_samples is not None |
|
else None |
|
), |
|
mid_block_additional_residual=( |
|
mid_block_res_sample.to(dtype=self.unet.dtype) if mid_block_res_sample is not None else None |
|
), |
|
return_dict=False, |
|
cross_attention_kwargs={ |
|
"mode": "r", |
|
"num_in_batch": N_gen, |
|
"dino_hidden_states": dino_hidden_states, |
|
"condition_embed_dict": condition_embed_dict, |
|
"mva_scale": mva_scale, |
|
"ref_scale": ref_scale, |
|
"position_voxel_indices": position_voxel_indices, |
|
}, |
|
) |
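

# Call sketch (illustrative, under assumed shapes): how this wrapper's forward is typically driven.
# `model` is a loaded `UNet2p5DConditionModel`; inputs must match its device/dtype. N_pbr follows
# `pbr_setting` (albedo, mr), N_gen/N_ref are generated/reference view counts, and the DINO token
# count and 64x64 latent resolution are placeholders, not requirements of the model.
def _example_unet_2p5d_call(model, B=1, N_pbr=2, N_gen=6, N_ref=1, latent_res=64):
    sample = torch.randn(B, N_pbr, N_gen, 4, latent_res, latent_res)
    timestep = torch.tensor(999)
    # In practice these come from the learned per-material text tokens registered in init_condition.
    encoder_hidden_states = torch.randn(B, N_pbr, 77, 1024)
    cached_condition = {
        "embeds_normal": torch.randn(B, N_gen, 4, latent_res, latent_res),
        "embeds_position": torch.randn(B, N_gen, 4, latent_res, latent_res),
        "ref_latents": torch.randn(B, N_ref, 4, latent_res, latent_res),
        "dino_hidden_states": torch.randn(B, 1370, 1536),
        "position_maps": torch.rand(B, N_gen, 3, latent_res * 8, latent_res * 8),
        "mva_scale": 1.0,
        "ref_scale": 1.0,
    }
    return model(sample, timestep, encoder_hidden_states, **cached_condition)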
|
|