# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from fairseq import optim
from omegaconf import DictConfig

logger = logging.getLogger(__name__)


class AMPOptimizer(optim.FairseqOptimizer):
    """
    Wrap an *optimizer* to support AMP (automatic mixed precision) training.
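
    A minimal usage sketch, assuming a fully populated fairseq ``DictConfig``
    (``cfg``) and placeholder ``model`` / ``criterion`` / ``batch`` objects::

        optimizer = AMPOptimizer.build_optimizer(cfg, model.parameters())
        with torch.cuda.amp.autocast():
            loss = criterion(model(batch))   # any scalar loss tensor
        optimizer.backward(loss)             # scaled backward pass
        optimizer.clip_grad_norm(0.1)        # unscales gradients, then clips
        optimizer.step()                     # skipped internally on overflow
        optimizer.zero_grad()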
    """

    def __init__(self, cfg: DictConfig, params, fp32_optimizer, **kwargs):
        super().__init__(cfg.optimizer)
        self.fp32_optimizer = fp32_optimizer
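        # Configure torch's GradScaler: start from the configured initial loss
        # scale and optionally override the growth interval (the number of
        # consecutive overflow-free steps before the scale is doubled).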
        amp_kwargs = {"init_scale": cfg.common.fp16_init_scale}
        if getattr(cfg.common, "amp_scale_window", None) is not None:
            amp_kwargs["growth_interval"] = cfg.common.amp_scale_window
        self._grad_scaler = torch.cuda.amp.GradScaler(**amp_kwargs)
        self.min_loss_scale = cfg.common.min_loss_scale

    @classmethod
    def build_optimizer(cls, cfg: DictConfig, params, **kwargs):
        """
        Args:
            cfg (omegaconf.DictConfig): fairseq args
            params (iterable): iterable of parameters to optimize
        """
        fp32_optimizer = optim.build_optimizer(cfg.optimizer, params)
        return cls(cfg, params, fp32_optimizer, **kwargs)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
        function additionally dynamically scales the loss to avoid gradient
        underflow.
        """
        self._grad_scaler.scale(loss).backward()

    def step(self):
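        # GradScaler.step() unscales the gradients (if not already done) and
        # skips the optimizer step when they contain inf/NaN; update() then
        # adjusts the loss scale accordingly.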
        self.scaler.step(self.fp32_optimizer)
        self.scaler.update()

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        self.scaler.unscale_(self.optimizer)
        grad_norm = self.fp32_optimizer.clip_grad_norm(max_norm, aggregate_norm_fn)
        if not torch.isfinite(grad_norm).all():
            new_loss_scale = self.next_loss_scale
            if new_loss_scale <= self.min_loss_scale:
                raise FloatingPointError(
                    (
                        "AMP: Minimum loss scale reached ({}). Your loss is probably exploding. "
                        "Try restarting training or use fp32. New loss scale: {}"
                    ).format(self.min_loss_scale, new_loss_scale)
                )
            else:
                logger.info(
                    f"AMP: overflow detected, setting scale to {new_loss_scale}"
                )
        return grad_norm

    @property
    def scaler(self):
        return self._grad_scaler

    @property
    def next_loss_scale(self):
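        # The scale GradScaler will fall back to once update() processes the
        # current overflow: the current scale times the backoff factor
        # (0.5 by default).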
        return self.scaler.get_scale() * self.scaler.get_backoff_factor()

    @property
    def optimizer(self):
        return self.fp32_optimizer.optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        self.fp32_optimizer.optimizer = optimizer

    @property
    def lr_scheduler(self):
        return getattr(self.fp32_optimizer, "lr_scheduler", None)

    @property
    def optimizer_config(self):
        return self.fp32_optimizer.optimizer_config

    def get_lr(self):
        return self.fp32_optimizer.get_lr()

    def set_lr(self, lr):
        self.fp32_optimizer.set_lr(lr)

    def all_reduce_grads(self, module):
        self.fp32_optimizer.all_reduce_grads(module)

    @property
    def supports_flat_params(self):
        return self.fp32_optimizer.supports_flat_params