from typing import *

import torch
import torch.nn.utils
import numpy as np

class AdaptiveGradClipper:
    """Adaptive gradient clipping for training.

    Keeps a ring buffer of recent (finite) gradient norms and, once the buffer
    is full, clips to the ``clip_percentile``-th percentile of those norms,
    optionally capped by ``max_norm``.
    """

    def __init__(
        self,
        max_norm=None,
        clip_percentile=95.0,
        buffer_size=1000,
    ):
        self.max_norm = max_norm
        self.clip_percentile = clip_percentile
        self.buffer_size = buffer_size
        # Ring buffer of recent gradient norms and the current clipping threshold.
        self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
        self._max_norm = max_norm
        self._buffer_ptr = 0
        self._buffer_length = 0

    def __repr__(self):
        return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'

    def state_dict(self):
        return {
            'grad_norm': self._grad_norm,
            'max_norm': self._max_norm,
            'buffer_ptr': self._buffer_ptr,
            'buffer_length': self._buffer_length,
        }

    def load_state_dict(self, state_dict):
        self._grad_norm = state_dict['grad_norm']
        self._max_norm = state_dict['max_norm']
        self._buffer_ptr = state_dict['buffer_ptr']
        self._buffer_length = state_dict['buffer_length']

    def log(self):
        return {
            'max_norm': self._max_norm,
        }

    def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
        """Clip the gradient norm of an iterable of parameters.

        The norm is computed over all gradients together, as if they were
        concatenated into a single vector. Gradients are modified in-place.

        Args:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized
            norm_type (float): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.
            error_if_nonfinite (bool): if True, an error is thrown if the total
                norm of the gradients from :attr:`parameters` is ``nan``,
                ``inf``, or ``-inf``. Default: False (will switch to True in the future)
            foreach (bool): use the faster foreach-based implementation.
                If ``None``, use the foreach implementation for CUDA and CPU native
                tensors and silently fall back to the slow implementation for other
                device types. Default: ``None``

        Returns:
            Total norm of the parameter gradients (viewed as a single vector).
        """
        # Clip with the current adaptive threshold (no clipping until one is available).
        max_norm = self._max_norm if self._max_norm is not None else float('inf')
        grad_norm = torch.nn.utils.clip_grad_norm_(
            parameters,
            max_norm=max_norm,
            norm_type=norm_type,
            error_if_nonfinite=error_if_nonfinite,
            foreach=foreach,
        )
        # Record only finite norms; .item() converts the 0-dim tensor to a Python
        # float so it can be stored in the NumPy ring buffer regardless of device.
        if torch.isfinite(grad_norm):
            self._grad_norm[self._buffer_ptr] = grad_norm.item()
            self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
            self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
            # Once the buffer is full, adapt the threshold to the chosen percentile,
            # never exceeding the user-provided max_norm.
            if self._buffer_length == self.buffer_size:
                self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
                if self.max_norm is not None:
                    self._max_norm = min(self._max_norm, self.max_norm)
        return grad_norm
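

# Usage sketch: a minimal, hypothetical example of calling the clipper once per
# optimization step and checkpointing its state. The model, optimizer, and
# synthetic data below are assumptions for illustration only.
if __name__ == '__main__':
    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    clipper = AdaptiveGradClipper(max_norm=10.0, clip_percentile=95.0, buffer_size=100)

    for step in range(200):
        x = torch.randn(32, 16)
        target = torch.randn(32, 4)
        loss = torch.nn.functional.mse_loss(model(x), target)

        optimizer.zero_grad()
        loss.backward()
        # Clip between backward() and step(); the pre-clip total norm is returned.
        total_norm = clipper(model.parameters())
        optimizer.step()

    # The ring buffer and adaptive threshold can be saved and restored alongside
    # the model/optimizer state. The restored instance should be constructed with
    # the same buffer_size, since load_state_dict() does not restore it.
    restored = AdaptiveGradClipper(max_norm=10.0, clip_percentile=95.0, buffer_size=100)
    restored.load_state_dict(clipper.state_dict())
    print(clipper, clipper.log())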