ai-toolkit / jobs /process /TrainVAEProcess.py
rahul7star's picture
boilerplate
fcc02a2 verified
raw
history blame
26 kB
import copy
import glob
import os
import shutil
import time
from collections import OrderedDict
from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import save_file, load_file
from torch.utils.data import DataLoader, ConcatDataset
import torch
from torch import nn
from torchvision.transforms import transforms
from jobs.process import BaseTrainProcess
from toolkit.image_utils import show_tensors
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
from toolkit.data_loader import ImageDataset
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
from toolkit.metadata import get_meta_for_safetensors
from toolkit.optimizer import get_optimizer
from toolkit.style import get_style_model_and_losses
from toolkit.train_tools import get_torch_dtype
from diffusers import AutoencoderKL
from tqdm import tqdm
import time
import numpy as np
from .models.vgg19_critic import Critic
from torchvision.transforms import Resize
import lpips
# Preprocessing applied to sample images before they are fed to the VAE:
# convert the PIL image to a tensor, then normalize with mean 0.5 / std 0.5.
# `unnormalize` below applies the inverse mapping (x / 2 + 0.5).
IMAGE_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def unnormalize(tensor):
    """Undo the 0.5/0.5 normalization and clamp the result to [0, 1].

    Args:
        tensor: A torch tensor holding normalized values.

    Returns:
        A new tensor equal to ``tensor / 2 + 0.5`` with every element
        limited to the closed interval [0, 1].
    """
    rescaled = tensor / 2 + 0.5
    return torch.clamp(rescaled, min=0, max=1)
class TrainVAEProcess(BaseTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
    """Read the process configuration and validate it.

    Args:
        process_id: Identifier forwarded to the base process.
        job: Owning job object; ``job.device`` is used as the default device.
        config: Process configuration, forwarded to the base process and
            read through ``self.get_conf``.

    Raises:
        ValueError: If ``sample_every`` is set without ``sample_sources``,
            if neither ``epochs`` nor ``max_steps`` is set, if ``datasets``
            is not a list, or if a dataset entry has no ``path`` or the
            path is not a directory.
    """
    super().__init__(process_id, job, config)
    self.data_loader = None
    self.vae = None
    self.device = self.get_conf('device', self.job.device)
    self.vae_path = self.get_conf('vae_path', required=True)
    self.datasets_objects = self.get_conf('datasets', required=True)
    self.batch_size = self.get_conf('batch_size', 1, as_type=int)
    self.resolution = self.get_conf('resolution', 256, as_type=int)
    self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float)
    self.sample_every = self.get_conf('sample_every', None)
    self.optimizer_type = self.get_conf('optimizer', 'adam')
    self.epochs = self.get_conf('epochs', None, as_type=int)
    self.max_steps = self.get_conf('max_steps', None, as_type=int)
    self.save_every = self.get_conf('save_every', None)
    self.dtype = self.get_conf('dtype', 'float32')
    self.sample_sources = self.get_conf('sample_sources', None)
    self.log_every = self.get_conf('log_every', 100, as_type=int)

    # per-loss weights; a weight of 0 disables the corresponding loss term
    self.style_weight = self.get_conf('style_weight', 0, as_type=float)
    self.content_weight = self.get_conf('content_weight', 0, as_type=float)
    self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
    self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
    self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
    self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
    self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
    self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)

    self.optimizer_params = self.get_conf('optimizer_params', {})
    self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
    self.torch_dtype = get_torch_dtype(self.dtype)

    # lazily created helpers (see setup_vgg19 / run / get_pattern_loss)
    self.vgg_19 = None
    self.style_weight_scalers = []
    self.content_weight_scalers = []
    self.lpips_loss: lpips.LPIPS = None
    self._pattern_loss = None

    # default only; load_vae recomputes it from the loaded VAE config
    self.vae_scale_factor = 8
    self.step_num = 0
    self.epoch_num = 0

    self.use_critic = self.get_conf('use_critic', False, as_type=bool)
    self.critic = None
    if self.use_critic:
        self.critic = Critic(
            device=self.device,
            dtype=self.dtype,
            process=self,
            **self.get_conf('critic', {})  # pass any other params
        )

    if self.sample_every is not None and self.sample_sources is None:
        raise ValueError('sample_every is specified but sample_sources is not')

    if self.epochs is None and self.max_steps is None:
        raise ValueError('epochs or max_steps must be specified')

    self.data_loaders = []

    # check datasets; raise explicitly because `assert` is stripped under -O
    if not isinstance(self.datasets_objects, list):
        raise ValueError('datasets must be a list')
    for dataset in self.datasets_objects:
        if 'path' not in dataset:
            raise ValueError('dataset must have a path')
        # check if is dir
        if not os.path.isdir(dataset['path']):
            raise ValueError(f"dataset path is not a directory: {dataset['path']}")

    # make training folder
    os.makedirs(self.save_root, exist_ok=True)
