cheng-hust's picture
Upload 91 files
e8861c0 verified
raw
history blame
5.71 kB
"""by lyuwenyu
"""
import torch
import torch.nn as nn
from datetime import datetime
from pathlib import Path
from typing import Dict
from src.misc import dist
from src.core import BaseConfig
class BaseSolver(object):
    """Shared scaffolding for training / evaluation solvers.

    Builds the model, criterion, postprocessor, scaler, EMA and data loaders
    from a config object and handles checkpoint (de)serialization.
    Subclasses must implement :meth:`fit` and :meth:`val`.
    """

    def __init__(self, cfg: 'BaseConfig') -> None:
        # String annotation: the config class is only needed for typing here.
        self.cfg = cfg

    def setup(self, ):
        '''Avoid instantiating unnecessary classes

        Builds only the components shared by training and evaluation (model,
        criterion, postprocessor, scaler, EMA, output dir). Training-only
        components (optimizer, lr scheduler, train loader) are built in
        :meth:`train`.
        '''
        cfg = self.cfg
        device = cfg.device
        self.device = device
        self.last_epoch = cfg.last_epoch

        # NOTE(review): dist.warp_model presumably wraps the model for
        # distributed / SyncBN training — confirm in src.misc.dist.
        self.model = dist.warp_model(cfg.model.to(device), cfg.find_unused_parameters, cfg.sync_bn)
        self.criterion = cfg.criterion.to(device)
        self.postprocessor = cfg.postprocessor

        # NOTE (lvwenyu): should load_tuning_state before ema instance building
        if self.cfg.tuning:
            print(f'Tuning checkpoint from {self.cfg.tuning}')
            self.load_tuning_state(self.cfg.tuning)

        self.scaler = cfg.scaler
        self.ema = cfg.ema.to(device) if cfg.ema is not None else None

        self.output_dir = Path(cfg.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def train(self, ):
        '''Prepare the solver for training: shared setup, optimizer,
        lr scheduler, optional resume, then train/val data loaders.'''
        self.setup()
        self.optimizer = self.cfg.optimizer
        self.lr_scheduler = self.cfg.lr_scheduler

        # NOTE instantiating order: resume only after the optimizer and
        # scheduler exist, since load_state_dict skips absent components.
        if self.cfg.resume:
            print(f'Resume checkpoint from {self.cfg.resume}')
            self.resume(self.cfg.resume)

        self.train_dataloader = dist.warp_loader(self.cfg.train_dataloader,
                                                 shuffle=self.cfg.train_dataloader.shuffle)
        self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader,
                                               shuffle=self.cfg.val_dataloader.shuffle)

    def eval(self, ):
        '''Prepare the solver for evaluation: shared setup, val loader and
        optional resume (optimizer / scheduler are never built here).'''
        self.setup()
        self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader,
                                               shuffle=self.cfg.val_dataloader.shuffle)

        if self.cfg.resume:
            print(f'resume from {self.cfg.resume}')
            self.resume(self.cfg.resume)

    def state_dict(self, last_epoch):
        '''Collect a checkpoint dict.

        Args:
            last_epoch: value stored under ``'last_epoch'``.

        Returns:
            dict with ``'model'``, ``'date'`` (ISO timestamp), ``'last_epoch'``
            and, for each of ``optimizer`` / ``lr_scheduler`` / ``ema`` /
            ``scaler`` that the solver currently holds, its ``state_dict()``.
        '''
        state = {
            'model': dist.de_parallel(self.model).state_dict(),
            'date': datetime.now().isoformat(),
            # TODO
            'last_epoch': last_epoch,
        }
        # getattr: optimizer / lr_scheduler are only created by train(), so a
        # solver prepared with eval() must still be able to build a state.
        for name in ('optimizer', 'lr_scheduler', 'ema', 'scaler'):
            component = getattr(self, name, None)
            if component is not None:
                state[name] = component.state_dict()
        return state

    def load_state_dict(self, state):
        '''Restore solver components from a checkpoint dict.

        Each entry is restored only when the checkpoint contains the key AND
        the solver currently holds that attribute (not None); anything else
        is skipped silently.
        '''
        def _present(name):
            # Explicit None test: plain truthiness would skip last_epoch == 0
            # and any component whose __len__ is 0.
            return getattr(self, name, None) is not None and name in state

        # TODO
        if _present('last_epoch'):
            self.last_epoch = state['last_epoch']
            print('Loading last_epoch')

        if _present('model'):
            if dist.is_parallel(self.model):
                self.model.module.load_state_dict(state['model'])
            else:
                self.model.load_state_dict(state['model'])
            print('Loading model.state_dict')

        for name in ('ema', 'optimizer', 'lr_scheduler', 'scaler'):
            if _present(name):
                getattr(self, name).load_state_dict(state[name])
                print(f'Loading {name}.state_dict')

    def save(self, path, last_epoch=None):
        '''Write a checkpoint to ``path`` through ``dist.save_on_master``.

        Args:
            path: destination file.
            last_epoch: epoch to record; defaults to ``self.last_epoch``
                (None if the solver was never set up).
        '''
        # state_dict() requires last_epoch; calling it bare raised TypeError.
        if last_epoch is None:
            last_epoch = getattr(self, 'last_epoch', None)
        state = self.state_dict(last_epoch)
        dist.save_on_master(state, path)

    def resume(self, path):
        '''Load a full checkpoint from ``path`` and restore solver state.'''
        # for cuda:0 memory: map to CPU so tensors are not all materialised
        # on the first GPU before being copied into the live components.
        # NOTE(review): torch.load unpickles; only resume from trusted files.
        state = torch.load(path, map_location='cpu')
        self.load_state_dict(state)

    def load_tuning_state(self, path,):
        """only load model for tuning and skip missed/dismatched keys

        ``path`` may be an http(s) URL or a local file. Weights are taken from
        ``state['ema']['module']`` when the checkpoint has an ``'ema'`` entry,
        otherwise from ``state['model']``; only keys whose shapes match the
        current model are loaded (``strict=False``).
        """
        # Check the URL scheme prefix rather than any "http" substring, so a
        # local path like '/data/http_ckpt/x.pth' is not sent to the
        # downloader; str() also accepts pathlib.Path arguments.
        # NOTE(review): both loaders unpickle — only use trusted sources.
        if str(path).startswith(('http://', 'https://')):
            state = torch.hub.load_state_dict_from_url(str(path), map_location='cpu')
        else:
            state = torch.load(path, map_location='cpu')

        module = dist.de_parallel(self.model)

        # TODO hard code
        if 'ema' in state:
            stat, infos = self._matched_state(module.state_dict(), state['ema']['module'])
        else:
            stat, infos = self._matched_state(module.state_dict(), state['model'])

        module.load_state_dict(stat, strict=False)
        print(f'Load model.state_dict, {infos}')

    @staticmethod
    def _matched_state(state: Dict[str, torch.Tensor], params: Dict[str, torch.Tensor]):
        '''Select entries of ``params`` usable for the model described by ``state``.

        Returns:
            (matched, info): ``matched`` maps each key of ``state`` that also
            exists in ``params`` with an equal shape to ``params[key]``;
            ``info['missed']`` lists ``state`` keys absent from ``params`` and
            ``info['unmatched']`` lists keys present with a different shape.
        '''
        missed_list = []
        unmatched_list = []
        matched_state = {}
        for k, v in state.items():
            if k in params:
                if v.shape == params[k].shape:
                    matched_state[k] = params[k]
                else:
                    unmatched_list.append(k)
            else:
                missed_list.append(k)

        return matched_state, {'missed': missed_list, 'unmatched': unmatched_list}

    def fit(self, ):
        '''Run training; must be implemented by subclasses.'''
        raise NotImplementedError('')

    def val(self, ):
        '''Run validation; must be implemented by subclasses.'''
        raise NotImplementedError('')