"""Training entry point: builds the dataset, model, and Lightning trainer
from a YAML config referenced by the XFL_CONFIG environment variable."""

import os
import time

import lightning as L
import torch
import yaml
from datasets import load_dataset
from torch.utils.data import DataLoader

from .data import ImageConditionDataset, Subject200KDataset, CartoonDataset
from .model import OminiModel
from .callbacks import TrainingCallback


def get_rank():
    # LOCAL_RANK is set by distributed launchers (e.g. torchrun); fall back
    # to 0 for single-process runs or malformed values.
    try:
        rank = int(os.environ.get("LOCAL_RANK", 0))
    except ValueError:
        rank = 0
    return rank


def get_config():
    config_path = os.environ.get("XFL_CONFIG")
    assert config_path is not None, "Please set the XFL_CONFIG environment variable"
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config


def init_wandb(wandb_config, run_name):
    import wandb

    try:
        # Logging requires a WANDB_API_KEY in the environment.
        assert os.environ.get("WANDB_API_KEY") is not None
        wandb.init(
            project=wandb_config["project"],
            name=run_name,
            config={},
        )
    except Exception as e:
        print("Failed to initialize WandB:", e)


def main():
    # Initialize
    rank = get_rank()
    is_main_process = rank == 0
    torch.cuda.set_device(rank)
    config = get_config()
    training_config = config["train"]
    run_name = time.strftime("%Y%m%d-%H%M%S")

    # Initialize WandB logging on the main process only
    wandb_config = training_config.get("wandb", None)
    if wandb_config is not None and is_main_process:
        init_wandb(wandb_config, run_name)

    print("Rank:", rank)
    if is_main_process:
        print("Config:", config)

    # Initialize dataset and dataloader
    dataset_config = training_config["dataset"]
    if dataset_config["type"] == "subject":
        dataset = load_dataset("Yuanshi/Subjects200K")

        # Keep only samples that score at least 5 on every quality metric
        def filter_func(item):
            if not item.get("quality_assessment"):
                return False
            return all(
                item["quality_assessment"].get(key, 0) >= 5
                for key in ["compositeStructure", "objectConsistency", "imageQuality"]
            )

        # Filter the dataset, caching the result on disk
        os.makedirs("./cache/dataset", exist_ok=True)
        data_valid = dataset["train"].filter(
            filter_func,
            num_proc=16,
            cache_file_name="./cache/dataset/data_valid.arrow",
        )
        dataset = Subject200KDataset(
            data_valid,
            condition_size=dataset_config["condition_size"],
            target_size=dataset_config["target_size"],
            image_size=dataset_config["image_size"],
            padding=dataset_config["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=dataset_config["drop_text_prob"],
            drop_image_prob=dataset_config["drop_image_prob"],
        )
    elif dataset_config["type"] == "img":
        # Load the text-to-image-2M dataset from webdataset shards
        dataset = load_dataset(
            "webdataset",
            data_files={"train": dataset_config["urls"]},
            split="train",
            cache_dir="cache/t2i2m",
            num_proc=32,
        )
        dataset = ImageConditionDataset(
            dataset,
            condition_size=dataset_config["condition_size"],
            target_size=dataset_config["target_size"],
            condition_type=training_config["condition_type"],
            drop_text_prob=dataset_config["drop_text_prob"],
            drop_image_prob=dataset_config["drop_image_prob"],
            position_scale=dataset_config.get("position_scale", 1.0),
        )
    elif dataset_config["type"] == "cartoon":
        dataset = load_dataset("saquiboye/oye-cartoon", split="train")
        dataset = CartoonDataset(
            dataset,
            condition_size=dataset_config["condition_size"],
            target_size=dataset_config["target_size"],
            image_size=dataset_config["image_size"],
            padding=dataset_config["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=dataset_config["drop_text_prob"],
            drop_image_prob=dataset_config["drop_image_prob"],
        )
    else:
        raise NotImplementedError(f"Unknown dataset type: {dataset_config['type']}")

    print("Dataset length:", len(dataset))
    train_loader = DataLoader(
        dataset,
        batch_size=training_config["batch_size"],
        shuffle=True,
        num_workers=training_config["dataloader_workers"],
    )

    # Initialize model
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    # Callbacks for logging and saving checkpoints (main process only)
    training_callbacks = (
        [TrainingCallback(run_name, training_config=training_config)]
        if is_main_process
        else []
    )

    # Initialize trainer
    trainer = L.Trainer(
        accumulate_grad_batches=training_config["accumulate_grad_batches"],
        callbacks=training_callbacks,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
        max_steps=training_config.get("max_steps", -1),
        max_epochs=training_config.get("max_epochs", -1),
        gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
    )
    # Expose the training config to callbacks via the trainer
    trainer.training_config = training_config

    # Save a copy of the config alongside the run outputs
    save_path = training_config.get("save_path", "./output")
    if is_main_process:
        os.makedirs(f"{save_path}/{run_name}", exist_ok=True)
        with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
            yaml.dump(config, f)

    # Start training
    trainer.fit(trainable_model, train_loader)


if __name__ == "__main__":
    main()
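# ---------------------------------------------------------------------------
# Example invocation and config (a hypothetical sketch, not a canonical
# schema). The key names below are exactly those read by this script; the
# values, file paths, module name, and model id are illustrative assumptions.
# The inner structure of `lora_config` and `optimizer` is passed through to
# OminiModel unchanged, so it is elided here.
#
#   export XFL_CONFIG=./config/subject.yaml   # hypothetical path
#   export WANDB_API_KEY=...                  # optional, enables WandB logging
#   torchrun --nproc_per_node=1 -m train      # hypothetical module name
#
#   # subject.yaml
#   flux_path: black-forest-labs/FLUX.1-dev   # assumed FLUX pipeline id
#   dtype: bfloat16
#   train:
#     batch_size: 1
#     dataloader_workers: 4
#     accumulate_grad_batches: 1
#     gradient_checkpointing: true
#     max_steps: 10000
#     condition_type: subject
#     save_path: ./output
#     dataset:
#       type: subject              # "subject", "img", or "cartoon"
#       condition_size: 512
#       target_size: 512
#       image_size: 512
#       padding: 8
#       drop_text_prob: 0.1
#       drop_image_prob: 0.1
#     lora_config: { ... }         # forwarded to OminiModel
#     optimizer: { ... }           # forwarded to OminiModel
#     wandb:
#       project: omini-training    # optional; omit to disable logging
# ---------------------------------------------------------------------------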