def update_training_metadata(self):
    """Push the current training info into the process metadata.

    Wraps the result of ``get_training_info`` under the ``training_info``
    key and hands it to ``self.add_meta``.
    """
    training_info = self.get_training_info()
    meta = OrderedDict([("training_info", training_info)])
    self.add_meta(meta)
def get_training_info(self):
    """Return the current progress counters.

    Returns:
        An ``OrderedDict`` with ``step`` (``self.step_num``) followed by
        ``epoch`` (``self.epoch_num``).
    """
    progress = OrderedDict()
    progress['step'] = self.step_num
    progress['epoch'] = self.epoch_num
    return progress
def load_datasets(self):
    """Build ``self.data_loader`` from the configured datasets (once).

    Does nothing when a data loader already exists. Otherwise each dataset
    config is shallow-copied, given the process resolution, wrapped in an
    ``ImageDataset``, and all of them are concatenated behind a single
    shuffling ``DataLoader``.
    """
    if self.data_loader is not None:
        return

    print(f"Loading datasets")
    image_datasets = []
    for dataset_conf in self.datasets_objects:
        print(f" - Dataset: {dataset_conf['path']}")
        # copy so the resolution override does not leak into the job config
        conf = copy.copy(dataset_conf)
        conf['resolution'] = self.resolution
        image_datasets.append(ImageDataset(conf))

    combined = ConcatDataset(image_datasets)
    self.data_loader = DataLoader(
        combined,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=6
    )
def remove_oldest_checkpoint(self, max_to_keep=4):
    """Delete the oldest ``*_diffusers`` checkpoint folders of this job.

    Checkpoint folders are matched as ``<save_root>/<job name>*_diffusers``
    and ordered by modification time; everything except the newest
    ``max_to_keep`` folders is removed.

    Args:
        max_to_keep: Number of most recent checkpoint folders to keep.
            Defaults to 4, the previously hard-coded value.

    Raises:
        ValueError: If ``max_to_keep`` is negative.
    """
    if max_to_keep < 0:
        raise ValueError('max_to_keep must be >= 0')

    # escape glob metacharacters so a path or job name containing
    # '[', '*' or '?' still matches its own checkpoints
    pattern = os.path.join(
        glob.escape(self.save_root),
        f"{glob.escape(self.job.name)}*_diffusers"
    )
    # only directories are checkpoints; rmtree would fail on a stray file
    folders = sorted(
        (path for path in glob.glob(pattern) if os.path.isdir(path)),
        key=os.path.getmtime
    )

    # explicit count instead of folders[:-max_to_keep], which is empty
    # when max_to_keep == 0
    surplus = max(len(folders) - max_to_keep, 0)
    for folder in folders[:surplus]:
        print(f"Removing {folder}")
        shutil.rmtree(folder)
