SwinTExCo / train_ddp.py
duongttr's picture
Upload folder using huggingface_hub
62ef5f4
raw
history blame
27.8 kB
import os
import sys
import wandb
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
from datetime import datetime
from zoneinfo import ZoneInfo
from time import gmtime, strftime
from collections import OrderedDict
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torchvision.transforms import CenterCrop
from torch.utils.data import ConcatDataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as torch_transforms
from torchvision.utils import make_grid
from src.losses import (
ContextualLoss,
ContextualLoss_forward,
Perceptual_loss,
consistent_loss_fn,
discriminator_loss_fn,
generator_loss_fn,
l1_loss_fn,
smoothness_loss_fn,
)
from src.models.CNN.GAN_models import Discriminator_x64
from src.models.CNN.ColorVidNet import ColorVidNet
from src.models.CNN.FrameColor import frame_colorization
from src.models.CNN.NonlocalNet import WeightedAverage_color, NonlocalWeightedAverage, WarpNet, WarpNet_new
from src.models.vit.embed import EmbedModel
from src.models.vit.config import load_config
from src.data import transforms
from src.data.dataloader import VideosDataset, VideosDataset_ImageNet
from src.utils import CenterPad_threshold
from src.utils import (
TimeHandler,
RGB2Lab,
ToTensor,
Normalize,
LossHandler,
WarpingLayer,
uncenter_l,
tensor_lab2rgb,
print_num_params,
SquaredPadding
)
from src.scheduler import PolynomialLR
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
parser = argparse.ArgumentParser()
parser.add_argument("--video_data_root_list", type=str, default="dataset")
parser.add_argument("--flow_data_root_list", type=str, default='flow')
parser.add_argument("--mask_data_root_list", type=str, default='mask')
parser.add_argument("--data_root_imagenet", default="imagenet", type=str)
parser.add_argument("--annotation_file_path", default="dataset/annotation.csv", type=str)
parser.add_argument("--imagenet_pairs_file", default="imagenet_pairs.txt", type=str)
parser.add_argument("--gpu_ids", type=str, default="0,1,2,3", help="separate by comma")
parser.add_argument("--workers", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--image_size", type=int, default=[384, 384])
parser.add_argument("--ic", type=int, default=7)
parser.add_argument("--epoch", type=int, default=40)
parser.add_argument("--resume_epoch", type=int, default=0)
parser.add_argument("--resume", action='store_true')
parser.add_argument("--load_pretrained_model", action='store_true')
parser.add_argument("--pretrained_model_dir", type=str, default='ckpt')
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--beta1", type=float, default=0.5)
parser.add_argument("--lr_step", type=int, default=1)
parser.add_argument("--lr_gamma", type=float, default=0.9)
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
parser.add_argument("--checkpoint_step", type=int, default=500)
parser.add_argument("--real_reference_probability", type=float, default=0.7)
parser.add_argument("--nonzero_placeholder_probability", type=float, default=0.0)
parser.add_argument("--domain_invariant", action='store_true')
parser.add_argument("--weigth_l1", type=float, default=2.0)
parser.add_argument("--weight_contextual", type=float, default="0.5")
parser.add_argument("--weight_perceptual", type=float, default="0.02")
parser.add_argument("--weight_smoothness", type=float, default="5.0")
parser.add_argument("--weight_gan", type=float, default="0.5")
parser.add_argument("--weight_nonlocal_smoothness", type=float, default="0.0")
parser.add_argument("--weight_nonlocal_consistent", type=float, default="0.0")
parser.add_argument("--weight_consistent", type=float, default="0.05")
parser.add_argument("--luminance_noise", type=float, default="2.0")
parser.add_argument("--permute_data", action='store_true')
parser.add_argument("--contextual_loss_direction", type=str, default="forward", help="forward or backward matching")
parser.add_argument("--batch_accum_size", type=int, default=10)
parser.add_argument("--epoch_train_discriminator", type=int, default=3)
parser.add_argument("--vit_version", type=str, default="vit_tiny_patch16_384")
parser.add_argument("--use_dummy", action='store_true')
parser.add_argument("--use_wandb", action='store_true')
parser.add_argument("--use_feature_transform", action='store_true')
parser.add_argument("--head_out_idx", type=str, default="8,9,10,11")
parser.add_argument("--wandb_token", type=str, default="")
parser.add_argument("--wandb_name", type=str, default="")
def ddp_setup():
dist.init_process_group(backend="nccl")
local_rank = int(os.environ['LOCAL_RANK'])
return local_rank
def ddp_cleanup():
dist.destroy_process_group()
def prepare_dataloader_ddp(dataset, batch_size=4, pin_memory=False, num_workers=0):
sampler = DistributedSampler(dataset, shuffle=True)
dataloader = DataLoader(dataset,
batch_size=batch_size,
pin_memory=pin_memory,
num_workers=num_workers,
sampler=sampler)
return dataloader
def is_master_process():
ddp_rank = int(os.environ['RANK'])
return ddp_rank == 0
def load_data():
transforms_video = [
SquaredPadding(target_size=opt.image_size[0]),
RGB2Lab(),
ToTensor(),
Normalize(),
]
train_dataset_videos = [
VideosDataset(
video_data_root=video_data_root,
flow_data_root=flow_data_root,
mask_data_root=mask_data_root,
imagenet_folder=opt.data_root_imagenet,
annotation_file_path=opt.annotation_file_path,
image_size=opt.image_size,
image_transform=torch_transforms.Compose(transforms_video),
real_reference_probability=opt.real_reference_probability,
nonzero_placeholder_probability=opt.nonzero_placeholder_probability,
)
for video_data_root, flow_data_root, mask_data_root in zip(opt.video_data_root_list, opt.flow_data_root_list, opt.mask_data_root_list)
]
transforms_imagenet = [SquaredPadding(target_size=opt.image_size[0]), RGB2Lab(), ToTensor(), Normalize()]
extra_reference_transform = [
torch_transforms.RandomHorizontalFlip(0.5),
torch_transforms.RandomResizedCrop(480, (0.98, 1.0), ratio=(0.8, 1.2)),
]
train_dataset_imagenet = VideosDataset_ImageNet(
imagenet_data_root=opt.data_root_imagenet,
pairs_file=opt.imagenet_pairs_file,
image_size=opt.image_size,
transforms_imagenet=transforms_imagenet,
distortion_level=4,
brightnessjitter=5,
nonzero_placeholder_probability=opt.nonzero_placeholder_probability,
extra_reference_transform=extra_reference_transform,
real_reference_probability=opt.real_reference_probability,
)
dataset_combined = ConcatDataset(train_dataset_videos + [train_dataset_imagenet])
data_loader = prepare_dataloader_ddp(dataset_combined,
batch_size=opt.batch_size,
pin_memory=False,
num_workers=opt.workers)
return data_loader
def save_checkpoints(saved_path):
# Make directory if the folder doesn't exists
os.makedirs(saved_path, exist_ok=True)
# Save model
torch.save(
nonlocal_net.module.state_dict(),
os.path.join(saved_path, "nonlocal_net.pth"),
)
torch.save(
colornet.module.state_dict(),
os.path.join(saved_path, "colornet.pth"),
)
torch.save(
discriminator.module.state_dict(),
os.path.join(saved_path, "discriminator.pth"),
)
torch.save(
embed_net.state_dict(),
os.path.join(saved_path, "embed_net.pth")
)
# Save learning state for restoring train
learning_state = {
"epoch": epoch_num,
"total_iter": total_iter,
"optimizer_g": optimizer_g.state_dict(),
"optimizer_d": optimizer_d.state_dict(),
"optimizer_schedule_g": step_optim_scheduler_g.state_dict(),
"optimizer_schedule_d": step_optim_scheduler_d.state_dict(),
}
torch.save(learning_state, os.path.join(saved_path, "learning_state.pth"))
def training_logger():
if (total_iter % opt.checkpoint_step == 0) or (total_iter == len(data_loader)):
train_loss_dict = {"train/" + str(k): v / loss_handler.count_sample for k, v in loss_handler.loss_dict.items()}
train_loss_dict["train/opt_g_lr_1"] = step_optim_scheduler_g.get_last_lr()[0]
train_loss_dict["train/opt_g_lr_2"] = step_optim_scheduler_g.get_last_lr()[1]
train_loss_dict["train/opt_d_lr"] = step_optim_scheduler_d.get_last_lr()[0]
alert_text = f"l1_loss: {l1_loss.item()}\npercep_loss: {perceptual_loss.item()}\nctx_loss: {contextual_loss_total.item()}\ncst_loss: {consistent_loss.item()}\nsm_loss: {smoothness_loss.item()}\ntotal: {total_loss.item()}"
if opt.use_wandb:
wandb.log(train_loss_dict)
wandb.alert(title=f"Progress training #{total_iter}", text=alert_text)
for idx in range(I_predict_rgb.shape[0]):
concated_I = make_grid(
[(I_predict_rgb[idx] * 255), (I_reference_rgb[idx] * 255), (I_current_rgb[idx] * 255)], nrow=3
)
wandb_concated_I = wandb.Image(
concated_I,
caption="[LEFT] Predict, [CENTER] Reference, [RIGHT] Ground truth\n[REF] {}, [FRAME] {}".format(
ref_path[idx], curr_frame_path[idx]
),
)
wandb.log({f"example_{idx}": wandb_concated_I})
# Save learning state checkpoint
# save_checkpoints(os.path.join(opt.checkpoint_dir, 'runs'))
loss_handler.reset()
def load_params(ckpt_file, local_rank, has_module=False):
params = torch.load(ckpt_file, map_location=f'cuda:{local_rank}')
new_params = []
for key, value in params.items():
new_params.append(("module."+key if has_module else key, value))
return OrderedDict(new_params)
def parse(parser, save=True):
opt = parser.parse_args()
args = vars(opt)
print("------------------------------ Options -------------------------------")
for k, v in sorted(args.items()):
print("%s: %s" % (str(k), str(v)))
print("-------------------------------- End ---------------------------------")
if save:
file_name = os.path.join("opt.txt")
with open(file_name, "wt") as opt_file:
opt_file.write(os.path.basename(sys.argv[0]) + " " + strftime("%Y-%m-%d %H:%M:%S", gmtime()) + "\n")
opt_file.write("------------------------------ Options -------------------------------\n")
for k, v in sorted(args.items()):
opt_file.write("%s: %s\n" % (str(k), str(v)))
opt_file.write("-------------------------------- End ---------------------------------\n")
return opt
def gpu_setup():
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
cudnn.benchmark = True
torch.cuda.set_device(opt.gpu_ids[0])
device = torch.device("cuda")
print("running on GPU", opt.gpu_ids)
return device
if __name__ == "__main__":
############################################## SETUP ###############################################
torch.multiprocessing.set_start_method("spawn", force=True)
# =============== GET PARSER OPTION ================
opt = parse(parser)
opt.video_data_root_list = opt.video_data_root_list.split(",")
opt.flow_data_root_list = opt.flow_data_root_list.split(",")
opt.mask_data_root_list = opt.mask_data_root_list.split(",")
opt.gpu_ids = list(map(int, opt.gpu_ids.split(",")))
opt.head_out_idx = list(map(int, opt.head_out_idx.split(",")))
n_dim_output = 3 if opt.use_feature_transform else 4
assert len(opt.head_out_idx) == 4, "Size of head_out_idx must be 4"
# =================== INIT WANDB ===================
# if is_master_process():
if opt.use_wandb:
print("Save images to Wandb")
if opt.wandb_token != "":
try:
wandb.login(key=opt.wandb_token)
except:
pass
if opt.use_wandb:
wandb.init(
project="video-colorization",
group=f"{opt.wandb_name} {datetime.now(tz=ZoneInfo('Asia/Ho_Chi_Minh')).strftime('%Y/%m/%d_%H-%M-%S')}",
#group="DDP"
)
# ================== SETUP DEVICE ==================
local_rank = ddp_setup()
# =================== VIT CONFIG ===================
cfg = load_config()
model_cfg = cfg["model"][opt.vit_version]
model_cfg["image_size"] = (384, 384)
model_cfg["backbone"] = opt.vit_version
model_cfg["dropout"] = 0.0
model_cfg["drop_path_rate"] = 0.1
model_cfg["n_cls"] = 10
############################################ LOAD DATA #############################################
data_loader = load_data()
########################################## DEFINE NETWORK ##########################################
colornet = DDP(ColorVidNet(opt.ic).to(local_rank), device_ids=[local_rank], output_device=local_rank)
if opt.use_feature_transform:
nonlocal_net = DDP(WarpNet().to(local_rank), device_ids=[local_rank], output_device=local_rank)
else:
nonlocal_net = DDP(WarpNet_new(model_cfg["d_model"]).to(local_rank), device_ids=[local_rank], output_device=local_rank)
discriminator = DDP(Discriminator_x64(ndf=64).to(local_rank), device_ids=[local_rank], output_device=local_rank)
weighted_layer_color = WeightedAverage_color().to(local_rank)
nonlocal_weighted_layer = NonlocalWeightedAverage().to(local_rank)
warping_layer = WarpingLayer(device=local_rank).to(local_rank)
embed_net = EmbedModel(model_cfg, head_out_idx=opt.head_out_idx, n_dim_output=n_dim_output, device=local_rank)
if is_master_process():
# Print number of parameters
print("-" * 59)
print("| TYPE | Model name | Num params |")
print("-" * 59)
colornet_params = print_num_params(colornet)
nonlocal_net_params = print_num_params(nonlocal_net)
discriminator_params = print_num_params(discriminator)
weighted_layer_color_params = print_num_params(weighted_layer_color)
nonlocal_weighted_layer_params = print_num_params(nonlocal_weighted_layer)
warping_layer_params = print_num_params(warping_layer)
embed_net_params = print_num_params(embed_net)
print("-" * 59)
print(
f"| TOTAL | | {('{:,}'.format(colornet_params+nonlocal_net_params+discriminator_params+weighted_layer_color_params+nonlocal_weighted_layer_params+warping_layer_params+embed_net_params)).rjust(10)} |"
)
print("-" * 59)
if opt.use_wandb:
wandb.watch(discriminator, log="all", log_freq=opt.checkpoint_step, idx=0)
wandb.watch(embed_net, log="all", log_freq=opt.checkpoint_step, idx=1)
wandb.watch(colornet, log="all", log_freq=opt.checkpoint_step, idx=2)
wandb.watch(nonlocal_net, log="all", log_freq=opt.checkpoint_step, idx=3)
###################################### DEFINE LOSS FUNCTIONS #######################################
perceptual_loss_fn = Perceptual_loss(opt.domain_invariant, opt.weight_perceptual)
contextual_loss = ContextualLoss().to(local_rank)
contextual_forward_loss = ContextualLoss_forward().to(local_rank)
######################################## DEFINE OPTIMIZERS #########################################
optimizer_g = optim.AdamW(
[
{"params": nonlocal_net.parameters(), "lr": opt.lr},
{"params": colornet.parameters(), "lr": 2 * opt.lr},
{"params": embed_net.parameters(), "lr": opt.lr},
],
betas=(0.5, 0.999),
eps=1e-5,
amsgrad=True,
)
optimizer_d = optim.AdamW(
filter(lambda p: p.requires_grad, discriminator.parameters()),
lr=opt.lr,
betas=(0.5, 0.999),
amsgrad=True,
)
step_optim_scheduler_g = PolynomialLR(
optimizer_g,
step_size=opt.lr_step,
iter_warmup=0,
iter_max=len(data_loader) * opt.epoch,
power=0.9,
min_lr=1e-8
)
step_optim_scheduler_d = PolynomialLR(
optimizer_d,
step_size=opt.lr_step,
iter_warmup=0,
iter_max=len(data_loader) * opt.epoch,
power=0.9,
min_lr=1e-8
)
########################################## DEFINE OTHERS ###########################################
downsampling_by2 = nn.AvgPool2d(kernel_size=2).to(local_rank)
# timer_handler = TimeHandler()
loss_handler = LossHandler()
############################################## TRAIN ###############################################
# ============= USE PRETRAINED OR NOT ==============
if opt.load_pretrained_model:
nonlocal_net.load_state_dict(load_params(os.path.join(opt.pretrained_model_dir, "nonlocal_net.pth"),
local_rank,
has_module=True))
colornet.load_state_dict(load_params(os.path.join(opt.pretrained_model_dir, "colornet.pth"),
local_rank,
has_module=True))
discriminator.load_state_dict(load_params(os.path.join(opt.pretrained_model_dir, "discriminator.pth"),
local_rank,
has_module=True))
embed_net_params = load_params(os.path.join(opt.pretrained_model_dir, "embed_net.pth"),
local_rank,
has_module=False)
if "module.vit.heads_out" in embed_net_params:
embed_net_params.pop("module.vit.heads_out")
elif "vit.heads_out" in embed_net_params:
embed_net_params.pop("vit.heads_out")
embed_net.load_state_dict(embed_net_params)
learning_checkpoint = torch.load(os.path.join(opt.pretrained_model_dir, "learning_state.pth"))
optimizer_g.load_state_dict(learning_checkpoint["optimizer_g"])
optimizer_d.load_state_dict(learning_checkpoint["optimizer_d"])
step_optim_scheduler_g.load_state_dict(learning_checkpoint["optimizer_schedule_g"])
step_optim_scheduler_d.load_state_dict(learning_checkpoint["optimizer_schedule_d"])
total_iter = learning_checkpoint['total_iter']
start_epoch = learning_checkpoint['epoch']+1
else:
total_iter = 0
start_epoch = 1
for epoch_num in range(start_epoch, opt.epoch+1):
data_loader.sampler.set_epoch(epoch_num-1)
if is_master_process():
train_progress_bar = tqdm(
data_loader,
desc =f'Epoch {epoch_num}[Training]',
position = 0,
leave = False
)
else:
train_progress_bar = data_loader
for iter, sample in enumerate(train_progress_bar):
# timer_handler.compute_time("load_sample")
total_iter += 1
# =============== LOAD DATA SAMPLE ================
(
I_last_lab, ######## (3, H, W)
I_current_lab, ##### (3, H, W)
I_reference_lab, ### (3, H, W)
flow_forward, ###### (2, H, W)
mask, ############## (1, H, W)
placeholder_lab, ### (3, H, W)
self_ref_flag, ##### (3, H, W)
prev_frame_path,
curr_frame_path,
ref_path,
) = sample
I_last_lab = I_last_lab.to(local_rank)
I_current_lab = I_current_lab.to(local_rank)
I_reference_lab = I_reference_lab.to(local_rank)
flow_forward = flow_forward.to(local_rank)
mask = mask.to(local_rank)
placeholder_lab = placeholder_lab.to(local_rank)
self_ref_flag = self_ref_flag.to(local_rank)
I_last_l = I_last_lab[:, 0:1, :, :]
I_last_ab = I_last_lab[:, 1:3, :, :]
I_current_l = I_current_lab[:, 0:1, :, :]
I_current_ab = I_current_lab[:, 1:3, :, :]
I_reference_l = I_reference_lab[:, 0:1, :, :]
I_reference_ab = I_reference_lab[:, 1:3, :, :]
I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1))
# _load_sample_time = timer_handler.compute_time("load_sample")
# timer_handler.compute_time("forward_model")
features_B = embed_net(I_reference_rgb)
_, B_feat_1, B_feat_2, B_feat_3 = features_B
# ================== COLORIZATION ==================
# The last frame
I_last_ab_predict, I_last_nonlocal_lab_predict = frame_colorization(
IA_l=I_last_l,
IB_lab=I_reference_lab,
IA_last_lab=placeholder_lab,
features_B=features_B,
embed_net=embed_net,
colornet=colornet,
nonlocal_net=nonlocal_net,
luminance_noise=opt.luminance_noise,
)
I_last_lab_predict = torch.cat((I_last_l, I_last_ab_predict), dim=1)
# The current frame
I_current_ab_predict, I_current_nonlocal_lab_predict = frame_colorization(
IA_l=I_current_l,
IB_lab=I_reference_lab,
IA_last_lab=I_last_lab_predict,
features_B=features_B,
embed_net=embed_net,
colornet=colornet,
nonlocal_net=nonlocal_net,
luminance_noise=opt.luminance_noise,
)
I_current_lab_predict = torch.cat((I_last_l, I_current_ab_predict), dim=1)
# ================ UPDATE GENERATOR ================
if opt.weight_gan > 0:
optimizer_g.zero_grad()
optimizer_d.zero_grad()
fake_data_lab = torch.cat(
(
uncenter_l(I_current_l),
I_current_ab_predict,
uncenter_l(I_last_l),
I_last_ab_predict,
),
dim=1,
)
real_data_lab = torch.cat(
(
uncenter_l(I_current_l),
I_current_ab,
uncenter_l(I_last_l),
I_last_ab,
),
dim=1,
)
if opt.permute_data:
batch_index = torch.arange(-1, opt.batch_size - 1, dtype=torch.long)
real_data_lab = real_data_lab[batch_index, ...]
discriminator_loss = discriminator_loss_fn(real_data_lab, fake_data_lab, discriminator)
discriminator_loss.backward()
optimizer_d.step()
optimizer_g.zero_grad()
optimizer_d.zero_grad()
# ================== COMPUTE LOSS ==================
# L1 loss
l1_loss = l1_loss_fn(I_current_ab, I_current_ab_predict) * opt.weigth_l1
# Generator_loss. TODO: freeze this to train some first epoch
if epoch_num > opt.epoch_train_discriminator:
generator_loss = generator_loss_fn(real_data_lab, fake_data_lab, discriminator, opt.weight_gan, local_rank)
# Perceptual Loss
I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_current_l), I_current_ab_predict), dim=1))
_, pred_feat_1, pred_feat_2, pred_feat_3 = embed_net(I_predict_rgb)
I_current_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_current_l), I_current_ab), dim=1))
A_feat_0, _, _, A_feat_3 = embed_net(I_current_rgb)
perceptual_loss = perceptual_loss_fn(A_feat_3, pred_feat_3)
# Contextual Loss
contextual_style5_1 = torch.mean(contextual_forward_loss(pred_feat_3, B_feat_3.detach())) * 8
contextual_style4_1 = torch.mean(contextual_forward_loss(pred_feat_2, B_feat_2.detach())) * 4
contextual_style3_1 = torch.mean(contextual_forward_loss(pred_feat_1, B_feat_1.detach())) * 2
contextual_loss_total = (
contextual_style5_1 + contextual_style4_1 + contextual_style3_1
) * opt.weight_contextual
# Consistent Loss
consistent_loss = consistent_loss_fn(
I_current_lab_predict,
I_last_ab_predict,
I_current_nonlocal_lab_predict,
I_last_nonlocal_lab_predict,
flow_forward,
mask,
warping_layer,
weight_consistent=opt.weight_consistent,
weight_nonlocal_consistent=opt.weight_nonlocal_consistent,
device=local_rank,
)
# Smoothness loss
smoothness_loss = smoothness_loss_fn(
I_current_l,
I_current_lab,
I_current_ab_predict,
A_feat_0,
weighted_layer_color,
nonlocal_weighted_layer,
weight_smoothness=opt.weight_smoothness,
weight_nonlocal_smoothness=opt.weight_nonlocal_smoothness,
device=local_rank
)
# Total loss
total_loss = l1_loss + perceptual_loss + contextual_loss_total + consistent_loss + smoothness_loss
if epoch_num > opt.epoch_train_discriminator:
total_loss += generator_loss
# Add loss to loss handler
loss_handler.add_loss(key="total_loss", loss=total_loss.item())
loss_handler.add_loss(key="l1_loss", loss=l1_loss.item())
loss_handler.add_loss(key="perceptual_loss", loss=perceptual_loss.item())
loss_handler.add_loss(key="contextual_loss", loss=contextual_loss_total.item())
loss_handler.add_loss(key="consistent_loss", loss=consistent_loss.item())
loss_handler.add_loss(key="smoothness_loss", loss=smoothness_loss.item())
loss_handler.add_loss(key="discriminator_loss", loss=discriminator_loss.item())
if epoch_num > opt.epoch_train_discriminator:
loss_handler.add_loss(key="generator_loss", loss=generator_loss.item())
loss_handler.count_one_sample()
total_loss.backward()
optimizer_g.step()
step_optim_scheduler_g.step()
step_optim_scheduler_d.step()
# _forward_model_time = timer_handler.compute_time("forward_model")
# timer_handler.compute_time("training_logger")
training_logger()
# _training_logger_time = timer_handler.compute_time("training_logger")
####
if is_master_process():
save_checkpoints(os.path.join(opt.checkpoint_dir, f"epoch_{epoch_num}"))
####
if opt.use_wandb:
wandb.finish()
ddp_cleanup()