import lightning as L
from diffusers.pipelines import FluxPipeline
import torch
from peft import LoraConfig, get_peft_model_state_dict
import prodigyopt
from ..flux.transformer import tranformer_forward
from ..flux.condition import Condition
from ..flux.pipeline_tools import encode_images, prepare_text_input
class OminiModel(L.LightningModule):
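    """Lightning training module that fine-tunes a FLUX transformer with LoRA.

    The module wraps a ``FluxPipeline``, freezes the text encoders and VAE,
    injects LoRA adapters into the transformer, and trains them with a
    flow-matching objective in which extra condition tokens are fed alongside
    the image tokens.

    Example (illustrative only; the model id and hyperparameter values below
    are assumptions, not values prescribed by this file)::

        model = OminiModel(
            flux_pipe_id="black-forest-labs/FLUX.1-dev",
            lora_config={
                "r": 4,
                "lora_alpha": 4,
                "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
            },
            optimizer_config={"type": "Prodigy", "params": {"lr": 1.0}},
        )
    """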
def __init__(
self,
flux_pipe_id: str,
lora_path: str = None,
lora_config: dict = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
        model_config: dict = None,
optimizer_config: dict = None,
gradient_checkpointing: bool = False,
):
# Initialize the LightningModule
super().__init__()
        self.model_config = model_config or {}
self.optimizer_config = optimizer_config
# Load the Flux pipeline
self.flux_pipe: FluxPipeline = (
FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
)
self.transformer = self.flux_pipe.transformer
self.transformer.gradient_checkpointing = gradient_checkpointing
self.transformer.train()
        # Freeze the parts of the pipeline that are not trained (text encoders and VAE)
self.flux_pipe.text_encoder.requires_grad_(False).eval()
self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
self.flux_pipe.vae.requires_grad_(False).eval()
# Initialize LoRA layers
self.lora_layers = self.init_lora(lora_path, lora_config)
self.to(device).to(dtype)
def init_lora(self, lora_path: str, lora_config: dict):
        assert lora_path or lora_config, "Either lora_path or lora_config must be provided"
if lora_path:
# TODO: Implement this
            raise NotImplementedError("Loading LoRA weights from a checkpoint path is not implemented yet")
else:
self.transformer.add_adapter(LoraConfig(**lora_config))
# TODO: Check if this is correct (p.requires_grad)
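            # Note: peft's adapter injection is expected to leave only the newly
            # added LoRA weights with requires_grad=True, so this filter should
            # collect exactly the adapter parameters.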
lora_layers = filter(
lambda p: p.requires_grad, self.transformer.parameters()
)
return list(lora_layers)
def save_lora(self, path: str):
FluxPipeline.save_lora_weights(
save_directory=path,
transformer_lora_layers=get_peft_model_state_dict(self.transformer),
safe_serialization=True,
)
def configure_optimizers(self):
# Freeze the transformer
self.transformer.requires_grad_(False)
opt_config = self.optimizer_config
# Set the trainable parameters
self.trainable_params = self.lora_layers
# Unfreeze trainable parameters
for p in self.trainable_params:
p.requires_grad_(True)
# Initialize the optimizer
if opt_config["type"] == "AdamW":
optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
elif opt_config["type"] == "Prodigy":
optimizer = prodigyopt.Prodigy(
self.trainable_params,
**opt_config["params"],
)
elif opt_config["type"] == "SGD":
optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
else:
            raise NotImplementedError(f"Unsupported optimizer type: {opt_config['type']}")
return optimizer
def training_step(self, batch, batch_idx):
step_loss = self.step(batch)
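        # Keep an exponential moving average of the loss (decay 0.95) so that
        # the tracked value is smoother than the raw per-step loss.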
self.log_loss = (
step_loss.item()
if not hasattr(self, "log_loss")
else self.log_loss * 0.95 + step_loss.item() * 0.05
)
return step_loss
def step(self, batch):
imgs = batch["image"]
conditions = batch["condition"]
condition_types = batch["condition_type"]
prompts = batch["description"]
position_delta = batch["position_delta"][0]
position_scale = float(batch.get("position_scale", [1.0])[0])
# Prepare inputs
with torch.no_grad():
# Prepare image input
x_0, img_ids = encode_images(self.flux_pipe, imgs)
# Prepare text input
prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
self.flux_pipe, prompts
)
# Prepare t and x_t
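            # Flow-matching setup: t is a noise level sampled from a logit-normal
            # distribution (sigmoid of a standard normal), x_1 is pure Gaussian
            # noise, and x_t linearly interpolates between the clean latents x_0
            # (t = 0) and the noise x_1 (t = 1).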
t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
x_1 = torch.randn_like(x_0).to(self.device)
t_ = t.unsqueeze(1).unsqueeze(1)
x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)
# Prepare conditions
condition_latents, condition_ids = encode_images(self.flux_pipe, conditions)
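            # condition_ids columns 1 and 2 hold the row/column position indices
            # used for rotary position embeddings; shifting (and optionally
            # rescaling) them controls where the condition tokens are placed
            # relative to the generated image tokens.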
# Add position delta
condition_ids[:, 1] += position_delta[0]
condition_ids[:, 2] += position_delta[1]
if position_scale != 1.0:
scale_bias = (position_scale - 1.0) / 2
condition_ids[:, 1] *= position_scale
condition_ids[:, 2] *= position_scale
condition_ids[:, 1] += scale_bias
condition_ids[:, 2] += scale_bias
# Prepare condition type
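            # Note: a single type id (taken from the first sample) is broadcast to
            # every condition token, which assumes the whole batch shares one
            # condition type.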
condition_type_ids = torch.tensor(
[
Condition.get_type_id(condition_type)
for condition_type in condition_types
]
).to(self.device)
condition_type_ids = (
torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
).unsqueeze(1)
# Prepare guidance
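            # Guidance-distilled FLUX variants embed a guidance scale; a constant
            # value of 1.0 per sample is used here during training. Models without
            # guidance embeddings receive None.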
guidance = (
torch.ones_like(t).to(self.device)
if self.transformer.config.guidance_embeds
else None
)
# Forward pass
transformer_out = tranformer_forward(
self.transformer,
# Model config
model_config=self.model_config,
# Inputs of the condition (new feature)
condition_latents=condition_latents,
condition_ids=condition_ids,
condition_type_ids=condition_type_ids,
# Inputs to the original transformer
hidden_states=x_t,
timestep=t,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=img_ids,
joint_attention_kwargs=None,
return_dict=False,
)
pred = transformer_out[0]
        # Flow-matching loss: regress the velocity (x_1 - x_0) of the straight path from clean latents to noise
loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
self.last_t = t.mean().item()
return loss