import os
import sys
sys.path.append(os.getcwd())

import copy
import math  # used by math.pi in the timestep-density and loss-weighting helpers below
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional
import numpy as np
import lpips
from torchvision import transforms
from PIL import Image
from peft import LoraConfig, get_peft_model
from copy import deepcopy
from tqdm import tqdm
from diffusers import StableDiffusion3Pipeline, FluxPipeline

from lora.lora_layers import LoraInjectedLinear, LoraInjectedConv2d

def inject_lora_vae(vae, lora_rank=4, init_lora_weights="gaussian", verbose=False):
    """
    Inject LoRA into the VAE's encoder.
    """
    vae.requires_grad_(False)
    vae.train()

    # Identify modules to LoRA-ify in the encoder
    l_grep = ["conv1", "conv2", "conv_in", "conv_shortcut",
              "conv", "conv_out", "to_k", "to_q", "to_v", "to_out.0"]
    l_target_modules_encoder = []
    for n, p in vae.named_parameters():
        if "bias" in n or "norm" in n:
            continue
        for pattern in l_grep:
            if (pattern in n) and ("encoder" in n):
                l_target_modules_encoder.append(n.replace(".weight", ""))
                break  # add each parameter at most once
            elif ("quant_conv" in n) and ("post_quant_conv" not in n):
                l_target_modules_encoder.append(n.replace(".weight", ""))
                break
    if verbose:
        print("The following VAE parameters will get LoRA:")
        print(l_target_modules_encoder)

    # Create and add a LoRA adapter
    lora_conf_encoder = LoraConfig(
        r=lora_rank,
        init_lora_weights=init_lora_weights,
        target_modules=l_target_modules_encoder,
    )
    adapter_name = "default_encoder"
    try:
        vae.add_adapter(lora_conf_encoder, adapter_name=adapter_name)
        vae.set_adapter(adapter_name)
    except ValueError as e:
        if "already exists" in str(e):
            print(f"Adapter with name {adapter_name} already exists. Skipping injection.")
        else:
            raise e
    return vae, l_target_modules_encoder

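# A minimal usage sketch for inject_lora_vae (illustrative only; the checkpoint id,
# subfolder, and rank are assumptions, not values prescribed by this module):
#
#     from diffusers import AutoencoderKL
#     vae = AutoencoderKL.from_pretrained(
#         "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae"
#     )
#     vae, lora_modules = inject_lora_vae(vae, lora_rank=4, verbose=True)
#     trainable = [p for p in vae.parameters() if p.requires_grad]  # only the LoRA weights
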
def _find_modules(model, ancestor_class=None, search_class=[nn.Linear], exclude_children_of=[LoraInjectedLinear]):
    # Get the targets we should replace all linears under
    if ancestor_class is not None:
        ancestors = (
            module
            for module in model.modules()
            if module.__class__.__name__ in ancestor_class
        )
    else:
        # this, in case you want to naively iterate over all modules.
        ancestors = [module for module in model.modules()]

    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if any([isinstance(module, _class) for _class in search_class]):
                *path, name = fullname.split(".")
                parent = ancestor
                while path:
                    parent = parent.get_submodule(path.pop(0))
                if exclude_children_of and any(
                    [isinstance(parent, _class) for _class in exclude_children_of]
                ):
                    continue
                yield parent, name, module

def inject_lora(model, ancestor_class, loras=None, r: int = 4, dropout_p: float = 0.0, scale: float = 1.0, verbose: bool = False):
    model.requires_grad_(False)
    model.train()

    names = []
    require_grad_params = []  # to be updated
    total_lora_params = 0

    if loras is not None:
        loras = torch.load(loras, map_location=model.device, weights_only=True)
        loras = [lora.float() for lora in loras]

    for _module, name, _child_module in _find_modules(model, ancestor_class):  # SiLU + Linear block
        weight = _child_module.weight
        bias = _child_module.bias
        if verbose:
            print(f'LoRA Injection : injecting lora into {name}')

        _tmp = LoraInjectedLinear(
            _child_module.in_features,
            _child_module.out_features,
            _child_module.bias is not None,
            r=r,
            dropout_p=dropout_p,
            scale=scale,
        )
        _tmp.linear.weight = nn.Parameter(weight.float())
        if bias is not None:
            _tmp.linear.bias = nn.Parameter(bias.float())

        # switch the module
        _tmp.to(device=_child_module.weight.device, dtype=torch.float)  # keep as float / mixed precision
        _module._modules[name] = _tmp

        require_grad_params.append(_module._modules[name].lora_up.parameters())
        require_grad_params.append(_module._modules[name].lora_down.parameters())

        if loras is not None:
            _module._modules[name].lora_up.weight = nn.Parameter(loras.pop(0))
            _module._modules[name].lora_down.weight = nn.Parameter(loras.pop(0))

        _module._modules[name].lora_up.weight.requires_grad = True
        _module._modules[name].lora_down.weight.requires_grad = True
        names.append(name)

        if verbose:
            # -------- Count LoRA parameters just added --------
            lora_up_count = sum(p.numel() for p in _tmp.lora_up.parameters())
            lora_down_count = sum(p.numel() for p in _tmp.lora_down.parameters())
            lora_total_for_this_layer = lora_up_count + lora_down_count
            total_lora_params += lora_total_for_this_layer
            print(f"  Added {lora_total_for_this_layer} params "
                  f"(lora_up={lora_up_count}, lora_down={lora_down_count})")

    if verbose:
        print(f"Total new LoRA parameters added: {total_lora_params}")
    return require_grad_params, names

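# A minimal usage sketch for inject_lora (illustrative only; `pipe` is assumed to be a
# loaded StableDiffusion3Pipeline and the rank / learning rate are arbitrary choices):
#
#     transformer = pipe.transformer                       # an SD3 transformer
#     params, names = inject_lora(transformer, {"AdaLayerNormZero"}, r=4, verbose=True)
#     optimizer = torch.optim.AdamW(
#         [p for p in transformer.parameters() if p.requires_grad], lr=1e-4
#     )
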
def add_mp_hook(transformer):
    '''
    For mixed precision of LoRA. (i.e. keep LoRA as float and others as half)
    '''
    def pre_hook(module, input):
        # forward pre-hooks receive the positional args as a tuple
        return tuple(x.float() for x in input)

    def post_hook(module, input, output):
        return output.half()

    hooks = []
    for _module, name, _child_module in _find_modules(transformer):
        if isinstance(_child_module, LoraInjectedLinear):
            hook = _child_module.lora_up.register_forward_pre_hook(pre_hook)
            hooks.append(hook)
            hook = _child_module.lora_down.register_forward_hook(post_hook)
            hooks.append(hook)
    return transformer, hooks

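# Sketch of how these hooks are used around a half-precision transformer (hypothetical
# variable names; the handles can be detached again with .remove() when no longer needed):
#
#     transformer, hooks = add_mp_hook(transformer)   # after inject_lora(...)
#     ...                                             # training steps
#     for h in hooks:
#         h.remove()
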
def compute_density_for_timestep_sampling(
    weighting_scheme: str, batch_size: int, logit_mean: float = 0.0, logit_std: float = 1.0, mode_scale: Optional[float] = None
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device="cpu")
    return u

def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas):
    """
    Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        weighting = torch.ones_like(sigmas)
    return weighting

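# A short sketch of how these two helpers are combined in the losses below (hypothetical
# names `scheduler` / `z0`; the 1000-step discretization mirrors distribution_matching_loss
# and diff_loss):
#
#     u = compute_density_for_timestep_sampling("logit_normal", batch_size=1)
#     t_idx = (u * 1000).long()
#     scheduler.set_timesteps(1000)
#     t = scheduler.timesteps[t_idx]
#     sigma = t / 1000                                 # flow-matching sigma in (0, 1]
#     zt = (1 - sigma) * z0 + sigma * torch.randn_like(z0)
#     weight = compute_loss_weighting_for_sd3("cosmap", sigma)
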
class StableDiffusion3Base():
    def __init__(self, model_key: str = 'stabilityai/stable-diffusion-3-medium-diffusers', device='cuda', dtype=torch.float16):
        self.device = device
        self.dtype = dtype

        pipe = StableDiffusion3Pipeline.from_pretrained(model_key, torch_dtype=self.dtype)

        self.scheduler = pipe.scheduler

        self.tokenizer_1 = pipe.tokenizer
        self.tokenizer_2 = pipe.tokenizer_2
        self.tokenizer_3 = pipe.tokenizer_3
        self.text_enc_1 = pipe.text_encoder.to(device)
        self.text_enc_2 = pipe.text_encoder_2.to(device)
        self.text_enc_3 = pipe.text_encoder_3.to(device)

        self.vae = pipe.vae.to(device)
        self.transformer = pipe.transformer.to(device)
        self.transformer.eval()
        self.transformer.requires_grad_(False)

        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )

        del pipe

    def encode_prompt(self, prompt: List[str], batch_size: int = 1) -> List[torch.Tensor]:
        '''
        We assume that
        1. number of tokens < max_length
        2. one prompt for one image
        '''
        # CLIP encode (used for modulation of adaLN-zero)
        # now, we have two CLIPs
        text_clip1_ids = self.tokenizer_1(prompt,
                                          padding="max_length",
                                          max_length=77,
                                          truncation=True,
                                          return_tensors='pt').input_ids
        text_clip1_emb = self.text_enc_1(text_clip1_ids.to(self.device), output_hidden_states=True)
        pool_clip1_emb = text_clip1_emb[0].to(dtype=self.dtype, device=self.device)
        text_clip1_emb = text_clip1_emb.hidden_states[-2].to(dtype=self.dtype, device=self.device)

        text_clip2_ids = self.tokenizer_2(prompt,
                                          padding="max_length",
                                          max_length=77,
                                          truncation=True,
                                          return_tensors='pt').input_ids
        text_clip2_emb = self.text_enc_2(text_clip2_ids.to(self.device), output_hidden_states=True)
        pool_clip2_emb = text_clip2_emb[0].to(dtype=self.dtype, device=self.device)
        text_clip2_emb = text_clip2_emb.hidden_states[-2].to(dtype=self.dtype, device=self.device)

        # T5 encode (used for text condition)
        text_t5_ids = self.tokenizer_3(prompt,
                                       padding="max_length",
                                       max_length=512,
                                       truncation=True,
                                       add_special_tokens=True,
                                       return_tensors='pt').input_ids
        text_t5_emb = self.text_enc_3(text_t5_ids.to(self.device))[0]
        text_t5_emb = text_t5_emb.to(dtype=self.dtype, device=self.device)

        # Merge
        clip_prompt_emb = torch.cat([text_clip1_emb, text_clip2_emb], dim=-1)
        clip_prompt_emb = torch.nn.functional.pad(
            clip_prompt_emb, (0, text_t5_emb.shape[-1] - clip_prompt_emb.shape[-1])
        )
        prompt_emb = torch.cat([clip_prompt_emb, text_t5_emb], dim=-2)
        pooled_prompt_emb = torch.cat([pool_clip1_emb, pool_clip2_emb], dim=-1)

        return prompt_emb, pooled_prompt_emb

    def initialize_latent(self, img_size: Tuple[int], batch_size: int = 1, **kwargs):
        H, W = img_size
        lH, lW = H // self.vae_scale_factor, W // self.vae_scale_factor
        lC = self.transformer.config.in_channels
        latent_shape = (batch_size, lC, lH, lW)
        z = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
        return z

    def encode(self, image: torch.Tensor) -> torch.Tensor:
        z = self.vae.encode(image).latent_dist.sample()
        z = (z - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        return z

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        z = (z / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        return self.vae.decode(z, return_dict=False)[0]
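
    def predict_vector(self, z, t, prompt_emb, pooled_emb):
        # Not present in the extracted listing, but SD3Euler.inversion/sample below call
        # self.predict_vector; this mirrors the identical method defined on the
        # OSEDiff_SD3_* wrappers, so the base class is assumed to expose the same forward pass.
        v = self.transformer(hidden_states=z,
                             timestep=t,
                             pooled_projections=pooled_emb,
                             encoder_hidden_states=prompt_emb,
                             return_dict=False)[0]
        return v
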
class SD3Euler(StableDiffusion3Base):
    def __init__(self, model_key: str = 'stabilityai/stable-diffusion-3-medium-diffusers', device='cuda'):
        super().__init__(model_key=model_key, device=device)

    def inversion(self, src_img, prompts: List[str], NFE: int, cfg_scale: float = 1.0, batch_size: int = 1):
        # encode text prompts
        prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size)
        null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size)

        # initialize latent
        src_img = src_img.to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            z = self.encode(src_img)
        z0 = z.clone()

        # timesteps (default option. You can make your custom here.)
        self.scheduler.set_timesteps(NFE, device=self.device)
        timesteps = self.scheduler.timesteps
        timesteps = torch.cat([timesteps, torch.zeros(1, device=self.device)])
        timesteps = reversed(timesteps)
        sigmas = timesteps / self.scheduler.config.num_train_timesteps

        # Solve ODE
        pbar = tqdm(timesteps[:-1], total=NFE, desc='SD3 Euler Inversion')
        for i, t in enumerate(pbar):
            timestep = t.expand(z.shape[0]).to(self.device)
            pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb)
            if cfg_scale != 1.0:
                pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb)
            else:
                pred_null_v = 0.0
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]
            z = z + (sigma_next - sigma) * (pred_null_v + cfg_scale * (pred_v - pred_null_v))
        return z

    def sample(self, prompts: List[str], NFE: int, img_shape: Optional[Tuple[int]] = None, cfg_scale: float = 1.0, batch_size: int = 1, latent: Optional[torch.Tensor] = None):
        imgH, imgW = img_shape if img_shape is not None else (512, 512)

        # encode text prompts
        with torch.no_grad():
            prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size)
            null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size)

        # initialize latent
        if latent is None:
            z = self.initialize_latent((imgH, imgW), batch_size)
        else:
            z = latent

        # timesteps (default option. You can make your custom here.)
        self.scheduler.set_timesteps(NFE, device=self.device)
        timesteps = self.scheduler.timesteps
        sigmas = timesteps / self.scheduler.config.num_train_timesteps

        # Solve ODE
        pbar = tqdm(timesteps, total=NFE, desc='SD3 Euler')
        for i, t in enumerate(pbar):
            timestep = t.expand(z.shape[0]).to(self.device)
            pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb)
            if cfg_scale != 1.0:
                pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb)
            else:
                pred_null_v = 0.0
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1] if i + 1 < NFE else 0.0
            z = z + (sigma_next - sigma) * (pred_null_v + cfg_scale * (pred_v - pred_null_v))

        # decode
        with torch.no_grad():
            img = self.decode(z)
        return img

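# A minimal sampling sketch (illustrative; the prompt, NFE, image size and guidance scale
# are arbitrary choices, not defaults required by this class):
#
#     sd3 = SD3Euler()
#     with torch.no_grad():
#         img = sd3.sample(["a photo of a corgi"], NFE=28, img_shape=(1024, 1024), cfg_scale=7.0)
#     # img is a (B, 3, H, W) tensor roughly in [-1, 1]; rescale with img * 0.5 + 0.5 to view.
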
class OSEDiff_SD3_GEN(torch.nn.Module):
    def __init__(self, args, base_model):
        super().__init__()
        self.args = args
        self.model = base_model

        # Add lora to transformer
        print('Adding Lora to OSEDiff_SD3_GEN')
        self.transformer_gen = copy.deepcopy(self.model.transformer)
        self.transformer_gen.to('cuda')
        # self.transformer_gen = self.transformer_gen.float()
        self.transformer_gen.requires_grad_(False)
        self.transformer_gen.train()
        self.transformer_gen, hooks = add_mp_hook(self.transformer_gen)
        self.hooks = hooks
        lora_params, _ = inject_lora(self.transformer_gen, {"AdaLayerNormZero"}, r=args.lora_rank, verbose=True)
        # self.lora_params = lora_params
        for name, param in self.transformer_gen.named_parameters():
            if "lora_" in name:
                param.requires_grad = True   # LoRA up/down
            else:
                param.requires_grad = False  # everything else

        # Insert LoRA into VAE
        print("Adding Lora to VAE")
        self.model.vae, self.lora_vae_modules_encoder = inject_lora_vae(self.model.vae, lora_rank=args.lora_rank, verbose=True)

    def predict_vector(self, z, t, prompt_emb, pooled_emb):
        v = self.transformer_gen(hidden_states=z,
                                 timestep=t,
                                 pooled_projections=pooled_emb,
                                 encoder_hidden_states=prompt_emb,
                                 return_dict=False)[0]
        return v

    def forward(self, x_src, batch=None, args=None):
        z_src = self.model.encode(x_src.to(dtype=torch.float32, device=self.model.vae.device))
        z_src = z_src.to(self.transformer_gen.device)

        # calculate prompt_embeddings and neg_prompt_embeddings
        batch_size, _, _, _ = x_src.shape
        with torch.no_grad():
            prompt_embeds, pooled_embeds = self.model.encode_prompt(batch["prompt"], batch_size)
            neg_prompt_embeds, neg_pooled_embeds = self.model.encode_prompt(batch["neg_prompt"], batch_size)

        NFE = 1
        self.model.scheduler.set_timesteps(NFE, device=self.model.device)
        timesteps = self.model.scheduler.timesteps
        sigmas = timesteps / self.model.scheduler.config.num_train_timesteps
        sigmas = sigmas.to(self.transformer_gen.device)

        # Solve ODE
        i = 0
        t = timesteps[0]
        timestep = t.expand(z_src.shape[0]).to(self.transformer_gen.device)
        prompt_embeds = prompt_embeds.to(self.transformer_gen.device, dtype=torch.float32)
        pooled_embeds = pooled_embeds.to(self.transformer_gen.device, dtype=torch.float32)
        pred_v = self.predict_vector(z_src, timestep, prompt_embeds, pooled_embeds)
        pred_null_v = 0.0
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1] if i + 1 < NFE else 0.0
        z_src = z_src + (sigma_next - sigma) * (pred_null_v + 1 * (pred_v - pred_null_v))

        output_image = self.model.decode(z_src.to(dtype=torch.float32, device=self.model.vae.device))
        return output_image, z_src, prompt_embeds, pooled_embeds

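# A training-loop sketch for the generator wrapper (hypothetical `args`, `x_lq` and batch
# contents; only the call pattern is taken from forward above):
#
#     base = SD3Euler()
#     gen = OSEDiff_SD3_GEN(args, base)                  # args.lora_rank must be set
#     x_hat, z_hat, prompt_embeds, pooled_embeds = gen(
#         x_lq, batch={"prompt": ["a photo"], "neg_prompt": [""]}
#     )
#     # x_hat feeds image-space losses (e.g. LPIPS); z_hat feeds the OSEDiff_SD3_REG losses.
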
class OSEDiff_SD3_REG(torch.nn.Module):
    def __init__(self, args, base_model):
        super().__init__()
        self.args = args
        self.model = base_model
        self.transformer_org = self.model.transformer

        # Add lora to transformer
        print('Adding Lora to OSEDiff_SD3_REG')
        self.transformer_reg = copy.deepcopy(self.transformer_org)
        self.transformer_reg.to('cuda')
        self.transformer_reg.requires_grad_(False)
        self.transformer_reg.train()
        self.transformer_reg, hooks = add_mp_hook(self.transformer_reg)
        self.hooks = hooks
        lora_params, _ = inject_lora(self.transformer_reg, {"AdaLayerNormZero"}, r=args.lora_rank, verbose=True)
        for name, param in self.transformer_reg.named_parameters():
            if "lora_" in name:
                param.requires_grad = True   # LoRA up/down
            else:
                param.requires_grad = False  # everything else

    def predict_vector_reg(self, z, t, prompt_emb, pooled_emb):
        v = self.transformer_reg(hidden_states=z,
                                 timestep=t,
                                 pooled_projections=pooled_emb,
                                 encoder_hidden_states=prompt_emb,
                                 return_dict=False)[0]
        return v

    def predict_vector_org(self, z, t, prompt_emb, pooled_emb):
        v = self.transformer_org(hidden_states=z,
                                 timestep=t,
                                 pooled_projections=pooled_emb,
                                 encoder_hidden_states=prompt_emb,
                                 return_dict=False)[0]
        return v

    def distribution_matching_loss(self, z0, prompt_embeds, pooled_embeds, global_step, args):
        with torch.no_grad():
            device = self.transformer_reg.device

            # get timesteps and sigma
            u = compute_density_for_timestep_sampling(
                weighting_scheme="uniform",
                batch_size=1,
                logit_mean=0.0,
                logit_std=1.0,
                mode_scale=1.29,
            )
            t_idx = (u * 1000).long().to(device)
            self.model.scheduler.set_timesteps(1000, device=device)
            times = self.model.scheduler.timesteps
            t = times[t_idx]
            sigma = t / 1000

            # get noise and xt
            z0 = z0.to(device)
            noise = torch.randn_like(z0)
            sigma = sigma.half()
            zt = (1 - sigma) * z0 + sigma * noise

            # Get x0_prediction of transformer_reg
            v_pred_reg = self.predict_vector_reg(zt, t, prompt_embeds.to(device), pooled_embeds.to(device))
            reg_model_pred = v_pred_reg * (-sigma) + zt  # this is x0_prediction for reg

            # Get x0_prediction of transformer_org
            org_device = self.transformer_org.device
            v_pred_org = self.predict_vector_org(zt.to(org_device), t.to(org_device), prompt_embeds.to(org_device), pooled_embeds.to(org_device))
            org_model_pred = v_pred_org * (-sigma.to(org_device)) + zt.to(org_device)  # this is x0_prediction for org

            # Visualization
            if global_step % 100 == 1:
                self.vsd_visualization(z0, noise, zt, reg_model_pred, org_model_pred, global_step, args)

        # The detach trick makes d(loss)/d(z0) proportional to grad, so a gradient step moves
        # the generator output toward the frozen model's x0 prediction and away from the
        # LoRA-regularized one.
        weighting_factor = torch.abs(z0 - org_model_pred.to(device)).mean(dim=[1, 2, 3], keepdim=True)
        grad = (reg_model_pred - org_model_pred.to(device)) / weighting_factor
        loss = F.mse_loss(z0, (z0 - grad).detach())
        return loss

    def vsd_visualization(self, z0, noise, zt, reg_model_pred, org_model_pred, global_step, args):
        # -------- Visualization -------- #
        # 1. Visualize latents, noise, zt
        z0_img = self.model.decode(z0.to(dtype=torch.float32, device=self.model.vae.device))
        ns_img = self.model.decode(noise.to(dtype=torch.float32, device=self.model.vae.device))
        zt_img = self.model.decode(zt.to(dtype=torch.float32, device=self.model.vae.device))
        z0_img_pil = transforms.ToPILImage()(torch.clamp(z0_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5)
        ns_img_pil = transforms.ToPILImage()(torch.clamp(ns_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5)
        zt_img_pil = transforms.ToPILImage()(torch.clamp(zt_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5)

        # 2. Visualize reg_img, org_img
        reg_img = self.model.decode(reg_model_pred.to(dtype=torch.float32, device=self.model.vae.device))
        org_img = self.model.decode(org_model_pred.to(dtype=torch.float32, device=self.model.vae.device))
        reg_img_pil = transforms.ToPILImage()(torch.clamp(reg_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5)
        org_img_pil = transforms.ToPILImage()(torch.clamp(org_img[0].cpu(), -1.0, 1.0) * 0.5 + 0.5)

        # Concatenate images side by side
        w, h = z0_img_pil.width, z0_img_pil.height
        combined_image = Image.new('RGB', (w * 5, h))
        combined_image.paste(z0_img_pil, (0, 0))
        combined_image.paste(ns_img_pil, (w, 0))
        combined_image.paste(zt_img_pil, (w * 2, 0))
        combined_image.paste(reg_img_pil, (w * 3, 0))
        combined_image.paste(org_img_pil, (w * 4, 0))
        combined_image.save(os.path.join(args.output_dir, f'visualization/vsd/{global_step}.png'))
        # -------- Visualization -------- #

    def diff_loss(self, z0, prompt_embeds, pooled_embeds, net_lpips, args):
        device = self.transformer_reg.device

        u = compute_density_for_timestep_sampling(
            weighting_scheme="uniform",
            batch_size=1,
            logit_mean=0.0,
            logit_std=1.0,
            mode_scale=1.29,
        )
        t_idx = (u * 1000).long().to(device)
        self.model.scheduler.set_timesteps(1000, device=device)
        times = self.model.scheduler.timesteps
        t = times[t_idx]
        sigma = t / 1000

        z0 = z0.to(device)
        z0, prompt_embeds = z0.detach(), prompt_embeds.detach()
        noise = torch.randn_like(z0)
        sigma = sigma.half()
        zt = (1 - sigma) * z0 + sigma * noise  # noisy latents

        # v-prediction
        v_pred = self.predict_vector_reg(zt, t, prompt_embeds.to(device), pooled_embeds.to(device))
        model_pred = v_pred * (-sigma) + zt
        target = z0

        loss_weight = compute_loss_weighting_for_sd3("logit_normal", sigma)
        diffusion_loss = loss_weight.float() * F.mse_loss(model_pred.float(), target.float())
        loss_d = diffusion_loss
        return loss_d.mean()

class OSEDiff_SD3_TEST(torch.nn.Module):
    def __init__(self, args, base_model):
        super().__init__()
        self.args = args
        self.model = base_model
        self.lora_path = args.lora_path
        self.vae_path = args.vae_path

        # Add lora to transformer
        print(f'Loading LoRA to Transformer from {self.lora_path}')
        self.model.transformer.requires_grad_(False)
        lora_params, _ = inject_lora(self.model.transformer, {"AdaLayerNormZero"}, loras=self.lora_path, r=args.lora_rank, verbose=False)
        for name, param in self.model.transformer.named_parameters():
            param.requires_grad = False

        # Insert LoRA into VAE
        print(f"Loading LoRA to VAE from {self.vae_path}")
        self.model.vae, self.lora_vae_modules_encoder = inject_lora_vae(self.model.vae, lora_rank=args.lora_rank, verbose=False)
        encoder_state_dict_fp16 = torch.load(self.vae_path, map_location="cpu")
        self.model.vae.encoder.load_state_dict(encoder_state_dict_fp16)

    def predict_vector(self, z, t, prompt_emb, pooled_emb):
        v = self.model.transformer(hidden_states=z,
                                   timestep=t,
                                   pooled_projections=pooled_emb,
                                   encoder_hidden_states=prompt_emb,
                                   return_dict=False)[0]
        return v

    def forward(self, x_src, prompt):
        z_src = self.model.vae.encode(x_src.to(dtype=torch.float32, device=self.model.vae.device)).latent_dist.sample() * self.model.vae.config.scaling_factor
        z_src = z_src.to(self.model.transformer.device)

        # calculate prompt_embeddings and neg_prompt_embeddings
        batch_size, _, _, _ = x_src.shape
        with torch.no_grad():
            prompt_embeds, pooled_embeds = self.model.encode_prompt([prompt], batch_size)

        self.model.scheduler.set_timesteps(1, device=self.model.device)
        timesteps = self.model.scheduler.timesteps

        # Solve ODE
        t = timesteps[0]
        timestep = t.expand(z_src.shape[0]).to(self.model.transformer.device)
        prompt_embeds = prompt_embeds.to(self.model.transformer.device, dtype=torch.float32)
        pooled_embeds = pooled_embeds.to(self.model.transformer.device, dtype=torch.float32)
        pred_v = self.predict_vector(z_src, timestep, prompt_embeds, pooled_embeds)
        z_src = z_src - pred_v

        with torch.no_grad():
            output_image = self.model.decode(z_src.to(dtype=torch.float32, device=self.model.vae.device))
        return output_image

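# An inference sketch for the test-time wrapper (hypothetical file names and argument
# values; `args` must provide lora_path, vae_path and lora_rank, matching __init__ above):
#
#     base = SD3Euler()
#     restorer = OSEDiff_SD3_TEST(args, base)
#     lq = transforms.ToTensor()(Image.open("input.png").convert("RGB")) * 2 - 1
#     with torch.no_grad():
#         out = restorer(lq.unsqueeze(0).cuda(), prompt="a high-quality photo")
#     out_pil = transforms.ToPILImage()(torch.clamp(out[0].cpu() * 0.5 + 0.5, 0, 1))
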
class OSEDiff_SD3_TEST_efficient(torch.nn.Module):
    def __init__(self, args, base_model):
        super().__init__()
        self.args = args
        self.model = base_model
        self.lora_path = args.lora_path
        self.vae_path = args.vae_path

        # Add lora to transformer
        print(f'Loading LoRA to Transformer from {self.lora_path}')
        self.model.transformer.requires_grad_(False)
        lora_params, _ = inject_lora(self.model.transformer, {"AdaLayerNormZero"}, loras=self.lora_path, r=args.lora_rank, verbose=False)
        for name, param in self.model.transformer.named_parameters():
            param.requires_grad = False

        # Insert LoRA into VAE
        print(f"Loading LoRA to VAE from {self.vae_path}")
        self.model.vae, self.lora_vae_modules_encoder = inject_lora_vae(self.model.vae, lora_rank=args.lora_rank, verbose=False)
        encoder_state_dict_fp16 = torch.load(self.vae_path, map_location="cpu")
        self.model.vae.encoder.load_state_dict(encoder_state_dict_fp16)

    def predict_vector(self, z, t, prompt_emb, pooled_emb):
        v = self.model.transformer(hidden_states=z,
                                   timestep=t,
                                   pooled_projections=pooled_emb,
                                   encoder_hidden_states=prompt_emb,
                                   return_dict=False)[0]
        return v

    def forward(self, x_src, prompt):
        z_src = self.model.vae.encode(x_src.to(dtype=torch.float32, device=self.model.vae.device)).latent_dist.sample() * self.model.vae.config.scaling_factor
        z_src = z_src.to(self.model.transformer.device)

        # calculate prompt_embeddings
        batch_size, _, _, _ = x_src.shape
        prompt_embeds, pooled_embeds = self.model.encode_prompt([prompt], batch_size)

        self.model.scheduler.set_timesteps(1, device=self.model.device)
        timesteps = self.model.scheduler.timesteps

        # Solve ODE
        t = timesteps[0]
        timestep = t.expand(z_src.shape[0]).to(self.model.transformer.device)
        prompt_embeds = prompt_embeds.to(self.model.transformer.device, dtype=torch.float32)
        pooled_embeds = pooled_embeds.to(self.model.transformer.device, dtype=torch.float32)
        pred_v = self.predict_vector(z_src, timestep, prompt_embeds, pooled_embeds)
        z_src = z_src - pred_v

        output_image = self.model.decode(z_src.to(dtype=torch.float32, device=self.model.vae.device))
        return output_image