# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os

import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)


class Callback:
    """
    Callback provides an easy way to react to events such as the beginning and end of an epoch.
    """

    def __init__(self):
        self.model = None
        self.mutator = None
        self.trainer = None

    def build(self, model, mutator, trainer):
        """
        Callback needs to be built with model, mutator and trainer, to receive updates from them.

        Parameters
        ----------
        model : nn.Module
            Model to be trained.
        mutator : nn.Module
            Mutator that mutates the model.
        trainer : BaseTrainer
            Trainer that calls the callback.
        """
        self.model = model
        self.mutator = mutator
        self.trainer = trainer
    def on_epoch_begin(self, epoch):
        """
        Implement this to do something at the beginning of an epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        pass

    def on_epoch_end(self, epoch):
        """
        Implement this to do something at the end of an epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        pass

    def on_batch_begin(self, epoch):
        """
        Implement this to do something at the beginning of a batch.

        Parameters
        ----------
        epoch : int
            Number of the epoch the batch belongs to, starting from 0.
        """
        pass

    def on_batch_end(self, epoch):
        """
        Implement this to do something at the end of a batch.

        Parameters
        ----------
        epoch : int
            Number of the epoch the batch belongs to, starting from 0.
        """
        pass
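

# Example (illustrative, not part of the original file): a minimal custom
# callback. It assumes the trainer invokes these hooks once per epoch; the
# ``LoggingCallback`` name is hypothetical.
class LoggingCallback(Callback):
    """Log the beginning and end of every epoch."""

    def on_epoch_begin(self, epoch):
        _logger.info("Epoch %d begins.", epoch)

    def on_epoch_end(self, epoch):
        _logger.info("Epoch %d ends.", epoch)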


class LRSchedulerCallback(Callback):
    """
    Calls ``scheduler.step()`` at the end of every epoch.

    Parameters
    ----------
    scheduler : LRScheduler
        Scheduler to be called.
    mode : str
        Step frequency. Currently only "epoch" is supported.
    """

    def __init__(self, scheduler, mode="epoch"):
        super().__init__()
        assert mode == "epoch"
        self.scheduler = scheduler
        self.mode = mode

    def on_epoch_end(self, epoch):
        """
        Call ``self.scheduler.step()`` on epoch end.
        """
        self.scheduler.step()
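

# Example (illustrative, not part of the original file): pairing the callback
# with a standard PyTorch scheduler. The linear model and epoch count are
# placeholders; pass the returned callback to a trainer's callback list.
def _example_lr_scheduler_callback(num_epochs=50):
    model = nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    return LRSchedulerCallback(scheduler)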


class ArchitectureCheckpoint(Callback):
    """
    Calls ``trainer.export()`` at the end of every epoch.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump the architecture to ``<checkpoint_dir>/epoch_{number}.json`` on epoch end.
        """
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))
        _logger.info("Saving architecture to %s", dest_path)
        self.trainer.export(dest_path)
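

# Example (illustrative, not part of the original file): exercising the
# callback by hand. ``_DummyTrainer`` is a hypothetical stand-in for a real
# NAS trainer that implements ``export()`` and calls the hooks during training.
def _example_architecture_checkpoint(checkpoint_dir="./checkpoints"):
    class _DummyTrainer:
        def export(self, file_path):
            # A real trainer dumps the currently selected architecture here.
            with open(file_path, "w") as f:
                f.write("{}")

    callback = ArchitectureCheckpoint(checkpoint_dir)
    callback.build(model=None, mutator=None, trainer=_DummyTrainer())
    callback.on_epoch_end(epoch=0)  # writes ./checkpoints/epoch_0.json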


class ModelCheckpoint(Callback):
    """
    Saves the model's ``state_dict`` at the end of every epoch.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump the model state to ``<checkpoint_dir>/epoch_{number}.pth.tar`` on epoch end.
        For ``DataParallel`` objects, the ``state_dict`` of the wrapped module is saved.
        """
        if isinstance(self.model, nn.DataParallel):
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
        _logger.info("Saving model to %s", dest_path)
        torch.save(state_dict, dest_path)
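

# Example (illustrative, not part of the original file): restoring a model
# from a checkpoint written by ``ModelCheckpoint``. The model passed in must
# have the same architecture as the one that was saved.
def _example_load_model_checkpoint(model, checkpoint_dir="./checkpoints", epoch=0):
    dest_path = os.path.join(checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
    model.load_state_dict(torch.load(dest_path))
    return model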