import inspect
import weakref
import torch
from typing import TYPE_CHECKING

from toolkit.lora_special import LoRASpecialNetwork
from diffusers import FluxTransformer2DModel
from toolkit.pixel_shuffle_encoder import AutoencoderPixelMixer

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
    from toolkit.custom_adapter import CustomAdapter
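
# Subpixel adapter for Flux: swaps the model's VAE for a pixel-shuffle
# autoencoder (AutoencoderPixelMixer) and wraps the transformer's x_embedder /
# proj_out linears with float32 replacements sized for the wider pixel-shuffled
# channel count, while a control LoRA handles the rest of the network.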


class InOutModule(torch.nn.Module):
    """Float32 replacements for the transformer's x_embedder and proj_out linears."""

    def __init__(
        self,
        adapter: 'SubpixelAdapter',
        orig_layer: torch.nn.Linear,
        orig_out_layer: torch.nn.Linear = None,
        in_channels=64,
        out_channels=3072
    ):
        super().__init__()
        # only hold the weights for the new input/output; they are initialized
        # from the original linear layers in from_model when the shapes match
        self.x_embedder = torch.nn.Linear(
            in_channels,
            out_channels,
            bias=True,
        )
        self.proj_out = torch.nn.Linear(
            out_channels,
            in_channels,
            bias=True,
        )
        # make sure the weights are float32
        self.x_embedder.weight.data = self.x_embedder.weight.data.float()
        self.x_embedder.bias.data = self.x_embedder.bias.data.float()
        self.proj_out.weight.data = self.proj_out.weight.data.float()
        self.proj_out.bias.data = self.proj_out.bias.data.float()

        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)
        # fall back to the input layer if no separate output layer is given
        self.orig_out_layer_ref: weakref.ref = (
            weakref.ref(orig_out_layer) if orig_out_layer is not None else self.orig_layer_ref
        )

    @classmethod
    def from_model(
        cls,
        model: FluxTransformer2DModel,
        adapter: 'SubpixelAdapter',
        num_channels: int = 768,
        downscale_factor: int = 8
    ):
        if model.__class__.__name__ == 'FluxTransformer2DModel':
            x_embedder: torch.nn.Linear = model.x_embedder
            proj_out: torch.nn.Linear = model.proj_out
            in_out_module = cls(
                adapter,
                orig_layer=x_embedder,
                orig_out_layer=proj_out,
                in_channels=num_channels,
                out_channels=x_embedder.out_features,
            )
            # hijack the forward methods, keeping the originals around
            x_embedder._orig_ctrl_lora_forward = x_embedder.forward
            x_embedder.forward = in_out_module.in_forward
            proj_out._orig_ctrl_lora_forward = proj_out.forward
            proj_out.forward = in_out_module.out_forward

            # update the config of the transformer (both attribute and dict entry)
            model.config.in_channels = num_channels
            model.config["in_channels"] = num_channels
            model.config.out_channels = num_channels
            model.config["out_channels"] = num_channels

            # if the shapes match, copy the original weights
            if x_embedder.weight.shape == in_out_module.x_embedder.weight.shape:
                in_out_module.x_embedder.weight.data = x_embedder.weight.data.clone().float()
                in_out_module.x_embedder.bias.data = x_embedder.bias.data.clone().float()
                in_out_module.proj_out.weight.data = proj_out.weight.data.clone().float()
                in_out_module.proj_out.bias.data = proj_out.bias.data.clone().float()

            # replace the vae of the model with the pixel-shuffle codec
            sd = adapter.sd_ref()
            sd.vae = AutoencoderPixelMixer(
                in_channels=3,
                downscale_factor=downscale_factor
            )
            sd.pipeline.vae = sd.vae
            return in_out_module
        else:
            raise ValueError("Model not supported")
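
    # Hedged usage sketch (`sd` and `adapter` are hypothetical stand-ins for an
    # already-built StableDiffusion wrapper and its SubpixelAdapter):
    #
    #   in_out = InOutModule.from_model(
    #       sd.unet_unwrapped,   # FluxTransformer2DModel
    #       adapter,
    #       num_channels=768,    # packed channels for downscale_factor=8
    #       downscale_factor=8,
    #   )
    #
    # Afterwards x_embedder/proj_out route through the new float32 layers and
    # sd.vae is an AutoencoderPixelMixer.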

    @property
    def is_active(self):
        return self.adapter_ref().is_active

    def in_forward(self, x):
        if not self.is_active:
            # make sure the lora is not active
            if self.adapter_ref().control_lora is not None:
                self.adapter_ref().control_lora.is_active = False
            return self.orig_layer_ref()._orig_ctrl_lora_forward(x)
        # make sure the lora is active
        if self.adapter_ref().control_lora is not None:
            self.adapter_ref().control_lora.is_active = True
        # run the replacement embedder in its own device/dtype, then cast back
        orig_device = x.device
        orig_dtype = x.dtype
        x = x.to(self.x_embedder.weight.device, dtype=self.x_embedder.weight.dtype)
        x = self.x_embedder(x)
        x = x.to(orig_device, dtype=orig_dtype)
        return x

    def out_forward(self, x):
        if not self.is_active:
            # make sure the lora is not active
            if self.adapter_ref().control_lora is not None:
                self.adapter_ref().control_lora.is_active = False
            # bypass to the original proj_out forward
            return self.orig_out_layer_ref()._orig_ctrl_lora_forward(x)
        # make sure the lora is active
        if self.adapter_ref().control_lora is not None:
            self.adapter_ref().control_lora.is_active = True
        orig_device = x.device
        orig_dtype = x.dtype
        x = x.to(self.proj_out.weight.device, dtype=self.proj_out.weight.dtype)
        x = self.proj_out(x)
        x = x.to(orig_device, dtype=orig_dtype)
        return x


class SubpixelAdapter(torch.nn.Module):
    """Trains a control LoRA plus replacement in/out layers over a pixel-shuffle VAE."""

    def __init__(
        self,
        adapter: 'CustomAdapter',
        sd: 'StableDiffusion',
        config: 'AdapterConfig',
        train_config: 'TrainConfig'
    ):
        super().__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref = weakref.ref(sd)
        self.model_config: ModelConfig = sd.model_config
        self.network_config = config.lora_config
        self.train_config = train_config
        self.device_torch = sd.device_torch

        self.control_lora = None
        if self.network_config is not None:
            network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
            if hasattr(sd, 'target_lora_modules'):
                network_kwargs['target_lin_modules'] = sd.target_lora_modules

            if 'ignore_if_contains' not in network_kwargs:
                network_kwargs['ignore_if_contains'] = []

            # always ignore the replaced in/out layers
            network_kwargs['ignore_if_contains'].append('transformer.x_embedder')
            network_kwargs['ignore_if_contains'].append('transformer.proj_out')

            self.control_lora = LoRASpecialNetwork(
                text_encoder=sd.text_encoder,
                unet=sd.unet,
                lora_dim=self.network_config.linear,
                multiplier=1.0,
                alpha=self.network_config.linear_alpha,
                train_unet=self.train_config.train_unet,
                train_text_encoder=self.train_config.train_text_encoder,
                conv_lora_dim=self.network_config.conv,
                conv_alpha=self.network_config.conv_alpha,
                is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
                is_v2=self.model_config.is_v2,
                is_v3=self.model_config.is_v3,
                is_pixart=self.model_config.is_pixart,
                is_auraflow=self.model_config.is_auraflow,
                is_flux=self.model_config.is_flux,
                is_lumina2=self.model_config.is_lumina2,
                is_ssd=self.model_config.is_ssd,
                is_vega=self.model_config.is_vega,
                dropout=self.network_config.dropout,
                use_text_encoder_1=self.model_config.use_text_encoder_1,
                use_text_encoder_2=self.model_config.use_text_encoder_2,
                use_bias=False,
                is_lorm=False,
                network_config=self.network_config,
                network_type=self.network_config.type,
                transformer_only=self.network_config.transformer_only,
                is_transformer=sd.is_transformer,
                base_model=sd,
                **network_kwargs
            )
            self.control_lora.force_to(self.device_torch, dtype=torch.float32)
            self.control_lora._update_torch_multiplier()
            self.control_lora.apply_to(
                sd.text_encoder,
                sd.unet,
                self.train_config.train_text_encoder,
                self.train_config.train_unet
            )
            self.control_lora.can_merge_in = False
            self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet)
            if self.train_config.gradient_checkpointing:
                self.control_lora.enable_gradient_checkpointing()

        downscale_factor = config.subpixel_downscale_factor
        # packed channel count: 3 RGB channels * downscale_factor**2 from the
        # pixel shuffle, * 4 for the transformer's 2x2 latent packing
        # (3 * 64 * 4 = 768, 3 * 256 * 4 = 3072)
        if downscale_factor == 8:
            num_channels = 768
        elif downscale_factor == 16:
            num_channels = 3072
        else:
            raise ValueError(
                f"downscale_factor {downscale_factor} not supported"
            )
        self.in_out: InOutModule = InOutModule.from_model(
            sd.unet_unwrapped,
            self,
            num_channels=num_channels,  # packed channels
            downscale_factor=downscale_factor
        )
        self.in_out.to(self.device_torch)

    def get_params(self):
        if self.control_lora is not None:
            config = {
                'text_encoder_lr': self.train_config.lr,
                'unet_lr': self.train_config.lr,
            }
            # prepare_optimizer_params takes different lr kwargs across
            # versions, so only pass the ones its signature accepts
            sig = inspect.signature(self.control_lora.prepare_optimizer_params)
            if 'default_lr' in sig.parameters:
                config['default_lr'] = self.train_config.lr
            if 'learning_rate' in sig.parameters:
                config['learning_rate'] = self.train_config.lr
            params_net = self.control_lora.prepare_optimizer_params(
                **config
            )

            # we want only tensors here, so flatten the param groups
            params = []
            for p in params_net:
                if isinstance(p, dict):
                    params += p["params"]
                elif isinstance(p, torch.Tensor):
                    params.append(p)
                elif isinstance(p, list):
                    params += p
        else:
            params = []

        # make sure the embedder is float32
        self.in_out.to(torch.float32)
        params += list(self.in_out.parameters())

        # return a flat list so callers can do `yield from params`
        return params
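
    # Hedged sketch: the returned flat tensor list can feed an optimizer
    # directly (the lr here is a hypothetical value):
    #
    #   optimizer = torch.optim.AdamW(adapter.get_params(), lr=1e-4)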

    def load_weights(self, state_dict, strict=True):
        lora_sd = {}
        img_embedder_sd = {}
        for key, value in state_dict.items():
            if "transformer.x_embedder" in key:
                new_key = key.replace("transformer.", "")
                img_embedder_sd[new_key] = value
            elif "transformer.proj_out" in key:
                new_key = key.replace("transformer.", "")
                img_embedder_sd[new_key] = value
            else:
                lora_sd[key] = value

        # todo process state dict before loading
        if self.control_lora is not None:
            self.control_lora.load_weights(lora_sd)
        # strict=False so the x_embedder is upgraded automatically if more
        # dims are added
        self.in_out.load_state_dict(img_embedder_sd, strict=False)

    def get_state_dict(self):
        if self.control_lora is not None:
            lora_sd = self.control_lora.get_state_dict(dtype=torch.float32)
        else:
            lora_sd = {}
        # todo make sure we match the key format of other loras elsewhere
        img_embedder_sd = self.in_out.state_dict()
        for key, value in img_embedder_sd.items():
            lora_sd[f"transformer.{key}"] = value
        return lora_sd

    @property
    def is_active(self):
        return self.adapter_ref().is_active
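

# Hedged end-to-end sketch (argument names are hypothetical; a real
# CustomAdapter / StableDiffusion pair comes from the toolkit's setup code):
#
#   adapter = SubpixelAdapter(custom_adapter, sd, adapter_config, train_config)
#   optimizer = torch.optim.AdamW(adapter.get_params(), lr=train_config.lr)
#   ...train, then...
#   torch.save(adapter.get_state_dict(), "subpixel_adapter.pt")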