def setup_vgg19(self):
    """Create the VGG19 style/content model and per-layer loss scalers.

    Runs only once: if ``self.vgg_19`` is already set nothing happens.
    After building the model, a batch of random noise is pushed through it
    and the reciprocal of each layer's mean loss is stored so the layer
    losses can later be normalized.
    """
    if self.vgg_19 is not None:
        return

    vgg_bundle = get_style_model_and_losses(
        single_target=True,
        device=self.device,
        output_layer_name='pool_4',
        dtype=self.torch_dtype
    )
    self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = vgg_bundle
    self.vgg_19.to(self.device, dtype=self.torch_dtype)
    self.vgg_19.requires_grad_(False)

    # we run random noise through first to get layer scalers to normalize the loss per layer
    # bs of 2 because we run pred and target through stacked
    noise_shape = (2, 3, self.resolution, self.resolution)
    noise = torch.randn(noise_shape, device=self.device, dtype=self.torch_dtype)
    self.vgg_19(noise)

    # each scaler is the inverse of the mean loss observed on noise
    self.style_weight_scalers.extend(
        1 / torch.mean(layer.loss).item() for layer in self.style_losses
    )
    self.content_weight_scalers.extend(
        1 / torch.mean(layer.loss).item() for layer in self.content_losses
    )

    self.print(f"Style weight scalers: {self.style_weight_scalers}")
    self.print(f"Content weight scalers: {self.content_weight_scalers}")
def get_style_loss(self):
    """Return the summed, scaler-normalized style loss.

    Returns:
        The sum over ``self.style_losses`` of ``loss * scaler``, or a zero
        tensor on ``self.device`` when ``style_weight`` is not positive.
    """
    if not self.style_weight > 0:
        return torch.tensor(0.0, device=self.device)

    # scale all losses with loss scalers
    scaled_terms = [
        layer.loss * scaler
        for layer, scaler in zip(self.style_losses, self.style_weight_scalers)
    ]
    return torch.stack(scaled_terms).sum()
def get_content_loss(self):
    """Return the summed, scaler-normalized content loss.

    Returns:
        The sum over ``self.content_losses`` of ``loss * scaler``, or a
        zero tensor on ``self.device`` when ``content_weight`` is not
        positive.
    """
    if not self.content_weight > 0:
        return torch.tensor(0.0, device=self.device)

    # scale all losses with loss scalers
    scaled_terms = [
        layer.loss * scaler
        for layer, scaler in zip(self.content_losses, self.content_weight_scalers)
    ]
    return torch.stack(scaled_terms).sum()
def get_mse_loss(self, pred, target):
    """Return the mean squared error between ``pred`` and ``target``.

    Args:
        pred: Predicted tensor.
        target: Reference tensor.

    Returns:
        The mean-reduced MSE, or a zero tensor on ``self.device`` when
        ``mse_weight`` is not positive.
    """
    if not self.mse_weight > 0:
        return torch.tensor(0.0, device=self.device)
    return nn.functional.mse_loss(pred, target)
def get_kld_loss(self, mu, log_var):
    """Return the KL-divergence term for a latent distribution.

    Args:
        mu: Mean tensor of the latent distribution.
        log_var: Log-variance tensor of the latent distribution.

    Returns:
        ``-0.5 * sum(1 + log_var - mu^2 - exp(log_var))``, or a zero
        tensor on ``self.device`` when ``kld_weight`` is not positive.
    """
    if not self.kld_weight > 0:
        return torch.tensor(0.0, device=self.device)

    # Kullback-Leibler divergence
    # added here for full training (not implemented). Not needed for only decoder
    # as we are not changing the distribution of the latent space
    # normally it would help keep a normal distribution for latents
    per_element = 1 + log_var - mu.pow(2) - log_var.exp()
    return -0.5 * torch.sum(per_element)
def get_tv_loss(self, pred, target):
    """Return the comparative total-variation loss of ``pred`` vs ``target``.

    Args:
        pred: Predicted tensor.
        target: Reference tensor.

    Returns:
        The output of a fresh ``ComparativeTotalVariation`` module, or a
        zero tensor on ``self.device`` when ``tv_weight`` is not positive.
    """
    if not self.tv_weight > 0:
        return torch.tensor(0.0, device=self.device)

    tv_criterion = ComparativeTotalVariation()
    return tv_criterion(pred, target)
