# python3.7
"""Contains the running controller to adjust the learing rate."""

from torch.optim import lr_scheduler

from .base_controller import BaseController

__all__ = ['build_lr_scheduler', 'LRScheduler']


class BaseWarmUpLR(lr_scheduler._LRScheduler):  # pylint: disable=protected-access
    """Defines a base learning rate scheduler with warm-up.

    NOTE: Different from the official LRSchedulers, the base unit for learning
    rate update is always set as `iteration` instead of `epoch`. Hence, the
    number of epochs should be converted to the number of iterations before use.
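
    For example, a warm-up of 5 epochs with 1000 iterations per epoch should
    be passed as `warmup_iters=5000`.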
    """

    def __init__(self,
                 optimizer,
                 warmup_type='NO',
                 warmup_iters=0,
                 warmup_factor=0.1):
        """Initializes the scheduler with warm-up settings.

        Following warm-up types are supported:

        (1) `NO`: Do not use warm-up.
        (2) `CONST`: Use a constant value for warm-up.
        (3) `LINEAR`: Increase the learning rate linearly.
        (4) `EXP`: Increase the learning rate exponentially.

        Whatever warm-up type is used, the initial learning rate for warm-up (if
        needed) is always set as `base_lr * warmup_factor`.
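
        For example, with `warmup_factor=0.1` and `warmup_iters=100`, the
        `LINEAR` warm-up starts at `0.1 * base_lr` and grows linearly to
        `base_lr` at iteration 100, while the `EXP` warm-up covers the same
        range geometrically.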

        Args:
            optimizer: The optimizer for applying gradients.
            warmup_type: The warm-up type. (default: `NO`)
            warmup_iters: Iterations for warm-up. (default: 0)
            warmup_factor: Factor to set the initial learning rate for warm-up.
                (default: 0.1)
        """
        self._warmup_type = warmup_type.upper()
        assert self.warmup_type in ['NO', 'CONST', 'LINEAR', 'EXP']
        self._warmup_iters = warmup_iters
        self._warmup_factor = float(warmup_factor)
        super().__init__(optimizer, last_epoch=-1)

    @property
    def warmup_type(self):
        """Gets the warm-up type."""
        return self._warmup_type

    @property
    def warmup_iters(self):
        """Gets the iterations for warm-up."""
        return self._warmup_iters

    @property
    def warmup_factor(self):
        """Gets the warm-up factor."""
        return self._warmup_factor

    def get_warmup_lr(self):
        """Gets learning rate at the warm-up stage."""
        if self.warmup_type == 'NO':
            return self.base_lrs
        # Compute the progress only when warm-up is enabled, which also avoids
        # a division by zero when `warmup_iters` is 0.
        progress = self.last_epoch / self.warmup_iters
        if self.warmup_type == 'CONST':
            return [lr * self.warmup_factor for lr in self.base_lrs]
        if self.warmup_type == 'LINEAR':
            scale = (1 - progress) * (1 - self.warmup_factor)
            return [lr * (1 - scale) for lr in self.base_lrs]
        if self.warmup_type == 'EXP':
            scale = self.warmup_factor ** (1 - progress)
            return [lr * scale for lr in self.base_lrs]
        raise ValueError(f'Invalid warm-up type `{self.warmup_type}`!')

    def _get_lr(self):
        """Gets the learning rate ignoring warm-up."""
        raise NotImplementedError('Should be implemented in derived classes!')

    def get_lr(self):
        """Gets the learning rate, taking the warm-up stage into account."""
        if self.last_epoch < self.warmup_iters:
            return self.get_warmup_lr()
        return self._get_lr()


class FixedWarmUpLR(BaseWarmUpLR):
    """Defines a warm-up LRScheduler with fixed learning rate."""

    def _get_lr(self):
        return self.base_lrs


class StepWarmUpLR(BaseWarmUpLR):
    """Defines a warm-up LRScheduler with periodically decayed learning rate.

    In particular, the learning rate will be decayed with factor `decay_factor`
    every `decay_step` iterations.

    If the `decay_step` is a list of integers, the learning rate will be
    adjusted at those particular iterations.
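
    For example, `decay_step=10000` with `decay_factor=0.1` scales the base
    learning rate by `0.1` at iteration 10000, by `0.01` at iteration 20000,
    and so on, whereas `decay_step=[10000, 15000]` applies the `0.1` scaling
    only at iterations 10000 and 15000.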
    """

    def __init__(self,
                 optimizer,
                 decay_step,
                 decay_factor=0.1,
                 warmup_type='NO',
                 warmup_iters=0,
                 warmup_factor=0.1):
        self._decay_step = decay_step
        self._decay_factor = decay_factor
        super().__init__(optimizer, warmup_type, warmup_iters, warmup_factor)

    @property
    def decay_step(self):
        """Gets the decay step."""
        return self._decay_step

    @property
    def decay_factor(self):
        """Gets the decay factor."""
        return self._decay_factor

    def _get_lr(self):
        if isinstance(self.decay_step, int):
            scale = self.decay_factor ** (self.last_epoch // self.decay_step)
            return [lr * scale for lr in self.base_lrs]
        if isinstance(self.decay_step, (list, tuple)):
            bucket_id = 0
            for step in set(self.decay_step):
                if self.last_epoch >= step:
                    bucket_id += 1
            scale = self.decay_factor ** bucket_id
            return [lr * scale for lr in self.base_lrs]
        raise TypeError(f'Type of LR decay step can only be integer, list, '
                        f'or tuple, but `{type(self.decay_step)}` is received!')


class EXPStepWarmUpLR(BaseWarmUpLR):
    """Defines a warm-up LRScheduler with exponentially decayed learning rate.

    In particular, when `decay_step` is an integer, the learning rate is
    decayed at every iteration such that it shrinks by factor `decay_factor`
    over each span of `decay_step` iterations.

    If `decay_step` is a list of integers, the learning rate will be adjusted
    at those particular iterations only (same as `StepWarmUpLR`).
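
    For example, with `decay_step=10000` and `decay_factor=0.1`, the learning
    rate at iteration 5000 is already scaled by `0.1 ** 0.5` (roughly `0.316`)
    rather than staying at `base_lr` until iteration 10000.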
    """
    def __init__(self,
                 optimizer,
                 decay_step,
                 decay_factor=0.1,
                 warmup_type='NO',
                 warmup_iters=0,
                 warmup_factor=0.1):
        self._decay_step = decay_step
        self._decay_factor = decay_factor
        super().__init__(optimizer, warmup_type, warmup_iters, warmup_factor)

    @property
    def decay_step(self):
        """Gets the decay step."""
        return self._decay_step

    @property
    def decay_factor(self):
        """Gets the decay factor."""
        return self._decay_factor

    def _get_lr(self):
        if isinstance(self.decay_step, int):
            scale = self.decay_factor ** (self.last_epoch / self.decay_step)
            return [lr * scale for lr in self.base_lrs]
        if isinstance(self.decay_step, (list, tuple)):
            bucket_id = 0
            for step in set(self.decay_step):
                if self.last_epoch >= step:
                    bucket_id += 1
            scale = self.decay_factor ** bucket_id
            return [lr * scale for lr in self.base_lrs]
        raise TypeError(f'Type of LR decay step can only be integer, list, '
                        f'or tuple, but `{type(self.decay_step)}` is received!')


_ALLOWED_LR_TYPES = ['FIXED', 'STEP', 'EXPSTEP']


def build_lr_scheduler(config, optimizer):
    """Builds a learning rate scheduler for the given optimizer.

    Basically, the configuration is expected to contain the following settings:

    (1) lr_type: The type of the learning rate scheduler. (required)
    (2) warmup_type: The warm-up type. (default: `NO`)
    (3) warmup_iters: Iterations for warm-up. (default: 0)
    (4) warmup_factor: Factor to set the initial learning rate for warm-up.
        (default: 0.1)
    (5) **kwargs: Additional settings for the scheduler.

    Args:
        config: The configuration used to build the learning rate scheduler.
        optimizer: The optimizer which the scheduler serves.

    Returns:
        An instance of a `BaseWarmUpLR` subclass.

    Raises:
        ValueError: If `lr_type` is not among the allowed types.
        NotImplementedError: If an allowed `lr_type` is not yet implemented.
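
    Example:
        # A minimal usage sketch (not part of the original code); `model` and
        # `num_iters` are assumed placeholders for the module being trained
        # and the total number of training iterations.
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        scheduler = build_lr_scheduler(
            config={'lr_type': 'STEP',
                    'decay_step': [30000, 60000],
                    'warmup_type': 'LINEAR',
                    'warmup_iters': 500},
            optimizer=optimizer)
        for _ in range(num_iters):
            optimizer.step()  # Stands in for one full training iteration.
            scheduler.step()  # Adjust the learning rate every iteration.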
    """
    assert isinstance(config, dict)
    lr_type = config['lr_type'].upper()
    warmup_type = config.get('warmup_type', 'NO')
    warmup_iters = config.get('warmup_iters', 0)
    warmup_factor = config.get('warmup_factor', 0.1)

    if lr_type not in _ALLOWED_LR_TYPES:
        raise ValueError(f'Invalid learning rate scheduler type `{lr_type}`! '
                         f'Allowed types: {_ALLOWED_LR_TYPES}.')

    if lr_type == 'FIXED':
        return FixedWarmUpLR(optimizer=optimizer,
                             warmup_type=warmup_type,
                             warmup_iters=warmup_iters,
                             warmup_factor=warmup_factor)
    if lr_type == 'STEP':
        return StepWarmUpLR(optimizer=optimizer,
                            decay_step=config['decay_step'],
                            decay_factor=config.get('decay_factor', 0.1),
                            warmup_type=warmup_type,
                            warmup_iters=warmup_iters,
                            warmup_factor=warmup_factor)
    if lr_type == 'EXPSTEP':
        return EXPStepWarmUpLR(optimizer=optimizer,
                               decay_step=config['decay_step'],
                               decay_factor=config.get('decay_factor', 0.1),
                               warmup_type=warmup_type,
                               warmup_iters=warmup_iters,
                               warmup_factor=warmup_factor)
    raise NotImplementedError(f'Not implemented scheduler type `{lr_type}`!')


class LRScheduler(BaseController):
    """Defines the running controller to adjust the learning rate.

    This controller will be executed after every iteration.

    NOTE: The controller is set to `FIRST` priority.
    """

    def __init__(self, lr_config):
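        """Initializes the controller with the learning rate configurations.

        NOTE: The expected structure is inferred from `setup()`: each key
        should match an optimizer name registered in `runner.optimizers`, and
        each value is a configuration dictionary accepted by
        `build_lr_scheduler()`.

        Args:
            lr_config: Dictionary mapping optimizer names to learning rate
                scheduler configurations.
        """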
        assert isinstance(lr_config, dict)
        config = {
            'priority': 'FIRST',
            'every_n_iters': 1,
        }
        super().__init__(config)
        self._lr_config = lr_config.copy()

    @property
    def lr_config(self):
        """Gets the configuration for learning rate scheduler."""
        return self._lr_config

    def setup(self, runner):
        for name, config in self.lr_config.items():
            if not name or not config:
                continue
            if name in runner.lr_schedulers:
                raise AttributeError(f'LR Scheduler `{name}` already exists!')
            if name not in runner.optimizers:
                raise AttributeError(f'Optimizer `{name}` is missing!')
            runner.lr_schedulers[name] = build_lr_scheduler(
                config, runner.optimizers[name])
            runner.running_stats.add(
                f'lr_{name}', log_format='.3e', log_name=f'lr ({name})',
                log_strategy='CURRENT')

    def execute_after_iteration(self, runner):
        for name, scheduler in runner.lr_schedulers.items():
            scheduler.step()
            assert scheduler.last_epoch == runner.iter
            current_lr = runner.optimizers[name].param_groups[0]['lr']
            runner.running_stats.update({f'lr_{name}': current_lr})