from typing import Optional

import numpy as np
import torch
import torch.nn.utils


class AdaptiveGradClipper:
    """Adaptive gradient clipping for training.

    Tracks the norms of recent finite gradient updates in a rolling buffer
    and, once the buffer is full, clips subsequent gradients to the
    ``clip_percentile`` percentile of those norms (never exceeding
    ``max_norm`` when one is given).
    """

    def __init__(
        self,
        max_norm: Optional[float] = None,
        clip_percentile: float = 95.0,
        buffer_size: int = 1000,
    ):
        """
        Args:
            max_norm: optional hard cap on the clipping threshold.
            clip_percentile: percentile of buffered norms used as the
                adaptive threshold once the buffer is full.
            buffer_size: number of recent gradient norms to track.
        """
        self.max_norm = max_norm
        self.clip_percentile = clip_percentile
        self.buffer_size = buffer_size

        # Rolling buffer of recent finite gradient norms.
        self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
        # Active clipping threshold; refreshed from the buffer once it fills.
        self._max_norm = max_norm
        self._buffer_ptr = 0
        self._buffer_length = 0

    def __repr__(self):
        return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'

    def state_dict(self):
        """Return the clipper's state for checkpointing."""
        return {
            'grad_norm': self._grad_norm,
            'max_norm': self._max_norm,
            'buffer_ptr': self._buffer_ptr,
            'buffer_length': self._buffer_length,
        }

    def load_state_dict(self, state_dict):
        """Restore state previously produced by :meth:`state_dict`."""
        self._grad_norm = state_dict['grad_norm']
        self._max_norm = state_dict['max_norm']
        self._buffer_ptr = state_dict['buffer_ptr']
        self._buffer_length = state_dict['buffer_length']

    def log(self):
        """Return the current clipping threshold for logging."""
        return {
            'max_norm': self._max_norm,
        }

    def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
        """Clip the gradient norm of an iterable of parameters.

        The norm is computed over all gradients together, as if they were
        concatenated into a single vector. Gradients are modified in-place.
        Clipping uses the current adaptive threshold; until the norm buffer
        has filled, only ``max_norm`` (if given) is enforced.

        Args:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized.
            norm_type (float): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.
            error_if_nonfinite (bool): if True, an error is thrown if the total
                norm of the gradients from :attr:`parameters` is ``nan``,
                ``inf``, or ``-inf``. Default: False (will switch to True in
                the future).
            foreach (bool): use the faster foreach-based implementation.
                If ``None``, use the foreach implementation for CUDA and CPU
                native tensors and silently fall back to the slow implementation
                for other device types. Default: ``None``.

        Returns:
            Total norm of the parameter gradients (viewed as a single vector).
        """
        max_norm = self._max_norm if self._max_norm is not None else float('inf')
        grad_norm = torch.nn.utils.clip_grad_norm_(
            parameters,
            max_norm=max_norm,
            norm_type=norm_type,
            error_if_nonfinite=error_if_nonfinite,
            foreach=foreach,
        )

        if torch.isfinite(grad_norm):
            # Record the (pre-clipping) total norm; .item() avoids issues when
            # the norm lives on a CUDA device.
            self._grad_norm[self._buffer_ptr] = grad_norm.item()
            self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
            self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
            # Once the buffer is full, refresh the adaptive threshold.
            if self._buffer_length == self.buffer_size:
                self._max_norm = float(np.percentile(self._grad_norm, self.clip_percentile))
                if self.max_norm is not None:
                    self._max_norm = min(self._max_norm, self.max_norm)

        return grad_norm
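

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): the toy model,
# optimizer, and random data below are placeholders chosen only to illustrate
# how the clipper is driven inside a training loop and checkpointed.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)

    model = torch.nn.Linear(8, 1)  # placeholder model (assumption)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    clipper = AdaptiveGradClipper(max_norm=10.0, clip_percentile=95.0, buffer_size=100)

    for step in range(200):
        x = torch.randn(32, 8)
        y = torch.randn(32, 1)

        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()

        # Clip gradients in-place; returns the pre-clipping total norm.
        total_norm = clipper(model.parameters())
        optimizer.step()

        if step % 50 == 0:
            print(f'step={step} loss={loss.item():.4f} '
                  f'grad_norm={total_norm.item():.4f} {clipper.log()}')

    # The clipper state can be checkpointed alongside model and optimizer state.
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'grad_clipper': clipper.state_dict(),
    }
    clipper.load_state_dict(checkpoint['grad_clipper'])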