# NOTE(review): the "Spaces: Running on Zero" banner text that previously sat
# here was Hugging Face page-scrape residue, not code; replaced with this note.
import os
import time

import lightning as L
import torch
import yaml
from datasets import load_dataset
from torch.utils.data import DataLoader

from .callbacks import TrainingCallback
from .data import CartoonDataset, ImageConditionDataset, Subject200KDataset
from .model import OminiModel
def get_rank():
    """Return the local process rank, falling back to 0.

    Reads the ``LOCAL_RANK`` environment variable (set by distributed
    launchers). When it is absent (``int(None)`` -> TypeError) or not a
    valid integer (ValueError), rank 0 is assumed.

    Returns:
        int: the local rank, or 0 when it cannot be determined.
    """
    try:
        # Narrow except clause: the original bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit, which must propagate.
        return int(os.environ.get("LOCAL_RANK"))
    except (TypeError, ValueError):
        return 0
def get_config():
    """Load the YAML training configuration pointed to by ``$XFL_CONFIG``.

    Returns:
        dict: the parsed configuration.

    Raises:
        RuntimeError: if the ``XFL_CONFIG`` environment variable is unset.
            (An explicit raise survives ``python -O``, unlike the previous
            ``assert``-based validation.)
        OSError: if the config file cannot be opened.
    """
    config_path = os.environ.get("XFL_CONFIG")
    if config_path is None:
        raise RuntimeError("Please set the XFL_CONFIG environment variable")
    # safe_load (not load) — config files should never execute arbitrary tags.
    with open(config_path, "r") as f:
        return yaml.safe_load(f)
def init_wandb(wandb_config, run_name):
    """Best-effort Weights & Biases initialization.

    Any failure inside the guarded section (missing API key, network issues,
    bad project name) is printed and swallowed so training can proceed
    without experiment tracking.

    Args:
        wandb_config: mapping with at least a ``"project"`` key.
        run_name: display name for this run.
    """
    import wandb

    try:
        # Refuse to start a run when no credentials are in the environment.
        assert os.environ.get("WANDB_API_KEY") is not None
        wandb.init(
            project=wandb_config["project"],
            name=run_name,
            config={},
        )
    except Exception as exc:
        print("Failed to initialize WanDB:", exc)
def _build_subject_dataset(training_config):
    """Build the Subjects200K dataset, keeping only high-quality samples."""
    dataset = load_dataset("Yuanshi/Subjects200K")

    def filter_func(item):
        # An item passes only when every quality score reaches 5.
        if not item.get("quality_assessment"):
            return False
        return all(
            item["quality_assessment"].get(key, 0) >= 5
            for key in ["compositeStructure", "objectConsistency", "imageQuality"]
        )

    # exist_ok avoids the check-then-create race of the previous
    # `if not os.path.exists(...)` guard; the arrow cache lets reruns
    # skip the expensive filter pass.
    os.makedirs("./cache/dataset", exist_ok=True)
    data_valid = dataset["train"].filter(
        filter_func,
        num_proc=16,
        cache_file_name="./cache/dataset/data_valid.arrow",
    )
    return Subject200KDataset(
        data_valid,
        condition_size=training_config["dataset"]["condition_size"],
        target_size=training_config["dataset"]["target_size"],
        image_size=training_config["dataset"]["image_size"],
        padding=training_config["dataset"]["padding"],
        condition_type=training_config["condition_type"],
        drop_text_prob=training_config["dataset"]["drop_text_prob"],
        drop_image_prob=training_config["dataset"]["drop_image_prob"],
    )


def _build_img_dataset(training_config):
    """Build the text-to-image-2M webdataset-backed dataset."""
    dataset = load_dataset(
        "webdataset",
        data_files={"train": training_config["dataset"]["urls"]},
        split="train",
        cache_dir="cache/t2i2m",
        num_proc=32,
    )
    return ImageConditionDataset(
        dataset,
        condition_size=training_config["dataset"]["condition_size"],
        target_size=training_config["dataset"]["target_size"],
        condition_type=training_config["condition_type"],
        drop_text_prob=training_config["dataset"]["drop_text_prob"],
        drop_image_prob=training_config["dataset"]["drop_image_prob"],
        position_scale=training_config["dataset"].get("position_scale", 1.0),
    )


def _build_cartoon_dataset(training_config):
    """Build the oye-cartoon dataset."""
    dataset = load_dataset("saquiboye/oye-cartoon", split="train")
    return CartoonDataset(
        dataset,
        condition_size=training_config["dataset"]["condition_size"],
        target_size=training_config["dataset"]["target_size"],
        image_size=training_config["dataset"]["image_size"],
        padding=training_config["dataset"]["padding"],
        condition_type=training_config["condition_type"],
        drop_text_prob=training_config["dataset"]["drop_text_prob"],
        drop_image_prob=training_config["dataset"]["drop_image_prob"],
    )


def _build_dataset(training_config):
    """Dispatch to the dataset builder named by ``train.dataset.type``.

    Raises:
        NotImplementedError: for an unrecognized dataset type.
    """
    dataset_type = training_config["dataset"]["type"]
    builders = {
        "subject": _build_subject_dataset,
        "img": _build_img_dataset,
        "cartoon": _build_cartoon_dataset,
    }
    if dataset_type not in builders:
        raise NotImplementedError
    return builders[dataset_type](training_config)


def main():
    """Train an OminiModel from the YAML config named by ``$XFL_CONFIG``.

    Flow: resolve rank/device -> load config -> (rank 0 only) init WanDB ->
    build dataset and dataloader -> build LoRA model -> run Lightning
    trainer -> (rank 0 only) save the config alongside checkpoints.
    """
    # Initialize: compute the rank once instead of calling get_rank() twice.
    rank = get_rank()
    is_main_process = rank == 0
    torch.cuda.set_device(rank)
    config = get_config()
    training_config = config["train"]
    run_name = time.strftime("%Y%m%d-%H%M%S")

    # Initialize WanDB (main process only; best-effort).
    wandb_config = training_config.get("wandb", None)
    if wandb_config is not None and is_main_process:
        init_wandb(wandb_config, run_name)

    print("Rank:", rank)
    if is_main_process:
        print("Config:", config)

    # Initialize dataset and dataloader
    dataset = _build_dataset(training_config)
    print("Dataset length:", len(dataset))
    train_loader = DataLoader(
        dataset,
        batch_size=training_config["batch_size"],
        shuffle=True,
        num_workers=training_config["dataloader_workers"],
    )

    # Initialize model
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    # Callbacks for logging and saving checkpoints (main process only).
    training_callbacks = (
        [TrainingCallback(run_name, training_config=training_config)]
        if is_main_process
        else []
    )

    # Initialize trainer; checkpointing/logging are handled by the callback,
    # so Lightning's built-ins are disabled.
    trainer = L.Trainer(
        accumulate_grad_batches=training_config["accumulate_grad_batches"],
        callbacks=training_callbacks,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
        max_steps=training_config.get("max_steps", -1),
        max_epochs=training_config.get("max_epochs", -1),
        gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
    )
    setattr(trainer, "training_config", training_config)

    # Save config next to the run's outputs (exist_ok makes reruns with the
    # same run_name non-fatal).
    save_path = training_config.get("save_path", "./output")
    if is_main_process:
        os.makedirs(f"{save_path}/{run_name}", exist_ok=True)
        with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
            yaml.dump(config, f)

    # Start training
    trainer.fit(trainable_model, train_loader)
# Standard script entry guard: run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()