Spaces:
Build error
Build error
| # This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/utils/base_trainer.py | |
| from __future__ import division | |
| import logging | |
| from utils import CheckpointSaver | |
| from tensorboardX import SummaryWriter | |
| import torch | |
| from tqdm import tqdm | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# A monitor_interval of 0 disables tqdm's background monitor thread.
tqdm.monitor_interval = 0
class BaseTrainer(object):
    """Base class for Trainer objects.

    Takes care of device selection, checkpointing, TensorBoard logging and
    resuming training. Derived classes must implement ``init_fn`` (which has
    to define ``self.models_dict`` and ``self.optimizers_dict``), ``train``,
    ``train_step``, ``train_summaries`` and ``visualize``.
    """

    def __init__(self, options):
        """Set up device, checkpoint saver, summary writer and resume state.

        Args:
            options: parsed run options. Fields read here: ``multiprocessing_distributed``,
                ``gpu``, ``checkpoint_dir``, ``overwrite``, ``rank``,
                ``summary_dir`` and ``resume``.
        """
        self.options = options
        if options.multiprocessing_distributed:
            # In distributed mode, each process uses the GPU index given in options.gpu.
            self.device = torch.device('cuda', options.gpu)
        else:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
        self.saver = CheckpointSaver(save_dir=options.checkpoint_dir,
                                     overwrite=options.overwrite)
        # Only rank 0 writes TensorBoard summaries. Other ranks get None, so
        # the attribute always exists and can be checked safely.
        self.summary_writer = (SummaryWriter(self.options.summary_dir)
                               if options.rank == 0 else None)
        # Override init_fn to define your models, optimizers etc. It must set
        # self.models_dict and self.optimizers_dict, which are used below.
        self.init_fn()

        self.checkpoint = None
        if options.resume and self.saver.exists_checkpoint():
            self.checkpoint = self.saver.load_checkpoint(
                self.models_dict, self.optimizers_dict)

        if self.checkpoint is None:
            # Fresh run: start from the very beginning.
            self.epoch_count = 0
            self.step_count = 0
            self.checkpoint_batch_idx = 0
        else:
            # Resume from the position recorded in the checkpoint.
            self.epoch_count = self.checkpoint['epoch']
            self.step_count = self.checkpoint['total_step_count']
            self.checkpoint_batch_idx = self.checkpoint['batch_idx']
        # Best metric seen so far. It starts at +inf, which suggests
        # lower-is-better (TODO confirm with the subclasses that update it).
        self.best_performance = float('inf')

    def load_pretrained(self, checkpoint_file=None):
        """Load a pretrained checkpoint.

        This is different from resuming training using --resume: only model
        weights are loaded, and no optimizer or epoch/step state is restored.
        For every entry of ``self.models_dict`` whose key appears in the
        checkpoint, the stored state dict is loaded with ``strict=True``.
        Does nothing when ``checkpoint_file`` is None.

        Args:
            checkpoint_file: path to a file saved with ``torch.save``, or None.
        """
        if checkpoint_file is None:
            return
        # map_location lets a GPU-saved checkpoint load on a CPU-only host, and
        # keeps each distributed rank on its own device instead of the device
        # the tensors were saved from.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_file, map_location=self.device)
        for model in self.models_dict:
            if model in checkpoint:
                self.models_dict[model].load_state_dict(checkpoint[model],
                                                        strict=True)
                print(f'Checkpoint {model} loaded')

    def move_dict_to_device(self, dict, device, tensor2float=False):
        """Move every tensor value in ``dict`` to ``device``, in place.

        Non-tensor values are left untouched. The parameter name ``dict``
        shadows the builtin; it is kept for compatibility with keyword callers.

        Args:
            dict: mapping whose tensor values are replaced by moved copies.
            device: target ``torch.device`` (or device string).
            tensor2float: if True, also cast the tensors to float32 first.
        """
        for k, v in dict.items():
            if isinstance(v, torch.Tensor):
                if tensor2float:
                    dict[k] = v.float().to(device)
                else:
                    dict[k] = v.to(device)

    # train, init_fn, train_step, train_summaries and visualize must be
    # implemented in the derived classes; validate, test and evaluate default
    # to no-ops.
    def train(self, epoch):
        raise NotImplementedError('You need to provide a train method')

    def init_fn(self):
        raise NotImplementedError('You need to provide an init_fn method')

    def train_step(self, input_batch):
        raise NotImplementedError('You need to provide a train_step method')

    def train_summaries(self, input_batch):
        raise NotImplementedError(
            'You need to provide a train_summaries method')

    def visualize(self, input_batch):
        raise NotImplementedError('You need to provide a visualize method')

    def validate(self):
        pass

    def test(self):
        pass

    def evaluate(self):
        pass

    def fit(self):
        """Run training from ``self.epoch_count`` up to ``options.num_epochs``.

        When resuming, the loop starts at the restored epoch, and the progress
        bar is offset so that it still shows the overall progress.
        """
        for epoch in tqdm(range(self.epoch_count, self.options.num_epochs),
                          total=self.options.num_epochs,
                          initial=self.epoch_count):
            # Record the current epoch on the instance before delegating.
            self.epoch_count = epoch
            self.train(epoch)