# coding=utf-8
"""PyTorch optimization for BERT model."""

from apex.optimizers import FP16_Optimizer


class FP16_Optimizer_State(FP16_Optimizer):
    """:class:`FP16_Optimizer` variant with explicit checkpoint support.

    Overrides ``state_dict()`` and ``load_state_dict()`` so that a checkpoint
    carries the loss-scaling attributes, the wrapped optimizer's own
    state_dict, and the ``fp32_groups_flat`` tensors.
    """

    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True):
        """Forward every argument, unchanged and positionally, to
        :class:`FP16_Optimizer`'s constructor."""
        super(FP16_Optimizer_State, self).__init__(
            init_optimizer,
            static_loss_scale,
            dynamic_loss_scale,
            dynamic_loss_args,
            verbose,
        )

    def state_dict(self):
        """
        Build a checkpointable dict describing this optimizer wrapper.

        The result always holds ``dynamic_loss_scale``, ``cur_scale`` and
        ``cur_iter``; when dynamic loss scaling is enabled it additionally
        holds ``last_overflow_iter``, ``scale_factor`` and ``scale_window``.
        The wrapped optimizer's state_dict is stored under
        ``optimizer_state_dict`` and ``self.fp32_groups_flat`` is stored
        (by reference, not copied) under ``fp32_groups_flat``.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        snapshot = {
            'dynamic_loss_scale': self.dynamic_loss_scale,
            'cur_scale': self.cur_scale,
            'cur_iter': self.cur_iter,
        }
        # The dynamic-scaling bookkeeping is only recorded when it is in use.
        if snapshot['dynamic_loss_scale']:
            for attr in ('last_overflow_iter', 'scale_factor', 'scale_window'):
                snapshot[attr] = getattr(self, attr)
        snapshot['optimizer_state_dict'] = self.optimizer.state_dict()
        snapshot['fp32_groups_flat'] = self.fp32_groups_flat
        return snapshot

    def load_state_dict(self, state_dict):
        """
        Restore this wrapper from a dict produced by ``state_dict()``.

        The documented usage is to call ``model.load_state_dict()`` first and
        this method second, although (per the original author's note)
        reloading the optimizer before the model should also be fine.

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # Loss-scaling attributes, restored in the order they were saved.
        restored_attrs = ['dynamic_loss_scale', 'cur_scale', 'cur_iter']
        if state_dict['dynamic_loss_scale']:
            restored_attrs.extend(
                ['last_overflow_iter', 'scale_factor', 'scale_window'])
        for attr in restored_attrs:
            setattr(self, attr, state_dict[attr])

        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])

        # The wrapped optimizer's hyperparameters and internal buffers are now
        # current, but the fp32 master copies of the fp16 model params are
        # not. They could either be
        #   1. regenerated from the model's fp16 params (less storage, but
        #      loses precision), or
        #   2. saved and restored on their own.
        # This class takes option 2.
        #
        # Pytorch's Optimizer.load_state_dict casts saved buffers (e.g.
        # momentum) to the type and device of their parameters because those
        # buffers may not exist yet. Here, provided this instance was built
        # the same way as the one that produced the checkpoint, the master
        # tensors already exist, so an in-place copy_() is sufficient.
        saved_groups = state_dict['fp32_groups_flat']
        for master, stored in zip(self.fp32_groups_flat, saved_groups):
            master.data.copy_(stored.data)