import lightning as L
from diffusers.pipelines import FluxPipeline
import torch
from peft import LoraConfig, get_peft_model_state_dict
import prodigyopt

from ..flux.transformer import tranformer_forward
from ..flux.condition import Condition
from ..flux.pipeline_tools import encode_images, prepare_text_input

class OminiModel(L.LightningModule):
    def __init__(
        self,
        flux_pipe_id: str,
        lora_path: str = None,
        lora_config: dict = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        model_config: dict = {},
        optimizer_config: dict = None,
        gradient_checkpointing: bool = False,
    ):
        # Initialize the LightningModule
        super().__init__()
        self.model_config = model_config
        self.optimizer_config = optimizer_config

        # Load the Flux pipeline
        self.flux_pipe: FluxPipeline = (
            FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
        )
        self.transformer = self.flux_pipe.transformer
        self.transformer.gradient_checkpointing = gradient_checkpointing
        self.transformer.train()

        # Freeze the Flux pipeline
        self.flux_pipe.text_encoder.requires_grad_(False).eval()
        self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
        self.flux_pipe.vae.requires_grad_(False).eval()

        # Initialize LoRA layers
        self.lora_layers = self.init_lora(lora_path, lora_config)

        self.to(device).to(dtype)
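
    # Illustrative construction (a sketch, not from the original file); the
    # pipeline id and config values below are assumptions:
    #
    #   model = OminiModel(
    #       flux_pipe_id="black-forest-labs/FLUX.1-dev",  # assumed base model
    #       lora_config={...},       # see the example after init_lora below
    #       optimizer_config={...},  # see the example after configure_optimizers
    #       dtype=torch.bfloat16,
    #       gradient_checkpointing=True,
    #   )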

    def init_lora(self, lora_path: str, lora_config: dict):
        assert lora_path or lora_config
        if lora_path:
            # TODO: Implement this
            raise NotImplementedError
        else:
            self.transformer.add_adapter(LoraConfig(**lora_config))
            # TODO: Check if this is correct (p.requires_grad)
            lora_layers = filter(
                lambda p: p.requires_grad, self.transformer.parameters()
            )
        return list(lora_layers)
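
    # Illustrative lora_config (values are assumptions): add_adapter forwards
    # these fields to peft.LoraConfig, and the target module names depend on
    # the Flux transformer's attention layer naming.
    #
    #   lora_config = {
    #       "r": 4,
    #       "lora_alpha": 4,
    #       "init_lora_weights": "gaussian",
    #       "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
    #   }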

    def save_lora(self, path: str):
        FluxPipeline.save_lora_weights(
            save_directory=path,
            transformer_lora_layers=get_peft_model_state_dict(self.transformer),
            safe_serialization=True,
        )

    def configure_optimizers(self):
        # Freeze the transformer
        self.transformer.requires_grad_(False)
        opt_config = self.optimizer_config

        # Set the trainable parameters
        self.trainable_params = self.lora_layers

        # Unfreeze trainable parameters
        for p in self.trainable_params:
            p.requires_grad_(True)

        # Initialize the optimizer
        if opt_config["type"] == "AdamW":
            optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
        elif opt_config["type"] == "Prodigy":
            optimizer = prodigyopt.Prodigy(
                self.trainable_params,
                **opt_config["params"],
            )
        elif opt_config["type"] == "SGD":
            optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
        else:
            raise NotImplementedError

        return optimizer
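
    # Illustrative optimizer_config (values are assumptions): "params" is
    # unpacked directly into the chosen optimizer's constructor, so any valid
    # keyword argument of AdamW, Prodigy, or SGD can appear there.
    #
    #   optimizer_config = {"type": "Prodigy", "params": {"lr": 1.0}}
    #   optimizer_config = {"type": "AdamW", "params": {"lr": 1e-4, "weight_decay": 0.01}}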

    def training_step(self, batch, batch_idx):
        step_loss = self.step(batch)
        # Keep a running exponential moving average of the step loss
        self.log_loss = (
            step_loss.item()
            if not hasattr(self, "log_loss")
            else self.log_loss * 0.95 + step_loss.item() * 0.05
        )
        return step_loss

    def step(self, batch):
        imgs = batch["image"]
        conditions = batch["condition"]
        condition_types = batch["condition_type"]
        prompts = batch["description"]
        position_delta = batch["position_delta"][0]
        position_scale = float(batch.get("position_scale", [1.0])[0])

        # Prepare inputs
        with torch.no_grad():
            # Prepare image input
            x_0, img_ids = encode_images(self.flux_pipe, imgs)

            # Prepare text input
            prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
                self.flux_pipe, prompts
            )

            # Prepare t and x_t
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
            x_1 = torch.randn_like(x_0).to(self.device)
            t_ = t.unsqueeze(1).unsqueeze(1)
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)
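
            # With t = sigmoid(z), z ~ N(0, 1), the timestep follows a
            # logit-normal distribution on (0, 1); x_t linearly interpolates
            # between the clean latents x_0 (t = 0) and pure noise x_1 (t = 1).
            # The loss below regresses the straight-path velocity x_1 - x_0
            # (rectified-flow / flow-matching style training).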

            # Prepare conditions
            condition_latents, condition_ids = encode_images(self.flux_pipe, conditions)

            # Add position delta
            condition_ids[:, 1] += position_delta[0]
            condition_ids[:, 2] += position_delta[1]

            if position_scale != 1.0:
                scale_bias = (position_scale - 1.0) / 2
                condition_ids[:, 1] *= position_scale
                condition_ids[:, 2] *= position_scale
                condition_ids[:, 1] += scale_bias
                condition_ids[:, 2] += scale_bias
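
            # The scale bias keeps token centers aligned under scaling:
            # p * position_scale + (position_scale - 1) / 2 + 0.5
            # equals (p + 0.5) * position_scale for every position p.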

            # Prepare condition type
            condition_type_ids = torch.tensor(
                [
                    Condition.get_type_id(condition_type)
                    for condition_type in condition_types
                ]
            ).to(self.device)
            condition_type_ids = (
                torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
            ).unsqueeze(1)

            # Prepare guidance
            guidance = (
                torch.ones_like(t).to(self.device)
                if self.transformer.config.guidance_embeds
                else None
            )

        # Forward pass
        transformer_out = tranformer_forward(
            self.transformer,
            # Model config
            model_config=self.model_config,
            # Inputs of the condition (new feature)
            condition_latents=condition_latents,
            condition_ids=condition_ids,
            condition_type_ids=condition_type_ids,
            # Inputs to the original transformer
            hidden_states=x_t,
            timestep=t,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=img_ids,
            joint_attention_kwargs=None,
            return_dict=False,
        )
        pred = transformer_out[0]

        # Compute loss
        loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
        self.last_t = t.mean().item()
        return loss
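
# Illustrative batch layout (a sketch inferred from the keys read in `step`;
# tensor shapes and types are assumptions, and the dataloader is not shown here):
#
#   batch = {
#       "image": ...,           # target images, e.g. a (B, 3, H, W) tensor
#       "condition": ...,       # condition images, e.g. a (B, 3, H, W) tensor
#       "condition_type": ...,  # e.g. ["canny"] * B
#       "description": ...,     # list of B text prompts
#       "position_delta": ...,  # (row, col) shifts; only the first item is used
#       "position_scale": ...,  # optional scale factors; only the first item is used
#   }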