import os.path import math import argparse import time import random import numpy as np from collections import OrderedDict import logging from torch.utils.data import DataLoader import torch from utils import utils_logger from utils import utils_image as util from utils import utils_option as option from utils import utils_sisr as sisr from data.select_dataset import define_Dataset from models.select_model import define_Model ''' # -------------------------------------------- # training code for USRNet # -------------------------------------------- # Kai Zhang (cskaizhang@gmail.com) # github: https://github.com/cszn/KAIR # https://github.com/cszn/USRNet # # Reference: @inproceedings{zhang2020deep, title={Deep unfolding network for image super-resolution}, author={Zhang, Kai and Van Gool, Luc and Timofte, Radu}, booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, pages={3217--3226}, year={2020} } # -------------------------------------------- ''' def main(json_path='options/train_usrnet.json'): ''' # ---------------------------------------- # Step--1 (prepare opt) # ---------------------------------------- ''' parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, default=json_path, help='Path to option JSON file.') opt = option.parse(parser.parse_args().opt, is_train=True) util.mkdirs((path for key, path in opt['path'].items() if 'pretrained' not in key)) # ---------------------------------------- # update opt # ---------------------------------------- # -->-->-->-->-->-->-->-->-->-->-->-->-->- init_iter, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G') opt['path']['pretrained_netG'] = init_path_G current_step = init_iter border = opt['scale'] # --<--<--<--<--<--<--<--<--<--<--<--<--<- # ---------------------------------------- # save opt to a '../option.json' file # ---------------------------------------- option.save(opt) # ---------------------------------------- # return None for missing key # ---------------------------------------- opt = option.dict_to_nonedict(opt) # ---------------------------------------- # configure logger # ---------------------------------------- logger_name = 'train' utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log')) logger = logging.getLogger(logger_name) logger.info(option.dict2str(opt)) # ---------------------------------------- # seed # ---------------------------------------- seed = opt['train']['manual_seed'] if seed is None: seed = random.randint(1, 10000) logger.info('Random seed: {}'.format(seed)) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) ''' # ---------------------------------------- # Step--2 (creat dataloader) # ---------------------------------------- ''' # ---------------------------------------- # 1) create_dataset # 2) creat_dataloader for train and test # ---------------------------------------- for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': train_set = define_Dataset(dataset_opt) train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size'])) logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size)) train_loader = DataLoader(train_set, batch_size=dataset_opt['dataloader_batch_size'], shuffle=dataset_opt['dataloader_shuffle'], num_workers=dataset_opt['dataloader_num_workers'], drop_last=True, pin_memory=True) elif phase == 'test': test_set = define_Dataset(dataset_opt) test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=1, drop_last=False, pin_memory=True) else: raise NotImplementedError("Phase [%s] is not recognized." % phase) ''' # ---------------------------------------- # Step--3 (initialize model) # ---------------------------------------- ''' model = define_Model(opt) logger.info(model.info_network()) model.init_train() logger.info(model.info_params()) ''' # ---------------------------------------- # Step--4 (main training) # ---------------------------------------- ''' for epoch in range(1000000): # keep running for i, train_data in enumerate(train_loader): current_step += 1 # ------------------------------- # 1) update learning rate # ------------------------------- model.update_learning_rate(current_step) # ------------------------------- # 2) feed patch pairs # ------------------------------- model.feed_data(train_data) # ------------------------------- # 3) optimize parameters # ------------------------------- model.optimize_parameters(current_step) # ------------------------------- # 4) training information # ------------------------------- if current_step % opt['train']['checkpoint_print'] == 0: logs = model.current_log() # such as loss message = ' '.format(epoch, current_step, model.current_learning_rate()) for k, v in logs.items(): # merge log information into message message += '{:s}: {:.3e} '.format(k, v) logger.info(message) # ------------------------------- # 5) save model # ------------------------------- if current_step % opt['train']['checkpoint_save'] == 0: logger.info('Saving the model.') model.save(current_step) # ------------------------------- # 6) testing # ------------------------------- if current_step % opt['train']['checkpoint_test'] == 0: avg_psnr = 0.0 idx = 0 for test_data in test_loader: idx += 1 image_name_ext = os.path.basename(test_data['L_path'][0]) img_name, ext = os.path.splitext(image_name_ext) img_dir = os.path.join(opt['path']['images'], img_name) util.mkdir(img_dir) model.feed_data(test_data) model.test() visuals = model.current_visuals() E_img = util.tensor2uint(visuals['E']) H_img = util.tensor2uint(visuals['H']) # ----------------------- # save estimated image E # ----------------------- save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) util.imsave(E_img, save_img_path) # ----------------------- # calculate PSNR # ----------------------- current_psnr = util.calculate_psnr(E_img, H_img, border=border) logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr)) avg_psnr += current_psnr avg_psnr = avg_psnr / idx # testing log logger.info('