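# Training entry point for OminiModel (FLUX + LoRA), driven by a YAML config
# whose path is taken from the XFL_CONFIG environment variable.
#
# A sketch of a typical launch (the config path is an illustrative assumption;
# the relative imports below mean the script must run as a package module):
#
#   XFL_CONFIG=./config.yaml python -m train.train   # module path depends on repo layout
#
# Multi-GPU runs are expected to come from a launcher that sets LOCAL_RANK
# (e.g. torchrun); without it, the script falls back to rank 0.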

import os
import time

import lightning as L
import torch
import yaml
from datasets import load_dataset
from torch.utils.data import DataLoader

from .data import ImageConditionDataset, Subject200KDataset, CartoonDataset
from .model import OminiModel
from .callbacks import TrainingCallback


def get_rank():
    """Return the local rank of this process (0 if LOCAL_RANK is unset or invalid)."""
    try:
        rank = int(os.environ.get("LOCAL_RANK"))
    except (TypeError, ValueError):
        rank = 0
    return rank


def get_config():
    """Load the YAML training config from the path given by XFL_CONFIG."""
    config_path = os.environ.get("XFL_CONFIG")
    assert config_path is not None, "Please set the XFL_CONFIG environment variable"
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config
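
# A minimal sketch of the YAML file that XFL_CONFIG should point to. The keys
# below are the ones this script actually reads; the values are illustrative
# assumptions, not a recommended configuration.
#
#   flux_path: black-forest-labs/FLUX.1-dev
#   dtype: bfloat16
#   model: {}
#   train:
#     batch_size: 1
#     dataloader_workers: 4
#     accumulate_grad_batches: 1
#     max_steps: 10000               # optional, defaults to -1 (unlimited)
#     gradient_checkpointing: true
#     save_path: ./output
#     condition_type: subject
#     lora_config: {...}             # schema consumed by OminiModel
#     optimizer: {...}               # schema consumed by OminiModel
#     wandb: {project: my-project}   # optional
#     dataset:
#       type: subject                # one of: subject, img, cartoon
#       condition_size: 512
#       target_size: 512
#       image_size: 512
#       padding: 8
#       drop_text_prob: 0.1
#       drop_image_prob: 0.1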


def init_wandb(wandb_config, run_name):
    """Initialize Weights & Biases logging; training continues if this fails."""
    import wandb

    try:
        assert os.environ.get("WANDB_API_KEY") is not None, "WANDB_API_KEY is not set"
        wandb.init(
            project=wandb_config["project"],
            name=run_name,
            config={},
        )
    except Exception as e:
        print("Failed to initialize W&B:", e)


def main():
    # Initialize process rank and select the corresponding GPU
    rank = get_rank()
    is_main_process = rank == 0
    torch.cuda.set_device(rank)
    config = get_config()
    training_config = config["train"]
    run_name = time.strftime("%Y%m%d-%H%M%S")

    # Initialize W&B logging (main process only)
    wandb_config = training_config.get("wandb", None)
    if wandb_config is not None and is_main_process:
        init_wandb(wandb_config, run_name)

    print("Rank:", rank)
    if is_main_process:
        print("Config:", config)
    # Initialize dataset and dataloader
    if training_config["dataset"]["type"] == "subject":
        dataset = load_dataset("Yuanshi/Subjects200K")

        # Keep only samples whose quality scores are all at least 5
        def filter_func(item):
            if not item.get("quality_assessment"):
                return False
            return all(
                item["quality_assessment"].get(key, 0) >= 5
                for key in ["compositeStructure", "objectConsistency", "imageQuality"]
            )

        # Filter the dataset, caching the result to disk
        os.makedirs("./cache/dataset", exist_ok=True)
        data_valid = dataset["train"].filter(
            filter_func,
            num_proc=16,
            cache_file_name="./cache/dataset/data_valid.arrow",
        )
        dataset = Subject200KDataset(
            data_valid,
            condition_size=training_config["dataset"]["condition_size"],
            target_size=training_config["dataset"]["target_size"],
            image_size=training_config["dataset"]["image_size"],
            padding=training_config["dataset"]["padding"],
            condition_type=training_config["condition_type"],
            drop_text_prob=training_config["dataset"]["drop_text_prob"],
            drop_image_prob=training_config["dataset"]["drop_image_prob"],
        )
elif training_config["dataset"]["type"] == "img":
# Load dataset text-to-image-2M
dataset = load_dataset(
"webdataset",
data_files={"train": training_config["dataset"]["urls"]},
split="train",
cache_dir="cache/t2i2m",
num_proc=32,
)
dataset = ImageConditionDataset(
dataset,
condition_size=training_config["dataset"]["condition_size"],
target_size=training_config["dataset"]["target_size"],
condition_type=training_config["condition_type"],
drop_text_prob=training_config["dataset"]["drop_text_prob"],
drop_image_prob=training_config["dataset"]["drop_image_prob"],
position_scale=training_config["dataset"].get("position_scale", 1.0),
)
elif training_config["dataset"]["type"] == "cartoon":
dataset = load_dataset("saquiboye/oye-cartoon", split="train")
dataset = CartoonDataset(
dataset,
condition_size=training_config["dataset"]["condition_size"],
target_size=training_config["dataset"]["target_size"],
image_size=training_config["dataset"]["image_size"],
padding=training_config["dataset"]["padding"],
condition_type=training_config["condition_type"],
drop_text_prob=training_config["dataset"]["drop_text_prob"],
drop_image_prob=training_config["dataset"]["drop_image_prob"],
)
else:
raise NotImplementedError
print("Dataset length:", len(dataset))
    train_loader = DataLoader(
        dataset,
        batch_size=training_config["batch_size"],
        shuffle=True,
        num_workers=training_config["dataloader_workers"],
    )
    # Initialize model
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )
    # Callbacks for logging and saving checkpoints
    training_callbacks = (
        [TrainingCallback(run_name, training_config=training_config)]
        if is_main_process
        else []
    )

    # Initialize trainer
    trainer = L.Trainer(
        accumulate_grad_batches=training_config["accumulate_grad_batches"],
        callbacks=training_callbacks,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
        max_steps=training_config.get("max_steps", -1),
        max_epochs=training_config.get("max_epochs", -1),
        gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
    )
setattr(trainer, "training_config", training_config)
# Save config
save_path = training_config.get("save_path", "./output")
if is_main_process:
os.makedirs(f"{save_path}/{run_name}")
with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
yaml.dump(config, f)
# Start training
trainer.fit(trainable_model, train_loader)
if __name__ == "__main__":
main()