|
|
|
"""Contains the base class for runner. |
|
|
|
This runner can be used for both training and inference with multi-threads. |
|
""" |
|
|
|
import os |
|
import json |
|
from copy import deepcopy |
|
|
|
import torch |
|
import torch.distributed as dist |
|
|
|
from datasets import BaseDataset |
|
from datasets import IterDataLoader |
|
from models import build_model |
|
from . import controllers |
|
from . import losses |
|
from . import misc |
|
from .optimizer import build_optimizers |
|
from .running_stats import RunningStats |
|
|
|
|
|
def _strip_state_dict_prefix(state_dict, prefix='module.'): |
|
"""Removes the name prefix in checkpoint. |
|
|
|
Basically, when the model is deployed in parallel, the prefix `module.` will |
|
be added to the saved checkpoint. This function is used to remove the |
|
prefix, which is friendly to checkpoint loading. |
|
|
|
Args: |
|
        state_dict: The state dict whose keys will be processed.
|
prefix: The prefix to remove. (default: `module.`) |
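
    Returns:
        A new state dict with the prefix stripped from every key, or the
        input `state_dict` unchanged if not all keys start with `prefix`.

    Example (illustrative; plain values stand in for tensors):
        >>> sd = {'module.conv.weight': 0, 'module.conv.bias': 1}
        >>> _strip_state_dict_prefix(sd)
        {'conv.weight': 0, 'conv.bias': 1}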
|
""" |
|
if not all(key.startswith(prefix) for key in state_dict.keys()): |
|
return state_dict |
|
|
|
stripped_state_dict = dict() |
|
for key in state_dict: |
|
        # Strip only the leading prefix (not other occurrences in the key).
        stripped_state_dict[key[len(prefix):]] = state_dict[key]
|
return stripped_state_dict |
|
|
|
|
|
class BaseRunner(object): |
|
"""Defines the base runner class.""" |
|
|
|
def __init__(self, config, logger): |
|
self._name = self.__class__.__name__ |
|
self._config = deepcopy(config) |
|
self.logger = logger |
|
self.work_dir = self.config.work_dir |
|
os.makedirs(self.work_dir, exist_ok=True) |
|
|
|
self.logger.info('Running Configuration:') |
|
config_str = json.dumps(self.config, indent=4).replace('"', '\'') |
|
self.logger.print(config_str + '\n') |
|
with open(os.path.join(self.work_dir, 'config.json'), 'w') as f: |
|
json.dump(self.config, f, indent=4) |
|
self._rank = dist.get_rank() |
|
self._world_size = dist.get_world_size() |
|
|
|
self.batch_size = self.config.batch_size |
|
self.val_batch_size = self.config.get('val_batch_size', self.batch_size) |
|
self._iter = 0 |
|
self._start_iter = 0 |
|
self.seen_img = 0 |
|
self.total_iters = self.config.get('total_iters', 0) |
|
if self.total_iters == 0 and self.config.get('total_img', 0) > 0: |
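            # Derive the total number of iterations from the total number of
            # images. E.g. (illustrative numbers) total_img=1_000_000 with
            # world_size=8 and batch_size=32 gives
            # int(1_000_000 / 256 + 0.5) = 3906 iterations.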
|
total_image = self.config.get('total_img') |
|
total_batch = self.world_size * self.batch_size |
|
self.total_iters = int(total_image / total_batch + 0.5) |
|
|
|
self.mode = None |
|
self.train_loader = None |
|
self.val_loader = None |
|
|
|
self.models = dict() |
|
self.optimizers = dict() |
|
self.lr_schedulers = dict() |
|
self.controllers = [] |
|
self.loss = None |
|
|
|
self.running_stats = RunningStats() |
|
self.start_time = 0 |
|
self.end_time = 0 |
|
self.timer = controllers.Timer() |
|
self.timer.start(self) |
|
|
|
self.build_models() |
|
self.build_controllers() |
|
|
|
def finish(self): |
|
"""Finishes runner by ending controllers and timer.""" |
|
for controller in self.controllers: |
|
controller.end(self) |
|
self.timer.end(self) |
|
self.logger.info(f'Finish runner in ' |
|
f'{misc.format_time(self.end_time - self.start_time)}') |
|
|
|
@property |
|
def name(self): |
|
"""Returns the name of the runner.""" |
|
return self._name |
|
|
|
@property |
|
def config(self): |
|
"""Returns the configuration of the runner.""" |
|
return self._config |
|
|
|
@property |
|
def rank(self): |
|
"""Returns the rank of the current runner.""" |
|
return self._rank |
|
|
|
@property |
|
def world_size(self): |
|
"""Returns the world size.""" |
|
return self._world_size |
|
|
|
@property |
|
def iter(self): |
|
"""Returns the current iteration.""" |
|
return self._iter |
|
|
|
@property |
|
def start_iter(self): |
|
"""Returns the start iteration.""" |
|
return self._start_iter |
|
|
|
def convert_epoch_to_iter(self, epoch): |
|
"""Converts number of epochs to number of iterations.""" |
|
return int(epoch * len(self.train_loader) + 0.5) |
|
|
|
def build_dataset(self, mode): |
|
"""Builds train/val dataset.""" |
|
if not hasattr(self.config, 'data'): |
|
return |
|
assert isinstance(mode, str) |
|
mode = mode.lower() |
|
self.logger.info(f'Building `{mode}` dataset ...') |
|
if mode not in ['train', 'val']: |
|
raise ValueError(f'Invalid dataset mode `{mode}`!') |
|
dataset = BaseDataset(**self.config.data[mode]) |
|
if mode == 'train': |
|
self.train_loader = IterDataLoader( |
|
dataset=dataset, |
|
batch_size=self.batch_size, |
|
shuffle=True, |
|
num_workers=self.config.data.get('num_workers', 2), |
|
current_iter=self.iter, |
|
repeat=self.config.data.get('repeat', 1)) |
|
elif mode == 'val': |
|
self.val_loader = IterDataLoader( |
|
dataset=dataset, |
|
batch_size=self.val_batch_size, |
|
shuffle=False, |
|
num_workers=self.config.data.get('num_workers', 2), |
|
current_iter=0, |
|
repeat=1) |
|
else: |
|
raise NotImplementedError(f'Not implemented dataset mode `{mode}`!') |
|
self.logger.info(f'Finish building `{mode}` dataset.') |
|
|
|
def build_models(self): |
|
"""Builds models, optimizers, and learning rate schedulers.""" |
|
self.logger.info(f'Building models ...') |
|
lr_config = dict() |
|
opt_config = dict() |
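        # Each entry of `config.modules` is expected to look roughly like
        # (illustrative module name; the exact fields depend on the project
        # config):
        #     generator:
        #         model: {...}   # kwargs forwarded to `build_model()`
        #         opt: {...}     # optimizer config (optional)
        #         lr: {...}      # learning-rate schedule config (optional)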
|
for module, module_config in self.config.modules.items(): |
|
model_config = module_config['model'] |
|
self.models[module] = build_model(module=module, **model_config) |
|
self.models[module].cuda() |
|
opt_config[module] = module_config.get('opt', None) |
|
lr_config[module] = module_config.get('lr', None) |
|
build_optimizers(opt_config, self) |
|
self.controllers.append(controllers.LRScheduler(lr_config)) |
|
self.logger.info(f'Finish building models.') |
|
|
|
model_info = 'Model structures:\n' |
|
model_info += '==============================================\n' |
|
for module in self.models: |
|
model_info += f'{module}\n' |
|
model_info += '----------------------------------------------\n' |
|
model_info += str(self.models[module]) |
|
model_info += '\n' |
|
model_info += "==============================================\n" |
|
self.logger.info(model_info) |
|
|
|
def distribute(self): |
|
"""Sets `self.model` as `torch.nn.parallel.DistributedDataParallel`.""" |
|
for name in self.models: |
|
self.models[name] = torch.nn.parallel.DistributedDataParallel( |
|
module=self.models[name], |
|
device_ids=[torch.cuda.current_device()], |
|
broadcast_buffers=False, |
|
find_unused_parameters=True) |
|
|
|
@staticmethod |
|
def get_module(model): |
|
"""Handles distributed model.""" |
|
if hasattr(model, 'module'): |
|
return model.module |
|
return model |
|
|
|
def build_controllers(self): |
|
"""Builds additional controllers besides LRScheduler.""" |
|
if not hasattr(self.config, 'controllers'): |
|
return |
|
self.logger.info(f'Building controllers ...') |
|
for key, ctrl_config in self.config.controllers.items(): |
|
self.controllers.append(getattr(controllers, key)(ctrl_config)) |
|
self.controllers.sort(key=lambda x: x.priority) |
|
for controller in self.controllers: |
|
controller.start(self) |
|
self.logger.info(f'Finish building controllers.') |
|
|
|
def build_loss(self): |
|
"""Builds loss functions.""" |
|
if not hasattr(self.config, 'loss'): |
|
return |
|
self.logger.info(f'Building loss function ...') |
|
loss_config = deepcopy(self.config.loss) |
|
loss_type = loss_config.pop('type') |
|
self.loss = getattr(losses, loss_type)(self, **loss_config) |
|
self.logger.info(f'Finish building loss function.') |
|
|
|
def pre_execute_controllers(self): |
|
"""Pre-executes all controllers in order of priority.""" |
|
for controller in self.controllers: |
|
controller.pre_execute(self) |
|
|
|
def post_execute_controllers(self): |
|
"""Post-executes all controllers in order of priority.""" |
|
for controller in self.controllers: |
|
controller.post_execute(self) |
|
|
|
def cpu(self): |
|
"""Puts models to CPU.""" |
|
for name in self.models: |
|
self.models[name].cpu() |
|
|
|
def cuda(self): |
|
"""Puts models to CUDA.""" |
|
for name in self.models: |
|
self.models[name].cuda() |
|
|
|
def set_model_requires_grad(self, name, requires_grad): |
|
"""Sets the `requires_grad` configuration for a particular model.""" |
|
for param in self.models[name].parameters(): |
|
param.requires_grad = requires_grad |
|
|
|
def set_models_requires_grad(self, requires_grad): |
|
"""Sets the `requires_grad` configuration for all models.""" |
|
for name in self.models: |
|
self.set_model_requires_grad(name, requires_grad) |
|
|
|
def set_model_mode(self, name, mode): |
|
"""Sets the `train/val` mode for a particular model.""" |
|
if isinstance(mode, str): |
|
mode = mode.lower() |
|
if mode == 'train' or mode is True: |
|
self.models[name].train() |
|
elif mode in ['val', 'test', 'eval'] or mode is False: |
|
self.models[name].eval() |
|
else: |
|
raise ValueError(f'Invalid model mode `{mode}`!') |
|
|
|
def set_mode(self, mode): |
|
"""Sets the `train/val` mode for all models.""" |
|
self.mode = mode |
|
for name in self.models: |
|
self.set_model_mode(name, mode) |
|
|
|
def train_step(self, data, **train_kwargs): |
|
"""Executes one training step.""" |
|
raise NotImplementedError('Should be implemented in derived class.') |
|
|
|
def train(self, **train_kwargs): |
|
"""Training function.""" |
|
self.set_mode('train') |
|
self.distribute() |
|
self.build_dataset('train') |
|
self.build_loss() |
|
|
|
self.logger.print() |
|
self.logger.info(f'Start training.') |
|
if self.total_iters == 0: |
|
total_epochs = self.config.get('total_epochs', 0) |
|
self.total_iters = self.convert_epoch_to_iter(total_epochs) |
|
assert self.total_iters > 0 |
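        # Main loop: each iteration pre-executes controllers, fetches a batch,
        # moves it to the current GPU, runs one `train_step()`, and then
        # post-executes controllers (e.g. learning-rate scheduling).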
|
while self.iter < self.total_iters: |
|
self._iter += 1 |
|
self.pre_execute_controllers() |
|
data_batch = next(self.train_loader) |
|
self.timer.pre_execute(self) |
|
for key in data_batch: |
|
assert data_batch[key].shape[0] == self.batch_size |
|
data_batch[key] = data_batch[key].cuda( |
|
torch.cuda.current_device(), non_blocking=True) |
|
self.train_step(data_batch, **train_kwargs) |
|
self.seen_img += self.batch_size * self.world_size |
|
self.timer.post_execute(self) |
|
self.post_execute_controllers() |
|
self.finish() |
|
|
|
def val(self, **val_kwargs): |
|
"""Validation function.""" |
|
raise NotImplementedError('Should be implemented in derived class.') |
|
|
|
def save(self, |
|
filepath, |
|
running_metadata=True, |
|
learning_rate=True, |
|
optimizer=True, |
|
running_stats=False): |
|
"""Saves the current running status. |
|
Args: |
|
filepath: File path to save the checkpoint. |
|
running_metadata: Whether to save the running metadata, such as |
|
batch size, current iteration, etc. (default: True) |
|
learning_rate: Whether to save the learning rate. (default: True) |
|
optimizer: Whether to save the optimizer. (default: True) |
|
running_stats: Whether to save the running stats. (default: False) |
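
        Example (illustrative; `runner` denotes a runner instance and the
        filename is arbitrary):
            runner.save(os.path.join(runner.work_dir,
                                     f'checkpoint_iter{runner.iter:06d}.pth'))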
|
""" |
|
checkpoint = dict() |
|
|
|
checkpoint['models'] = dict() |
|
for name, model in self.models.items(): |
|
checkpoint['models'][name] = self.get_module(model).state_dict() |
|
|
|
if running_metadata: |
|
checkpoint['running_metadata'] = { |
|
'iter': self.iter, |
|
'seen_img': self.seen_img, |
|
} |
|
|
|
if optimizer: |
|
checkpoint['optimizers'] = dict() |
|
for opt_name, opt in self.optimizers.items(): |
|
checkpoint['optimizers'][opt_name] = opt.state_dict() |
|
|
|
if learning_rate: |
|
checkpoint['learning_rates'] = dict() |
|
for lr_name, lr in self.lr_schedulers.items(): |
|
checkpoint['learning_rates'][lr_name] = lr.state_dict() |
|
|
|
|
|
if running_stats: |
|
checkpoint['running_stats'] = self.running_stats |
|
|
|
        # `or '.'` handles bare filenames that have no directory component.
        os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)
|
torch.save(checkpoint, filepath) |
|
self.logger.info(f'Successfully saved checkpoint to `{filepath}`.') |
|
|
|
def load(self, |
|
filepath, |
|
running_metadata=True, |
|
learning_rate=True, |
|
optimizer=True, |
|
running_stats=False, |
|
map_location='cpu'): |
|
"""Loads previous running status. |
|
|
|
Args: |
|
filepath: File path to load the checkpoint. |
|
running_metadata: Whether to load the running metadata, such as |
|
batch size, current iteration, etc. (default: True) |
|
learning_rate: Whether to load the learning rate. (default: True) |
|
optimizer: Whether to load the optimizer. (default: True) |
|
running_stats: Whether to load the running stats. (default: False) |
|
map_location: Map location used for model loading. (default: `cpu`) |
|
""" |
|
self.logger.info(f'Resuming from checkpoint `{filepath}` ...') |
|
if not os.path.isfile(filepath): |
|
raise IOError(f'Checkpoint `{filepath}` does not exist!') |
|
map_location = map_location.lower() |
|
assert map_location in ['cpu', 'gpu'] |
|
if map_location == 'gpu': |
|
device = torch.cuda.current_device() |
|
map_location = lambda storage, location: storage.cuda(device) |
|
checkpoint = torch.load(filepath, map_location=map_location) |
|
|
|
if 'models' not in checkpoint: |
|
checkpoint = {'models': checkpoint} |
|
for model_name, model in self.models.items(): |
|
if model_name not in checkpoint['models']: |
|
self.logger.warning(f'Model `{model_name}` is not included in ' |
|
f'the checkpoint, and hence will NOT be ' |
|
f'loaded!') |
|
continue |
|
state_dict = _strip_state_dict_prefix( |
|
checkpoint['models'][model_name]) |
|
            self.get_module(model).load_state_dict(state_dict)
|
self.logger.info(f' Successfully loaded model `{model_name}`.') |
|
|
|
if running_metadata: |
|
if 'running_metadata' not in checkpoint: |
|
self.logger.warning(f'Running metadata is not included in the ' |
|
f'checkpoint, and hence will NOT be ' |
|
f'loaded!') |
|
else: |
|
self._iter = checkpoint['running_metadata']['iter'] |
|
self._start_iter = self._iter |
|
self.seen_img = checkpoint['running_metadata']['seen_img'] |
|
|
|
if optimizer: |
|
if 'optimizers' not in checkpoint: |
|
self.logger.warning(f'Optimizers are not included in the ' |
|
f'checkpoint, and hence will NOT be ' |
|
f'loaded!') |
|
else: |
|
for opt_name, opt in self.optimizers.items(): |
|
if opt_name not in checkpoint['optimizers']: |
|
self.logger.warning(f'Optimizer `{opt_name}` is not ' |
|
f'included in the checkpoint, and ' |
|
f'hence will NOT be loaded!') |
|
continue |
|
opt.load_state_dict(checkpoint['optimizers'][opt_name]) |
|
self.logger.info(f' Successfully loaded optimizer ' |
|
f'`{opt_name}`.') |
|
|
|
if learning_rate: |
|
if 'learning_rates' not in checkpoint: |
|
self.logger.warning(f'Learning rates are not included in the ' |
|
f'checkpoint, and hence will NOT be ' |
|
f'loaded!') |
|
else: |
|
for lr_name, lr in self.lr_schedulers.items(): |
|
if lr_name not in checkpoint['learning_rates']: |
|
self.logger.warning(f'Learning rate `{lr_name}` is not ' |
|
f'included in the checkpoint, and ' |
|
f'hence will NOT be loaded!') |
|
continue |
|
lr.load_state_dict(checkpoint['learning_rates'][lr_name]) |
|
self.logger.info(f' Successfully loaded learning rate ' |
|
f'`{lr_name}`.') |
|
|
|
if running_stats: |
|
if 'running_stats' not in checkpoint: |
|
self.logger.warning(f'Running stats is not included in the ' |
|
f'checkpoint, and hence will NOT be ' |
|
f'loaded!') |
|
else: |
|
self.running_stats = deepcopy(checkpoint['running_stats']) |
|
self.logger.info(f' Successfully loaded running stats.') |
|
|
|
        trailing_message = ''
        if running_metadata and 'running_metadata' in checkpoint:
            trailing_message = f' (iteration {self.iter})'
        self.logger.info(f'Successfully resumed from checkpoint `{filepath}`.'
                         f'{trailing_message}')
|
|