# coding=utf-8
"""PyTorch optimization for BERT model."""
from apex.optimizers import FP16_Optimizer


class FP16_Optimizer_State(FP16_Optimizer):
    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True):
        super(FP16_Optimizer_State, self).__init__(init_optimizer,
                                                   static_loss_scale, dynamic_loss_scale,
                                                   dynamic_loss_args, verbose)
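
    # Note (added for clarity, not part of the original file): the attributes used
    # below -- cur_scale, cur_iter, last_overflow_iter, scale_factor, scale_window,
    # and fp32_groups_flat -- are maintained by apex's FP16_Optimizer base class.
    # This subclass only adds (de)serialization of that loss-scaling and
    # master-parameter state so training can resume exactly from a checkpoint.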
    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained PyTorch optimizer.
        Example::
            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['cur_scale'] = self.cur_scale
        state_dict['cur_iter'] = self.cur_iter
        if state_dict['dynamic_loss_scale']:
            state_dict['last_overflow_iter'] = self.last_overflow_iter
            state_dict['scale_factor'] = self.scale_factor
            state_dict['scale_window'] = self.scale_window
        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
        state_dict['fp32_groups_flat'] = self.fp32_groups_flat
        return state_dict
    def load_state_dict(self, state_dict):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.
        Example::
            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        if state_dict['dynamic_loss_scale']:
            self.last_overflow_iter = state_dict['last_overflow_iter']
            self.scale_factor = state_dict['scale_factor']
            self.scale_window = state_dict['scale_window']
        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer
        # are still out of date. There are two options:
        # 1: Refresh the master params from the model's fp16 params.
        #    This requires less storage but incurs precision loss.
        # 2: Save and restore the fp32 master copies separately.
        # We choose option 2.
        #
        # PyTorch's Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type
        # and device of their associated parameters, because it's possible those buffers might
        # not exist yet in the current optimizer instance. In our case, as long as the current
        # FP16_Optimizer has been constructed in the same way as the one whose state_dict we
        # are loading, the same master params are guaranteed to exist, so we can just copy_()
        # from the saved master params.
        for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
            current.data.copy_(saved.data)
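

# A minimal end-to-end checkpointing sketch (illustrative addition, not part of the
# original module). It assumes apex with FusedAdam and a CUDA device are available,
# mirroring the mixed-precision BERT setup this wrapper targets; the toy model and
# file name are placeholders.
if __name__ == "__main__":
    import torch
    from apex.optimizers import FusedAdam

    # Toy fp16 model; layer sizes are arbitrary.
    model = torch.nn.Linear(16, 4).cuda().half()
    optimizer = FP16_Optimizer_State(FusedAdam(model.parameters(), lr=1e-3),
                                     dynamic_loss_scale=True)

    # Save both the model weights and the optimizer state; the latter includes
    # the loss-scale bookkeeping and the flat fp32 master copies.
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict()}, "saved.pth")

    # Restore: load the model weights first, then the optimizer state, as
    # described in load_state_dict's docstring.
    checkpoint = torch.load("saved.pth")
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])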