"""by lyuwenyu
"""
import torch
import torch.nn as nn
from datetime import datetime
from pathlib import Path
from typing import Dict
from src.misc import dist
from src.core import BaseConfig


class BaseSolver(object):
    def __init__(self, cfg: BaseConfig) -> None:
        self.cfg = cfg

    def setup(self, ):
        '''Avoid instantiating unnecessary classes
        '''
        cfg = self.cfg
        device = cfg.device
        self.device = device
        self.last_epoch = cfg.last_epoch

        self.model = dist.warp_model(cfg.model.to(device), cfg.find_unused_parameters, cfg.sync_bn)
        self.criterion = cfg.criterion.to(device)
        self.postprocessor = cfg.postprocessor

        # NOTE (lvwenyu): should load_tuning_state before ema instance building
        if self.cfg.tuning:
            print(f'Tuning checkpoint from {self.cfg.tuning}')
            self.load_tuning_state(self.cfg.tuning)

        self.scaler = cfg.scaler
        self.ema = cfg.ema.to(device) if cfg.ema is not None else None

        self.output_dir = Path(cfg.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def train(self, ):
        self.setup()
        self.optimizer = self.cfg.optimizer
        self.lr_scheduler = self.cfg.lr_scheduler

        # NOTE instantiating order
        if self.cfg.resume:
            print(f'Resume checkpoint from {self.cfg.resume}')
            self.resume(self.cfg.resume)

        self.train_dataloader = dist.warp_loader(self.cfg.train_dataloader,
                                                 shuffle=self.cfg.train_dataloader.shuffle)
        self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader,
                                               shuffle=self.cfg.val_dataloader.shuffle)

    def eval(self, ):
        self.setup()
        self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader,
                                               shuffle=self.cfg.val_dataloader.shuffle)

        if self.cfg.resume:
            print(f'Resume checkpoint from {self.cfg.resume}')
            self.resume(self.cfg.resume)

    def state_dict(self, last_epoch):
        '''Build a checkpoint state dict for the current training state.
        '''
        state = {}
        state['model'] = dist.de_parallel(self.model).state_dict()
        state['date'] = datetime.now().isoformat()

        # TODO
        state['last_epoch'] = last_epoch

        if self.optimizer is not None:
            state['optimizer'] = self.optimizer.state_dict()

        if self.lr_scheduler is not None:
            state['lr_scheduler'] = self.lr_scheduler.state_dict()
            # state['last_epoch'] = self.lr_scheduler.last_epoch

        if self.ema is not None:
            state['ema'] = self.ema.state_dict()

        if self.scaler is not None:
            state['scaler'] = self.scaler.state_dict()

        return state

    def load_state_dict(self, state):
        '''Load a checkpoint state dict into whichever components exist on this solver.
        '''
        # TODO
        if getattr(self, 'last_epoch', None) and 'last_epoch' in state:
            self.last_epoch = state['last_epoch']
            print('Loading last_epoch')

        if getattr(self, 'model', None) and 'model' in state:
            if dist.is_parallel(self.model):
                self.model.module.load_state_dict(state['model'])
            else:
                self.model.load_state_dict(state['model'])
            print('Loading model.state_dict')

        if getattr(self, 'ema', None) and 'ema' in state:
            self.ema.load_state_dict(state['ema'])
            print('Loading ema.state_dict')

        if getattr(self, 'optimizer', None) and 'optimizer' in state:
            self.optimizer.load_state_dict(state['optimizer'])
            print('Loading optimizer.state_dict')

        if getattr(self, 'lr_scheduler', None) and 'lr_scheduler' in state:
            self.lr_scheduler.load_state_dict(state['lr_scheduler'])
            print('Loading lr_scheduler.state_dict')

        if getattr(self, 'scaler', None) and 'scaler' in state:
            self.scaler.load_state_dict(state['scaler'])
            print('Loading scaler.state_dict')

    def save(self, path):
        '''Save the checkpoint state on the master process.
        '''
        state = self.state_dict(self.last_epoch)
        dist.save_on_master(state, path)

    def resume(self, path):
        '''Resume training state from a checkpoint file.
        '''
        # load on CPU first to avoid allocating extra memory on cuda:0
        state = torch.load(path, map_location='cpu')
        self.load_state_dict(state)
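
    # Usage sketch (the path below is an assumption, not fixed by this module):
    # `save` serializes whatever `state_dict(last_epoch)` returns, and `resume`
    # feeds the same dict back through `load_state_dict`.
    #
    #   solver.save(solver.output_dir / 'checkpoint.pth')
    #   solver.resume(str(solver.output_dir / 'checkpoint.pth'))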

    def load_tuning_state(self, path):
        """Only load model weights for tuning; skip missing/mismatched keys.
        """
        if 'http' in path:
            state = torch.hub.load_state_dict_from_url(path, map_location='cpu')
        else:
            state = torch.load(path, map_location='cpu')

        module = dist.de_parallel(self.model)

        # TODO hard code
        if 'ema' in state:
            stat, infos = self._matched_state(module.state_dict(), state['ema']['module'])
        else:
            stat, infos = self._matched_state(module.state_dict(), state['model'])

        module.load_state_dict(stat, strict=False)
        print(f'Load model.state_dict, {infos}')

    @staticmethod
    def _matched_state(state: Dict[str, torch.Tensor], params: Dict[str, torch.Tensor]):
        '''Keep the checkpoint params whose key and shape both match the current
        state; report the rest as missed (key absent) or unmatched (shape differs).
        '''
        missed_list = []
        unmatched_list = []
        matched_state = {}
        for k, v in state.items():
            if k in params:
                if v.shape == params[k].shape:
                    matched_state[k] = params[k]
                else:
                    unmatched_list.append(k)
            else:
                missed_list.append(k)

        return matched_state, {'missed': missed_list, 'unmatched': unmatched_list}
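
    # For illustration only (hypothetical shapes, not from this repo): if the
    # current model has {'backbone.w': (64, 3), 'head.w': (80, 256)} and the
    # checkpoint has {'backbone.w': (64, 3), 'head.w': (91, 256)}, then
    # _matched_state returns ({'backbone.w': <checkpoint tensor>},
    # {'missed': [], 'unmatched': ['head.w']}).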

    def fit(self, ):
        raise NotImplementedError('')

    def val(self, ):
        raise NotImplementedError('')
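

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): concrete solvers are expected to subclass
# BaseSolver and implement `fit` / `val`. The names below (`DetSolver`,
# `cfg.epoches`, the loss/step details) are assumptions for illustration and
# are not defined in this module.
#
#   class DetSolver(BaseSolver):
#       def fit(self):
#           self.train()                      # builds model, optimizer, dataloaders
#           for epoch in range(self.last_epoch + 1, self.cfg.epoches):
#               for samples, targets in self.train_dataloader:
#                   ...                       # forward, criterion, backward, step
#               self.lr_scheduler.step()
#               self.save(self.output_dir / 'checkpoint.pth')
#
#       def val(self):
#           self.eval()                       # builds model and val dataloader
#           ...                               # run postprocessor + metrics
# ---------------------------------------------------------------------------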