import os
import time

import lightning as L
import torch
import yaml
from datasets import load_dataset
from torch.utils.data import DataLoader

from .data import ImageConditionDataset, Subject200KDataset, CartoonDataset
from .model import OminiModel
from .callbacks import TrainingCallback


def get_rank():
    """Return the local rank assigned by the distributed launcher, or 0 for single-process runs."""
    try:
        rank = int(os.environ.get("LOCAL_RANK"))
    except (TypeError, ValueError):
        # LOCAL_RANK is unset or not an integer; fall back to rank 0.
        rank = 0
    return rank


def get_config():
    """Load the YAML training config from the path given by the XFL_CONFIG environment variable."""
    config_path = os.environ.get("XFL_CONFIG")
    assert config_path is not None, "Please set the XFL_CONFIG environment variable"
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config
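
# A sketch of the expected YAML layout, inferred from the keys this script reads
# below. The values shown are illustrative assumptions, not canonical defaults:
#
#   flux_path: black-forest-labs/FLUX.1-dev
#   dtype: bfloat16
#   train:
#     batch_size: 1
#     dataloader_workers: 4
#     accumulate_grad_batches: 1
#     condition_type: subject
#     dataset:
#       type: subject          # one of: subject, img, cartoon
#       condition_size: 512
#       target_size: 512
#       image_size: 512
#       padding: 8
#       drop_text_prob: 0.1
#       drop_image_prob: 0.1
#     lora_config: {}
#     optimizer: {}
#     save_path: ./output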


def init_wandb(wandb_config, run_name):
    import wandb

    try:
        assert os.environ.get("WANDB_API_KEY") is not None, "WANDB_API_KEY is not set"
        wandb.init(
            project=wandb_config["project"],
            name=run_name,
            config={},
        )
    except Exception as e:
        print("Failed to initialize WandB:", e)


def main():
    # Initialize distributed rank and device
    rank = get_rank()
    is_main_process = rank == 0
    torch.cuda.set_device(rank)
    config = get_config()
    training_config = config["train"]
    run_name = time.strftime("%Y%m%d-%H%M%S")

    # Initialize WandB (main process only)
    wandb_config = training_config.get("wandb", None)
    if wandb_config is not None and is_main_process:
        init_wandb(wandb_config, run_name)

    print("Rank:", rank)
    if is_main_process:
        print("Config:", config)
    # Initialize dataset and dataloader
    if training_config["dataset"]["type"] == "subject":
        dataset = load_dataset("Yuanshi/Subjects200K")

        # Keep only samples whose quality scores are all at least 5
        def filter_func(item):
            if not item.get("quality_assessment"):
                return False
            return all(
                item["quality_assessment"].get(key, 0) >= 5
                for key in ["compositeStructure", "objectConsistency", "imageQuality"]
            )

        # Filter the dataset (results are cached on disk, so this runs once)
        if not os.path.exists("./cache/dataset"):
            os.makedirs("./cache/dataset")
        data_valid = dataset["train"].filter(
            filter_func,
            num_proc=16,
            cache_file_name="./cache/dataset/data_valid.arrow",
        )
        dataset = Subject200KDataset(
            data_valid,
            condition_size=training_config["dataset"]["condition_size"],
            target_size=training_config["dataset"]["target_size"],
            image_size=training_config["dataset"]["image_size"],
            padding=training_config["dataset"]["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=training_config["dataset"]["drop_text_prob"],
            drop_image_prob=training_config["dataset"]["drop_image_prob"],
        )
    elif training_config["dataset"]["type"] == "img":
        # Load the text-to-image-2M dataset from webdataset shards
        dataset = load_dataset(
            "webdataset",
            data_files={"train": training_config["dataset"]["urls"]},
            split="train",
            cache_dir="cache/t2i2m",
            num_proc=32,
        )
        dataset = ImageConditionDataset(
            dataset,
            condition_size=training_config["dataset"]["condition_size"],
            target_size=training_config["dataset"]["target_size"],
            condition_type=training_config["condition_type"],
            drop_text_prob=training_config["dataset"]["drop_text_prob"],
            drop_image_prob=training_config["dataset"]["drop_image_prob"],
            position_scale=training_config["dataset"].get("position_scale", 1.0),
        )
    elif training_config["dataset"]["type"] == "cartoon":
        dataset = load_dataset("saquiboye/oye-cartoon", split="train")
        dataset = CartoonDataset(
            dataset,
            condition_size=training_config["dataset"]["condition_size"],
            target_size=training_config["dataset"]["target_size"],
            image_size=training_config["dataset"]["image_size"],
            padding=training_config["dataset"]["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=training_config["dataset"]["drop_text_prob"],
            drop_image_prob=training_config["dataset"]["drop_image_prob"],
        )
    else:
        raise NotImplementedError(
            f"Unknown dataset type: {training_config['dataset']['type']}"
        )
print("Dataset length:", len(dataset))
train_loader = DataLoader(
dataset,
batch_size=training_config["batch_size"],
shuffle=True,
num_workers=training_config["dataloader_workers"],
)
    # Initialize model
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )
    # Callbacks for logging and saving checkpoints
    training_callbacks = (
        [TrainingCallback(run_name, training_config=training_config)]
        if is_main_process
        else []
    )
    # Initialize trainer
    trainer = L.Trainer(
        accumulate_grad_batches=training_config["accumulate_grad_batches"],
        callbacks=training_callbacks,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
        max_steps=training_config.get("max_steps", -1),
        max_epochs=training_config.get("max_epochs", -1),
        gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
    )
    trainer.training_config = training_config
    # Save config
    save_path = training_config.get("save_path", "./output")
    if is_main_process:
        os.makedirs(f"{save_path}/{run_name}")
        with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
            yaml.dump(config, f)

    # Start training
    trainer.fit(trainable_model, train_loader)


if __name__ == "__main__":
    main()
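
# Usage sketch: the relative imports above mean this file must run as a module
# inside its package. Assuming the package is named `train` (an assumption, not
# confirmed by this file), a single-GPU launch could look like:
#
#   XFL_CONFIG=./config.yaml python -m train.train
#
# For multi-GPU runs, a launcher such as torchrun sets LOCAL_RANK per process,
# which get_rank() above reads to pick the CUDA device.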