Yuanshi's picture
Upload 61 files
fb6a167 verified
import lightning as L
from PIL import Image, ImageFilter, ImageDraw
import numpy as np
from transformers import pipeline
import cv2
import torch
import os
try:
import wandb
except ImportError:
wandb = None
from ..flux.condition import Condition
from ..flux.generate import generate
class TrainingCallback(L.Callback):
    """Lightning callback for progress reporting, LoRA checkpoints and samples.

    After every training batch the callback:

    * measures the mean and max L2 norm of the available parameter gradients,
    * optionally logs step/epoch/loss/gradient statistics to wandb,
    * prints a progress line every ``print_every_n_steps`` steps,
    * saves LoRA weights every ``save_interval`` steps, and
    * generates sample images every ``sample_interval`` steps.
    """

    def __init__(self, run_name, training_config: dict = None):
        """Store the run name and read the callback options.

        Args:
            run_name: Name used as a sub-directory of ``save_path`` for
                checkpoints and sample outputs.
            training_config: Optional mapping of options. Recognised keys are
                ``print_every_n_steps`` (default 10), ``save_interval``
                (default 1000), ``sample_interval`` (default 1000),
                ``save_path`` (default ``"./output"``) and ``wandb``
                (default ``None``). ``None`` is treated as an empty mapping.
        """
        super().__init__()
        # Build a fresh dict per instance instead of sharing a mutable default.
        if training_config is None:
            training_config = {}
        self.run_name, self.training_config = run_name, training_config

        self.print_every_n_steps = training_config.get("print_every_n_steps", 10)
        self.save_interval = training_config.get("save_interval", 1000)
        self.sample_interval = training_config.get("sample_interval", 1000)
        self.save_path = training_config.get("save_path", "./output")

        self.wandb_config = training_config.get("wandb", None)
        # wandb logging needs both the package and an API key in the environment.
        self.use_wandb = (
            wandb is not None and os.environ.get("WANDB_API_KEY") is not None
        )

        # Number of batches seen by this callback (not reset between epochs).
        self.total_steps = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Report progress and trigger periodic checkpointing / sampling.

        Args:
            trainer: The running trainer; ``current_epoch`` and
                ``accumulate_grad_batches`` are read from it.
            pl_module: The module being trained; must provide
                ``named_parameters()``, ``log_loss``, ``last_t`` (when wandb
                logging is active) and ``save_lora(path)``.
            outputs: Step output mapping; ``outputs["loss"]`` is read when
                wandb logging is active.
            batch: Current batch; ``batch["condition_type"][0]`` selects the
                sample condition when a sample is generated.
            batch_idx: Index of the batch within the current epoch.
        """
        gradient_size = 0
        max_gradient_size = 0
        count = 0
        for _, param in pl_module.named_parameters():
            if param.grad is not None:
                # Compute the norm once; each .item() copies a value to the host.
                grad_norm = param.grad.norm(2).item()
                gradient_size += grad_norm
                max_gradient_size = max(max_gradient_size, grad_norm)
                count += 1
        if count > 0:
            # Mean of the per-parameter gradient norms.
            gradient_size /= count

        self.total_steps += 1

        # Log metrics to wandb on every step when it is enabled
        if self.use_wandb:
            report_dict = {
                "steps": self.total_steps,
                "epoch": trainer.current_epoch,
                "gradient_size": gradient_size,
            }
            # Scale the step loss by the number of accumulated batches.
            loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches
            report_dict["loss"] = loss_value
            report_dict["t"] = pl_module.last_t
            wandb.log(report_dict)

        # Print training progress every n steps
        if self.total_steps % self.print_every_n_steps == 0:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}"
            )

        # Save LoRA weights at specified intervals
        if self.total_steps % self.save_interval == 0:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights"
            )
            pl_module.save_lora(
                f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}"
            )

        # Generate and save a sample image at specified intervals
        if self.total_steps % self.sample_interval == 0:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample"
            )
            self.generate_a_sample(
                trainer,
                pl_module,
                f"{self.save_path}/{self.run_name}/output",
                f"lora_{self.total_steps}",
                # Use the condition type from the current batch
                batch["condition_type"][0],
            )

    def _build_test_list(self, condition_type, condition_size, position_scale):
        """Build the fixed sample cases for ``condition_type``.

        Returns:
            A list of tuples ``(condition_image, position_delta, prompt)`` with
            an optional fourth element holding extra keyword arguments for
            ``Condition``.

        Raises:
            NotImplementedError: If ``condition_type`` has no sample defined.
        """
        size = (condition_size, condition_size)
        vase_prompt = "A beautiful vase on a table."
        # Only forward position_scale when it differs from the 1.0 default.
        scale_kwargs = (
            {"position_scale": position_scale} if position_scale != 1.0 else {}
        )

        test_list = []
        if condition_type == "subject":
            test_list.extend(
                [
                    (
                        Image.open("assets/test_in.jpg"),
                        [0, -32],
                        "Resting on the picnic table at a lakeside campsite, it's caught in the golden glow of early morning, with mist rising from the water and tall pines casting long shadows behind the scene.",
                    ),
                    (
                        Image.open("assets/test_out.jpg"),
                        [0, -32],
                        "In a bright room. It is placed on a table.",
                    ),
                ]
            )
        elif condition_type == "canny":
            condition_img = Image.open("assets/vase_hq.jpg").resize(size)
            condition_img = np.array(condition_img)
            condition_img = cv2.Canny(condition_img, 100, 200)
            condition_img = Image.fromarray(condition_img).convert("RGB")
            test_list.append((condition_img, [0, 0], vase_prompt, scale_kwargs))
        elif condition_type == "coloring":
            # Round-trip through grayscale to strip the colour information.
            condition_img = (
                Image.open("assets/vase_hq.jpg")
                .resize(size)
                .convert("L")
                .convert("RGB")
            )
            test_list.append((condition_img, [0, 0], vase_prompt))
        elif condition_type == "depth":
            # Create the depth estimator lazily and keep it for later samples.
            if not hasattr(self, "deepth_pipe"):
                self.deepth_pipe = pipeline(
                    task="depth-estimation",
                    model="LiheYoung/depth-anything-small-hf",
                    device="cpu",
                )
            condition_img = (
                Image.open("assets/vase_hq.jpg").resize(size).convert("RGB")
            )
            condition_img = self.deepth_pipe(condition_img)["depth"].convert("RGB")
            test_list.append((condition_img, [0, 0], vase_prompt, scale_kwargs))
        elif condition_type == "depth_pred":
            condition_img = (
                Image.open("assets/vase_hq.jpg").resize(size).convert("RGB")
            )
            test_list.append((condition_img, [0, 0], vase_prompt))
        elif condition_type == "deblurring":
            blur_radius = 5
            image = Image.open("./assets/vase_hq.jpg")
            condition_img = (
                image.convert("RGB")
                .resize(size)
                .filter(ImageFilter.GaussianBlur(blur_radius))
                .convert("RGB")
            )
            test_list.append((condition_img, [0, 0], vase_prompt, scale_kwargs))
        elif condition_type == "fill":
            condition_img = (
                Image.open("./assets/vase_hq.jpg").resize(size).convert("RGB")
            )
            # Keep only the central square of the image; the rest becomes black.
            mask = Image.new("L", condition_img.size, 0)
            draw = ImageDraw.Draw(mask)
            a = condition_img.size[0] // 4
            b = a * 3
            draw.rectangle([a, a, b, b], fill=255)
            condition_img = Image.composite(
                condition_img, Image.new("RGB", condition_img.size, (0, 0, 0)), mask
            )
            test_list.append((condition_img, [0, 0], vase_prompt))
        elif condition_type == "sr":
            condition_img = (
                Image.open("assets/vase_hq.jpg").resize(size).convert("RGB")
            )
            test_list.append((condition_img, [0, -16], vase_prompt))
        elif condition_type == "cartoon":
            condition_img = (
                Image.open("assets/cartoon_boy.png").resize(size).convert("RGB")
            )
            test_list.append(
                (
                    condition_img,
                    [0, -16],
                    "A cartoon character in a white background. He is looking right, and running.",
                )
            )
        else:
            raise NotImplementedError(
                f"No sample is defined for condition type {condition_type!r}"
            )
        return test_list

    @torch.no_grad()
    def generate_a_sample(
        self,
        trainer,
        pl_module,
        save_path,
        file_name,
        condition_type="super_resolution",
    ):
        """Generate sample image(s) for ``condition_type`` and save them as JPEGs.

        Args:
            trainer: Trainer carrying a ``training_config`` attribute whose
                ``"dataset"`` entry provides ``condition_size``,
                ``target_size`` and optionally ``position_scale``.
            pl_module: Module providing ``device``, ``flux_pipe`` and
                ``model_config``.
            save_path: Output directory; created if missing.
            file_name: Prefix of the saved files, which are named
                ``{file_name}_{condition_type}_{i}.jpg``.
            condition_type: One of ``"subject"``, ``"canny"``, ``"coloring"``,
                ``"depth"``, ``"depth_pred"``, ``"deblurring"``, ``"fill"``,
                ``"sr"`` or ``"cartoon"``. Note that the default value
                ``"super_resolution"`` is not among the handled types, so
                callers must pass a supported type explicitly.

        Raises:
            NotImplementedError: If ``condition_type`` is not supported.
        """
        # TODO: change these variables to parameters
        dataset_config = trainer.training_config["dataset"]
        condition_size = dataset_config["condition_size"]
        target_size = dataset_config["target_size"]
        position_scale = dataset_config.get("position_scale", 1.0)

        # Fixed seed for every call.
        generator = torch.Generator(device=pl_module.device)
        generator.manual_seed(42)

        test_list = self._build_test_list(
            condition_type, condition_size, position_scale
        )

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(save_path, exist_ok=True)
        for i, (condition_img, position_delta, prompt, *others) in enumerate(test_list):
            condition = Condition(
                condition_type=condition_type,
                condition=condition_img.resize(
                    (condition_size, condition_size)
                ).convert("RGB"),
                position_delta=position_delta,
                # Optional fourth tuple element carries extra Condition kwargs.
                **(others[0] if others else {}),
            )
            res = generate(
                pl_module.flux_pipe,
                prompt=prompt,
                conditions=[condition],
                height=target_size,
                width=target_size,
                generator=generator,
                model_config=pl_module.model_config,
                default_lora=True,
            )
            res.images[0].save(
                os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
            )