|
import datetime
import hashlib
import json
import logging
import math
import os
import pdb
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import diffusers
import torch
import transformers
import wandb
from accelerate.accelerator import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import (
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
    ProjectConfiguration,
    broadcast_object_list,
    gather_object,
    set_seed,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines import DiffusionPipeline
from diffusers.utils.export_utils import export_to_video
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from PIL import Image
from safetensors.torch import load_file, save_file
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from finetune.constants import LOG_LEVEL, LOG_NAME
from finetune.datasets import I2VDatasetWithResize, I2VFlowDataset, T2VDatasetWithResize
from finetune.datasets.utils import (
    load_images,
    load_prompts,
    load_videos,
    preprocess_image_with_resize,
    preprocess_video_with_resize,
)
from finetune.schemas import Args, Components, State
from finetune.utils import (
    cast_training_params,
    free_memory,
    get_intermediate_ckpt_path,
    get_latest_ckpt_path_to_resume_from,
    get_memory_statistics,
    get_optimizer,
    string_to_filename,
    unload_model,
    unwrap_model,
)
|
|
|
# Module-wide logger obtained from accelerate's logging helper.
logger = get_logger(LOG_NAME, LOG_LEVEL)

# Short precision names mapped to the corresponding torch dtypes.
_DTYPE_MAP = dict(
    fp32=torch.float32,
    fp16=torch.float16,
    bf16=torch.bfloat16,
)
|
|
|
|
|
class Trainer:
    """Base trainer that wires dataset, optimizer, training loop, validation and checkpointing.

    Subclasses provide the model-specific parts by overriding ``load_components``,
    ``initialize_pipeline``, ``collate_fn``, ``encode_video``, ``encode_text``,
    ``compute_loss`` and ``validation_step`` (all raise ``NotImplementedError`` here).
    The usual entry point is :meth:`fit`.
    """

    # Names of ``Components`` fields that are skipped when components are moved to
    # the training device (and moved back to CPU after DeepSpeed validation).
    # ``None`` is allowed and treated as "unload nothing".
    UNLOAD_LIST: Optional[List[str]] = None

    def __init__(self, args: Args) -> None:
        """Create the training state, load the components and set up Accelerate and logging.

        Args:
            args: Training configuration. ``train_resolution`` is read as
                ``(frames, height, width)``.
        """
        self.args = args
        self.state = State(
            weight_dtype=self.__get_training_dtype(),
            train_frames=self.args.train_resolution[0],
            train_height=self.args.train_resolution[1],
            train_width=self.args.train_resolution[2],
        )

        self.components: Components = self.load_components()
        self.accelerator: Accelerator = None
        self.dataset: Dataset = None
        self.data_loader: DataLoader = None

        self.optimizer = None
        self.lr_scheduler = None

        self._init_distributed()
        self._init_logging()

        self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None

    def _init_distributed(self):
        """Create the ``Accelerator`` and a per-run output directory.

        The main process picks a timestamp-based run id which is broadcast to all
        processes so every rank agrees on ``<output_dir>/<run_name>-<run_id>``.
        ``self.args.output_dir`` is rewritten to that directory.
        """
        project_dir = Path(self.args.output_dir)
        logging_dir = project_dir / "tmp_logs"
        project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)

        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        init_process_group_kwargs = InitProcessGroupKwargs(
            backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
        )
        mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision
        report_to = None if self.args.report_to.lower() == "none" else self.args.report_to

        accelerator = Accelerator(
            project_config=project_config,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            mixed_precision=mixed_precision,
            log_with=report_to,
            kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
        )

        # Only the main process generates the id; the broadcast makes it identical everywhere.
        run_id = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") if accelerator.is_main_process else ""
        [run_id] = broadcast_object_list([run_id])

        final_out_dir = project_dir / f"{self.args.run_name}-{run_id}"
        final_log_dir = final_out_dir / "logs"

        if accelerator.is_main_process:
            final_log_dir.mkdir(parents=True, exist_ok=True)
        # Every rank must reach this barrier, so it stays outside the main-process branch.
        accelerator.wait_for_everyone()

        self.args.output_dir = final_out_dir
        accelerator.project_configuration.project_dir = final_out_dir
        accelerator.project_configuration.logging_dir = final_log_dir

        accelerator.init_trackers(
            project_name=self.args.model_name,
            config=vars(self.args),
            init_kwargs={
                "wandb": {
                    "dir": final_log_dir,
                    "name": self.args.run_name,
                }
            },
        )

        if torch.backends.mps.is_available():
            accelerator.native_amp = False

        self.accelerator = accelerator

        if self.args.seed is not None:
            set_seed(self.args.seed)

    def _init_logging(self) -> None:
        """Configure stdlib logging and the verbosity of transformers/diffusers per process."""
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=LOG_LEVEL,
        )
        if self.accelerator.is_local_main_process:
            transformers.utils.logging.set_verbosity_warning()
            diffusers.utils.logging.set_verbosity_info()
        else:
            transformers.utils.logging.set_verbosity_error()
            diffusers.utils.logging.set_verbosity_error()

        logger.info("Initialized Trainer")
        logger.info(f"Accelerator state: \n{self.accelerator.state}", main_process_only=False)

    def _trainable_component_name(self) -> str:
        """Return the ``Components`` field that is optimised for the configured training type.

        Raises:
            NotImplementedError: if ``args.training_type`` is not ``sft``, ``lora`` or ``controlnet``.
        """
        training_type = self.args.training_type
        if training_type in ("sft", "lora"):
            return "transformer"
        if training_type == "controlnet":
            return "controlnet"
        raise NotImplementedError("Choose training_type among 'sft', 'lora', 'controlnet'")

    def _trainable_model(self):
        """Return the model object (transformer or controlnet) that is being trained."""
        return getattr(self.components, self._trainable_component_name())

    def check_setting(self) -> None:
        """Validate ``UNLOAD_LIST`` against the component fields.

        Raises:
            ValueError: if a name in ``UNLOAD_LIST`` is not a field of the components object.
        """
        if self.UNLOAD_LIST is None:
            logger.warning(
                "\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m"
            )
            return

        for name in self.UNLOAD_LIST:
            if name not in self.components.model_fields:
                raise ValueError(f"Invalid component name in unload_list: {name}")

    def prepare_models(self) -> None:
        """Apply optional VAE slicing/tiling and record the transformer config in the state."""
        logger.info("Initializing models")

        vae = self.components.vae
        if vae is not None:
            if self.args.enable_slicing:
                vae.enable_slicing()
            if self.args.enable_tiling:
                vae.enable_tiling()

        self.state.transformer_config = self.components.transformer.config

    def prepare_dataset(self) -> None:
        """Build the dataset for ``args.model_type``, run the warm-up pass and create the dataloader.

        Raises:
            ValueError: if ``args.model_type`` is not ``i2v``, ``t2v`` or ``i2vFlow``.
        """
        logger.info("Initializing dataset and dataloader")

        dataset_classes = {
            "i2v": I2VDatasetWithResize,
            "t2v": T2VDatasetWithResize,
            "i2vFlow": I2VFlowDataset,
        }
        dataset_cls = dataset_classes.get(self.args.model_type)
        if dataset_cls is None:
            raise ValueError(f"Invalid model type: {self.args.model_type}")

        # All dataset flavours take exactly the same constructor arguments.
        self.dataset = dataset_cls(
            **(self.args.model_dump()),
            device=self.accelerator.device,
            max_num_frames=self.state.train_frames,
            height=self.state.train_height,
            width=self.state.train_width,
            trainer=self,
        )

        # NOTE(review): the nesting of this section was reconstructed from a source whose
        # indentation was lost. It is assumed that only the transformer freeze is
        # controlnet-specific, and that the unload below runs for every model type.
        # Please confirm against the original file.
        if self.args.training_type == "controlnet":
            self.components.transformer.requires_grad_(False)
        self.components.vae.requires_grad_(False)
        self.components.text_encoder.requires_grad_(False)
        self.components.vae = self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype)
        self.components.text_encoder = self.components.text_encoder.to(
            self.accelerator.device, dtype=self.state.weight_dtype
        )

        if not self.args.model_type == "i2vFlow":
            # Iterate once over the dataset so it can do its one-off per-sample work;
            # the batches themselves are discarded.
            logger.info("Precomputing latent for video and prompt embedding ...")
            tmp_data_loader = torch.utils.data.DataLoader(
                self.dataset,
                collate_fn=self.collate_fn,
                batch_size=1,
                num_workers=0,
                pin_memory=self.args.pin_memory,
            )
            tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
            for _ in tqdm(tmp_data_loader, desc="prepare dataloader"):
                pass
            self.accelerator.wait_for_everyone()
            logger.info("Precomputing latent for video and prompt embedding ... Done")

        unload_model(self.components.vae)
        unload_model(self.components.text_encoder)
        free_memory()

        self.data_loader = torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.collate_fn,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory,
            shuffle=True,
        )

    def set_additional_trainable_parameters(self, block_names):
        """Mark extra transformer parameters as trainable.

        ``block_names`` is currently unused: only ``patch_embed.proj`` of the
        transformer is switched to ``requires_grad=True``.
        """
        self.components.transformer.patch_embed.proj.requires_grad_(True)

    def prepare_trainable_parameters(self):
        """Freeze/unfreeze components, add LoRA adapters if requested and move components to device.

        Raises:
            ValueError: when bf16 is requested while the MPS backend is available.
        """
        logger.info("Initializing trainable parameters")

        weight_dtype = self.state.weight_dtype

        if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
            raise ValueError(
                "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
            )

        training_type = self.args.training_type
        for attr_name, component in vars(self.components).items():
            if not hasattr(component, "requires_grad_"):
                continue
            if training_type == "sft" and attr_name == "transformer":
                component.requires_grad_(True)
            elif training_type == "controlnet" and attr_name == "controlnet":
                component.requires_grad_(True)
                if self.args.notextinflow:
                    component.patch_embed.text_proj.requires_grad_(False)
            else:
                component.requires_grad_(False)

        if training_type == "lora":
            transformer_lora_config = LoraConfig(
                r=self.args.rank,
                lora_alpha=self.args.lora_alpha,
                init_lora_weights=True,
                target_modules=self.args.target_modules,
            )
            self.components.transformer.add_adapter(transformer_lora_config)
            self.__prepare_saving_loading_hooks(transformer_lora_config, block_names=self.args.additional_save_blocks)

            # NOTE(review): placed inside the LoRA branch because the extra blocks are only
            # saved by the LoRA hooks; the original nesting could not be recovered - confirm.
            self.set_additional_trainable_parameters(block_names=self.args.additional_save_blocks)

        # UNLOAD_LIST may legitimately be None (see check_setting); the helper accepts that.
        self.__move_components_to_device(dtype=weight_dtype, ignore_list=self.UNLOAD_LIST)

        if self.args.gradient_checkpointing:
            self.components.transformer.enable_gradient_checkpointing()
            if training_type == "controlnet":
                self.components.controlnet.enable_gradient_checkpointing()

    def prepare_optimizer(self) -> None:
        """Create the optimizer and LR scheduler for the trainable parameters.

        Trainable parameters of the trained model are cast to fp32 first. When
        ``args.train_steps`` is ``None`` it is derived from ``args.train_epochs``.

        Raises:
            NotImplementedError: for an unknown ``args.training_type``.
        """
        logger.info("Initializing optimizer and lr scheduler")

        model = self._trainable_model()
        cast_training_params([model], dtype=torch.float32)

        named_trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
        trainable_parameters_name = [name for name, _ in named_trainable]
        trainable_parameters = [param for _, param in named_trainable]

        print("-"*200)
        print(f"Training type: {self.args.training_type}")
        print(f"Trainable parameters: {trainable_parameters_name}")
        print("-"*200)

        params_to_optimize = [
            {
                "params": trainable_parameters,
                "lr": self.args.learning_rate,
            }
        ]
        self.state.num_trainable_parameters = sum(p.numel() for p in trainable_parameters)

        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        use_deepspeed_opt = deepspeed_plugin is not None and "optimizer" in deepspeed_plugin.deepspeed_config
        optimizer = get_optimizer(
            params_to_optimize=params_to_optimize,
            optimizer_name=self.args.optimizer,
            learning_rate=self.args.learning_rate,
            beta1=self.args.beta1,
            beta2=self.args.beta2,
            beta3=self.args.beta3,
            epsilon=self.args.epsilon,
            weight_decay=self.args.weight_decay,
            use_deepspeed=use_deepspeed_opt,
        )

        num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
        if self.args.train_steps is None:
            self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
            # Remembered so prepare_for_training can recompute once the loader is sharded.
            self.state.overwrote_max_train_steps = True

        use_deepspeed_lr_scheduler = deepspeed_plugin is not None and "scheduler" in deepspeed_plugin.deepspeed_config
        total_training_steps = self.args.train_steps * self.accelerator.num_processes
        num_warmup_steps = self.args.lr_warmup_steps * self.accelerator.num_processes

        if use_deepspeed_lr_scheduler:
            from accelerate.utils import DummyScheduler

            lr_scheduler = DummyScheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                total_num_steps=total_training_steps,
                num_warmup_steps=num_warmup_steps,
            )
        else:
            lr_scheduler = get_scheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_training_steps,
                num_cycles=self.args.lr_num_cycles,
                power=self.args.lr_power,
            )

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def prepare_for_training(self) -> None:
        """Wrap model, optimizer, dataloader and scheduler with Accelerate and fix step/epoch counts."""
        component_name = self._trainable_component_name()
        prepared_model, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare(
            getattr(self.components, component_name), self.optimizer, self.data_loader, self.lr_scheduler
        )
        setattr(self.components, component_name, prepared_model)

        # The dataloader length may have changed after prepare(), so recompute.
        num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
        if self.state.overwrote_max_train_steps:
            self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch

        self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
        self.state.num_update_steps_per_epoch = num_update_steps_per_epoch

    def prepare_for_validation(self):
        """Load validation prompts/images/videos into the state and run a validation at step 0."""
        validation_dir = self.args.validation_dir
        validation_prompts = load_prompts(validation_dir / self.args.validation_prompts)
        placeholders = [None] * len(validation_prompts)

        if self.args.validation_images is not None:
            validation_images = load_images(validation_dir / self.args.validation_images)
        else:
            validation_images = list(placeholders)

        if self.args.validation_videos is not None:
            validation_videos = load_videos(validation_dir / self.args.validation_videos)
        else:
            validation_videos = list(placeholders)

        self.state.validation_prompts = validation_prompts
        self.state.validation_images = validation_images
        self.state.validation_videos = validation_videos

        self.validate(0)

    def prepare_trackers(self) -> None:
        """Initialise the experiment trackers under ``args.tracker_name``.

        NOTE(review): ``_init_distributed`` already calls ``init_trackers``; confirm
        that initialising them a second time here is intended.
        """
        logger.info("Initializing trackers")

        tracker_name = self.args.tracker_name or "finetrainers-experiment"
        self.accelerator.init_trackers(tracker_name, config=self.args.model_dump())

    def load_state_single_gpu(self, resume_from_checkpoint_path) -> None:
        """Load controlnet weights from ``pytorch_model/mp_rank_00_model_states.pt`` under the checkpoint.

        The state dict is only applied when ``training_type`` is ``controlnet``.
        """
        state_dict_path = resume_from_checkpoint_path / "pytorch_model" / "mp_rank_00_model_states.pt"
        # Load onto CPU so the call does not depend on the device the file was saved from;
        # load_state_dict copies the tensors into the model's own parameters.
        state_dict = torch.load(state_dict_path, map_location="cpu")["module"]
        if self.args.training_type == "controlnet":
            controlnet_ = unwrap_model(self.accelerator, self.components.controlnet)
            controlnet_.load_state_dict(state_dict)

    def train(self) -> None:
        """Run the training loop with gradient accumulation, checkpointing and periodic validation."""
        logger.info("Starting training")

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")

        self.state.total_batch_size_count = (
            self.args.batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps
        )
        info = {
            "trainable parameters": self.state.num_trainable_parameters,
            "total samples": len(self.dataset),
            "train epochs": self.args.train_epochs,
            "train steps": self.args.train_steps,
            "batches per device": self.args.batch_size,
            "total batches observed per epoch": len(self.data_loader),
            "train batch size total count": self.state.total_batch_size_count,
            "gradient accumulation steps": self.args.gradient_accumulation_steps,
        }
        logger.info(f"Training configuration: {json.dumps(info, indent=4)}")

        (
            resume_from_checkpoint_path,
            initial_global_step,
            global_step,
            first_epoch,
        ) = get_latest_ckpt_path_to_resume_from(
            resume_from_checkpoint=self.args.resume_from_checkpoint,
            num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
        )

        if resume_from_checkpoint_path is not None:
            self.accelerator.load_state(resume_from_checkpoint_path)

        progress_bar = tqdm(
            range(0, self.args.train_steps),
            initial=initial_global_step,
            desc="Training steps",
            disable=not self.accelerator.is_local_main_process,
        )

        accelerator = self.accelerator
        generator = torch.Generator(device=accelerator.device)
        if self.args.seed is not None:
            generator = generator.manual_seed(self.args.seed)
        self.state.generator = generator

        # When resuming, the resumed step counts as already validated.
        last_validated_step = global_step if global_step != 0 else -1

        free_memory()
        for epoch in range(first_epoch, self.args.train_epochs):
            logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")

            epoch_model = self._trainable_model()
            epoch_model.train()
            models_to_accumulate = [epoch_model]

            for step, batch in enumerate(self.data_loader):
                logger.debug(f"Starting step {step + 1}")
                logs = {}

                with accelerator.accumulate(models_to_accumulate):
                    loss = self.compute_loss(batch)
                    accelerator.backward(loss)

                    if accelerator.sync_gradients:
                        trained_model = self._trainable_model()
                        if accelerator.distributed_type == DistributedType.DEEPSPEED:
                            grad_norm = trained_model.get_global_grad_norm()
                        else:
                            grad_norm = accelerator.clip_grad_norm_(
                                trained_model.parameters(), self.args.max_grad_norm
                            )
                        # The norm may come back as a tensor; log a plain number.
                        if torch.is_tensor(grad_norm):
                            grad_norm = grad_norm.item()

                        logs["grad_norm"] = grad_norm

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                # sync_gradients is True only on iterations that completed an optimizer update.
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1
                    self.__maybe_save_checkpoint(global_step)

                logs["loss"] = loss.detach().item()
                logs["lr"] = self.lr_scheduler.get_last_lr()[0]
                progress_bar.set_postfix(logs)

                # global_step does not advance on accumulation-only iterations, so
                # last_validated_step prevents validating the same step repeatedly.
                should_run_validation = (
                    self.args.do_validation
                    and global_step % self.args.validation_steps == 0
                    and global_step != 0
                    and global_step != last_validated_step
                )

                if should_run_validation:
                    del loss
                    free_memory()
                    self.validate(global_step)
                    last_validated_step = global_step

                accelerator.log(logs, step=global_step)

                if global_step >= self.args.train_steps:
                    break

            memory_statistics = get_memory_statistics()
            logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")

        accelerator.wait_for_everyone()
        self.__maybe_save_checkpoint(global_step, must_save=True)
        if self.args.do_validation:
            free_memory()
            self.validate(global_step)

        del self.components
        free_memory()
        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")

        accelerator.end_training()

    def validate(self, step: int) -> None:
        """Generate validation artifacts, save them under ``validation_res`` and log them to wandb.

        Args:
            step: Training step used in the artifact file names and as the tracker step.
        """
        logger.info("Starting validation")

        accelerator = self.accelerator
        num_validation_samples = len(self.state.validation_prompts)

        if num_validation_samples == 0:
            logger.warning("No validation samples found. Skipping validation.")
            return

        self.components.transformer.eval()
        torch.set_grad_enabled(False)

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")

        pipe = self.initialize_pipeline()

        if self.state.using_deepspeed:
            # Under DeepSpeed every component except the transformer is moved to the device directly.
            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
        else:
            pipe.enable_model_cpu_offload(device=self.accelerator.device)
            # This also converts the trainable weights; they are cast back to fp32 further below.
            pipe = pipe.to(dtype=self.state.weight_dtype)

        is_zero3 = self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage == 3
        shard_samples = self.state.using_deepspeed and not is_zero3

        all_processes_artifacts = []
        for i in range(num_validation_samples):
            # DeepSpeed without ZeRO-3: samples are distributed round-robin over processes.
            if shard_samples and i % accelerator.num_processes != accelerator.process_index:
                continue

            prompt = self.state.validation_prompts[i]
            image = self.state.validation_images[i]
            video = self.state.validation_videos[i]

            if image is not None:
                image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
                image = image.to(torch.uint8)
                image = image.permute(1, 2, 0).cpu().numpy()
                image = Image.fromarray(image)

            if video is not None:
                video = preprocess_video_with_resize(
                    video, self.state.train_frames, self.state.train_height, self.state.train_width
                )
                video = video.round().clamp(0, 255).to(torch.uint8)
                video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]

            logger.debug(
                f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
                main_process_only=False,
            )
            validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)

            # With ZeRO-3 every process runs the step, but only the main process keeps the output.
            if is_zero3 and not accelerator.is_main_process:
                continue

            prompt_filename = string_to_filename(prompt)[:25]
            # Short hash keeps file names distinct for prompts that share their first 25 characters.
            reversed_prompt = prompt[::-1]
            hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]

            artifacts = {
                "image": {"type": "image", "value": image},
                "video": {"type": "video", "value": video},
            }
            # A dedicated index name avoids shadowing the sample index ``i``.
            for artifact_idx, (artifact_type, artifact_value) in enumerate(validation_artifacts):
                artifacts[f"artifact_{artifact_idx}"] = {"type": artifact_type, "value": artifact_value}
            logger.debug(
                f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
                main_process_only=False,
            )

            for artifact in artifacts.values():
                artifact_type = artifact["type"]
                artifact_value = artifact["value"]
                if artifact_type not in ["image", "video"] or artifact_value is None:
                    continue

                # NOTE(review): the file name does not contain the artifact key, so two artifacts
                # of the same type for one sample are written to the same path.
                extension = "png" if artifact_type == "image" else "mp4"
                filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}.{extension}"
                validation_path = self.args.output_dir / "validation_res"
                validation_path.mkdir(parents=True, exist_ok=True)
                filename = str(validation_path / filename)

                if artifact_type == "image":
                    logger.debug(f"Saving image to {filename}")
                    artifact_value.save(filename)
                    artifact_value = wandb.Image(filename)
                elif artifact_type == "video":
                    logger.debug(f"Saving video to {filename}")
                    export_to_video(artifact_value, filename, fps=self.args.gen_fps)
                    artifact_value = wandb.Video(filename, caption=prompt)

                all_processes_artifacts.append(artifact_value)

        all_artifacts = gather_object(all_processes_artifacts)

        if accelerator.is_main_process:
            tracker_key = "validation"
            for tracker in accelerator.trackers:
                if tracker.name == "wandb":
                    image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
                    video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
                    tracker.log(
                        {
                            tracker_key: {"images": image_artifacts, "videos": video_artifacts},
                        },
                        step=step,
                    )

        # ---- clean up and restore the training layout ----
        if self.state.using_deepspeed:
            del pipe
            self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
        else:
            pipe.remove_all_hooks()
            del pipe

            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
            self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)

            # The moves above converted trainable weights to the training dtype; restore fp32
            # as set up in prepare_optimizer. In controlnet mode the trainable weights live in
            # the controlnet, so it needs the same treatment.
            cast_training_params([self.components.transformer], dtype=torch.float32)
            if self.args.training_type == "controlnet":
                cast_training_params([self.components.controlnet], dtype=torch.float32)

        free_memory()
        accelerator.wait_for_everyone()

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
        # Guarded because this call is CUDA-only while the trainer has MPS code paths.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(accelerator.device)

        torch.set_grad_enabled(True)
        self.components.transformer.train()

    def fit(self):
        """Run the full pipeline: checks, preparation steps, optional initial validation, training."""
        self.check_setting()
        self.prepare_models()
        self.prepare_dataset()
        self.prepare_trainable_parameters()
        self.prepare_optimizer()
        self.prepare_for_training()
        self.prepare_trackers()
        if self.args.do_validation:
            self.prepare_for_validation()
        self.train()

    def collate_fn(self, examples: List[Dict[str, Any]]):
        """Collate dataset examples into a batch. Must be implemented by subclasses."""
        raise NotImplementedError

    def load_components(self) -> Components:
        """Load and return the model components. Must be implemented by subclasses."""
        raise NotImplementedError

    def initialize_pipeline(self) -> DiffusionPipeline:
        """Build the pipeline used for validation. Must be implemented by subclasses."""
        raise NotImplementedError

    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        """Encode a video tensor. Must be implemented by subclasses."""
        raise NotImplementedError

    def encode_text(self, text: str) -> torch.Tensor:
        """Encode a text prompt. Must be implemented by subclasses."""
        raise NotImplementedError

    def compute_loss(self, batch) -> torch.Tensor:
        """Return the training loss for ``batch``. Must be implemented by subclasses."""
        raise NotImplementedError

    def validation_step(
        self,
        eval_data: Optional[Dict[str, Any]] = None,
        pipe: Optional[DiffusionPipeline] = None,
    ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
        """Generate artifacts for one validation sample. Must be implemented by subclasses.

        ``validate`` calls this as ``validation_step(eval_data, pipe)`` where ``eval_data``
        has the keys ``prompt``, ``image`` and ``video``. The result is iterated as
        ``(artifact_type, artifact_value)`` pairs; only the types ``"image"`` and
        ``"video"`` are saved.
        """
        raise NotImplementedError

    def __get_training_dtype(self) -> torch.dtype:
        """Translate ``args.mixed_precision`` into a torch dtype.

        Raises:
            ValueError: if the setting is not ``no``, ``fp16`` or ``bf16``.
        """
        dtype_key = {"no": "fp32", "fp16": "fp16", "bf16": "bf16"}.get(self.args.mixed_precision)
        if dtype_key is None:
            raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
        return _DTYPE_MAP[dtype_key]

    def __move_components_to_device(self, dtype, ignore_list: Optional[List[str]] = None):
        """Move every component with a ``to`` method to the accelerator device, cast to ``dtype``.

        Args:
            dtype: Target dtype passed to ``component.to``.
            ignore_list: Component names to leave untouched; ``None`` means none.
        """
        ignored = set(ignore_list or ())
        for name, component in self.components.model_dump().items():
            # Classes (e.g. a pipeline class stored as a field) are skipped; only instances move.
            if isinstance(component, type) or not hasattr(component, "to"):
                continue
            if name not in ignored:
                setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))

    def __move_components_to_cpu(self, unload_list: Optional[List[str]] = None):
        """Move the components named in ``unload_list`` to the CPU; ``None`` means none."""
        to_unload = set(unload_list or ())
        for name, component in self.components.model_dump().items():
            if isinstance(component, type) or not hasattr(component, "to"):
                continue
            if name in to_unload:
                setattr(self.components, name, component.to("cpu"))

    def __prepare_saving_loading_hooks(self, transformer_lora_config, block_names=None):
        """Register Accelerate save/load pre-hooks for LoRA weights and extra transformer blocks.

        Args:
            transformer_lora_config: LoRA config, re-applied to a freshly loaded transformer
                in the DeepSpeed load path.
            block_names: Attribute names of transformer sub-modules whose full state dict is
                stored next to the LoRA weights in ``selected_blocks.safetensors``.
                ``None`` or empty disables this.
        """
        block_names = list(block_names or [])
        blocks_filename = "selected_blocks.safetensors"
        lora_prefix = "transformer."

        def save_model_hook(models, weights, output_dir):
            if not self.accelerator.is_main_process:
                return

            transformer_cls = type(unwrap_model(self.accelerator, self.components.transformer))
            transformer_lora_layers_to_save = None
            # Initialised up front so the save below never sees an unbound name.
            tensor_dict = {}

            for model in models:
                unwrapped = unwrap_model(self.accelerator, model)
                if not isinstance(unwrapped, transformer_cls):
                    raise ValueError(f"Unexpected save model: {model.__class__}")

                transformer_lora_layers_to_save = get_peft_model_state_dict(unwrapped)

                for block_name in block_names:
                    if not hasattr(unwrapped, block_name):
                        raise ValueError(f"Model has no attribute '{block_name}'")
                    for k, v in getattr(unwrapped, block_name).state_dict().items():
                        tensor_dict[f"{block_name}.{k}"] = v

                # Pop the weights so Accelerate does not save this model a second time.
                if weights:
                    weights.pop()

            self.components.pipeline_cls.save_lora_weights(
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
            )

            if block_names:
                save_file(tensor_dict, os.path.join(output_dir, blocks_filename))

        def load_model_hook(models, input_dir):
            transformer_ = None
            if not self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                transformer_cls = type(unwrap_model(self.accelerator, self.components.transformer))
                while len(models) > 0:
                    model = models.pop()
                    unwrapped = unwrap_model(self.accelerator, model)
                    if not isinstance(unwrapped, transformer_cls):
                        raise ValueError(f"Unexpected save model: {unwrapped.__class__}")
                    transformer_ = unwrapped
                if transformer_ is None:
                    raise ValueError("No transformer model was passed to the load hook")
            else:
                transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
                    self.args.model_path, subfolder="transformer"
                )
                transformer_.add_adapter(transformer_lora_config)

            lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
            # Strip only the leading prefix; a blanket replace would also alter later occurrences.
            transformer_state_dict = {
                k[len(lora_prefix):]: v for k, v in lora_state_dict.items() if k.startswith(lora_prefix)
            }
            incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
            if incompatible_keys is not None:
                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
                if unexpected_keys:
                    logger.warning(
                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                        f" {unexpected_keys}. "
                    )

            load_path = os.path.join(input_dir, blocks_filename)
            if os.path.exists(load_path):
                tensor_dict = load_file(load_path)

                # Regroup the flat "<block>.<param>" keys into one state dict per block.
                block_state_dicts = {}
                for k, v in tensor_dict.items():
                    block_name, param_name = k.split(".", 1)
                    block_state_dicts.setdefault(block_name, {})[param_name] = v

                for block_name, state_dict in block_state_dicts.items():
                    if not hasattr(transformer_, block_name):
                        raise ValueError(f"Transformer has no attribute '{block_name}'")
                    getattr(transformer_, block_name).load_state_dict(state_dict)

            # NOTE(review): this runs in a *pre*-load hook, so it touches whatever optimizer
            # state exists before the checkpoint is restored - confirm this is intended.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device=self.accelerator.device, dtype=torch.float32)

        self.accelerator.register_save_state_pre_hook(save_model_hook)
        self.accelerator.register_load_state_pre_hook(load_model_hook)

    def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
        """Save the accelerator state if ``must_save`` or on a ``checkpointing_steps`` boundary.

        Under DeepSpeed every process takes part in saving; otherwise only the main process does.
        """
        participates = (
            self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process
        )
        if not participates:
            return
        if must_save or global_step % self.args.checkpointing_steps == 0:
            save_path = get_intermediate_ckpt_path(
                checkpointing_limit=self.args.checkpointing_limit,
                step=global_step,
                output_dir=self.args.output_dir,
            )
            self.accelerator.save_state(save_path, safe_serialization=True)
|
|