def get_pattern_loss(self, pred, target):
    """Return the mean pattern loss between ``pred`` and ``target``.

    Like the other loss getters, the loss is skipped entirely when its
    weight is not positive, so a disabled pattern loss neither builds the
    ``PatternLoss`` module nor spends a forward pass on it.

    Args:
        pred: Predicted tensor.
        target: Reference tensor.

    Returns:
        The mean of the ``PatternLoss`` output, or a zero tensor on
        ``self.device`` when ``pattern_weight`` is not positive.
    """
    if not self.pattern_weight > 0:
        return torch.tensor(0.0, device=self.device)

    # the module is built lazily and cached for later steps
    if self._pattern_loss is None:
        self._pattern_loss = PatternLoss(pattern_size=16, dtype=self.torch_dtype).to(
            self.device, dtype=self.torch_dtype
        )
    return torch.mean(self._pattern_loss(pred, target))
def save(self, step=None):
    """Save a float16 diffusers checkpoint of the VAE (and the critic).

    The checkpoint is written to
    ``<save_root>/<job name>[_<step zero-padded to 9 digits>]_diffusers``.
    Afterwards old checkpoints are pruned via ``remove_oldest_checkpoint``.

    The export is made from a deep copy: ``nn.Module.to`` converts a module
    in place, so casting ``self.vae`` itself to float16 and back would
    round the weights that are still being trained on every save.

    Args:
        step: Optional step number appended to the folder name.
    """
    os.makedirs(self.save_root, exist_ok=True)

    # zeropad 9 digits
    step_suffix = '' if step is None else f"_{str(step).zfill(9)}"

    self.update_training_metadata()
    filename = f'{self.job.name}{step_suffix}_diffusers'
    save_path = os.path.join(self.save_root, filename)

    # cast a throwaway copy so the training weights keep their precision
    export_vae = copy.deepcopy(self.vae).to("cpu", dtype=torch.float16)
    export_vae.save_pretrained(save_directory=save_path)
    del export_vae

    self.print(f"Saved to {save_path}")

    if self.use_critic:
        self.critic.save(step)

    self.remove_oldest_checkpoint()
