"""Training base class | |
""" | |
import torchvision.transforms as transforms | |
import torchvision.utils as vutils | |
import torch.backends.cudnn as cudnn | |
import torch.nn.functional as F | |
import torch.fft | |
import torch | |
import numpy as np | |
import argparse | |
import wandb | |
import math | |
import time | |
import os | |
from . import flow_transforms | |
class TrainerBase:
    def __init__(self, args):
        """
        Initialization function.
        """
        cudnn.benchmark = True
        os.environ['WANDB_DIR'] = args.work_dir
        args.use_wandb = (args.use_wandb == 1)
        if args.use_wandb:
            # API key is taken from the WANDB_API_KEY environment variable or a
            # prior `wandb login`; keys should not be hard-coded in source.
            wandb.login()
            wandb.init(project=args.project_name, name=args.exp_name)
            wandb.config.update(args)
        self.mean_values = torch.tensor([0.411, 0.432, 0.45]).view(1, 3, 1, 1).cuda()
        self.color_palette = np.loadtxt('data/palette.txt', dtype=np.uint8).reshape(-1, 3)
        self.args = args
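        # Note: the channel means above match the second Normalize step applied
        # in init_dataset() for BSD500; 'data/palette.txt' is assumed to hold a
        # label color palette, resolved relative to the working directory.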
    def init_dataset(self):
        """
        Initialize the training dataset and dataloader.
        """
        if self.args.dataset == 'BSD500':
            from ..data import BSD500
            # ========== Data loading code ==============
            input_transform = transforms.Compose([
                flow_transforms.ArrayToTensor(),
                transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
                transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
            ])
            val_input_transform = transforms.Compose([
                flow_transforms.ArrayToTensor(),
                transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
                transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
            ])
            target_transform = transforms.Compose([
                flow_transforms.ArrayToTensor(),
            ])
            co_transform = flow_transforms.Compose([
                flow_transforms.RandomCrop((self.args.train_img_height, self.args.train_img_width)),
                flow_transforms.RandomVerticalFlip(),
                flow_transforms.RandomHorizontalFlip()
            ])
            print("=> loading img pairs from '{}'".format(self.args.data))
            train_set, val_set = BSD500(self.args.data,
                                        transform=input_transform,
                                        val_transform=val_input_transform,
                                        target_transform=target_transform,
                                        co_transform=co_transform,
                                        bi_filter=True)
            print('{} samples found: {} train samples and {} val samples.'.format(
                len(val_set) + len(train_set), len(train_set), len(val_set)))
            self.train_loader = torch.utils.data.DataLoader(
                train_set, batch_size=self.args.batch_size,
                num_workers=self.args.workers, pin_memory=True, shuffle=True, drop_last=True)
        elif self.args.dataset == 'texture':
            from ..data.texture_v3 import Dataset
            dataset = Dataset(self.args.data_path, crop_size=self.args.train_img_height)
            self.train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                                            batch_size=self.args.batch_size,
                                                            shuffle=True,
                                                            num_workers=self.args.workers,
                                                            drop_last=True)
        elif self.args.dataset == 'DIV2K':
            from basicsr.data import create_dataloader, create_dataset
            opt = {}
            opt['dist'] = False
            opt['phase'] = 'train'
            opt['name'] = 'DIV2K'
            opt['type'] = 'PairedImageDataset'
            opt['dataroot_gt'] = self.args.HR_dir
            opt['dataroot_lq'] = self.args.LR_dir
            opt['filename_tmpl'] = '{}'
            opt['io_backend'] = dict(type='disk')
            opt['gt_size'] = self.args.train_img_height
            opt['use_flip'] = True
            opt['use_rot'] = True
            opt['use_shuffle'] = True
            opt['num_worker_per_gpu'] = self.args.workers
            opt['batch_size_per_gpu'] = self.args.batch_size
            opt['scale'] = int(self.args.ratio)
            opt['dataset_enlarge_ratio'] = 1
            dataset = create_dataset(opt)
            self.train_loader = create_dataloader(
                dataset, opt, num_gpu=1, dist=opt['dist'], sampler=None)
        else:
            raise ValueError("Unknown dataset: {}.".format(self.args.dataset))
    def init_training(self):
        self.init_constant()
        self.init_dataset()
        self.define_model()
        self.define_criterion()
        self.define_optimizer()
    def adjust_learning_rate(self, iteration):
        """Decay the initial LR by a factor of 0.95 every `lr_decay_freq` iterations."""
        lr = self.args.lr * (0.95 ** (iteration // self.args.lr_decay_freq))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
    def logging(self, iteration, epoch):
        print_str = "[{}/{}][{}/{}], ".format(iteration, len(self.train_loader), epoch, self.args.nepochs)
        for k, v in self.losses.items():
            print_str += "{}: {:.4f} ".format(k, v)
        print_str += "time: {:.2f}.".format(self.iter_time)
        print(print_str)
    def get_sp_grid(self, H, W, G, R=1):
        """Tile a G x G block of superpixel indices over an H x W grid
        (optionally downscaled by R) and crop it to the target size."""
        W = int(W // R)
        H = int(H // R)
        if G > min(H, W):
            raise ValueError('Grid size must be smaller than image size!')
        grid = torch.from_numpy(np.arange(G**2)).view(1, 1, G, G)
        grid = torch.cat([grid] * (int(math.ceil(W / G))), dim=-1)
        grid = torch.cat([grid] * (int(math.ceil(H / G))), dim=-2)
        grid = grid[:, :, :H, :W]
        return grid.float()
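    # Example for get_sp_grid above (illustrative): get_sp_grid(4, 6, 2) tiles
    # the 2 x 2 index block [[0, 1], [2, 3]] across the plane and returns a
    # float tensor of shape (1, 1, 4, 6) that repeats the block every 2 pixels.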
    def save_network(self, name=None):
        cpk = {}
        cpk['epoch'] = self.epoch
        cpk['lr'] = self.optimizer.param_groups[0]['lr']
        if hasattr(self.model, 'module'):
            cpk['model'] = self.model.module.cpu().state_dict()
        else:
            cpk['model'] = self.model.cpu().state_dict()
        if name is None:
            out_path = os.path.join(self.args.out_dir, "cpk.pth")
        else:
            out_path = os.path.join(self.args.out_dir, name + ".pth")
        torch.save(cpk, out_path)
        self.model.cuda()
        return
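    # Checkpoint loading sketch (assumed usage, not defined in this class):
    #     cpk = torch.load(os.path.join(out_dir, "cpk.pth"))
    #     model.load_state_dict(cpk['model'])
    # where `out_dir` and `model` correspond to the values used at save time.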
    def init_constant(self):
        return
    def define_model(self):
        raise NotImplementedError
    def define_criterion(self):
        raise NotImplementedError
    def define_optimizer(self):
        raise NotImplementedError
    def display(self):
        raise NotImplementedError
    def forward(self):
        raise NotImplementedError
    def train(self):
        args = self.args
        total_iteration = 0
        for epoch in range(args.nepochs):
            self.epoch = epoch
            for iteration, data in enumerate(self.train_loader):
                if args.dataset == 'BSD500':
                    image = data[0].cuda()
                    self.label = data[1].cuda()
                elif args.dataset == 'texture':
                    image = data[0].cuda()
                    self.image2 = data[1].cuda()
                else:
                    image = data['lq'].cuda()
                    self.gt = data['gt'].cuda()
                start_time = time.time()
                total_iteration += 1
                self.optimizer.zero_grad()
                image = image.cuda()
                if args.dataset == 'BSD500':
                    # Undo the mean subtraction applied in the input transform.
                    self.image = image + self.mean_values
                    self.gt = self.image
                else:
                    self.image = image
                self.forward()
                # Weight each loss term by an optional `<name>_wt` argument.
                total_loss = 0
                for k, v in self.losses.items():
                    if hasattr(args, '{}_wt'.format(k)):
                        total_loss += v * getattr(args, '{}_wt'.format(k))
                    else:
                        total_loss += v
                total_loss.backward()
                self.optimizer.step()
                end_time = time.time()
                self.iter_time = end_time - start_time
                self.adjust_learning_rate(total_iteration)
                if (iteration + 1) % args.log_freq == 0:
                    self.logging(iteration, epoch)
                    if args.use_wandb:
                        wandb.log(self.losses)
                if iteration % args.display_freq == 0:
                    example_images = self.display()
                    if args.use_wandb:
                        wandb.log({'images': example_images})
            if (epoch + 1) % args.save_freq == 0:
                self.save_network()
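# Usage sketch (illustrative only): a concrete trainer fills in the hooks left
# as NotImplementedError above. The toy model, MSE criterion, and Adam
# optimizer below are assumptions for illustration, not the project's actual
# configuration.
#
#   import torch.nn as nn
#
#   class ToyTrainer(TrainerBase):
#       def define_model(self):
#           self.model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
#                                      nn.ReLU(),
#                                      nn.Conv2d(16, 3, 3, padding=1)).cuda()
#       def define_criterion(self):
#           self.criterion = nn.MSELoss()
#       def define_optimizer(self):
#           self.optimizer = torch.optim.Adam(self.model.parameters(),
#                                             lr=self.args.lr)
#       def forward(self):
#           self.recon = self.model(self.image)
#           self.losses = {'recon': self.criterion(self.recon, self.gt)}
#       def display(self):
#           grid = vutils.make_grid(self.recon.clamp(0, 1), nrow=4)
#           return [wandb.Image(grid)]
#
#   trainer = ToyTrainer(args)   # args from the project's argparse setup
#   trainer.init_training()
#   trainer.train()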