import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import weakref
from typing import Union, TYPE_CHECKING, Optional, Tuple

from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer
from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention

from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.custom_adapter import CustomAdapter


class TEAugAdapterCLIPAttention(nn.Module):
    """CLIPAttention wrapper that adds adapter key/value projections over the CustomAdapter's conditional embeddings."""

    def __init__(self, attn_module: 'CLIPAttention', adapter: 'TEAugAdapter'):
        super().__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.attn_module_ref: weakref.ref = weakref.ref(attn_module)

        self.k_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)
        self.v_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)

        # initialize from the original module's weights, scaled down
        self.k_proj_adapter.weight.data = attn_module.k_proj.weight.data.clone() * 0.01
        self.v_proj_adapter.weight.data = attn_module.v_proj.weight.data.clone() * 0.01
        # scale down the bias as well
        self.k_proj_adapter.bias.data = attn_module.k_proj.bias.data.clone() * 0.001
        self.v_proj_adapter.bias.data = attn_module.v_proj.bias.data.clone() * 0.001
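        # NOTE: the 0.01 / 0.001 scaling keeps the adapter branch close to a no-op at
        # initialization, presumably so training starts from roughly the original
        # text-encoder behavior and the adapter's influence grows gradually.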
        self.zipper = ZipperModule(
            in_size=attn_module.embed_dim,
            in_tokens=77 * 2,
            out_size=attn_module.embed_dim,
            out_tokens=77,
            hidden_size=attn_module.embed_dim,
            hidden_tokens=77,
        )
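        # The zipper takes the original attention output and the adapter attention output
        # (77 tokens each, concatenated along the token dimension) and mixes them back
        # down to a single 77-token sequence; see the adapter branch in forward() below.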
        # self.k_proj_adapter.weight.data = torch.zeros_like(attn_module.k_proj.weight.data)
        # self.v_proj_adapter.weight.data = torch.zeros_like(attn_module.v_proj.weight.data)
        # # reset the bias
        # self.k_proj_adapter.bias.data = torch.zeros_like(attn_module.k_proj.bias.data)
        # self.v_proj_adapter.bias.data = torch.zeros_like(attn_module.v_proj.bias.data)

        # replace the original forward with our forward
        self.original_forward = attn_module.forward
        attn_module.forward = self.forward
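        # self.original_forward is not used elsewhere in this module; it is kept so the
        # patch could be undone (or the unpatched attention called) if ever needed.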

    @property
    def is_active(self):
        return self.adapter_ref().is_active

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # mirror CLIPAttention._shape using the wrapped module's head configuration
        attn_module = self.attn_module_ref()
        return tensor.view(bsz, seq_len, attn_module.num_heads, attn_module.head_dim).transpose(1, 2).contiguous()
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""
        attn_module = self.attn_module_ref()
        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj
        query_states = attn_module.q_proj(hidden_states) * attn_module.scale
        key_states = attn_module._shape(attn_module.k_proj(hidden_states), -1, bsz)
        value_states = attn_module._shape(attn_module.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * attn_module.num_heads, -1, attn_module.head_dim)
        query_states = attn_module._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * attn_module.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * attn_module.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * attn_module.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
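        # If the adapter is active, run an extra attention pass in which the text queries
        # attend over the adapter-projected embeddings; its output is fused with the
        # regular self-attention output below.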
        adapter: 'CustomAdapter' = self.adapter_ref().adapter_ref()
        if self.adapter_ref().is_active and adapter.conditional_embeds is not None:
            # apply the adapter
            if adapter.is_unconditional_run:
                embeds = adapter.unconditional_embeds
            else:
                embeds = adapter.conditional_embeds
                # if the batch sizes do not match, we are doing CFG and need to concat the unconditional embeds as well
                if embeds.size(0) != bsz:
                    embeds = torch.cat([adapter.unconditional_embeds, embeds], dim=0)

            key_states_raw = self.k_proj_adapter(embeds)
            key_states = attn_module._shape(key_states_raw, -1, bsz)
            value_states_raw = self.v_proj_adapter(embeds)
            value_states = attn_module._shape(value_states_raw, -1, bsz)

            key_states = key_states.view(*proj_shape)
            value_states = value_states.view(*proj_shape)

            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
            attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)
            attn_output_adapter = torch.bmm(attn_probs, value_states)

            if attn_output_adapter.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
                raise ValueError(
                    f"`attn_output_adapter` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
                    f" {attn_output_adapter.size()}"
                )

            attn_output_adapter = attn_output_adapter.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
            attn_output_adapter = attn_output_adapter.transpose(1, 2)
            attn_output_adapter = attn_output_adapter.reshape(bsz, tgt_len, embed_dim)
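            # Fuse the adapter attention output with the original attention output via the
            # zipper (77 + 77 tokens in, 77 tokens out), then add it back as a residual.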
            attn_output_adapter = self.zipper(torch.cat([attn_output_adapter, attn_output], dim=1))
            # attn_output_adapter = attn_module.out_proj(attn_output_adapter)
            attn_output = attn_output + attn_output_adapter

        attn_output = attn_module.out_proj(attn_output)

        return attn_output, attn_weights_reshaped


class TEAugAdapter(torch.nn.Module):
    def __init__(
        self,
        adapter: 'CustomAdapter',
        sd: 'StableDiffusion',
    ):
        super(TEAugAdapter, self).__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref: weakref.ref = weakref.ref(sd)

        if isinstance(sd.text_encoder, list):
            raise ValueError("Dual text encoders are not yet supported")

        # dim comes from the text encoder
        # dim = sd.unet.config['cross_attention_dim']
        text_encoder: CLIPTextModel = sd.text_encoder
        dim = text_encoder.config.hidden_size

        clip_encoder: CLIPEncoder = text_encoder.text_model.encoder
        # dim = clip_encoder.layers[-1].self_attn

        if hasattr(adapter.vision_encoder.config, 'hidden_sizes'):
            embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1]
        else:
            embedding_dim = adapter.vision_encoder.config.hidden_size
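        # Note: ConvNeXt-style configs expose a per-stage `hidden_sizes` list while
        # CLIP-style vision configs expose a single `hidden_size`; the check above picks
        # whichever attribute is present so both encoder families work.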

        image_encoder_state_dict = adapter.vision_encoder.state_dict()
        # max_seq_len = CLIP patch tokens + CLS token
        in_tokens = 257
        if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
            # CLIP vision encoders: read the token count from the position embedding table
            in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])

        if adapter.config.image_encoder_arch.startswith('convnext'):
            in_tokens = 16 * 16
            embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1]

        out_tokens = adapter.config.num_tokens if adapter.config.num_tokens > 0 else in_tokens
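        # Project the vision-encoder output tokens into the text encoder's hidden space
        # (embedding_dim -> dim) and resample them to out_tokens tokens.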
        self.image_proj_model = ZipperModule(
            in_size=embedding_dim,
            in_tokens=in_tokens,
            out_size=dim,
            out_tokens=out_tokens,
            hidden_size=dim,
            hidden_tokens=out_tokens,
        )

        # init adapter modules
        attn_procs = {}
        for idx, layer in enumerate(clip_encoder.layers):
            name = f"clip_attention.{idx}"
            attn_procs[name] = TEAugAdapterCLIPAttention(
                layer.self_attn,
                self
            )
        self.adapter_modules = torch.nn.ModuleList(list(attn_procs.values()))
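        # Building each TEAugAdapterCLIPAttention patches that layer's self_attn.forward in
        # place; keeping the wrappers in a ModuleList registers their parameters with this
        # module, so they appear in its parameters() and state_dict().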

    # make a getter to see if is active
    @property
    def is_active(self):
        return self.adapter_ref().is_active

    def forward(self, input):
        # apply the image projection / resampling
        input = self.image_proj_model(input)
        # self.embeds = input
        return input
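

if __name__ == "__main__":
    # Minimal shape sanity check for the zipper used above. This is only a sketch: it
    # assumes ZipperModule maps (batch, in_tokens, in_size) -> (batch, out_tokens, out_size),
    # which is what its usage in TEAugAdapterCLIPAttention.forward implies. The sizes here
    # (768, 77) are illustrative, not tied to any particular checkpoint.
    zipper = ZipperModule(
        in_size=768,
        in_tokens=77 * 2,
        out_size=768,
        out_tokens=77,
        hidden_size=768,
        hidden_tokens=77,
    )
    dummy = torch.randn(2, 77 * 2, 768)
    print(zipper(dummy).shape)  # expected: torch.Size([2, 77, 768])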