def sample(self, step=None):
    """Write side-by-side input/reconstruction previews for the sample images.

    Each source image is cropped to a square from its top-left corner,
    resized to ``resolution``, passed through the VAE, and saved next to
    its reconstruction in ``<save_root>/samples`` as
    ``<unix seconds>[_<step>]_<index>.png``.

    Does nothing when no ``sample_sources`` are configured, so callers may
    invoke it unconditionally.

    Args:
        step: Optional step number embedded in the file names.
    """
    if not self.sample_sources:
        return

    sample_folder = os.path.join(self.save_root, 'samples')
    os.makedirs(sample_folder, exist_ok=True)

    # zero-pad 9 digits
    step_num = '' if step is None else f"_{str(step).zfill(9)}"

    with torch.no_grad():
        for i, img_url in enumerate(self.sample_sources):
            # context manager closes the file handle once pixels are copied
            with Image.open(img_url) as raw_img:
                img = exif_transpose(raw_img).convert('RGB')

            # crop if not square
            if img.width != img.height:
                min_dim = min(img.width, img.height)
                img = img.crop((0, 0, min_dim, min_dim))
            # resize
            input_img = img.resize((self.resolution, self.resolution))

            img_tensor = IMAGE_TRANSFORMS(input_img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
            decoded = self.vae(img_tensor).sample
            decoded = (decoded / 2 + 0.5).clamp(0, 1)
            # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
            decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()

            # convert to pillow image
            decoded = Image.fromarray((decoded * 255).astype(np.uint8))
            decoded = decoded.resize((self.resolution, self.resolution))

            # stack input image and decoded image
            output_img = Image.new('RGB', (self.resolution * 2, self.resolution))
            output_img.paste(input_img, (0, 0))
            output_img.paste(decoded, (self.resolution, 0))

            scale_up = 4 if output_img.height <= 300 else 2
            # scale up using nearest neighbor
            output_img = output_img.resize(
                (output_img.width * scale_up, output_img.height * scale_up),
                Image.NEAREST
            )

            seconds_since_epoch = int(time.time())
            # zero-pad 2 digits
            i_str = str(i).zfill(2)
            filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
            output_img.save(os.path.join(sample_folder, filename))
def load_vae(self):
    """Load the VAE, preferring the newest checkpoint in ``save_root``.

    Looks for ``<save_root>/<job name>*_diffusers`` folders and, if any
    exist, loads the most recently modified one instead of ``vae_path``.
    The model is only instantiated when ``self.vae`` is still None. It is
    then moved to the training device/dtype, frozen, put in eval mode with
    only the decoder switched to train mode, and ``vae_scale_factor`` is
    derived from the number of ``block_out_channels`` in its config.
    """
    # see if we have a checkpoint in out output to resume from
    self.print(f"Looking for latest checkpoint in {self.save_root}")
    checkpoint_glob = os.path.join(self.save_root, f"{self.job.name}*_diffusers")
    candidates = glob.glob(checkpoint_glob)

    path_to_load = self.vae_path
    if candidates:
        path_to_load = max(candidates, key=os.path.getmtime)
        print(f" - Latest checkpoint is: {path_to_load}")
        # todo update step and epoch count
    else:
        self.print(f" - No checkpoint found, starting from scratch")

    # load vae
    self.print(f"Loading VAE")
    self.print(f" - Loading VAE: {path_to_load}")
    if self.vae is None:
        self.vae = AutoencoderKL.from_pretrained(path_to_load)

    # set decoder to train
    self.vae.to(self.device, dtype=self.torch_dtype)
    self.vae.requires_grad_(False)
    self.vae.eval()
    self.vae.decoder.train()

    num_blocks = len(self.vae.config['block_out_channels'])
    self.vae_scale_factor = 2 ** (num_blocks - 1)
def run(self):
    """Train the VAE decoder and save the result.

    Loads the datasets and the VAE, selects the decoder blocks to train,
    sets up the optional VGG19 / critic / LPIPS losses, then runs the
    epoch/step loop with periodic sampling, saving and logging. A final
    checkpoint is saved when the loop ends.

    Raises:
        ValueError: If the data loader is empty, if neither ``epochs`` nor
            ``max_steps`` is set, or if ``blocks_to_train`` selects no
            decoder parameters.
    """
    super().run()
    self.load_datasets()

    num_epochs, num_steps = self._resolve_training_length(
        self.epochs, self.max_steps, len(self.data_loader)
    )
    self.max_steps = num_steps
    self.epochs = num_epochs
    start_step = self.step_num
    self.first_step = start_step

    self.print(f"Training VAE")
    self.print(f" - Training folder: {self.training_folder}")
    self.print(f" - Batch size: {self.batch_size}")
    self.print(f" - Learning rate: {self.learning_rate}")
    self.print(f" - Epochs: {num_epochs}")
    self.print(f" - Max steps: {self.max_steps}")

    # load vae
    self.load_vae()
    params = self._collect_trainable_params()

    # the critic consumes VGG19 features, so it needs the model as well
    uses_vgg = self.style_weight > 0 or self.content_weight > 0 or self.use_critic
    if uses_vgg:
        self.setup_vgg19()
        self.vgg_19.requires_grad_(False)
        self.vgg_19.eval()
        if self.use_critic:
            self.critic.setup()

    if self.lpips_weight > 0 and self.lpips_loss is None:
        self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=self.torch_dtype)

    optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
                              optimizer_params=self.optimizer_params)

    # setup scheduler
    # todo allow other schedulers
    # `verbose` is left at its default (False); the argument is deprecated
    scheduler = torch.optim.lr_scheduler.ConstantLR(
        optimizer,
        total_iters=num_steps,
        factor=1
    )

    # setup tqdm progress bar
    self.progress_bar = tqdm(
        total=num_steps,
        desc='Training VAE',
        leave=True
    )

    # sample first (only when sample images are configured)
    if self.sample_sources:
        self.sample()

    blank_losses = OrderedDict({
        "total": [],
        "lpips": [],
        "style": [],
        "content": [],
        "mse": [],
        "kl": [],
        "tv": [],
        "ptn": [],
        "crD": [],
        "crG": [],
    })
    epoch_losses = copy.deepcopy(blank_losses)
    log_losses = copy.deepcopy(blank_losses)

    # range start at self.epoch_num go to self.epochs
    for epoch in range(self.epoch_num, self.epochs, 1):
        if self.step_num >= self.max_steps:
            break
        # keep the counter current so saved training metadata is accurate
        self.epoch_num = epoch

        for batch in self.data_loader:
            if self.step_num >= self.max_steps:
                break

            # the encoder is frozen: encode without tracking gradients,
            # which makes the latents plain constants for the decoder
            with torch.no_grad():
                batch = batch.to(self.device, dtype=self.torch_dtype)

                # resize so it matches size of vae evenly
                if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
                    batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor,
                                    batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)

                # forward pass
                dgd = self.vae.encode(batch).latent_dist
                mu, logvar = dgd.mean, dgd.logvar
                latents = dgd.sample()

            pred = self.vae.decode(latents).sample

            # NOTE(review): this runs on every training step and looks like
            # a debugging aid - confirm it is wanted in unattended runs
            with torch.no_grad():
                show_tensors(
                    pred.clamp(-1, 1).clone(),
                    "combined tensor"
                )

            # Run through VGG19
            if uses_vgg:
                stacked = torch.cat([pred, batch], dim=0)
                stacked = (stacked / 2 + 0.5).clamp(0, 1)
                self.vgg_19(stacked)

            if self.use_critic:
                critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
            else:
                critic_d_loss = 0.0

            style_loss = self.get_style_loss() * self.style_weight
            content_loss = self.get_content_loss() * self.content_weight
            kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
            mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
            if self.lpips_weight > 0:
                lpips_loss = self.lpips_loss(
                    pred.clamp(-1, 1),
                    batch.clamp(-1, 1)
                ).mean() * self.lpips_weight
            else:
                lpips_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
            tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
            pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight

            if self.use_critic:
                critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight

                # do not let abs critic gen loss be higher than abs lpips * 0.1 if using it
                if self.lpips_weight > 0:
                    max_target = lpips_loss.abs() * 0.1
                    with torch.no_grad():
                        crit_g_scaler = 1.0
                        if critic_gen_loss.abs() > max_target:
                            crit_g_scaler = max_target / critic_gen_loss.abs()

                    critic_gen_loss *= crit_g_scaler
            else:
                critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

            loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            # update progress bar
            loss_value = loss.item()
            # get exponent like 3.54e-4
            loss_string = f"loss: {loss_value:.2e}"
            if self.lpips_weight > 0:
                loss_string += f" lpips: {lpips_loss.item():.2e}"
            if self.content_weight > 0:
                loss_string += f" cnt: {content_loss.item():.2e}"
            if self.style_weight > 0:
                loss_string += f" sty: {style_loss.item():.2e}"
            if self.kld_weight > 0:
                loss_string += f" kld: {kld_loss.item():.2e}"
            if self.mse_weight > 0:
                loss_string += f" mse: {mse_loss.item():.2e}"
            if self.tv_weight > 0:
                loss_string += f" tv: {tv_loss.item():.2e}"
            if self.pattern_weight > 0:
                loss_string += f" ptn: {pattern_loss.item():.2e}"
            if self.use_critic and self.critic_weight > 0:
                loss_string += f" crG: {critic_gen_loss.item():.2e}"
            if self.use_critic:
                loss_string += f" crD: {critic_d_loss:.2e}"

            if self.optimizer_type.startswith('dadaptation') or \
                    self.optimizer_type.lower().startswith('prodigy'):
                learning_rate = (
                        optimizer.param_groups[0]["d"] *
                        optimizer.param_groups[0]["lr"]
                )
            else:
                learning_rate = optimizer.param_groups[0]['lr']

            lr_critic_string = ''
            if self.use_critic:
                lr_critic = self.critic.get_lr()
                lr_critic_string = f" lrC: {lr_critic:.1e}"

            self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}")
            self.progress_bar.set_description(f"E: {epoch}")
            self.progress_bar.update(1)

            # record this step in both the per-epoch and the per-log windows
            step_losses = OrderedDict([
                ("total", loss_value),
                ("lpips", lpips_loss.item()),
                ("style", style_loss.item()),
                ("content", content_loss.item()),
                ("mse", mse_loss.item()),
                ("kl", kld_loss.item()),
                ("tv", tv_loss.item()),
                ("ptn", pattern_loss.item()),
                ("crG", critic_gen_loss.item()),
                ("crD", critic_d_loss),
            ])
            for key, value in step_losses.items():
                epoch_losses[key].append(value)
                log_losses[key].append(value)

            # don't do on first step
            if self.step_num != start_step:
                if self.sample_every and self.step_num % self.sample_every == 0:
                    # print above the progress bar
                    self.print(f"Sampling at step {self.step_num}")
                    self.sample(self.step_num)

                if self.save_every and self.step_num % self.save_every == 0:
                    # print above the progress bar
                    self.print(f"Saving at step {self.step_num}")
                    self.save(self.step_num)

                if self.log_every and self.step_num % self.log_every == 0:
                    # log to tensorboard
                    if self.writer is not None:
                        for key, values in log_losses.items():
                            avg_loss = sum(values) / len(values) if values else 0.0
                            self.writer.add_scalar(f"loss/{key}", avg_loss, self.step_num)
                    # reset log losses
                    log_losses = copy.deepcopy(blank_losses)

            self.step_num += 1

        # end epoch
        if self.writer is not None:
            # average over the whole epoch (epoch_losses, not the short
            # log window, which is reset every `log_every` steps)
            for key, values in epoch_losses.items():
                avg_loss = sum(values) / len(values) if values else 0.0
                if avg_loss > 0:
                    self.writer.add_scalar(f"epoch loss/{key}", avg_loss, epoch)
        # reset epoch losses
        epoch_losses = copy.deepcopy(blank_losses)

    self.save()

