from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor2_0
from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero, AdaLayerNormZeroSingle
from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
from diffusers.utils import is_torch_version
from diffusers.utils.torch_utils import maybe_allow_in_graph


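# ControlNet-style conditioning module for CogVideoX (summary inferred from the code below):
# control latents are concatenated channel-wise with the noisy video latents, patch-embedded
# together with the text tokens, and run through a small stack of CogVideoXBlocks; each block's
# hidden states (optionally projected and masked) are returned, presumably for injection into
# the corresponding blocks of the base transformer.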
class CogVideoXControlnetPCD(ModelMixin, ConfigMixin, PeftAdapterMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        use_zero_conv: bool = False,
        attention_head_dim: int = 64,
        vae_channels: int = 16,
        in_channels: int = 3,
        downscale_coef: int = 8,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        num_layers: int = 8,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        out_proj_dim: Optional[int] = None,
        out_proj_dim_zero_init: bool = False,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
            raise ValueError(
                "There are no CogVideoX checkpoints available with disabled rotary embeddings and learned "
                "positional embeddings. If you're using a custom model and/or believe this should be supported, "
                "please open an issue at https://github.com/huggingface/diffusers/issues."
            )

        self.vae_channels = vae_channels
        start_channels = in_channels * (downscale_coef ** 2)
        input_channels = [start_channels, start_channels // 2, start_channels // 4]

        # Pack the pixel-space control video into channels so it can be reduced with 1x1 convolutions.
        self.unshuffle = nn.PixelUnshuffle(downscale_coef)

        self.use_zero_conv = use_zero_conv
        if use_zero_conv:
            self.controlnet_encode_first = nn.Sequential(
                nn.Conv2d(input_channels[0], input_channels[1], kernel_size=1, stride=1, padding=0),
                nn.GroupNorm(2, input_channels[1]),
                nn.ReLU(),
            )
            self.controlnet_encode_second = nn.Sequential(
                nn.Conv2d(input_channels[1], input_channels[2], kernel_size=1, stride=1, padding=0),
                nn.GroupNorm(2, input_channels[2]),
                nn.ReLU(),
            )
            patch_embed_in_channels = vae_channels + input_channels[2]
        else:
            # Control states are VAE latents and are concatenated directly with the noisy latents.
            patch_embed_in_channels = vae_channels * 2

        # 1. Patch embedding
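        # `CogVideoXPatchEmbed` projects the text tokens with a linear layer and the video latents
        # with a per-frame 2x2 Conv2d, concatenates the [text, video] token sequence, and (when
        # rotary embeddings are disabled) adds 3D sincos positional embeddings.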
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            in_channels=patch_embed_in_channels,
            embed_dim=inner_dim,
            bias=True,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        # 3. Spatio-temporal transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
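
        # Optional per-block output projections; `out_proj_dim` would typically be the hidden size
        # of the base CogVideoX transformer (e.g. 48 * 64 = 3072 for CogVideoX-5B) so that each
        # block's features can be added to the matching block of the base model.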
        self.out_projectors = None
        if out_proj_dim is not None:
            self.out_projectors = nn.ModuleList(
                [nn.Linear(inner_dim, out_proj_dim) for _ in range(num_layers)]
            )
            if out_proj_dim_zero_init:
                for out_projector in self.out_projectors:
                    self.zeros_init_linear(out_projector)

        self.gradient_checkpointing = False

    def zeros_init_linear(self, linear: nn.Module):
        if isinstance(linear, (nn.Linear, nn.Conv1d)):
            if hasattr(linear, "weight"):
                nn.init.zeros_(linear.weight)
            # `bias` can be None (e.g. when bias=False), so guard before zero-initializing it.
            if getattr(linear, "bias", None) is not None:
                nn.init.zeros_(linear.bias)

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def compress_time(self, x, num_frames):
        """Halve the temporal dimension with average pooling; when the frame count is odd,
        the first frame is kept as-is and the remaining frames are pooled pairwise."""
        x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
        batch_size, frames, channels, height, width = x.shape
        x = rearrange(x, 'b f c h w -> (b h w) c f')

        if x.shape[-1] % 2 == 1:
            x_first, x_rest = x[..., 0], x[..., 1:]
            if x_rest.shape[-1] > 0:
                x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
            x = torch.cat([x_first[..., None], x_rest], dim=-1)
        else:
            x = F.avg_pool1d(x, kernel_size=2, stride=2)

        x = rearrange(x, '(b h w) c f -> (b f) c h w', b=batch_size, h=height, w=width)
        return x
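
    # Note: `compress_time` is not called in the forward pass below; presumably it mirrors the
    # causal VAE's 2x temporal pooling (first frame kept, later frames averaged pairwise),
    # e.g. 13 frames -> 7 frames.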

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        controlnet_output_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        # Concatenate the control latents with the noisy video latents along the channel
        # dimension (dim=2 of a (batch, frames, channels, height, width) tensor).
        hidden_states = torch.cat([hidden_states, controlnet_states], dim=2)

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and always returns f32 tensors,
        # but `time_embedding` might be running in fp16, so we need to cast here.
        # There might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        controlnet_hidden_states = ()

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

            # Collect per-block control features, optionally projected and masked.
            if self.out_projectors is not None:
                if controlnet_output_mask is not None:
                    controlnet_hidden_states += (self.out_projectors[i](hidden_states) * controlnet_output_mask,)
                else:
                    controlnet_hidden_states += (self.out_projectors[i](hidden_states),)
            else:
                controlnet_hidden_states += (hidden_states,)

        if not return_dict:
            return (controlnet_hidden_states,)
        return Transformer2DModelOutput(sample=controlnet_hidden_states)
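

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original project). The toy sizes
    # below are chosen only so the forward pass runs quickly on CPU: with `sample_frames=9` and a
    # temporal compression ratio of 4, the latent video has (9 - 1) // 4 + 1 = 3 frames.
    controlnet = CogVideoXControlnetPCD(
        num_attention_heads=4,
        attention_head_dim=16,
        num_layers=2,
        sample_width=32,
        sample_height=32,
        sample_frames=9,
        out_proj_dim=64,
    )
    batch, frames, channels, height, width = 1, 3, 16, 32, 32
    hidden_states = torch.randn(batch, frames, channels, height, width)      # noisy video latents
    controlnet_states = torch.randn(batch, frames, channels, height, width)  # control latents
    encoder_hidden_states = torch.randn(batch, 226, 4096)                    # T5-style text embeddings
    timestep = torch.tensor([500], dtype=torch.long)

    with torch.no_grad():
        block_features = controlnet(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_states=controlnet_states,
            timestep=timestep,
            return_dict=False,
        )[0]

    # One tensor per transformer block, shaped (batch, num_video_tokens, out_proj_dim).
    print([tuple(f.shape) for f in block_features])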