# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import ControlNetModel
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.controlnet import zero_module
from fourm.utils import to_2tuple
from .lm_models import create_model
class ControlNetAdapterEmbedding(nn.Module):
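    """Conditioning embedding that runs the conditioning input through a pretrained
    adapter model (see `create_model` in `lm_models`) and projects the result with a
    zero-initialized convolution, so the ControlNet branch initially contributes nothing.
    """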
def __init__(
self,
conditioning_embedding_channels,
adapter,
conditioning_channels=3,
):
super().__init__()
self.adapter_model = create_model(
in_channels=conditioning_channels,
output_type="stats",
)
self._load_adapter(adapter)
self.conv_out = zero_module(
nn.Conv2d(8, conditioning_embedding_channels, kernel_size=3, padding=1)
)
def forward(self, conditioning):
embedding = self.adapter_model(quant=conditioning)
embedding = self.conv_out(embedding)
return embedding
def _load_adapter(self, path):
        ckpt = torch.load(path, map_location='cpu')['model']
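        # Drop VQ/VAE weights from the checkpoint; only the adapter itself is needed here.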
for key in list(ckpt.keys()):
if 'vq_model' in key or 'vae' in key:
del ckpt[key]
self.adapter_model.load_state_dict(ckpt)
print("Loaded the adapter model")
class ControlNetConditioningEmbedding(nn.Module):
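    """Default ControlNet conditioning embedding (adapted from diffusers): a small
    stack of convolutions that encodes the conditioning input into feature maps,
    ending in a zero-initialized convolution.
    """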
def __init__(
self,
conditioning_embedding_channels,
        conditioning_channels=3,
        block_out_channels=(16, 32, 96, 256),
):
super().__init__()
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(block_out_channels) - 1):
channel_in = block_out_channels[i]
channel_out = block_out_channels[i + 1]
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1))
self.conv_out = zero_module(
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
)
def forward(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)
for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)
embedding = self.conv_out(embedding)
return embedding
class ControlnetCond(ModelMixin, ConfigMixin):
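    """Wraps a Stable Diffusion pipeline's UNet with a ControlNet that is conditioned
    on dense feature maps (e.g. embedded tokens) instead of RGB images. The ControlNet
    residuals steer the (optionally frozen) UNet's noise prediction.
    """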
def __init__(self,
in_channels,
cond_channels,
sd_pipeline,
image_size,
freeze_params=True,
                 block_out_channels=(320, 640, 1280, 1280),
                 conditioning_embedding_out_channels=(32, 32, 96, 256),
pretrained_cn=False,
enable_xformer=False,
adapter=None,
*args,
**kwargs
):
super().__init__()
self.in_channels = in_channels
self.cond_channels = cond_channels
self.sd_pipeline = sd_pipeline
self.unet = sd_pipeline.unet
self.text_encoder = sd_pipeline.text_encoder
self.tokenizer = sd_pipeline.tokenizer
if pretrained_cn:
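            # Initialize the ControlNet from the pretrained UNet weights, then override
            # the number of conditioning channels to match the dense conditioning input.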
self.controlnet = ControlNetModel.from_unet(self.unet, conditioning_embedding_out_channels=conditioning_embedding_out_channels)
self.controlnet.conditioning_channels = cond_channels
self.controlnet.config.conditioning_channels = cond_channels
else:
self.controlnet = ControlNetModel(
in_channels=in_channels,
conditioning_channels=cond_channels,
block_out_channels=block_out_channels,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
*args,
**kwargs,
)
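        # Replace the default RGB conditioning embedding with one that accepts
        # `cond_channels` input channels, optionally through a pretrained adapter.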
self.use_adapter = adapter is not None
if adapter is not None:
self.controlnet.controlnet_cond_embedding = ControlNetAdapterEmbedding(
conditioning_embedding_channels=self.controlnet.config.block_out_channels[0],
adapter=adapter,
conditioning_channels=cond_channels,
)
else:
self.controlnet.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=self.controlnet.config.block_out_channels[0],
block_out_channels=self.controlnet.config.conditioning_embedding_out_channels,
conditioning_channels=cond_channels,
)
if enable_xformer:
            print('xFormers enabled')
self.unet.enable_xformers_memory_efficient_attention()
self.controlnet.enable_xformers_memory_efficient_attention()
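        # Cache the empty-prompt embedding so unconditional passes don't re-run the text encoder.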
self.empty_str_encoding = nn.Parameter(self._encode_prompt(""), requires_grad=False)
if freeze_params:
self.freeze_params()
self.sample_size = image_size // sd_pipeline.vae_scale_factor
self.H, self.W = to_2tuple(self.sample_size)
def forward(self,
sample: torch.FloatTensor, # Shape (B, C, H, W),
timestep: Union[torch.Tensor, float, int],
                encoder_hidden_states: Optional[torch.Tensor] = None, # Shape (B, D_C, H_C, W_C)
                cond_mask: Optional[torch.BoolTensor] = None, # Boolean tensor of shape (B, H_C, W_C); True marks masked-out pixels
                prompt=None,
                unconditional=False,
                cond_scale=1.0,
**kwargs):
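        """Predicts noise for `sample` at `timestep`, steered by ControlNet residuals
        computed from the dense conditioning in `encoder_hidden_states`.
        """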
# Optionally mask out conditioning
if cond_mask is not None:
encoder_hidden_states = torch.where(cond_mask[:,None,:,:], 0.0, encoder_hidden_states)
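        # Resize the dense conditioning to the spatial size expected by the conditioning
        # embedding (half the latent resolution when using the adapter).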
if not self.use_adapter:
controlnet_cond = F.interpolate(encoder_hidden_states, (self.H, self.W), mode="nearest")
else:
controlnet_cond = F.interpolate(encoder_hidden_states, (self.H // 2, self.W // 2), mode="nearest")
        # encoder_hidden_states is the prompt embedding for the ControlNet; when no prompt is given
        # (or the pass is unconditional), fall back to the cached empty-string encoding.
if prompt is None or unconditional:
encoder_hidden_states = torch.cat([self.empty_str_encoding] * sample.shape[0])
else:
encoder_hidden_states = self._encode_prompt(prompt)
down_block_res_samples, mid_block_res_sample = self.controlnet(
sample,
timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=controlnet_cond,
conditioning_scale=cond_scale,
return_dict=False,
)
        # TODO: not the most efficient way; the ControlNet pass above is discarded for unconditional samples
if unconditional:
down_block_res_samples = [torch.zeros_like(s) for s in down_block_res_samples]
controlnet_cond = torch.zeros_like(controlnet_cond)
noise_pred = self.unet(
sample,
timestep,
encoder_hidden_states=encoder_hidden_states,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
return_dict=False,
)[0]
return noise_pred
def freeze_params(self):
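        """Freezes the UNet and text encoder; only the ControlNet remains trainable."""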
for param in self.unet.parameters():
param.requires_grad = False
for param in self.text_encoder.parameters():
param.requires_grad = False
def unfreeze_params(self):
for param in self.unet.parameters():
param.requires_grad = True
for param in self.text_encoder.parameters():
param.requires_grad = True
@torch.no_grad()
def _encode_prompt(self, prompt):
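        """Tokenizes `prompt` and returns the text encoder's last hidden state."""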
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(self.device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(self.device),
attention_mask=attention_mask,
)[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=self.device)
return prompt_embeds
def controlnet(*args, **kwargs):
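    """Factory that builds a `ControlnetCond` with a Stable Diffusion v1-style UNet
    configuration (`cross_attention_dim=768`, four down blocks).
    """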
return ControlnetCond(
flip_sin_to_cos=True,
freq_shift=0,
        down_block_types=[
            'CrossAttnDownBlock2D',
            'CrossAttnDownBlock2D',
            'CrossAttnDownBlock2D',
            'DownBlock2D',
        ],
only_cross_attention=False,
block_out_channels=[320, 640, 1280, 1280],
layers_per_block=2,
downsample_padding=1,
mid_block_scale_factor=1,
act_fn='silu',
norm_num_groups=32,
norm_eps=1e-05,
cross_attention_dim=768,
attention_head_dim=8,
num_attention_heads=None,
use_linear_projection=False,
class_embed_type=None,
num_class_embeds=None,
upcast_attention=False,
resnet_time_scale_shift='default',
projection_class_embeddings_input_dim=None,
controlnet_conditioning_channel_order='rgb',
conditioning_embedding_out_channels=[kwargs['cond_channels'], 32, 96, 256],
global_pool_conditions=False,
freeze_params=True,
*args,
**kwargs,
)