import os
import sys
import wandb
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
from datetime import datetime
from zoneinfo import ZoneInfo
from time import gmtime, strftime
from collections import OrderedDict
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torchvision.transforms import CenterCrop
from torch.utils.data import ConcatDataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as torch_transforms
from torchvision.utils import make_grid

from src.losses import (
    ContextualLoss,
    ContextualLoss_forward,
    Perceptual_loss,
    consistent_loss_fn,
    discriminator_loss_fn,
    generator_loss_fn,
    l1_loss_fn,
    smoothness_loss_fn,
)
from src.models.CNN.GAN_models import Discriminator_x64
from src.models.CNN.ColorVidNet import ColorVidNet
from src.models.CNN.FrameColor import frame_colorization
from src.models.CNN.NonlocalNet import WeightedAverage_color, NonlocalWeightedAverage, WarpNet, WarpNet_new
from src.models.vit.embed import EmbedModel
from src.models.vit.config import load_config
from src.data import transforms
from src.data.dataloader import VideosDataset, VideosDataset_ImageNet
from src.utils import CenterPad_threshold
from src.utils import (
    TimeHandler,
    RGB2Lab,
    ToTensor,
    Normalize,
    LossHandler,
    WarpingLayer,
    uncenter_l,
    tensor_lab2rgb,
    print_num_params,
    SquaredPadding,
)
from src.scheduler import PolynomialLR
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
parser = argparse.ArgumentParser()
parser.add_argument("--video_data_root_list", type=str, default="dataset")
parser.add_argument("--flow_data_root_list", type=str, default='flow')
parser.add_argument("--mask_data_root_list", type=str, default='mask')
parser.add_argument("--data_root_imagenet", default="imagenet", type=str)
parser.add_argument("--annotation_file_path", default="dataset/annotation.csv", type=str)
parser.add_argument("--imagenet_pairs_file", default="imagenet_pairs.txt", type=str)
parser.add_argument("--gpu_ids", type=str, default="0,1,2,3", help="separate by comma")
parser.add_argument("--workers", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--image_size", type=int, nargs=2, default=[384, 384])
parser.add_argument("--ic", type=int, default=7)
parser.add_argument("--epoch", type=int, default=40)
parser.add_argument("--resume_epoch", type=int, default=0)
parser.add_argument("--resume", action='store_true')
parser.add_argument("--load_pretrained_model", action='store_true')
parser.add_argument("--pretrained_model_dir", type=str, default='ckpt')
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--beta1", type=float, default=0.5)
parser.add_argument("--lr_step", type=int, default=1)
parser.add_argument("--lr_gamma", type=float, default=0.9)
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
parser.add_argument("--checkpoint_step", type=int, default=500)
parser.add_argument("--real_reference_probability", type=float, default=0.7)
parser.add_argument("--nonzero_placeholder_probability", type=float, default=0.0)
parser.add_argument("--domain_invariant", action='store_true')
parser.add_argument("--weight_l1", type=float, default=2.0)
parser.add_argument("--weight_contextual", type=float, default=0.5)
parser.add_argument("--weight_perceptual", type=float, default=0.02)
parser.add_argument("--weight_smoothness", type=float, default=5.0)
parser.add_argument("--weight_gan", type=float, default=0.5)
parser.add_argument("--weight_nonlocal_smoothness", type=float, default=0.0)
parser.add_argument("--weight_nonlocal_consistent", type=float, default=0.0)
parser.add_argument("--weight_consistent", type=float, default=0.05)
parser.add_argument("--luminance_noise", type=float, default=2.0)
parser.add_argument("--permute_data", action='store_true')
parser.add_argument("--contextual_loss_direction", type=str, default="forward", help="forward or backward matching")
parser.add_argument("--batch_accum_size", type=int, default=10)
parser.add_argument("--epoch_train_discriminator", type=int, default=3)
parser.add_argument("--vit_version", type=str, default="vit_tiny_patch16_384")
parser.add_argument("--use_dummy", action='store_true')
parser.add_argument("--use_wandb", action='store_true')
parser.add_argument("--use_feature_transform", action='store_true')
parser.add_argument("--head_out_idx", type=str, default="8,9,10,11")
parser.add_argument("--wandb_token", type=str, default="")
parser.add_argument("--wandb_name", type=str, default="")


def ddp_setup():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ['LOCAL_RANK'])
    # Bind this process to its own GPU (recommended for NCCL; assumes one GPU per process)
    torch.cuda.set_device(local_rank)
    return local_rank


def ddp_cleanup():
    dist.destroy_process_group()


def prepare_dataloader_ddp(dataset, batch_size=4, pin_memory=False, num_workers=0):
    sampler = DistributedSampler(dataset, shuffle=True)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            pin_memory=pin_memory,
                            num_workers=num_workers,
                            sampler=sampler)
    return dataloader
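
# Note (standard DistributedSampler behaviour, not repo-specific): each process sees a
# disjoint shard of the dataset, and the training loop below calls sampler.set_epoch()
# once per epoch so the shuffling order changes between epochs.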


def is_master_process():
    ddp_rank = int(os.environ['RANK'])
    return ddp_rank == 0


def load_data():
    transforms_video = [
        SquaredPadding(target_size=opt.image_size[0]),
        RGB2Lab(),
        ToTensor(),
        Normalize(),
    ]
    train_dataset_videos = [
        VideosDataset(
            video_data_root=video_data_root,
            flow_data_root=flow_data_root,
            mask_data_root=mask_data_root,
            imagenet_folder=opt.data_root_imagenet,
            annotation_file_path=opt.annotation_file_path,
            image_size=opt.image_size,
            image_transform=torch_transforms.Compose(transforms_video),
            real_reference_probability=opt.real_reference_probability,
            nonzero_placeholder_probability=opt.nonzero_placeholder_probability,
        )
        for video_data_root, flow_data_root, mask_data_root in zip(
            opt.video_data_root_list, opt.flow_data_root_list, opt.mask_data_root_list
        )
    ]
    transforms_imagenet = [SquaredPadding(target_size=opt.image_size[0]), RGB2Lab(), ToTensor(), Normalize()]
    extra_reference_transform = [
        torch_transforms.RandomHorizontalFlip(0.5),
        torch_transforms.RandomResizedCrop(480, (0.98, 1.0), ratio=(0.8, 1.2)),
    ]
    train_dataset_imagenet = VideosDataset_ImageNet(
        imagenet_data_root=opt.data_root_imagenet,
        pairs_file=opt.imagenet_pairs_file,
        image_size=opt.image_size,
        transforms_imagenet=transforms_imagenet,
        distortion_level=4,
        brightnessjitter=5,
        nonzero_placeholder_probability=opt.nonzero_placeholder_probability,
        extra_reference_transform=extra_reference_transform,
        real_reference_probability=opt.real_reference_probability,
    )
    dataset_combined = ConcatDataset(train_dataset_videos + [train_dataset_imagenet])
    data_loader = prepare_dataloader_ddp(dataset_combined,
                                         batch_size=opt.batch_size,
                                         pin_memory=False,
                                         num_workers=opt.workers)
    return data_loader


def save_checkpoints(saved_path):
    # Make the directory if it doesn't exist
    os.makedirs(saved_path, exist_ok=True)
    # Save models
    torch.save(
        nonlocal_net.module.state_dict(),
        os.path.join(saved_path, "nonlocal_net.pth"),
    )
    torch.save(
        colornet.module.state_dict(),
        os.path.join(saved_path, "colornet.pth"),
    )
    torch.save(
        discriminator.module.state_dict(),
        os.path.join(saved_path, "discriminator.pth"),
    )
    torch.save(
        embed_net.state_dict(),
        os.path.join(saved_path, "embed_net.pth")
    )
    # Save learning state for resuming training
    learning_state = {
        "epoch": epoch_num,
        "total_iter": total_iter,
        "optimizer_g": optimizer_g.state_dict(),
        "optimizer_d": optimizer_d.state_dict(),
        "optimizer_schedule_g": step_optim_scheduler_g.state_dict(),
        "optimizer_schedule_d": step_optim_scheduler_d.state_dict(),
    }
    torch.save(learning_state, os.path.join(saved_path, "learning_state.pth"))


def training_logger():
    if (total_iter % opt.checkpoint_step == 0) or (total_iter == len(data_loader)):
        train_loss_dict = {"train/" + str(k): v / loss_handler.count_sample for k, v in loss_handler.loss_dict.items()}
        train_loss_dict["train/opt_g_lr_1"] = step_optim_scheduler_g.get_last_lr()[0]
        train_loss_dict["train/opt_g_lr_2"] = step_optim_scheduler_g.get_last_lr()[1]
        train_loss_dict["train/opt_d_lr"] = step_optim_scheduler_d.get_last_lr()[0]
        alert_text = (
            f"l1_loss: {l1_loss.item()}\n"
            f"percep_loss: {perceptual_loss.item()}\n"
            f"ctx_loss: {contextual_loss_total.item()}\n"
            f"cst_loss: {consistent_loss.item()}\n"
            f"sm_loss: {smoothness_loss.item()}\n"
            f"total: {total_loss.item()}"
        )
        if opt.use_wandb:
            wandb.log(train_loss_dict)
            wandb.alert(title=f"Progress training #{total_iter}", text=alert_text)
            for idx in range(I_predict_rgb.shape[0]):
                concated_I = make_grid(
                    [(I_predict_rgb[idx] * 255), (I_reference_rgb[idx] * 255), (I_current_rgb[idx] * 255)], nrow=3
                )
                wandb_concated_I = wandb.Image(
                    concated_I,
                    caption="[LEFT] Predict, [CENTER] Reference, [RIGHT] Ground truth\n[REF] {}, [FRAME] {}".format(
                        ref_path[idx], curr_frame_path[idx]
                    ),
                )
                wandb.log({f"example_{idx}": wandb_concated_I})
        # Save learning state checkpoint
        # save_checkpoints(os.path.join(opt.checkpoint_dir, 'runs'))
        loss_handler.reset()


def load_params(ckpt_file, local_rank, has_module=False):
    params = torch.load(ckpt_file, map_location=f'cuda:{local_rank}')
    new_params = []
    for key, value in params.items():
        new_params.append(("module." + key if has_module else key, value))
    return OrderedDict(new_params)
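
# Usage sketch (assumption about the checkpoint layout): save_checkpoints() stores
# `model.module.state_dict()`, i.e. keys without a "module." prefix, so has_module=True
# re-adds the prefix when loading into a DDP-wrapped model:
#   colornet.load_state_dict(load_params("ckpt/colornet.pth", local_rank, has_module=True))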


def parse(parser, save=True):
    opt = parser.parse_args()
    args = vars(opt)
    print("------------------------------ Options -------------------------------")
    for k, v in sorted(args.items()):
        print("%s: %s" % (str(k), str(v)))
    print("-------------------------------- End ---------------------------------")
    if save:
        file_name = "opt.txt"
        with open(file_name, "wt") as opt_file:
            opt_file.write(os.path.basename(sys.argv[0]) + " " + strftime("%Y-%m-%d %H:%M:%S", gmtime()) + "\n")
            opt_file.write("------------------------------ Options -------------------------------\n")
            for k, v in sorted(args.items()):
                opt_file.write("%s: %s\n" % (str(k), str(v)))
            opt_file.write("-------------------------------- End ---------------------------------\n")
    return opt


def gpu_setup():
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    cudnn.benchmark = True
    torch.cuda.set_device(opt.gpu_ids[0])
    device = torch.device("cuda")
    print("running on GPU", opt.gpu_ids)
    return device


if __name__ == "__main__":
    ############################################## SETUP ###############################################
    torch.multiprocessing.set_start_method("spawn", force=True)

    # =============== GET PARSER OPTION ================
    opt = parse(parser)
    opt.video_data_root_list = opt.video_data_root_list.split(",")
    opt.flow_data_root_list = opt.flow_data_root_list.split(",")
    opt.mask_data_root_list = opt.mask_data_root_list.split(",")
    opt.gpu_ids = list(map(int, opt.gpu_ids.split(",")))
    opt.head_out_idx = list(map(int, opt.head_out_idx.split(",")))
    n_dim_output = 3 if opt.use_feature_transform else 4
    assert len(opt.head_out_idx) == 4, "Size of head_out_idx must be 4"

    # =================== INIT WANDB ===================
    # if is_master_process():
    if opt.use_wandb:
        print("Save images to Wandb")
    if opt.wandb_token != "":
        try:
            wandb.login(key=opt.wandb_token)
        except Exception:
            pass
    if opt.use_wandb:
        wandb.init(
            project="video-colorization",
            group=f"{opt.wandb_name} {datetime.now(tz=ZoneInfo('Asia/Ho_Chi_Minh')).strftime('%Y/%m/%d_%H-%M-%S')}",
            # group="DDP"
        )

    # ================== SETUP DEVICE ==================
    local_rank = ddp_setup()

    # =================== VIT CONFIG ===================
    cfg = load_config()
    model_cfg = cfg["model"][opt.vit_version]
    model_cfg["image_size"] = (384, 384)
    model_cfg["backbone"] = opt.vit_version
    model_cfg["dropout"] = 0.0
    model_cfg["drop_path_rate"] = 0.1
    model_cfg["n_cls"] = 10
    ############################################ LOAD DATA #############################################
    data_loader = load_data()

    ########################################## DEFINE NETWORK ##########################################
    colornet = DDP(ColorVidNet(opt.ic).to(local_rank), device_ids=[local_rank], output_device=local_rank)
    if opt.use_feature_transform:
        nonlocal_net = DDP(WarpNet().to(local_rank), device_ids=[local_rank], output_device=local_rank)
    else:
        nonlocal_net = DDP(WarpNet_new(model_cfg["d_model"]).to(local_rank), device_ids=[local_rank], output_device=local_rank)
    discriminator = DDP(Discriminator_x64(ndf=64).to(local_rank), device_ids=[local_rank], output_device=local_rank)
    weighted_layer_color = WeightedAverage_color().to(local_rank)
    nonlocal_weighted_layer = NonlocalWeightedAverage().to(local_rank)
    warping_layer = WarpingLayer(device=local_rank).to(local_rank)
    embed_net = EmbedModel(model_cfg, head_out_idx=opt.head_out_idx, n_dim_output=n_dim_output, device=local_rank)

    if is_master_process():
        # Print number of parameters
        print("-" * 59)
        print("| TYPE | Model name | Num params |")
        print("-" * 59)
        colornet_params = print_num_params(colornet)
        nonlocal_net_params = print_num_params(nonlocal_net)
        discriminator_params = print_num_params(discriminator)
        weighted_layer_color_params = print_num_params(weighted_layer_color)
        nonlocal_weighted_layer_params = print_num_params(nonlocal_weighted_layer)
        warping_layer_params = print_num_params(warping_layer)
        embed_net_params = print_num_params(embed_net)
        total_params = (
            colornet_params + nonlocal_net_params + discriminator_params
            + weighted_layer_color_params + nonlocal_weighted_layer_params
            + warping_layer_params + embed_net_params
        )
        print("-" * 59)
        print(f"| TOTAL | | {'{:,}'.format(total_params).rjust(10)} |")
        print("-" * 59)

    if opt.use_wandb:
        wandb.watch(discriminator, log="all", log_freq=opt.checkpoint_step, idx=0)
        wandb.watch(embed_net, log="all", log_freq=opt.checkpoint_step, idx=1)
        wandb.watch(colornet, log="all", log_freq=opt.checkpoint_step, idx=2)
        wandb.watch(nonlocal_net, log="all", log_freq=opt.checkpoint_step, idx=3)
    ###################################### DEFINE LOSS FUNCTIONS #######################################
    perceptual_loss_fn = Perceptual_loss(opt.domain_invariant, opt.weight_perceptual)
    contextual_loss = ContextualLoss().to(local_rank)
    contextual_forward_loss = ContextualLoss_forward().to(local_rank)

    ######################################## DEFINE OPTIMIZERS #########################################
    optimizer_g = optim.AdamW(
        [
            {"params": nonlocal_net.parameters(), "lr": opt.lr},
            {"params": colornet.parameters(), "lr": 2 * opt.lr},
            {"params": embed_net.parameters(), "lr": opt.lr},
        ],
        betas=(0.5, 0.999),
        eps=1e-5,
        amsgrad=True,
    )
    optimizer_d = optim.AdamW(
        filter(lambda p: p.requires_grad, discriminator.parameters()),
        lr=opt.lr,
        betas=(0.5, 0.999),
        amsgrad=True,
    )
    step_optim_scheduler_g = PolynomialLR(
        optimizer_g,
        step_size=opt.lr_step,
        iter_warmup=0,
        iter_max=len(data_loader) * opt.epoch,
        power=0.9,
        min_lr=1e-8,
    )
    step_optim_scheduler_d = PolynomialLR(
        optimizer_d,
        step_size=opt.lr_step,
        iter_warmup=0,
        iter_max=len(data_loader) * opt.epoch,
        power=0.9,
        min_lr=1e-8,
    )

    ########################################## DEFINE OTHERS ###########################################
    downsampling_by2 = nn.AvgPool2d(kernel_size=2).to(local_rank)
    # timer_handler = TimeHandler()
    loss_handler = LossHandler()
    ############################################## TRAIN ###############################################
    # ============= USE PRETRAINED OR NOT ==============
    if opt.load_pretrained_model:
        nonlocal_net.load_state_dict(load_params(os.path.join(opt.pretrained_model_dir, "nonlocal_net.pth"),
                                                 local_rank,
                                                 has_module=True))
        colornet.load_state_dict(load_params(os.path.join(opt.pretrained_model_dir, "colornet.pth"),
                                             local_rank,
                                             has_module=True))
        discriminator.load_state_dict(load_params(os.path.join(opt.pretrained_model_dir, "discriminator.pth"),
                                                  local_rank,
                                                  has_module=True))
        embed_net_params = load_params(os.path.join(opt.pretrained_model_dir, "embed_net.pth"),
                                       local_rank,
                                       has_module=False)
        if "module.vit.heads_out" in embed_net_params:
            embed_net_params.pop("module.vit.heads_out")
        elif "vit.heads_out" in embed_net_params:
            embed_net_params.pop("vit.heads_out")
        embed_net.load_state_dict(embed_net_params)

        learning_checkpoint = torch.load(os.path.join(opt.pretrained_model_dir, "learning_state.pth"))
        optimizer_g.load_state_dict(learning_checkpoint["optimizer_g"])
        optimizer_d.load_state_dict(learning_checkpoint["optimizer_d"])
        step_optim_scheduler_g.load_state_dict(learning_checkpoint["optimizer_schedule_g"])
        step_optim_scheduler_d.load_state_dict(learning_checkpoint["optimizer_schedule_d"])
        total_iter = learning_checkpoint['total_iter']
        start_epoch = learning_checkpoint['epoch'] + 1
    else:
        total_iter = 0
        start_epoch = 1
    for epoch_num in range(start_epoch, opt.epoch + 1):
        data_loader.sampler.set_epoch(epoch_num - 1)
        if is_master_process():
            train_progress_bar = tqdm(
                data_loader,
                desc=f'Epoch {epoch_num} [Training]',
                position=0,
                leave=False,
            )
        else:
            train_progress_bar = data_loader

        for iter_idx, sample in enumerate(train_progress_bar):
            # timer_handler.compute_time("load_sample")
            total_iter += 1

            # =============== LOAD DATA SAMPLE ================
            (
                I_last_lab,        # (3, H, W)
                I_current_lab,     # (3, H, W)
                I_reference_lab,   # (3, H, W)
                flow_forward,      # (2, H, W)
                mask,              # (1, H, W)
                placeholder_lab,   # (3, H, W)
                self_ref_flag,     # (3, H, W)
                prev_frame_path,
                curr_frame_path,
                ref_path,
            ) = sample
            I_last_lab = I_last_lab.to(local_rank)
            I_current_lab = I_current_lab.to(local_rank)
            I_reference_lab = I_reference_lab.to(local_rank)
            flow_forward = flow_forward.to(local_rank)
            mask = mask.to(local_rank)
            placeholder_lab = placeholder_lab.to(local_rank)
            self_ref_flag = self_ref_flag.to(local_rank)

            I_last_l = I_last_lab[:, 0:1, :, :]
            I_last_ab = I_last_lab[:, 1:3, :, :]
            I_current_l = I_current_lab[:, 0:1, :, :]
            I_current_ab = I_current_lab[:, 1:3, :, :]
            I_reference_l = I_reference_lab[:, 0:1, :, :]
            I_reference_ab = I_reference_lab[:, 1:3, :, :]
            I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1))
            # _load_sample_time = timer_handler.compute_time("load_sample")

            # timer_handler.compute_time("forward_model")
            features_B = embed_net(I_reference_rgb)
            _, B_feat_1, B_feat_2, B_feat_3 = features_B

            # ================== COLORIZATION ==================
            # The last frame
            I_last_ab_predict, I_last_nonlocal_lab_predict = frame_colorization(
                IA_l=I_last_l,
                IB_lab=I_reference_lab,
                IA_last_lab=placeholder_lab,
                features_B=features_B,
                embed_net=embed_net,
                colornet=colornet,
                nonlocal_net=nonlocal_net,
                luminance_noise=opt.luminance_noise,
            )
            I_last_lab_predict = torch.cat((I_last_l, I_last_ab_predict), dim=1)

            # The current frame
            I_current_ab_predict, I_current_nonlocal_lab_predict = frame_colorization(
                IA_l=I_current_l,
                IB_lab=I_reference_lab,
                IA_last_lab=I_last_lab_predict,
                features_B=features_B,
                embed_net=embed_net,
                colornet=colornet,
                nonlocal_net=nonlocal_net,
                luminance_noise=opt.luminance_noise,
            )
            I_current_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
            # ============== UPDATE DISCRIMINATOR ==============
            if opt.weight_gan > 0:
                optimizer_g.zero_grad()
                optimizer_d.zero_grad()
                fake_data_lab = torch.cat(
                    (
                        uncenter_l(I_current_l),
                        I_current_ab_predict,
                        uncenter_l(I_last_l),
                        I_last_ab_predict,
                    ),
                    dim=1,
                )
                real_data_lab = torch.cat(
                    (
                        uncenter_l(I_current_l),
                        I_current_ab,
                        uncenter_l(I_last_l),
                        I_last_ab,
                    ),
                    dim=1,
                )
                if opt.permute_data:
                    batch_index = torch.arange(-1, opt.batch_size - 1, dtype=torch.long)
                    real_data_lab = real_data_lab[batch_index, ...]
                discriminator_loss = discriminator_loss_fn(real_data_lab, fake_data_lab, discriminator)
                discriminator_loss.backward()
                optimizer_d.step()

            # ================ UPDATE GENERATOR ================
            optimizer_g.zero_grad()
            optimizer_d.zero_grad()
            # ================== COMPUTE LOSS ==================
            # L1 loss
            l1_loss = l1_loss_fn(I_current_ab, I_current_ab_predict) * opt.weight_l1

            # Generator loss. TODO: keep this frozen for the first few epochs
            if epoch_num > opt.epoch_train_discriminator:
                generator_loss = generator_loss_fn(real_data_lab, fake_data_lab, discriminator, opt.weight_gan, local_rank)

            # Perceptual loss
            I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_current_l), I_current_ab_predict), dim=1))
            _, pred_feat_1, pred_feat_2, pred_feat_3 = embed_net(I_predict_rgb)
            I_current_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_current_l), I_current_ab), dim=1))
            A_feat_0, _, _, A_feat_3 = embed_net(I_current_rgb)
            perceptual_loss = perceptual_loss_fn(A_feat_3, pred_feat_3)

            # Contextual loss
            contextual_style5_1 = torch.mean(contextual_forward_loss(pred_feat_3, B_feat_3.detach())) * 8
            contextual_style4_1 = torch.mean(contextual_forward_loss(pred_feat_2, B_feat_2.detach())) * 4
            contextual_style3_1 = torch.mean(contextual_forward_loss(pred_feat_1, B_feat_1.detach())) * 2
            contextual_loss_total = (
                contextual_style5_1 + contextual_style4_1 + contextual_style3_1
            ) * opt.weight_contextual

            # Consistency loss
            consistent_loss = consistent_loss_fn(
                I_current_lab_predict,
                I_last_ab_predict,
                I_current_nonlocal_lab_predict,
                I_last_nonlocal_lab_predict,
                flow_forward,
                mask,
                warping_layer,
                weight_consistent=opt.weight_consistent,
                weight_nonlocal_consistent=opt.weight_nonlocal_consistent,
                device=local_rank,
            )

            # Smoothness loss
            smoothness_loss = smoothness_loss_fn(
                I_current_l,
                I_current_lab,
                I_current_ab_predict,
                A_feat_0,
                weighted_layer_color,
                nonlocal_weighted_layer,
                weight_smoothness=opt.weight_smoothness,
                weight_nonlocal_smoothness=opt.weight_nonlocal_smoothness,
                device=local_rank,
            )

            # Total loss
            total_loss = l1_loss + perceptual_loss + contextual_loss_total + consistent_loss + smoothness_loss
            if epoch_num > opt.epoch_train_discriminator:
                total_loss += generator_loss

            # Add losses to the loss handler
            loss_handler.add_loss(key="total_loss", loss=total_loss.item())
            loss_handler.add_loss(key="l1_loss", loss=l1_loss.item())
            loss_handler.add_loss(key="perceptual_loss", loss=perceptual_loss.item())
            loss_handler.add_loss(key="contextual_loss", loss=contextual_loss_total.item())
            loss_handler.add_loss(key="consistent_loss", loss=consistent_loss.item())
            loss_handler.add_loss(key="smoothness_loss", loss=smoothness_loss.item())
            loss_handler.add_loss(key="discriminator_loss", loss=discriminator_loss.item())
            if epoch_num > opt.epoch_train_discriminator:
                loss_handler.add_loss(key="generator_loss", loss=generator_loss.item())
            loss_handler.count_one_sample()

            total_loss.backward()
            optimizer_g.step()
            step_optim_scheduler_g.step()
            step_optim_scheduler_d.step()
            # _forward_model_time = timer_handler.compute_time("forward_model")

            # timer_handler.compute_time("training_logger")
            training_logger()
            # _training_logger_time = timer_handler.compute_time("training_logger")
        ####
        if is_master_process():
            save_checkpoints(os.path.join(opt.checkpoint_dir, f"epoch_{epoch_num}"))
        ####

    if opt.use_wandb:
        wandb.finish()
    ddp_cleanup()