import os
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler, UNet2DConditionModel
from diffusers.models import AutoencoderKL
from diffusers.training_utils import compute_snr
from einops import rearrange
from omegaconf import OmegaConf

from ..pipelines.ig2mv_sdxl_pipeline import IG2MVSDXLPipeline
from ..schedulers.scheduling_shift_snr import ShiftSNRScheduler
from ..utils.core import find
from ..utils.typing import *
from .base import BaseSystem
from .utils import encode_prompt, vae_encode


def compute_embeddings(
    prompt_batch,
    empty_prompt_indices,
    text_encoders,
    tokenizers,
    is_train=True,
    **kwargs,
):
    original_size = kwargs["original_size"]
    target_size = kwargs["target_size"]
    crops_coords_top_left = kwargs["crops_coords_top_left"]

    # Replace dropped prompts with the empty string (classifier-free guidance).
    for i in range(empty_prompt_indices.shape[0]):
        if empty_prompt_indices[i]:
            prompt_batch[i] = ""

    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        prompt_batch, text_encoders, tokenizers, 0, is_train
    )
    add_text_embeds = pooled_prompt_embeds.to(
        device=prompt_embeds.device, dtype=prompt_embeds.dtype
    )

    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    add_time_ids = torch.tensor([add_time_ids])
    add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
    add_time_ids = add_time_ids.to(
        device=prompt_embeds.device, dtype=prompt_embeds.dtype
    )

    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
    return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}


class IG2MVSDXLSystem(BaseSystem):
    @dataclass
    class Config(BaseSystem.Config):
        # Model / Adapter
        pretrained_model_name_or_path: str = (
            "stabilityai/stable-diffusion-xl-base-1.0"
        )
        pretrained_vae_name_or_path: Optional[str] = "madebyollin/sdxl-vae-fp16-fix"
        pretrained_adapter_name_or_path: Optional[str] = None
        pretrained_unet_name_or_path: Optional[str] = None
        init_adapter_kwargs: Dict[str, Any] = field(default_factory=dict)
        use_fp16_vae: bool = True
        use_fp16_clip: bool = True

        # Training
        trainable_modules: List[str] = field(default_factory=list)
        train_cond_encoder: bool = True
        prompt_drop_prob: float = 0.0
        image_drop_prob: float = 0.0
        cond_drop_prob: float = 0.0
        gradient_checkpointing: bool = False

        # Noise sampler
        noise_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
        noise_offset: float = 0.0
        input_perturbation: float = 0.0
        snr_gamma: Optional[float] = 5.0
        prediction_type: Optional[str] = None
        shift_noise: bool = False
        shift_noise_mode: str = "interpolated"
        shift_noise_scale: float = 1.0

        # Evaluation
        eval_seed: int = 0
        eval_num_inference_steps: int = 30
        eval_guidance_scale: float = 1.0
        eval_height: int = 512
        eval_width: int = 512

    cfg: Config

    def configure(self):
        super().configure()

        # Prepare pipeline
        pipeline_kwargs = {}
        if self.cfg.pretrained_vae_name_or_path is not None:
            pipeline_kwargs["vae"] = AutoencoderKL.from_pretrained(
                self.cfg.pretrained_vae_name_or_path
            )
        if self.cfg.pretrained_unet_name_or_path is not None:
            pipeline_kwargs["unet"] = UNet2DConditionModel.from_pretrained(
                self.cfg.pretrained_unet_name_or_path
            )

        pipeline: IG2MVSDXLPipeline
        pipeline = IG2MVSDXLPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path, **pipeline_kwargs
        )

        init_adapter_kwargs = OmegaConf.to_container(self.cfg.init_adapter_kwargs)
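        # `self_attn_processor` may arrive from the config as a dotted import
        # path string (e.g. "module.ClassName"); `find` is assumed here to
        # resolve such a string to the processor class before adapter init.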
        if "self_attn_processor" in init_adapter_kwargs:
            self_attn_processor = init_adapter_kwargs["self_attn_processor"]
            if self_attn_processor is not None and isinstance(
                self_attn_processor, str
            ):
                self_attn_processor = find(self_attn_processor)
            init_adapter_kwargs["self_attn_processor"] = self_attn_processor
        pipeline.init_custom_adapter(**init_adapter_kwargs)

        if self.cfg.pretrained_adapter_name_or_path:
            pretrained_path = os.path.dirname(self.cfg.pretrained_adapter_name_or_path)
            adapter_name = os.path.basename(self.cfg.pretrained_adapter_name_or_path)
            pipeline.load_custom_adapter(pretrained_path, weight_name=adapter_name)

        noise_scheduler = DDPMScheduler.from_config(
            pipeline.scheduler.config, **self.cfg.noise_scheduler_kwargs
        )
        if self.cfg.shift_noise:
            noise_scheduler = ShiftSNRScheduler.from_scheduler(
                noise_scheduler,
                shift_mode=self.cfg.shift_noise_mode,
                shift_scale=self.cfg.shift_noise_scale,
                scheduler_class=DDPMScheduler,
            )
        pipeline.scheduler = noise_scheduler

        # Prepare models
        self.pipeline: IG2MVSDXLPipeline = pipeline
        self.vae = self.pipeline.vae.to(
            dtype=torch.float16 if self.cfg.use_fp16_vae else torch.float32
        )
        self.tokenizer = self.pipeline.tokenizer
        self.tokenizer_2 = self.pipeline.tokenizer_2
        self.text_encoder = self.pipeline.text_encoder.to(
            dtype=torch.float16 if self.cfg.use_fp16_clip else torch.float32
        )
        self.text_encoder_2 = self.pipeline.text_encoder_2.to(
            dtype=torch.float16 if self.cfg.use_fp16_clip else torch.float32
        )
        self.feature_extractor = self.pipeline.feature_extractor
        self.cond_encoder = self.pipeline.cond_encoder
        self.unet = self.pipeline.unet

        self.noise_scheduler = self.pipeline.scheduler
        self.inference_scheduler = DDPMScheduler.from_config(
            self.noise_scheduler.config
        )
        self.pipeline.scheduler = self.inference_scheduler

        if self.cfg.prediction_type is not None:
            self.noise_scheduler.register_to_config(
                prediction_type=self.cfg.prediction_type
            )

        # Prepare trainable / non-trainable modules
        trainable_modules = self.cfg.trainable_modules
        if trainable_modules:
            self.unet.requires_grad_(False)
            for name, module in self.unet.named_modules():
                for trainable_module in trainable_modules:
                    if trainable_module in name:
                        module.requires_grad_(True)
        else:
            self.unet.requires_grad_(True)
        self.cond_encoder.requires_grad_(self.cfg.train_cond_encoder)
        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.text_encoder_2.requires_grad_(False)

        # Prepare gradient checkpointing
        if self.cfg.gradient_checkpointing:
            self.unet.enable_gradient_checkpointing()

    def forward(
        self,
        noisy_latents: Tensor,
        conditioning_pixel_values: Tensor,
        timesteps: Tensor,
        ref_latents: Tensor,
        prompts: List[str],
        num_views: int,
        **kwargs,
    ) -> Dict[str, Any]:
        bsz = noisy_latents.shape[0]
        b_samples = bsz // num_views
        num_batch_images = num_views

        prompt_drop_mask = (
            torch.rand(b_samples, device=noisy_latents.device)
            < self.cfg.prompt_drop_prob
        )
        image_drop_mask = (
            torch.rand(b_samples, device=noisy_latents.device)
            < self.cfg.image_drop_prob
        )
        cond_drop_mask = (
            torch.rand(b_samples, device=noisy_latents.device)
            < self.cfg.cond_drop_prob
        )
        prompt_drop_mask = prompt_drop_mask | cond_drop_mask
        image_drop_mask = image_drop_mask | cond_drop_mask

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            # Here, we compute not just the text embeddings but also the
            # additional embeddings needed for the SDXL UNet to operate.
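            # SDXL's added conditioning consists of the pooled text embedding
            # plus "time ids" built from (original_size, crops_coords_top_left,
            # target_size); compute_embeddings packs both into the dict below.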
            additional_embeds = compute_embeddings(
                prompts,
                prompt_drop_mask,
                [self.text_encoder, self.text_encoder_2],
                [self.tokenizer, self.tokenizer_2],
                **kwargs,
            )

        # Process reference latents to obtain reference features
        with torch.no_grad():
            ref_timesteps = torch.zeros_like(timesteps[:b_samples])
            ref_hidden_states = {}
            self.unet(
                ref_latents,
                ref_timesteps,
                encoder_hidden_states=additional_embeds["prompt_embeds"],
                added_cond_kwargs={
                    "text_embeds": additional_embeds["text_embeds"],
                    "time_ids": additional_embeds["time_ids"],
                },
                cross_attention_kwargs={
                    "cache_hidden_states": ref_hidden_states,
                    "use_mv": False,
                    "use_ref": False,
                },
                return_dict=False,
            )
            # Zero out reference features for dropped images, then repeat them
            # for each view in the batch.
            for k, v in ref_hidden_states.items():
                v_ = v
                v_[image_drop_mask] = 0.0
                ref_hidden_states[k] = v_.repeat_interleave(num_batch_images, dim=0)

        # Repeat additional embeddings for each image in the batch
        for key, value in additional_embeds.items():
            kwargs[key] = value.repeat_interleave(num_batch_images, dim=0)

        conditioning_features = self.cond_encoder(conditioning_pixel_values)

        added_cond_kwargs = {
            "text_embeds": kwargs["text_embeds"],
            "time_ids": kwargs["time_ids"],
        }
        noise_pred = self.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=kwargs["prompt_embeds"],
            added_cond_kwargs=added_cond_kwargs,
            down_intrablock_additional_residuals=conditioning_features,
            cross_attention_kwargs={
                "ref_hidden_states": ref_hidden_states,
                "num_views": num_views,
            },
        ).sample

        return {"noise_pred": noise_pred}

    def training_step(self, batch, batch_idx):
        num_views = batch["num_views"]

        # Encode images in slices to bound peak VAE memory usage.
        vae_max_slice = 8
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            latents = []
            for i in range(0, batch["rgb"].shape[0], vae_max_slice):
                latents.append(
                    vae_encode(
                        self.vae,
                        batch["rgb"][i : i + vae_max_slice].to(self.vae.dtype) * 2 - 1,
                        sample=True,
                        apply_scale=True,
                    ).float()
                )
            latents = torch.cat(latents, dim=0)

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            ref_latents = vae_encode(
                self.vae,
                batch["reference_rgb"].to(self.vae.dtype) * 2 - 1,
                sample=True,
                apply_scale=True,
            ).float()

        bsz = latents.shape[0]
        b_samples = bsz // num_views

        noise = torch.randn_like(latents)
        if self.cfg.noise_offset > 0:
            # https://www.crosslabs.org//blog/diffusion-with-offset-noise
            noise += self.cfg.noise_offset * torch.randn(
                (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
            )

        noise_mask = (
            batch["noise_mask"]
            if "noise_mask" in batch
            else torch.ones((bsz,), dtype=torch.bool, device=latents.device)
        )

        # Sample one timestep per sample and share it across that sample's views.
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (b_samples,),
            device=latents.device,
            dtype=torch.long,
        )
        timesteps = timesteps.repeat_interleave(num_views)
        timesteps[~noise_mask] = 0

        if self.cfg.input_perturbation > 0:
            new_noise = noise + self.cfg.input_perturbation * torch.randn_like(noise)
            noisy_latents = self.noise_scheduler.add_noise(
                latents, new_noise, timesteps
            )
        else:
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        noisy_latents[~noise_mask] = latents[~noise_mask]

        if self.noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(
                f"Unsupported prediction type {self.noise_scheduler.config.prediction_type}"
            )

        model_pred = self(
            noisy_latents, batch["source_rgb"], timesteps, ref_latents, **batch
        )["noise_pred"]
        model_pred = model_pred[noise_mask]
        target = target[noise_mask]
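
        # Numeric sketch of the Min-SNR weighting below (epsilon prediction):
        # with snr_gamma=5.0, a low-noise timestep with SNR=20 gets weight
        # min(20, 5)/20 = 0.25, while a high-noise timestep with SNR=2 keeps
        # weight min(2, 5)/2 = 1.0, down-weighting the easy low-noise steps.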
        if self.cfg.snr_gamma is None:
            loss = F.mse_loss(model_pred, target, reduction="mean")
        else:
            # Compute loss weights as per Section 3.4 of
            # https://arxiv.org/abs/2303.09556. Since we predict the noise
            # instead of x_0, the original formulation is slightly changed;
            # this is discussed in Section 4.2 of the same paper.
            # SNR must be computed on the same masked timesteps as the loss
            # terms, otherwise the shapes diverge whenever noise_mask is set.
            snr = compute_snr(self.noise_scheduler, timesteps[noise_mask])
            if self.noise_scheduler.config.prediction_type == "v_prediction":
                # The velocity objective requires adding one to the SNR values
                # before dividing by them.
                snr = snr + 1
            mse_loss_weights = (
                torch.stack(
                    [snr, self.cfg.snr_gamma * torch.ones_like(snr)], dim=1
                ).min(dim=1)[0]
                / snr
            )
            loss = F.mse_loss(model_pred, target, reduction="none")
            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
            loss = loss.mean()

        self.log("train/loss", loss, prog_bar=True)

        # Runs self.on_check_train every self.cfg.check_train_every_n_steps steps.
        self.check_train(batch)

        return {"loss": loss}

    def on_train_batch_end(self, outputs, batch, batch_idx):
        pass

    def get_input_visualizations(self, batch):
        return [
            {
                "type": "rgb",
                "img": rearrange(
                    batch["source_rgb"],
                    "(B N) C H W -> (B H) (N W) C",
                    N=batch["num_views"],
                ),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(batch["reference_rgb"], "B C H W -> (B H) W C"),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(
                    batch["rgb"], "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
                ),
                "kwargs": {"data_format": "HWC"},
            },
        ]

    def get_output_visualizations(self, batch, outputs):
        images = [
            {
                "type": "rgb",
                "img": rearrange(
                    batch["source_rgb"],
                    "(B N) C H W -> (B H) (N W) C",
                    N=batch["num_views"],
                ),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(
                    batch["rgb"], "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
                ),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(batch["reference_rgb"], "B C H W -> (B H) W C"),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(
                    outputs, "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
                ),
                "kwargs": {"data_format": "HWC"},
            },
        ]
        return images

    def generate_images(self, batch, **kwargs):
        return self.pipeline(
            prompt=batch["prompts"],
            control_image=batch["source_rgb"],
            num_images_per_prompt=batch["num_views"],
            generator=torch.Generator(device=self.device).manual_seed(
                self.cfg.eval_seed
            ),
            num_inference_steps=self.cfg.eval_num_inference_steps,
            guidance_scale=self.cfg.eval_guidance_scale,
            height=self.cfg.eval_height,
            width=self.cfg.eval_width,
            reference_image=batch["reference_rgb"],
            output_type="pt",
        ).images

    def on_save_checkpoint(self, checkpoint):
        if self.global_rank == 0:
            self.pipeline.save_custom_adapter(
                os.path.dirname(self.get_save_dir()),
                "step1x-3d-ig2v.safetensors",
                safe_serialization=True,
                include_keys=self.cfg.trainable_modules,
            )

    def on_check_train(self, batch):
        self.save_image_grid(
            f"it{self.true_global_step}-train.jpg",
            self.get_input_visualizations(batch),
            name="train_step_input",
            step=self.true_global_step,
        )

    def validation_step(self, batch, batch_idx):
        out = self.generate_images(batch)
        if (
            self.cfg.check_val_limit_rank > 0
            and self.global_rank < self.cfg.check_val_limit_rank
        ):
            self.save_image_grid(
                f"it{self.true_global_step}-validation-{self.global_rank}_{batch_idx}.jpg",
                self.get_output_visualizations(batch, out),
                name=f"validation_step_output_{self.global_rank}_{batch_idx}",
                step=self.true_global_step,
            )

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        out = self.generate_images(batch)
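        # generate_images seeds its torch.Generator with cfg.eval_seed, so the
        # saved test grids are reproducible for a given checkpoint and batch.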
f"it{self.true_global_step}-test-{self.global_rank}_{batch_idx}.jpg", self.get_output_visualizations(batch, out), name=f"test_step_output_{self.global_rank}_{batch_idx}", step=self.true_global_step, ) def on_test_end(self): pass