import time
import wandb
import logging
import numpy as np
import os.path as osp
from collections import OrderedDict

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from .logger import CustomLogger
from utils.utils import AverageMeterGroups
from metrics.psnr_ssim import calculate_psnr
from utils.build_utils import build_from_cfg


class Trainer:
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.rank = self.config['local_rank']
        init_log = self._init_logger()
        self._init_dataset()
        self._init_loss()
        self.model_name = config['exp_name']

        # Build the network and wrap it in DDP when training on multiple GPUs.
        self.model = build_from_cfg(config.network).to(self.config.device)
        if config['distributed']:
            self.model = DDP(self.model,
                             device_ids=[self.rank],
                             output_device=self.rank,
                             broadcast_buffers=True,
                             find_unused_parameters=False)
        init_log += str(self.model)

        self.optimizer = AdamW(self.model.parameters(),
                               lr=config.lr,
                               weight_decay=config.weight_decay)

        if self.rank == 0:
            print(init_log)
            self.logger(init_log)
        self.resume_training()

    def resume_training(self):
        # Restore model/optimizer state if a resume checkpoint is configured.
        ckpt_path = self.config.get('resume_state')
        if ckpt_path is not None:
            ckpt = torch.load(ckpt_path)
            if self.config['distributed']:
                self.model.module.load_state_dict(ckpt['state_dict'])
            else:
                self.model.load_state_dict(ckpt['state_dict'])
            self.optimizer.load_state_dict(ckpt['optim'])
            self.resume_epoch = ckpt.get('epoch')
            self.logger(f'load model from {ckpt_path} '
                        f'and training resumes from epoch {self.resume_epoch}')
        else:
            self.resume_epoch = 0

    def _init_logger(self):
        init_log = ''
        console_cfg = dict(
            level=logging.INFO,
            format="%(asctime)s %(filename)s[line:%(lineno)d] "
                   "%(levelname)s %(message)s",
            datefmt="%a, %d %b %Y %H:%M:%S",
            filename=f"{self.config['save_dir']}/log",
            filemode='w')
        tb_cfg = dict(log_dir=osp.join(self.config['save_dir'], 'tb_logger'))

        # Optional wandb logging: reuse the run id when resuming, otherwise create one.
        wandb_cfg = None
        use_wandb = self.config['logger'].get('use_wandb', False)
        if use_wandb:
            resume_id = self.config['logger'].get('resume_id', None)
            if resume_id:
                wandb_id = resume_id
                resume = 'allow'
                init_log += f'Resume wandb logger with id={wandb_id}.'
            else:
                wandb_id = wandb.util.generate_id()
                resume = 'never'
            wandb_cfg = dict(id=wandb_id,
                             resume=resume,
                             name=osp.basename(self.config['save_dir']),
                             config=self.config,
                             project="YOUR PROJECT",
                             entity="YOUR ENTITY",
                             sync_tensorboard=True)
            init_log += f'Use wandb logger with id={wandb_id}; project=[YOUR PROJECT].'
        self.logger = CustomLogger(console_cfg, tb_cfg, wandb_cfg, self.rank)
        return init_log

    def _init_dataset(self):
        dataset_train = build_from_cfg(self.config.data.train)
        dataset_val = build_from_cfg(self.config.data.val)
        self.sampler = DistributedSampler(dataset_train,
                                          num_replicas=self.config['world_size'],
                                          rank=self.config['local_rank'])
        # The configured batch size is global; split it across processes.
        self.config.data.train_loader.batch_size //= self.config['world_size']
        self.loader_train = DataLoader(dataset_train,
                                       **self.config.data.train_loader,
                                       pin_memory=True,
                                       drop_last=True,
                                       sampler=self.sampler)
        self.loader_val = DataLoader(dataset_val,
                                     **self.config.data.val_loader,
                                     pin_memory=True,
                                     shuffle=False,
                                     drop_last=False)

    def _init_loss(self):
        # Build each configured loss and index it by its nickname for logging.
        self.loss_dict = dict()
        for loss_cfg in self.config.losses:
            self.loss_dict[loss_cfg['nickname']] = build_from_cfg(loss_cfg)

    def set_lr(self, optimizer, lr):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def get_lr(self, iters):
        # Cosine annealing from lr down to lr_min over the full training schedule.
        total_iters = self.config['epochs'] * len(self.loader_train)
        ratio = 0.5 * (1.0 + np.cos(iters / total_iters * np.pi))
        return (self.config['lr'] - self.config['lr_min']) * ratio + self.config['lr_min']

    def train(self):
        local_rank = self.config['local_rank']
        best_psnr = 0.0
        loss_group = AverageMeterGroups()
        time_group = AverageMeterGroups()
        iters_per_epoch = len(self.loader_train)
        iters = self.resume_epoch * iters_per_epoch
        total_iters = self.config['epochs'] * iters_per_epoch
        start_t = time.time()
        total_t = 0

        for epoch in range(self.resume_epoch, self.config['epochs']):
            self.sampler.set_epoch(epoch)
            for data in self.loader_train:
                for k, v in data.items():
                    data[k] = v.to(self.config['device'])
                data_t = time.time() - start_t

                lr = self.get_lr(iters)
                self.set_lr(self.optimizer, lr)
                self.optimizer.zero_grad()
                results = self.model(**data)

                # Sum all configured losses and track their running averages.
                total_loss = torch.tensor(0., device=self.config['device'])
                for name, loss in self.loss_dict.items():
                    loss_val = loss(**results, **data)
                    loss_group.update({name: loss_val.cpu().data})
                    total_loss += loss_val
                total_loss.backward()
                self.optimizer.step()

                iters += 1
                iter_t = time.time() - start_t
                total_t += iter_t
                time_group.update({'data_t': data_t, 'iter_t': iter_t})

                # Periodic console/tensorboard logging on the main process.
                if (iters + 1) % 100 == 0 and local_rank == 0:
                    tpi = total_t / (iters - self.resume_epoch * iters_per_epoch)
                    eta = self.eta_format(total_iters * tpi)
                    remainder = self.eta_format((total_iters - iters) * tpi)
                    log_str = f"[{self.model_name}]epoch:{epoch + 1}/{self.config['epochs']} "
                    log_str += f"iter:{iters + 1}/{total_iters} "
                    log_str += f"time:{time_group.avg('iter_t'):.3f}({time_group.avg('data_t'):.3f}) "
                    log_str += f"lr:{lr:.3e} eta:{remainder}({eta})\n"
                    for name in self.loss_dict.keys():
                        avg_l = loss_group.avg(name)
                        log_str += f"{name}:{avg_l:.3e} "
                        self.logger(tb_msg=[f'loss/{name}', avg_l, iters])
                    log_str += f'best:{best_psnr:.2f}dB\n\n'
                    self.logger(log_str)
                    loss_group.reset()
                    time_group.reset()

                start_t = time.time()

            # Periodic evaluation and checkpointing on the main process.
            if (epoch + 1) % self.config['eval_interval'] == 0 and local_rank == 0:
                psnr, eval_t = self.evaluate(epoch)
                total_t += eval_t
                self.logger(tb_msg=['eval/psnr', psnr, epoch])
                if psnr > best_psnr:
                    best_psnr = psnr
                    self.save('psnr_best.pth', epoch)
                    if self.logger.enable_wandb:
                        wandb.run.summary["best_psnr"] = best_psnr
                if (epoch + 1) % 50 == 0:
                    self.save(f'epoch_{epoch + 1}.pth', epoch)
                self.save('latest.pth', epoch)

        self.logger.close()

    def evaluate(self, epoch):
        psnr_list = []
        time_stamp = time.time()
        for data in self.loader_val:
            for k, v in data.items():
                data[k] = v.to(self.config['device'])
            with torch.no_grad():
                results = self.model(**data, eval=True)
                imgt_pred = results['imgt_pred']

            # Compute per-sample PSNR against the ground-truth target frame.
            for j in range(data['img0'].shape[0]):
                psnr = calculate_psnr(imgt_pred[j].detach().unsqueeze(0),
                                      data['imgt'][j].unsqueeze(0)).cpu().data
                psnr_list.append(psnr)

        eval_time = time.time() - time_stamp
        avg_psnr = np.array(psnr_list).mean()
        self.logger('eval epoch:{}/{} time:{:.2f} psnr:{:.3f}'.format(
            epoch + 1, self.config['epochs'], eval_time, avg_psnr))
        return avg_psnr, eval_time

    def save(self, name, epoch):
        # Checkpoints are written to <save_dir>/ckpts/<name>.
        save_path = osp.join(self.config['save_dir'], 'ckpts', name)
        ckpt = OrderedDict(epoch=epoch)
        if self.config['distributed']:
            ckpt['state_dict'] = self.model.module.state_dict()
        else:
            ckpt['state_dict'] = self.model.state_dict()
        ckpt['optim'] = self.optimizer.state_dict()
        torch.save(ckpt, save_path)

    def eta_format(self, eta):
        # Format a duration in seconds as H:MM:SS (or M:SS below one hour).
        eta = int(eta)
        hours, rem = divmod(eta, 3600)
        mins, secs = divmod(rem, 60)
        if hours > 0:
            return f'{hours}:{mins:02}:{secs:02}'
        return f'{mins}:{secs:02}'
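

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): one way
# this Trainer might be driven from a launch script. `parse_config` and
# 'configs/example.yaml' are hypothetical placeholders for the project's
# actual config loader and config file; the block also assumes the module is
# invoked as part of its package (e.g. via `torchrun -m ...`) so that the
# relative import of CustomLogger resolves.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.distributed as dist
    from utils.build_utils import parse_config  # hypothetical config loader

    cfg = parse_config('configs/example.yaml')  # placeholder config path
    if cfg['distributed']:
        # Relies on the usual torchrun-provided environment variables.
        dist.init_process_group(backend='nccl')
    torch.cuda.set_device(cfg['local_rank'])
    trainer = Trainer(cfg)
    trainer.train()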