@staticmethod
def _resolve_training_length(epochs, max_steps, steps_per_epoch):
    """Return ``(num_epochs, num_steps)`` honoring both optional limits."""
    if steps_per_epoch <= 0:
        raise ValueError('no training batches: the data loader is empty')
    if epochs is None and max_steps is None:
        raise ValueError('epochs or max_steps must be specified')
    if max_steps is None:
        return epochs, epochs * steps_per_epoch
    # ceiling division: a final partial epoch is still needed to reach
    # max_steps when it is not a multiple of the epoch length
    epochs_for_steps = -(-max_steps // steps_per_epoch)
    if epochs is None or epochs > epochs_for_steps:
        epochs = epochs_for_steps
    return epochs, min(max_steps, epochs * steps_per_epoch)

def _collect_trainable_params(self):
    """Freeze the decoder, then unfreeze and return the blocks to train."""
    decoder = self.vae.decoder
    decoder.requires_grad_(False)

    if 'all' in self.blocks_to_train:
        decoder.requires_grad_(True)
        return list(decoder.parameters())

    params = []
    # mid_block, up_blocks, conv_out (single conv layer output)
    for block_name in ('mid_block', 'up_blocks', 'conv_out'):
        if block_name in self.blocks_to_train:
            block = getattr(decoder, block_name)
            block.requires_grad_(True)
            params += list(block.parameters())

    if not params:
        raise ValueError(
            f"blocks_to_train selects no decoder parameters: {self.blocks_to_train}"
        )
    return params