"""by lyuwenyu """ import torch import torch.nn as nn from datetime import datetime from pathlib import Path from typing import Dict from src.misc import dist from src.core import BaseConfig class BaseSolver(object): def __init__(self, cfg: BaseConfig) -> None: self.cfg = cfg def setup(self, ): '''Avoid instantiating unnecessary classes ''' cfg = self.cfg device = cfg.device self.device = device self.last_epoch = cfg.last_epoch self.model = dist.warp_model(cfg.model.to(device), cfg.find_unused_parameters, cfg.sync_bn) self.criterion = cfg.criterion.to(device) self.postprocessor = cfg.postprocessor # NOTE (lvwenyu): should load_tuning_state before ema instance building if self.cfg.tuning: print(f'Tuning checkpoint from {self.cfg.tuning}') self.load_tuning_state(self.cfg.tuning) self.scaler = cfg.scaler self.ema = cfg.ema.to(device) if cfg.ema is not None else None self.output_dir = Path(cfg.output_dir) self.output_dir.mkdir(parents=True, exist_ok=True) def train(self, ): self.setup() self.optimizer = self.cfg.optimizer self.lr_scheduler = self.cfg.lr_scheduler # NOTE instantiating order if self.cfg.resume: print(f'Resume checkpoint from {self.cfg.resume}') self.resume(self.cfg.resume) self.train_dataloader = dist.warp_loader(self.cfg.train_dataloader, \ shuffle=self.cfg.train_dataloader.shuffle) self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, \ shuffle=self.cfg.val_dataloader.shuffle) def eval(self, ): self.setup() self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, \ shuffle=self.cfg.val_dataloader.shuffle) if self.cfg.resume: print(f'resume from {self.cfg.resume}') self.resume(self.cfg.resume) def state_dict(self, last_epoch): '''state dict ''' state = {} state['model'] = dist.de_parallel(self.model).state_dict() state['date'] = datetime.now().isoformat() # TODO state['last_epoch'] = last_epoch if self.optimizer is not None: state['optimizer'] = self.optimizer.state_dict() if self.lr_scheduler is not None: state['lr_scheduler'] = self.lr_scheduler.state_dict() # state['last_epoch'] = self.lr_scheduler.last_epoch if self.ema is not None: state['ema'] = self.ema.state_dict() if self.scaler is not None: state['scaler'] = self.scaler.state_dict() return state def load_state_dict(self, state): '''load state dict ''' # TODO if getattr(self, 'last_epoch', None) and 'last_epoch' in state: self.last_epoch = state['last_epoch'] print('Loading last_epoch') if getattr(self, 'model', None) and 'model' in state: if dist.is_parallel(self.model): self.model.module.load_state_dict(state['model']) else: self.model.load_state_dict(state['model']) print('Loading model.state_dict') if getattr(self, 'ema', None) and 'ema' in state: self.ema.load_state_dict(state['ema']) print('Loading ema.state_dict') if getattr(self, 'optimizer', None) and 'optimizer' in state: self.optimizer.load_state_dict(state['optimizer']) print('Loading optimizer.state_dict') if getattr(self, 'lr_scheduler', None) and 'lr_scheduler' in state: self.lr_scheduler.load_state_dict(state['lr_scheduler']) print('Loading lr_scheduler.state_dict') if getattr(self, 'scaler', None) and 'scaler' in state: self.scaler.load_state_dict(state['scaler']) print('Loading scaler.state_dict') def save(self, path): '''save state ''' state = self.state_dict() dist.save_on_master(state, path) def resume(self, path): '''load resume ''' # for cuda:0 memory state = torch.load(path, map_location='cpu') self.load_state_dict(state) def load_tuning_state(self, path,): """only load model for tuning and skip missed/dismatched keys """ if 

    def load_tuning_state(self, path):
        """Load only the model weights for fine-tuning, skipping keys that are
        missing or whose shapes are mismatched.
        """
        if 'http' in path:
            state = torch.hub.load_state_dict_from_url(path, map_location='cpu')
        else:
            state = torch.load(path, map_location='cpu')

        module = dist.de_parallel(self.model)

        # TODO hard code: prefer the EMA weights when the checkpoint has them
        if 'ema' in state:
            stat, infos = self._matched_state(module.state_dict(), state['ema']['module'])
        else:
            stat, infos = self._matched_state(module.state_dict(), state['model'])

        module.load_state_dict(stat, strict=False)
        print(f'Load model.state_dict, {infos}')

    @staticmethod
    def _matched_state(state: Dict[str, torch.Tensor], params: Dict[str, torch.Tensor]):
        '''Keep only the checkpoint params whose key and shape match the current model.'''
        missed_list = []
        unmatched_list = []
        matched_state = {}

        for k, v in state.items():
            if k in params:
                if v.shape == params[k].shape:
                    matched_state[k] = params[k]
                else:
                    unmatched_list.append(k)
            else:
                missed_list.append(k)

        return matched_state, {'missed': missed_list, 'unmatched': unmatched_list}

    def fit(self):
        raise NotImplementedError('')

    def val(self):
        raise NotImplementedError('')
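
# Minimal smoke test of the partial-loading rule used by `load_tuning_state`
# (illustration only; the parameter names and shapes below are made up).
# A key present in both dicts with an identical shape is kept; a shape
# mismatch lands in 'unmatched'; a key absent from the checkpoint in 'missed'.
if __name__ == '__main__':
    current = {'backbone.w': torch.zeros(4, 4), 'head.w': torch.zeros(2, 4)}
    ckpt = {'backbone.w': torch.ones(4, 4), 'head.w': torch.ones(3, 4)}
    matched, infos = BaseSolver._matched_state(current, ckpt)
    print(sorted(matched))  # ['backbone.w'] -- key and shape both match
    print(infos)            # {'missed': [], 'unmatched': ['head.w']}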