TextureScraping / libs /train_base.py
sunshineatnoon
Add application file
1b2a9b1
"""Training base class
"""
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.fft
import torch
import numpy as np
import argparse
import wandb
import math
import time
import os
from . import flow_transforms
class TrainerBase():
def __init__(self, args):
"""
Initialization function.
"""
cudnn.benchmark = True
os.environ['WANDB_DIR'] = args.work_dir
args.use_wandb = (args.use_wandb == 1)
if args.use_wandb:
wandb.login(key="d56eb81cd6396f0a181524ba214f488cf281e76b")
wandb.init(project=args.project_name, name=args.exp_name)
wandb.config.update(args)
self.mean_values = torch.tensor([0.411, 0.432, 0.45]).view(1, 3, 1, 1).cuda()
self.color_palette = np.loadtxt('data/palette.txt',dtype=np.uint8).reshape(-1,3)
self.args = args
def init_dataset(self):
"""
Initialize dataset
"""
if self.args.dataset == 'BSD500':
from ..data import BSD500
# ========== Data loading code ==============
input_transform = transforms.Compose([
flow_transforms.ArrayToTensor(),
transforms.Normalize(mean=[0,0,0], std=[255,255,255]),
transforms.Normalize(mean=[0.411,0.432,0.45], std=[1,1,1])
])
val_input_transform = transforms.Compose([
flow_transforms.ArrayToTensor(),
transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
])
target_transform = transforms.Compose([
flow_transforms.ArrayToTensor(),
])
co_transform = flow_transforms.Compose([
flow_transforms.RandomCrop((self.args.train_img_height , self.args.train_img_width)),
flow_transforms.RandomVerticalFlip(),
flow_transforms.RandomHorizontalFlip()
])
print("=> loading img pairs from '{}'".format(self.args.data))
train_set, val_set = BSD500(self.args.data,
transform=input_transform,
val_transform = val_input_transform,
target_transform=target_transform,
co_transform=co_transform,
bi_filter=True)
print('{} samples found, {} train samples and {} val samples '.format(len(val_set)+len(train_set), len(train_set), len(val_set)))
self.train_loader = torch.utils.data.DataLoader(
train_set, batch_size=self.args.batch_size,
num_workers=self.args.workers, pin_memory=True, shuffle=True, drop_last=True)
elif self.args.dataset == 'texture':
from ..data.texture_v3 import Dataset
dataset = Dataset(self.args.data_path, crop_size=self.args.train_img_height)
self.train_loader = torch.utils.data.DataLoader(dataset = dataset,
batch_size = self.args.batch_size,
shuffle = True,
num_workers = self.args.workers,
drop_last = True)
elif self.args.dataset == 'DIV2K':
from basicsr.data import create_dataloader, create_dataset
opt = {}
opt['dist'] = False
opt['phase'] = 'train'
opt['name'] = 'DIV2K'
opt['type'] = 'PairedImageDataset'
opt['dataroot_gt'] = self.args.HR_dir
opt['dataroot_lq'] = self.args.LR_dir
opt['filename_tmpl'] = '{}'
opt['io_backend'] = dict(type='disk')
opt['gt_size'] = self.args.train_img_height
opt['use_flip'] = True
opt['use_rot'] = True
opt['use_shuffle'] = True
opt['num_worker_per_gpu'] = self.args.workers
opt['batch_size_per_gpu'] = self.args.batch_size
opt['scale'] = int(self.args.ratio)
opt['dataset_enlarge_ratio'] = 1
dataset = create_dataset(opt)
self.train_loader = create_dataloader(
dataset, opt, num_gpu=1, dist=opt['dist'], sampler=None)
else:
raise ValueError("Unknown dataset: {}.".format(self.args.dataset))
def init_training(self):
self.init_constant()
self.init_dataset()
self.define_model()
self.define_criterion()
self.define_optimizer()
def adjust_learning_rate(self, iteration):
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
lr = self.args.lr * (0.95 ** (iteration // self.args.lr_decay_freq))
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
def logging(self, iteration, epoch):
print_str = "[{}/{}][{}/{}], ".format(iteration, len(self.train_loader), epoch, self.args.nepochs)
for k,v in self.losses.items():
print_str += "{}: {:4f} ".format(k, v)
print_str += "time: {:2f}.".format(self.iter_time)
print(print_str)
def get_sp_grid(self, H, W, G, R = 1):
W = int(W // R)
H = int(H // R)
if G > min(H, W):
raise ValueError('Grid size must be smaller than image size!')
grid = torch.from_numpy(np.arange(G**2)).view(1, 1, G, G)
grid = torch.cat([grid]*(int(math.ceil(W/G))), dim = -1)
grid = torch.cat([grid]*(int(math.ceil(H/G))), dim = -2)
grid = grid[:, :, :H, :W]
return grid.float()
def save_network(self, name = None):
cpk = {}
cpk['epoch'] = self.epoch
cpk['lr'] = self.optimizer.param_groups[0]['lr']
if hasattr(self.model, 'module'):
cpk['model'] = self.model.module.cpu().state_dict()
else:
cpk['model'] = self.model.cpu().state_dict()
if name is None:
out_path = os.path.join(self.args.out_dir, "cpk.pth")
else:
out_path = os.path.join(self.args.out_dir, name + ".pth")
torch.save(cpk, out_path)
self.model.cuda()
return
def init_constant(self):
return
def define_model(self):
raise NotImplementedError
def define_criterion(self):
raise NotImplementedError
def define_optimizer(self):
raise NotImplementedError
def display(self):
raise NotImplementedError
def forward(self):
raise NotImplementedError
def train(self):
args = self.args
total_iteration = 0
for epoch in range(args.nepochs):
self.epoch = epoch
for iteration, data in enumerate(self.train_loader):
if args.dataset == 'BSD500':
image = data[0].cuda()
self.label = data[1].cuda()
elif args.dataset == 'texture':
image = data[0].cuda()
self.image2 = data[1].cuda()
else:
image = data['lq'].cuda()
self.gt = data['gt'].cuda()
start_time = time.time()
total_iteration += 1
self.optimizer.zero_grad()
image = image.cuda()
if args.dataset == 'BSD500':
self.image = image + self.mean_values
self.gt = self.image
else:
self.image = image
self.forward()
total_loss = 0
for k,v in self.losses.items():
if hasattr(args, '{}_wt'.format(k)):
total_loss += v * getattr(args, '{}_wt'.format(k))
else:
total_loss += v
total_loss.backward()
self.optimizer.step()
end_time = time.time()
self.iter_time = end_time - start_time
self.adjust_learning_rate(total_iteration)
if((iteration + 1) % args.log_freq == 0):
self.logging(iteration, epoch)
if args.use_wandb:
wandb.log(self.losses)
if(iteration % args.display_freq == 0):
example_images = self.display()
if args.use_wandb:
wandb.log({'images': example_images})
if((epoch + 1) % args.save_freq == 0):
self.save_network()