Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| import math | |
| import weakref | |
| import torch | |
| import torch.nn as nn | |
| from typing import TYPE_CHECKING, List, Dict, Any | |
| from toolkit.models.clip_fusion import ZipperBlock | |
| from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler | |
| import sys | |
| from collections import OrderedDict | |
| if TYPE_CHECKING: | |
| from toolkit.lora_special import LoRAModule | |
| from toolkit.stable_diffusion_model import StableDiffusion | |
class TransformerBlock(nn.Module):
    """Post-norm transformer layer: self-attention, cross-attention, then an MLP.

    Every sub-layer is wrapped in a residual connection followed by its own
    LayerNorm. All attention uses ``batch_first=True``, so both ``x`` and
    ``cross_attn_input`` are laid out as ``(batch, seq, feature)`` with
    ``feature == d_model``.
    """

    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        # Creation order is significant: it fixes parameter initialisation
        # under a seeded RNG and the state_dict layout of saved checkpoints.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model),
        )
        # One LayerNorm per residual branch (self-attn, cross-attn, MLP).
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))

    def forward(self, x, cross_attn_input):
        """Refine ``x`` by attending to itself and then to ``cross_attn_input``."""
        context = cross_attn_input

        attended = self.self_attn(x, x, x)[0]
        hidden = self.norm1(x + attended)

        attended = self.cross_attn(hidden, context, context)[0]
        hidden = self.norm2(hidden + attended)

        return self.norm3(hidden + self.feed_forward(hidden))
