import os
import random

from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler, UNet2DConditionModel
from diffusers.models import AutoencoderKL
from diffusers.training_utils import compute_snr
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image

from ..pipelines.ig2mv_sdxl_pipeline import IG2MVSDXLPipeline
from ..schedulers.scheduling_shift_snr import ShiftSNRScheduler
from ..utils.core import find
from ..utils.typing import *
from .base import BaseSystem
from .utils import encode_prompt, vae_encode

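# Build the SDXL text conditioning: CLIP prompt embeddings, the pooled
# "text_embeds", and the additional "time_ids" (original size, crop offset,
# target size). Prompts flagged by `empty_prompt_indices` are replaced with
# the empty string (prompt dropout for classifier-free guidance).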
def compute_embeddings(
    prompt_batch,
    empty_prompt_indices,
    text_encoders,
    tokenizers,
    is_train=True,
    **kwargs,
):
    original_size = kwargs["original_size"]
    target_size = kwargs["target_size"]
    crops_coords_top_left = kwargs["crops_coords_top_left"]

    for i in range(empty_prompt_indices.shape[0]):
        if empty_prompt_indices[i]:
            prompt_batch[i] = ""

    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        prompt_batch, text_encoders, tokenizers, 0, is_train
    )
    add_text_embeds = pooled_prompt_embeds.to(
        device=prompt_embeds.device, dtype=prompt_embeds.dtype
    )

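    # SDXL micro-conditioning: pack original size, crop offset, and target size
    # into one row of time ids per prompt.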
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    add_time_ids = torch.tensor([add_time_ids])
    add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
    add_time_ids = add_time_ids.to(
        device=prompt_embeds.device, dtype=prompt_embeds.dtype
    )

    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

    return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}


class IG2MVSDXLSystem(BaseSystem):
    @dataclass
    class Config(BaseSystem.Config):
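        # Base SDXL weights, with optional overrides for the VAE and UNet and an
        # optional pretrained multi-view adapter checkpoint.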
        pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-xl-base-1.0"
        pretrained_vae_name_or_path: Optional[str] = "madebyollin/sdxl-vae-fp16-fix"
        pretrained_adapter_name_or_path: Optional[str] = None
        pretrained_unet_name_or_path: Optional[str] = None
        init_adapter_kwargs: Dict[str, Any] = field(default_factory=dict)

        use_fp16_vae: bool = True
        use_fp16_clip: bool = True

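        # Module selection and conditioning-dropout probabilities used to learn
        # the unconditional branch for classifier-free guidance.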
        trainable_modules: List[str] = field(default_factory=list)
        train_cond_encoder: bool = True
        prompt_drop_prob: float = 0.0
        image_drop_prob: float = 0.0
        cond_drop_prob: float = 0.0

        gradient_checkpointing: bool = False

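        # Noise schedule, noise augmentation, and loss-weighting options.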
        noise_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
        noise_offset: float = 0.0
        input_perturbation: float = 0.0
        snr_gamma: Optional[float] = 5.0
        prediction_type: Optional[str] = None
        shift_noise: bool = False
        shift_noise_mode: str = "interpolated"
        shift_noise_scale: float = 1.0

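        # Sampling settings for validation and test image generation.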
        eval_seed: int = 0
        eval_num_inference_steps: int = 30
        eval_guidance_scale: float = 1.0
        eval_height: int = 512
        eval_width: int = 512

    cfg: Config

    def configure(self):
        super().configure()

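        # Assemble the SDXL pipeline, optionally swapping in a separately
        # released VAE and/or UNet.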
        pipeline_kwargs = {}
        if self.cfg.pretrained_vae_name_or_path is not None:
            pipeline_kwargs["vae"] = AutoencoderKL.from_pretrained(
                self.cfg.pretrained_vae_name_or_path
            )
        if self.cfg.pretrained_unet_name_or_path is not None:
            pipeline_kwargs["unet"] = UNet2DConditionModel.from_pretrained(
                self.cfg.pretrained_unet_name_or_path
            )

        pipeline: IG2MVSDXLPipeline
        pipeline = IG2MVSDXLPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path, **pipeline_kwargs
        )

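        # Resolve the self-attention processor class from its import path,
        # install the multi-view adapter, and optionally load pretrained
        # adapter weights.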
        init_adapter_kwargs = OmegaConf.to_container(self.cfg.init_adapter_kwargs)
        if "self_attn_processor" in init_adapter_kwargs:
            self_attn_processor = init_adapter_kwargs["self_attn_processor"]
            if self_attn_processor is not None and isinstance(self_attn_processor, str):
                self_attn_processor = find(self_attn_processor)
            init_adapter_kwargs["self_attn_processor"] = self_attn_processor
        pipeline.init_custom_adapter(**init_adapter_kwargs)

        if self.cfg.pretrained_adapter_name_or_path:
            pretrained_path = os.path.dirname(self.cfg.pretrained_adapter_name_or_path)
            adapter_name = os.path.basename(self.cfg.pretrained_adapter_name_or_path)
            pipeline.load_custom_adapter(pretrained_path, weight_name=adapter_name)

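        # Training uses a DDPM scheduler built from the pipeline config,
        # optionally with its SNR schedule shifted (ShiftSNRScheduler).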
        noise_scheduler = DDPMScheduler.from_config(
            pipeline.scheduler.config, **self.cfg.noise_scheduler_kwargs
        )
        if self.cfg.shift_noise:
            noise_scheduler = ShiftSNRScheduler.from_scheduler(
                noise_scheduler,
                shift_mode=self.cfg.shift_noise_mode,
                shift_scale=self.cfg.shift_noise_scale,
                scheduler_class=DDPMScheduler,
            )
        pipeline.scheduler = noise_scheduler

        self.pipeline: IG2MVSDXLPipeline = pipeline
        self.vae = self.pipeline.vae.to(
            dtype=torch.float16 if self.cfg.use_fp16_vae else torch.float32
        )
        self.tokenizer = self.pipeline.tokenizer
        self.tokenizer_2 = self.pipeline.tokenizer_2
        self.text_encoder = self.pipeline.text_encoder.to(
            dtype=torch.float16 if self.cfg.use_fp16_clip else torch.float32
        )
        self.text_encoder_2 = self.pipeline.text_encoder_2.to(
            dtype=torch.float16 if self.cfg.use_fp16_clip else torch.float32
        )
        self.feature_extractor = self.pipeline.feature_extractor

        self.cond_encoder = self.pipeline.cond_encoder
        self.unet = self.pipeline.unet
        self.noise_scheduler = self.pipeline.scheduler
        self.inference_scheduler = DDPMScheduler.from_config(
            self.noise_scheduler.config
        )
        self.pipeline.scheduler = self.inference_scheduler
        if self.cfg.prediction_type is not None:
            self.noise_scheduler.register_to_config(
                prediction_type=self.cfg.prediction_type
            )

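        # Train either the full UNet or only the modules whose names match an
        # entry in `trainable_modules`; the conditioning encoder follows
        # `train_cond_encoder`, while the VAE and text encoders stay frozen.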
        trainable_modules = self.cfg.trainable_modules
        if trainable_modules and len(trainable_modules) > 0:
            self.unet.requires_grad_(False)
            for name, module in self.unet.named_modules():
                for trainable_module in trainable_modules:
                    if trainable_module in name:
                        module.requires_grad_(True)
        else:
            self.unet.requires_grad_(True)
        self.cond_encoder.requires_grad_(self.cfg.train_cond_encoder)

        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.text_encoder_2.requires_grad_(False)

        if self.cfg.gradient_checkpointing:
            self.unet.enable_gradient_checkpointing()

    def forward(
        self,
        noisy_latents: Tensor,
        conditioning_pixel_values: Tensor,
        timesteps: Tensor,
        ref_latents: Tensor,
        prompts: List[str],
        num_views: int,
        **kwargs,
    ) -> Dict[str, Any]:
        bsz = noisy_latents.shape[0]
        b_samples = bsz // num_views
        num_batch_images = num_views

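        # Per-sample dropout masks for classifier-free guidance: prompts and the
        # reference image are dropped independently, while `cond_drop_prob`
        # drops both together.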
        prompt_drop_mask = (
            torch.rand(b_samples, device=noisy_latents.device)
            < self.cfg.prompt_drop_prob
        )
        image_drop_mask = (
            torch.rand(b_samples, device=noisy_latents.device)
            < self.cfg.image_drop_prob
        )
        cond_drop_mask = (
            torch.rand(b_samples, device=noisy_latents.device) < self.cfg.cond_drop_prob
        )
        prompt_drop_mask = prompt_drop_mask | cond_drop_mask
        image_drop_mask = image_drop_mask | cond_drop_mask

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            additional_embeds = compute_embeddings(
                prompts,
                prompt_drop_mask,
                [self.text_encoder, self.text_encoder_2],
                [self.tokenizer, self.tokenizer_2],
                **kwargs,
            )

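        # Reference pass: run the UNet once on the reference-image latents at
        # timestep 0 to cache its attention hidden states. States of dropped
        # reference images are zeroed, then repeated for every target view.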
        with torch.no_grad():
            ref_timesteps = torch.zeros_like(timesteps[:b_samples])
            ref_hidden_states = {}
            self.unet(
                ref_latents,
                ref_timesteps,
                encoder_hidden_states=additional_embeds["prompt_embeds"],
                added_cond_kwargs={
                    "text_embeds": additional_embeds["text_embeds"],
                    "time_ids": additional_embeds["time_ids"],
                },
                cross_attention_kwargs={
                    "cache_hidden_states": ref_hidden_states,
                    "use_mv": False,
                    "use_ref": False,
                },
                return_dict=False,
            )
            for k, v in ref_hidden_states.items():
                v_ = v
                v_[image_drop_mask] = 0.0
                ref_hidden_states[k] = v_.repeat_interleave(num_batch_images, dim=0)

        for key, value in additional_embeds.items():
            kwargs[key] = value.repeat_interleave(num_batch_images, dim=0)

        conditioning_features = self.cond_encoder(conditioning_pixel_values)

        added_cond_kwargs = {
            "text_embeds": kwargs["text_embeds"],
            "time_ids": kwargs["time_ids"],
        }

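        # Multi-view denoising pass: features from the control images are
        # injected as intra-block residuals, and the cached reference states
        # are consumed by the custom attention processors.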
        noise_pred = self.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=kwargs["prompt_embeds"],
            added_cond_kwargs=added_cond_kwargs,
            down_intrablock_additional_residuals=conditioning_features,
            cross_attention_kwargs={
                "ref_hidden_states": ref_hidden_states,
                "num_views": num_views,
            },
        ).sample

        return {"noise_pred": noise_pred}

    def training_step(self, batch, batch_idx):
        num_views = batch["num_views"]

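        # Encode the target views and the reference image to latents outside of
        # autocast, slicing the view batch to bound VAE memory usage.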
        vae_max_slice = 8
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            latents = []
            for i in range(0, batch["rgb"].shape[0], vae_max_slice):
                latents.append(
                    vae_encode(
                        self.vae,
                        batch["rgb"][i : i + vae_max_slice].to(self.vae.dtype) * 2 - 1,
                        sample=True,
                        apply_scale=True,
                    ).float()
                )
            latents = torch.cat(latents, dim=0)

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            ref_latents = vae_encode(
                self.vae,
                batch["reference_rgb"].to(self.vae.dtype) * 2 - 1,
                sample=True,
                apply_scale=True,
            ).float()

        bsz = latents.shape[0]
        b_samples = bsz // num_views

        noise = torch.randn_like(latents)
        if self.cfg.noise_offset:
            # Per-channel offset noise augmentation.
            noise += self.cfg.noise_offset * torch.randn(
                (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
            )

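        # `noise_mask` marks which views are noised. Each sample draws a single
        # timestep shared across its views; un-noised views keep timestep 0 and
        # their clean latents.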
        noise_mask = (
            batch["noise_mask"]
            if "noise_mask" in batch
            else torch.ones((bsz,), dtype=torch.bool, device=latents.device)
        )
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (b_samples,),
            device=latents.device,
            dtype=torch.long,
        )
        timesteps = timesteps.repeat_interleave(num_views)
        timesteps[~noise_mask] = 0

        if self.cfg.input_perturbation:
            new_noise = noise + self.cfg.input_perturbation * torch.randn_like(noise)
            noisy_latents = self.noise_scheduler.add_noise(
                latents, new_noise, timesteps
            )
        else:
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        noisy_latents[~noise_mask] = latents[~noise_mask]

        if self.noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(
                f"Unsupported prediction type {self.noise_scheduler.config.prediction_type}"
            )

        model_pred = self(
            noisy_latents, batch["source_rgb"], timesteps, ref_latents, **batch
        )["noise_pred"]

        model_pred = model_pred[noise_mask]
        target = target[noise_mask]

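        # Min-SNR loss weighting: clamp the per-timestep SNR at `snr_gamma` and
        # rescale the per-sample MSE accordingly (using SNR + 1 for
        # v-prediction).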
        if self.cfg.snr_gamma is None:
            loss = F.mse_loss(model_pred, target, reduction="mean")
        else:
            # Use only the timesteps of the views that were noised so the
            # weights line up with the masked predictions.
            kept_timesteps = timesteps[noise_mask]
            snr = compute_snr(self.noise_scheduler, kept_timesteps)
            if self.noise_scheduler.config.prediction_type == "v_prediction":
                snr = snr + 1
            mse_loss_weights = (
                torch.stack(
                    [snr, self.cfg.snr_gamma * torch.ones_like(kept_timesteps)], dim=1
                ).min(dim=1)[0]
                / snr
            )

            loss = F.mse_loss(model_pred, target, reduction="none")
            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
            loss = loss.mean()

        self.log("train/loss", loss, prog_bar=True)

        self.check_train(batch)

        return {"loss": loss}

    def on_train_batch_end(self, outputs, batch, batch_idx):
        pass

    def get_input_visualizations(self, batch):
        return [
            {
                "type": "rgb",
                "img": rearrange(
                    batch["source_rgb"],
                    "(B N) C H W -> (B H) (N W) C",
                    N=batch["num_views"],
                ),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(batch["reference_rgb"], "B C H W -> (B H) W C"),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(
                    batch["rgb"], "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
                ),
                "kwargs": {"data_format": "HWC"},
            },
        ]

    def get_output_visualizations(self, batch, outputs):
        images = [
            {
                "type": "rgb",
                "img": rearrange(
                    batch["source_rgb"],
                    "(B N) C H W -> (B H) (N W) C",
                    N=batch["num_views"],
                ),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(
                    batch["rgb"], "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
                ),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(batch["reference_rgb"], "B C H W -> (B H) W C"),
                "kwargs": {"data_format": "HWC"},
            },
            {
                "type": "rgb",
                "img": rearrange(
                    outputs, "(B N) C H W -> (B H) (N W) C", N=batch["num_views"]
                ),
                "kwargs": {"data_format": "HWC"},
            },
        ]
        return images

    def generate_images(self, batch, **kwargs):
        return self.pipeline(
            prompt=batch["prompts"],
            control_image=batch["source_rgb"],
            num_images_per_prompt=batch["num_views"],
            generator=torch.Generator(device=self.device).manual_seed(
                self.cfg.eval_seed
            ),
            num_inference_steps=self.cfg.eval_num_inference_steps,
            guidance_scale=self.cfg.eval_guidance_scale,
            height=self.cfg.eval_height,
            width=self.cfg.eval_width,
            reference_image=batch["reference_rgb"],
            output_type="pt",
        ).images

    def on_save_checkpoint(self, checkpoint):
        if self.global_rank == 0:
            self.pipeline.save_custom_adapter(
                os.path.dirname(self.get_save_dir()),
                "step1x-3d-ig2v.safetensors",
                safe_serialization=True,
                include_keys=self.cfg.trainable_modules,
            )

    def on_check_train(self, batch):
        self.save_image_grid(
            f"it{self.true_global_step}-train.jpg",
            self.get_input_visualizations(batch),
            name="train_step_input",
            step=self.true_global_step,
        )

    def validation_step(self, batch, batch_idx):
        out = self.generate_images(batch)

        if (
            self.cfg.check_val_limit_rank > 0
            and self.global_rank < self.cfg.check_val_limit_rank
        ):
            self.save_image_grid(
                f"it{self.true_global_step}-validation-{self.global_rank}_{batch_idx}.jpg",
                self.get_output_visualizations(batch, out),
                name=f"validation_step_output_{self.global_rank}_{batch_idx}",
                step=self.true_global_step,
            )

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        out = self.generate_images(batch)

        self.save_image_grid(
            f"it{self.true_global_step}-test-{self.global_rank}_{batch_idx}.jpg",
            self.get_output_visualizations(batch, out),
            name=f"test_step_output_{self.global_rank}_{batch_idx}",
            step=self.true_global_step,
        )

    def on_test_end(self):
        pass