# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models import ControlNetModel
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.controlnet import zero_module

from fourm.utils import to_2tuple

from .lm_models import create_model


class ControlNetAdapterEmbedding(nn.Module):
    """Conditioning embedding that encodes the control input with a pretrained adapter."""

    def __init__(
        self,
        conditioning_embedding_channels,
        adapter,
        conditioning_channels=3,
    ):
        super().__init__()

        self.adapter_model = create_model(
            in_channels=conditioning_channels,
            output_type="stats",
        )

        self._load_adapter(adapter)

        self.conv_out = zero_module(
            nn.Conv2d(8, conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        embedding = self.adapter_model(quant=conditioning)
        embedding = self.conv_out(embedding)
        return embedding

    def _load_adapter(self, path):
        # Load on CPU so the checkpoint does not require the device it was saved
        # from; the module is moved to the target device later.
        ckpt = torch.load(path, map_location="cpu")['model']
        # Drop the VQ/VAE weights from the checkpoint; only the adapter itself is needed.
        for key in list(ckpt.keys()):
            if 'vq_model' in key or 'vae' in key:
                del ckpt[key]
        self.adapter_model.load_state_dict(ckpt)
        print("Loaded the adapter model")


class ControlNetConditioningEmbedding(nn.Module):
    def __init__(
        self,
        conditioning_embedding_channels,
        conditioning_channels=3,
        block_out_channels=(16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        # All convolutions are stride 1, so the spatial resolution of the
        # conditioning is preserved throughout.
        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1))

        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding
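

def _check_conditioning_embedding_shapes():
    # Hypothetical sanity check, not part of the original module and never
    # called: the embedding above uses only stride-1 convolutions (unlike the
    # stride-2 embedding that ships with diffusers), so it maps (B, C_in, H, W)
    # to (B, C_out, H, W). The channel counts below are arbitrary assumptions.
    emb = ControlNetConditioningEmbedding(
        conditioning_embedding_channels=320,
        conditioning_channels=8,
    )
    out = emb(torch.randn(1, 8, 32, 32))
    assert out.shape == (1, 320, 32, 32)  # Spatial size preserved, channels remapped.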


class ControlnetCond(ModelMixin, ConfigMixin):
    def __init__(self,
                 in_channels,
                 cond_channels,
                 sd_pipeline,
                 image_size,
                 freeze_params=True,
                 block_out_channels=(320, 640, 1280, 1280),
                 conditioning_embedding_out_channels=(32, 32, 96, 256),
                 pretrained_cn=False,
                 enable_xformer=False,
                 adapter=None,
                 *args,
                 **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.cond_channels = cond_channels
        self.sd_pipeline = sd_pipeline
        self.unet = sd_pipeline.unet
        self.text_encoder = sd_pipeline.text_encoder
        self.tokenizer = sd_pipeline.tokenizer

        if pretrained_cn:
            # Initialize the ControlNet from the pipeline's UNet weights, then
            # override the number of conditioning channels.
            self.controlnet = ControlNetModel.from_unet(self.unet, conditioning_embedding_out_channels=conditioning_embedding_out_channels)
            self.controlnet.conditioning_channels = cond_channels
            self.controlnet.config.conditioning_channels = cond_channels
        else:
            self.controlnet = ControlNetModel(
                *args,
                in_channels=in_channels,
                conditioning_channels=cond_channels,
                block_out_channels=block_out_channels,
                conditioning_embedding_out_channels=conditioning_embedding_out_channels,
                **kwargs,
            )

        # Replace the default conditioning embedding with either the adapter-based
        # one or the stride-1 convolutional one defined above.
        self.use_adapter = adapter is not None
        if adapter is not None:
            self.controlnet.controlnet_cond_embedding = ControlNetAdapterEmbedding(
                conditioning_embedding_channels=self.controlnet.config.block_out_channels[0],
                adapter=adapter,
                conditioning_channels=cond_channels,
            )
        else:
            self.controlnet.controlnet_cond_embedding = ControlNetConditioningEmbedding(
                conditioning_embedding_channels=self.controlnet.config.block_out_channels[0],
                block_out_channels=self.controlnet.config.conditioning_embedding_out_channels,
                conditioning_channels=cond_channels,
            )

        if enable_xformer:
            print('xFormers enabled')
            self.unet.enable_xformers_memory_efficient_attention()
            self.controlnet.enable_xformers_memory_efficient_attention()

        # Cache the empty-prompt encoding, reused whenever no prompt is given.
        self.empty_str_encoding = nn.Parameter(self._encode_prompt(""), requires_grad=False)

        if freeze_params:
            self.freeze_params()

        self.sample_size = image_size // sd_pipeline.vae_scale_factor
        self.H, self.W = to_2tuple(self.sample_size)

    def forward(self,
                sample: torch.FloatTensor,  # Shape (B, C, H, W)
                timestep: Union[torch.Tensor, float, int],
                encoder_hidden_states: torch.Tensor = None,  # Shape (B, D_C, H_C, W_C)
                cond_mask: Optional[torch.BoolTensor] = None,  # Boolean tensor of shape (B, H_C, W_C). True for masked-out pixels.
                prompt=None,
                unconditional=False,
                cond_scale=1.0,
                **kwargs):
        # Optionally mask out the conditioning
        if cond_mask is not None:
            encoder_hidden_states = torch.where(cond_mask[:, None, :, :], 0.0, encoder_hidden_states)

        # Resize the conditioning to the latent resolution; the adapter variant
        # expects it at half the latent resolution instead.
        if not self.use_adapter:
            controlnet_cond = F.interpolate(encoder_hidden_states, (self.H, self.W), mode="nearest")
        else:
            controlnet_cond = F.interpolate(encoder_hidden_states, (self.H // 2, self.W // 2), mode="nearest")

        # encoder_hidden_states is the prompt embedding in the ControlNet model;
        # with no prompt (or unconditional sampling) the cached empty-string encoding is used.
        if prompt is None or unconditional:
            encoder_hidden_states = torch.cat([self.empty_str_encoding] * sample.shape[0])
        else:
            encoder_hidden_states = self._encode_prompt(prompt)

        down_block_res_samples, mid_block_res_sample = self.controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=controlnet_cond,
            conditioning_scale=cond_scale,
            return_dict=False,
        )

        # TODO: not the most efficient way; the ControlNet forward pass above is
        # computed and then discarded when sampling unconditionally.
        if unconditional:
            down_block_res_samples = [torch.zeros_like(s) for s in down_block_res_samples]
            mid_block_res_sample = torch.zeros_like(mid_block_res_sample)

        noise_pred = self.unet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample,
            return_dict=False,
        )[0]

        return noise_pred

    def freeze_params(self):
        for param in self.unet.parameters():
            param.requires_grad = False
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def unfreeze_params(self):
        for param in self.unet.parameters():
            param.requires_grad = True
        for param in self.text_encoder.parameters():
            param.requires_grad = True

    def _encode_prompt(self, prompt):
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(self.device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(self.device),
            attention_mask=attention_mask,
        )[0]
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=self.device)

        return prompt_embeds
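

def _cfg_noise_pred(model, sample, timestep, cond, guidance_scale=3.0):
    # Hypothetical helper, not part of the original module: since
    # ControlnetCond.forward exposes an `unconditional` flag, classifier-free
    # guidance can be sketched as two forward passes whose predictions are
    # blended. The function name and default guidance scale are assumptions.
    eps_cond = model(sample, timestep, encoder_hidden_states=cond)
    eps_uncond = model(sample, timestep, encoder_hidden_states=cond, unconditional=True)
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)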


def controlnet(*args, **kwargs):
    # Note: conditioning_embedding_out_channels[0] is tied to the number of
    # conditioning channels, so `cond_channels` must be passed as a keyword argument.
    return ControlnetCond(
        *args,
        flip_sin_to_cos=True,
        freq_shift=0,
        down_block_types=['CrossAttnDownBlock2D',
                          'CrossAttnDownBlock2D',
                          'CrossAttnDownBlock2D',
                          'DownBlock2D'],
        only_cross_attention=False,
        block_out_channels=[320, 640, 1280, 1280],
        layers_per_block=2,
        downsample_padding=1,
        mid_block_scale_factor=1,
        act_fn='silu',
        norm_num_groups=32,
        norm_eps=1e-05,
        cross_attention_dim=768,
        attention_head_dim=8,
        num_attention_heads=None,
        use_linear_projection=False,
        class_embed_type=None,
        num_class_embeds=None,
        upcast_attention=False,
        resnet_time_scale_shift='default',
        projection_class_embeddings_input_dim=None,
        controlnet_conditioning_channel_order='rgb',
        conditioning_embedding_out_channels=[kwargs['cond_channels'], 32, 96, 256],
        global_pool_conditions=False,
        freeze_params=True,
        **kwargs,
    )
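

if __name__ == "__main__":
    # Minimal usage sketch, assuming a Stable Diffusion v1.5 pipeline; the
    # checkpoint name and all shapes below are illustrative assumptions, not
    # part of this module. With vae_scale_factor=8, image_size=256 gives
    # 32x32 latents.
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    model = controlnet(
        in_channels=4,    # SD v1.5 latent channels
        cond_channels=8,  # number of conditioning channels (assumption)
        sd_pipeline=pipe,
        image_size=256,
    )
    latents = torch.randn(1, 4, model.H, model.W)
    cond = torch.randn(1, 8, model.H, model.W)
    noise_pred = model(latents, timestep=10, encoder_hidden_states=cond)
    print(noise_pred.shape)  # Expected: torch.Size([1, 4, 32, 32])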