class InstantLoRAMidModule(torch.nn.Module):
    """Per-sample replacement for one LoRA module's down/up projections.

    Instead of using the LoRA module's own weights, ``down_forward`` and
    ``up_forward`` read a flat per-sample weight row from
    ``instant_lora_module.img_embeds[index]``. In each row, the first
    ``prod(down_shape)`` values are the down weight and the last
    ``prod(up_shape)`` values are the up weight.

    A 4-D weight shape is applied with ``conv2d``. Any other shape is
    applied as a matrix product ``x @ W.T``.

    Both parents are held through weak references, so this module does not
    keep the LoRA network or its owner alive.
    """

    def __init__(
            self,
            index: int,
            lora_module: 'LoRAModule',
            instant_lora_module: 'InstantLoRAModule',
            up_shape: list = None,
            down_shape: list = None,
    ):
        super(InstantLoRAMidModule, self).__init__()
        self.up_shape = up_shape
        self.down_shape = down_shape
        # Position of this module's slice within the owner's img_embeds list.
        self.index = index
        self.lora_module_ref = weakref.ref(lora_module)
        self.instant_lora_module_ref = weakref.ref(instant_lora_module)
        # Most recently fetched embed slice, refreshed on every forward call.
        self.embed = None

    def _current_embed(self):
        """Fetch this module's embed slice from the owner and cache it on ``self.embed``.

        Raises:
            RuntimeError: if the owner has been garbage collected or has not
                populated ``img_embeds`` yet.
        """
        owner = self.instant_lora_module_ref()
        if owner is None:
            raise RuntimeError(
                f"InstantLoRAMidModule[{self.index}]: owning module no longer exists"
            )
        if owner.img_embeds is None:
            raise RuntimeError(
                f"InstantLoRAMidModule[{self.index}]: img_embeds is not set; "
                "run the owning module's forward first"
            )
        self.embed = owner.img_embeds[self.index]
        return self.embed

    def _match_batch(self, weight, batch_size):
        """Tile ``weight`` along dim 0 so it has exactly one row per input sample.

        A weight batch that evenly divides the input batch is repeated
        block-wise. One case is the doubled batch that the original code
        handled for the unconditional half, presumably a classifier-free
        guidance batch. Another is a single embed shared across the batch.

        Raises:
            ValueError: if the weight batch cannot be tiled to ``batch_size``.
        """
        weight_batch = weight.shape[0]
        if weight_batch == batch_size:
            return weight
        if weight_batch > 0 and batch_size > 0 and batch_size % weight_batch == 0:
            return torch.cat([weight] * (batch_size // weight_batch), dim=0)
        raise ValueError(
            f"InstantLoRAMidModule[{self.index}]: weight batch {weight_batch} "
            f"cannot be tiled to input batch {batch_size}"
        )

    def _apply_per_sample(self, x, weight, shape):
        """Apply row ``i`` of ``weight``, reshaped to ``shape``, to sample ``i`` of ``x``."""
        weight = self._match_batch(weight, x.shape[0])
        outputs = []
        for weight_row, x_sample in zip(weight, x.split(1, dim=0)):
            w = weight_row.reshape(shape)
            if w.dim() == 4:
                # 3-wide kernels get padding 1, which preserves the spatial
                # size at the default stride; other kernels get no padding.
                padding = 1 if w.shape[-1] == 3 else 0
                outputs.append(nn.functional.conv2d(x_sample, w, padding=padding))
            else:
                # Linear layer without bias using this sample's weight.
                outputs.append(x_sample @ w.T)
        return torch.cat(outputs, dim=0)

    def down_forward(self, x, *args, **kwargs):
        """Replacement for ``lora_down.forward``; extra arguments are ignored."""
        embed = self._current_embed()
        down_size = math.prod(self.down_shape)
        return self._apply_per_sample(x, embed[:, :down_size], self.down_shape)

    def up_forward(self, x, *args, **kwargs):
        """Replacement for ``lora_up.forward``; extra arguments are ignored."""
        embed = self._current_embed()
        up_size = math.prod(self.up_shape)
        return self._apply_per_sample(x, embed[:, -up_size:], self.up_shape)
| # Initialize the network | |
| # num_blocks = 8 | |
| # d_model = 1024 # Adjust as needed | |
| # nhead = 16 # Adjust as needed | |
| # dim_feedforward = 4096 # Adjust as needed | |
| # latent_dim = 1695744 | |
class LoRAFormer(torch.nn.Module):
    """Generate per-sample LoRA weights from image embeddings with a small transformer.

    Construction hooks every LoRA module of ``sd.network``. Its
    ``lora_down``/``lora_up`` forwards are replaced by an
    ``InstantLoRAMidModule``, which reads its weights from
    ``self.img_embeds``.

    ``forward`` fills ``self.img_embeds`` in four steps:

    1. A learned latent of width ``output_size`` is projected to ``d_model``.
    2. It is refined by ``num_blocks`` ``TransformerBlock`` layers that
       cross-attend to the image tokens.
    3. The result is projected back to ``output_size``.
    4. It is split into one flat weight vector per LoRA module.

    NOTE(review): ``latent_proj`` and ``final_proj`` each hold
    ``d_model * output_size`` weights, which grows with the total LoRA size.
    """

    def __init__(
            self,
            num_blocks,
            d_model=1024,
            nhead=16,
            dim_feedforward=4096,
            sd: 'StableDiffusion' = None,
    ):
        super(LoRAFormer, self).__init__()
        if sd is None:
            raise TypeError("LoRAFormer requires a StableDiffusion instance via `sd`")
        self.sd_ref = weakref.ref(sd)
        self.dim = sd.network.lora_dim

        # Hyper-parameters, kept so get_additional_save_metadata() can report them.
        self.num_blocks = num_blocks
        self.d_model = d_model
        self.num_heads = nhead
        self.head_dim = d_model // nhead
        self.dim_feedforward = dim_feedforward
        # Cross-attention uses nn.MultiheadAttention with default kdim/vdim,
        # so image tokens must already have width d_model.
        self.vision_hidden_size = d_model
        # This architecture accepts any number of image tokens, so no fixed count is recorded.
        self.vision_tokens = None

        # Stores the per-module weight slices. Read by the InstantLoRAMidModules.
        self.img_embeds: List[torch.Tensor] = None

        network = self.sd_ref().network
        # Disable merging in. It is slower on inference.
        network.can_merge_in = False

        self.ilora_modules = torch.nn.ModuleList()
        self.embed_lengths = []
        self.weight_mapping = []
        for idx, lora_module in enumerate(network.get_all_modules()):
            module_dict = lora_module.state_dict()
            down_shape = list(module_dict['lora_down.weight'].shape)
            up_shape = list(module_dict['lora_up.weight'].shape)
            self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])
            self.embed_lengths.append(math.prod(down_shape) + math.prod(up_shape))

            # The mid module serves this LoRA's down/up weights from img_embeds[idx].
            instant_module = InstantLoRAMidModule(
                idx,
                lora_module,
                self,
                up_shape=up_shape,
                down_shape=down_shape
            )
            self.ilora_modules.append(instant_module)
            # Replace the LoRA forwards.
            lora_module.lora_down.forward = instant_module.down_forward
            lora_module.lora_up.forward = instant_module.up_forward

        self.output_size = sum(self.embed_lengths)

        # The registration order below fixes the state_dict layout; keep it stable.
        self.latent = nn.Parameter(torch.randn(1, self.output_size))
        self.latent_proj = nn.Linear(self.output_size, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, nhead, dim_feedforward)
            for _ in range(num_blocks)
        ])
        self.final_proj = nn.Linear(d_model, self.output_size)

        self.migrate_weight_mapping()

    def migrate_weight_mapping(self):
        """Intentionally a no-op: keymap renaming of ``weight_mapping`` is disabled."""
        return
        # # changes the names of the modules to common ones
        # keymap = self.sd_ref().network.get_keymap()
        # save_keymap = {}
        # if keymap is not None:
        #     for ldm_key, diffusers_key in keymap.items():
        #         # invert them
        #         save_keymap[diffusers_key] = ldm_key
        #
        #     new_keymap = {}
        #     for key, value in self.weight_mapping:
        #         if key in save_keymap:
        #             new_keymap[save_keymap[key]] = value
        #         else:
        #             print(f"Key {key} not found in keymap")
        #             new_keymap[key] = value
        #     self.weight_mapping = new_keymap
        # else:
        #     print("No keymap found. Using default names")
        # return

    def _split_embeds(self, flat):
        """Split ``(batch, output_size)`` into one ``(batch, length)`` view per LoRA module."""
        return list(torch.split(flat, self.embed_lengths, dim=1))

    def forward(self, img_embeds):
        """Compute per-sample LoRA weights and store them in ``self.img_embeds``.

        Args:
            img_embeds: image features of shape ``(batch, d_model)`` or
                ``(batch, tokens, d_model)``.

        Returns:
            None. The result is stored on ``self.img_embeds`` for the hooked
            LoRA modules to read.
        """
        # Expand the token rank if the input is only rank 2.
        if img_embeds.dim() == 2:
            img_embeds = img_embeds.unsqueeze(1)
        batch_size = img_embeds.shape[0]

        # (1, output_size) -> (1, d_model) -> (batch, 1, d_model): one query token per sample.
        hidden = self.latent_proj(self.latent).unsqueeze(0).expand(batch_size, -1, -1)
        for block in self.blocks:
            hidden = block(hidden, img_embeds)

        # (batch, 1, output_size) -> (batch, output_size)
        flat = self.final_proj(hidden).mean(dim=1)
        self.img_embeds = self._split_embeds(flat)

    def get_additional_save_metadata(self) -> Dict[str, Any]:
        """Return metadata to save alongside the weights.

        ``weight_mapping`` lists ``[lora_name, [down_shape, up_shape]]`` in
        embed order. The other entries are the hyper-parameters needed to
        rebuild this module.
        """
        return {
            "weight_mapping": self.weight_mapping,
            "num_heads": self.num_heads,
            "vision_hidden_size": self.vision_hidden_size,
            "head_dim": self.head_dim,
            "vision_tokens": self.vision_tokens,
            "output_size": self.output_size,
            "num_blocks": self.num_blocks,
            "d_model": self.d_model,
            "dim_feedforward": self.dim_feedforward,
        }
