Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| class DynamicLossScaler(object): | |
| def __init__( | |
| self, | |
| init_scale=2.0 ** 15, | |
| scale_factor=2.0, | |
| scale_window=2000, | |
| tolerance=0.0, | |
| threshold=None, | |
| min_loss_scale=1e-4, | |
| ): | |
| self.loss_scale = init_scale | |
| self.scale_factor = scale_factor | |
| self.scale_window = scale_window | |
| self.tolerance = tolerance | |
| self.threshold = threshold | |
| self._iter = 0 | |
| self._last_overflow_iter = -1 | |
| self._last_rescale_iter = -1 | |
| self._overflows_since_rescale = 0 | |
| self.min_loss_scale = min_loss_scale | |
| def scale(self, outputs): | |
| return self.loss_scale * outputs | |
| def update(self): | |
| if (self._iter - self._last_overflow_iter) % self.scale_window == 0: | |
| self.loss_scale *= self.scale_factor | |
| self._last_rescale_iter = self._iter | |
| self._iter += 1 | |
| def _decrease_loss_scale(self): | |
| self.loss_scale /= self.scale_factor | |
| if self.threshold is not None: | |
| self.loss_scale = max(self.loss_scale, self.threshold) | |
| def check_overflow(self, grad_norm): | |
| # detect inf and nan | |
| if grad_norm == float("inf") or grad_norm != grad_norm: | |
| # overflow has occured | |
| prev_scale = self.loss_scale | |
| iter_since_rescale = self._iter - self._last_rescale_iter | |
| self._last_overflow_iter = self._iter | |
| self._overflows_since_rescale += 1 | |
| pct_overflow = self._overflows_since_rescale / float(iter_since_rescale) | |
| if pct_overflow >= self.tolerance: | |
| self._decrease_loss_scale() | |
| self._last_rescale_iter = self._iter | |
| self._overflows_since_rescale = 0 | |
| if self.loss_scale <= self.min_loss_scale: | |
| # Use FloatingPointError as an uncommon error that parent | |
| # functions can safely catch to stop training. | |
| self.loss_scale = prev_scale | |
| raise FloatingPointError( | |
| ( | |
| "Minimum loss scale reached ({}). Your loss is probably exploding. " | |
| "Try lowering the learning rate, using gradient clipping or " | |
| "increasing the batch size." | |
| ).format(self.min_loss_scale) | |
| ) | |
| self._iter += 1 | |
| raise OverflowError("setting loss scale to: " + str(self.loss_scale)) | |