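"""Lightning training module for OminiModel: conditional LoRA fine-tuning of a FLUX pipeline.

Wraps a frozen FluxPipeline, injects trainable LoRA adapters into its transformer,
and trains them with a flow-matching (velocity-prediction) objective on image,
text, and spatial-condition inputs.
"""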
from typing import Optional

import lightning as L
from diffusers.pipelines import FluxPipeline
import torch
from peft import LoraConfig, get_peft_model_state_dict
import prodigyopt
from ..flux.transformer import tranformer_forward
from ..flux.condition import Condition
from ..flux.pipeline_tools import encode_images, prepare_text_input
class OminiModel(L.LightningModule):
def __init__(
self,
flux_pipe_id: str,
        lora_path: Optional[str] = None,
        lora_config: Optional[dict] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
        model_config: Optional[dict] = None,
        optimizer_config: Optional[dict] = None,
gradient_checkpointing: bool = False,
):
# Initialize the LightningModule
super().__init__()
        self.model_config = model_config or {}
self.optimizer_config = optimizer_config
# Load the Flux pipeline
self.flux_pipe: FluxPipeline = (
FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
)
self.transformer = self.flux_pipe.transformer
self.transformer.gradient_checkpointing = gradient_checkpointing
self.transformer.train()
# Freeze the Flux pipeline
self.flux_pipe.text_encoder.requires_grad_(False).eval()
self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
self.flux_pipe.vae.requires_grad_(False).eval()
# Initialize LoRA layers
self.lora_layers = self.init_lora(lora_path, lora_config)
self.to(device).to(dtype)
def init_lora(self, lora_path: str, lora_config: dict):
assert lora_path or lora_config
if lora_path:
# TODO: Implement this
raise NotImplementedError
else:
self.transformer.add_adapter(LoraConfig(**lora_config))
# TODO: Check if this is correct (p.requires_grad)
lora_layers = filter(
lambda p: p.requires_grad, self.transformer.parameters()
)
return list(lora_layers)
def save_lora(self, path: str):
FluxPipeline.save_lora_weights(
save_directory=path,
transformer_lora_layers=get_peft_model_state_dict(self.transformer),
safe_serialization=True,
)
def configure_optimizers(self):
# Freeze the transformer
self.transformer.requires_grad_(False)
opt_config = self.optimizer_config
# Set the trainable parameters
self.trainable_params = self.lora_layers
# Unfreeze trainable parameters
for p in self.trainable_params:
p.requires_grad_(True)
# Initialize the optimizer
if opt_config["type"] == "AdamW":
optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
elif opt_config["type"] == "Prodigy":
optimizer = prodigyopt.Prodigy(
self.trainable_params,
**opt_config["params"],
)
elif opt_config["type"] == "SGD":
optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
else:
raise NotImplementedError
return optimizer
def training_step(self, batch, batch_idx):
step_loss = self.step(batch)
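        # Keep an exponential moving average of the step loss (decay 0.95) so the
        # tracked value is smoother than the raw per-step loss.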
self.log_loss = (
step_loss.item()
if not hasattr(self, "log_loss")
else self.log_loss * 0.95 + step_loss.item() * 0.05
)
return step_loss
def step(self, batch):
imgs = batch["image"]
conditions = batch["condition"]
condition_types = batch["condition_type"]
prompts = batch["description"]
position_delta = batch["position_delta"][0]
position_scale = float(batch.get("position_scale", [1.0])[0])
# Prepare inputs
with torch.no_grad():
# Prepare image input
x_0, img_ids = encode_images(self.flux_pipe, imgs)
# Prepare text input
prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
self.flux_pipe, prompts
)
# Prepare t and x_t
t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
x_1 = torch.randn_like(x_0).to(self.device)
t_ = t.unsqueeze(1).unsqueeze(1)
x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)
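            # Flow-matching setup: x_t linearly interpolates between the clean latents
            # x_0 (t=0) and pure noise x_1 (t=1); the loss below regresses the constant
            # velocity (x_1 - x_0).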
# Prepare conditions
condition_latents, condition_ids = encode_images(self.flux_pipe, conditions)
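            # condition_ids rows hold [id, row, col] RoPE coordinates for each condition
            # token (following FLUX's latent image id layout), so shifting/scaling
            # columns 1-2 repositions the condition tokens relative to the target image.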
# Add position delta
condition_ids[:, 1] += position_delta[0]
condition_ids[:, 2] += position_delta[1]
if position_scale != 1.0:
scale_bias = (position_scale - 1.0) / 2
condition_ids[:, 1] *= position_scale
condition_ids[:, 2] *= position_scale
condition_ids[:, 1] += scale_bias
condition_ids[:, 2] += scale_bias
# Prepare condition type
condition_type_ids = torch.tensor(
[
Condition.get_type_id(condition_type)
for condition_type in condition_types
]
).to(self.device)
condition_type_ids = (
torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
).unsqueeze(1)
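            # Broadcast a single type id (one condition type per batch) to every
            # condition token.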
# Prepare guidance
guidance = (
torch.ones_like(t).to(self.device)
if self.transformer.config.guidance_embeds
else None
)
# Forward pass
transformer_out = tranformer_forward(
self.transformer,
# Model config
model_config=self.model_config,
# Inputs of the condition (new feature)
condition_latents=condition_latents,
condition_ids=condition_ids,
condition_type_ids=condition_type_ids,
# Inputs to the original transformer
hidden_states=x_t,
timestep=t,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=img_ids,
joint_attention_kwargs=None,
return_dict=False,
)
pred = transformer_out[0]
# Compute loss
loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
self.last_t = t.mean().item()
return loss
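
# Example usage (illustrative sketch; the model id and hyperparameters below are
# hypothetical placeholders, not values prescribed by this module):
#
#   model = OminiModel(
#       flux_pipe_id="black-forest-labs/FLUX.1-schnell",
#       lora_config={
#           "r": 4,
#           "lora_alpha": 4,
#           "init_lora_weights": "gaussian",
#           "target_modules": ["to_q", "to_k", "to_v", "to_out.0"],
#       },
#       optimizer_config={"type": "Prodigy", "params": {"lr": 1.0}},  # Prodigy adapts its own step size
#       gradient_checkpointing=True,
#   )
#   trainer = L.Trainer(max_steps=10_000)
#   trainer.fit(model, train_dataloaders=train_loader)  # train_loader yields the batch dict used in step()
#   model.save_lora("./lora_weights")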