File size: 17,684 Bytes
3dd84f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
# modified from transformers.optimization
import math
from functools import partial

import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

def _get_constant_lambda(_=None):
    return 1


def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
    """
    Build a scheduler that leaves the learning rate unchanged.

    The multiplier is 1 at every step, so each parameter group keeps the learning rate it was configured with.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            Optimizer whose learning rate is scheduled.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch; pass it when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    scheduler = LambdaLR(optimizer, lr_lambda=_get_constant_lambda, last_epoch=last_epoch)
    return scheduler


def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
    """
    Build a scheduler that keeps the learning rate constant until a monitored metric stops improving.

    This is a thin wrapper: every keyword argument is forwarded unchanged to `ReduceLROnPlateau`.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            Optimizer whose learning rate is scheduled.
        kwargs (`dict`, *optional*):
            Extra keyword arguments for `torch.optim.lr_scheduler.ReduceLROnPlateau` (e.g. `mode`, `factor`,
            `patience`).

    Return:
        `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
    """
    scheduler = ReduceLROnPlateau(optimizer, **kwargs)
    return scheduler


def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1.0, num_warmup_steps))
    return 1.0


def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
    """
    Build a scheduler with a linear warmup followed by a constant learning rate.

    For the first `num_warmup_steps` steps the learning rate rises linearly from 0 to the initial lr set in the
    optimizer; afterwards it stays at that initial lr.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            Optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch; pass it when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    warmup_multiplier = partial(
        _get_constant_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
    )
    return LambdaLR(optimizer, lr_lambda=warmup_multiplier, last_epoch=last_epoch)


def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Build a scheduler with a linear warmup followed by a linear decay to 0.

    The learning rate climbs linearly from 0 to the optimizer's initial lr over `num_warmup_steps`. It then
    decreases linearly and reaches 0 at `num_training_steps`.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            Optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch; pass it when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    schedule_fn = partial(
        _get_linear_schedule_with_warmup_lr_lambda,
        num_training_steps=num_training_steps,
        num_warmup_steps=num_warmup_steps,
    )
    return LambdaLR(optimizer, lr_lambda=schedule_fn, last_epoch=last_epoch)


def _get_cosine_schedule_with_warmup_lr_lambda(

    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float

):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))


def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Build a scheduler with a linear warmup followed by a cosine decay.

    The learning rate grows linearly from 0 to the optimizer's initial lr during warmup. It then follows a
    cosine curve from that initial lr down to 0 (for the default `num_cycles`).

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            Optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            Number of cosine waves; the default 0.5 is a single half-cosine going from the max value to 0.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch; pass it when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    cosine_fn = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_cycles=num_cycles,
        num_training_steps=num_training_steps,
        num_warmup_steps=num_warmup_steps,
    )
    return LambdaLR(optimizer, lr_lambda=cosine_fn, last_epoch=last_epoch)


def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda(

    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int

):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    if progress >= 1.0:
        return 0.0
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))


def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
    """
    Build a scheduler with a linear warmup followed by cosine decays that restart abruptly.

    The learning rate rises linearly from 0 to the optimizer's initial lr during warmup. It then repeatedly
    follows a cosine curve from the initial lr toward 0, jumping back to the initial lr at each of the
    `num_cycles` restarts.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            Optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            Number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch; pass it when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    restart_fn = partial(
        _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda,
        num_cycles=num_cycles,
        num_training_steps=num_training_steps,
        num_warmup_steps=num_warmup_steps,
    )
    return LambdaLR(optimizer, lr_lambda=restart_fn, last_epoch=last_epoch)


def _get_polynomial_decay_schedule_with_warmup_lr_lambda(

    current_step: int,

    *,

    num_warmup_steps: int,

    num_training_steps: int,

    lr_end: float,

    power: float,

    lr_init: int,

):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    elif current_step > num_training_steps:
        return lr_end / lr_init  # as LambdaLR multiplies by lr_init
    else:
        lr_range = lr_init - lr_end
        decay_steps = num_training_steps - num_warmup_steps
        pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
        decay = lr_range * pct_remaining**power + lr_end
        return decay / lr_init  # as LambdaLR multiplies by lr_init


def get_polynomial_decay_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
):
    """
    Build a scheduler with a linear warmup followed by a polynomial decay down to `lr_end`.

    The learning rate rises linearly from 0 to the optimizer's default lr during warmup. It then decays
    polynomially (with exponent `power`) to `lr_end`.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            Optimizer whose learning rate is scheduled; its `defaults["lr"]` is used as the starting lr.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            Final learning rate.
        power (`float`, *optional*, defaults to 1.0):
            Exponent of the polynomial decay.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch; pass it when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Raises:
        ValueError: If `lr_end` is not strictly smaller than the optimizer's default lr.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    lr_init = optimizer.defaults["lr"]
    # The decay target must lie strictly below the starting learning rate.
    if not lr_end < lr_init:
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    decay_settings = dict(
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        lr_end=lr_end,
        power=power,
        lr_init=lr_init,
    )
    poly_fn = partial(_get_polynomial_decay_schedule_with_warmup_lr_lambda, **decay_settings)
    return LambdaLR(optimizer, lr_lambda=poly_fn, last_epoch=last_epoch)


def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int = None):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    shift = timescale - num_warmup_steps
    decay = 1.0 / math.sqrt((current_step + shift) / timescale)
    return decay


def get_inverse_sqrt_schedule(
    optimizer: Optimizer, num_warmup_steps: int, timescale: int = None, last_epoch: int = -1
):
    """
    Build a scheduler with a linear warmup followed by an inverse-square-root decay.

    The learning rate rises linearly from 0 to the optimizer's initial lr during warmup. It then decays
    proportionally to the inverse square root of the (shifted) step.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            Optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        timescale (`int`, *optional*, defaults to `num_warmup_steps`):
            Time scale of the decay; when omitted and `num_warmup_steps` is 0, 10_000 is used.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch; pass it when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Note: this implementation is adapted from
    # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930

    if timescale is None:
        # Fall back to the warmup length, or to 10_000 steps when there is no warmup.
        timescale = num_warmup_steps if num_warmup_steps else 10_000

    inv_sqrt_fn = partial(
        _get_inverse_sqrt_schedule_lr_lambda,
        timescale=timescale,
        num_warmup_steps=num_warmup_steps,
    )
    return LambdaLR(optimizer, lr_lambda=inv_sqrt_fn, last_epoch=last_epoch)


def _get_cosine_schedule_with_warmup_lr_lambda(

    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0

):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    factor = factor * (1 - min_lr_rate) + min_lr_rate
    return max(0, factor)


def get_cosine_with_min_lr_schedule_with_warmup(

    optimizer: Optimizer,

    num_warmup_steps: int,

    num_training_steps: int,

    num_cycles: float = 0.5,

    last_epoch: int = -1,

    min_lr: float = None,

    min_lr_rate: float = None,

):
    """

    Create a schedule with a learning rate that decreases following the values of the cosine function between the

    initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the

    initial lr set in the optimizer.



    Args:

        optimizer ([`~torch.optim.Optimizer`]):

            The optimizer for which to schedule the learning rate.

        num_warmup_steps (`int`):

            The number of steps for the warmup phase.

        num_training_steps (`int`):

            The total number of training steps.

        num_cycles (`float`, *optional*, defaults to 0.5):

            The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0

            following a half-cosine).

        last_epoch (`int`, *optional*, defaults to -1):

            The index of the last epoch when resuming training.

        min_lr (`float`, *optional*):

            The minimum learning rate to reach after the cosine schedule.

        min_lr_rate (`float`, *optional*):

            The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.



    Return:

        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    """

    if min_lr is not None and min_lr_rate is not None:
        raise ValueError("Only one of min_lr or min_lr_rate should be set")
    elif min_lr is not None:
        min_lr_rate = min_lr / optimizer.defaults["lr"]
    elif min_lr_rate is None:
        raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")

    lr_lambda = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
        min_lr_rate=min_lr_rate,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)


def _get_wsd_scheduler_lambda(

    current_step: int,

    *,

    num_warmup_steps: int,

    num_stable_steps: int,

    num_decay_steps: int,

    num_cycles: float,

    min_lr_ratio: float,

):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    if current_step < num_warmup_steps + num_stable_steps:
        return 1.0
    if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
        progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
        value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
        return (1.0 - min_lr_ratio) * value + min_lr_ratio
    return min_lr_ratio


def get_wsd_schedule(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_decay_steps: int,
    min_lr_ratio: float = 0,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Build a three-stage warmup / stable / decay scheduler.

    1. The learning rate rises linearly from 0 to the initial lr.
    2. It is held constant at the initial lr.
    3. It follows a cosine curve from the initial lr down to `min_lr_ratio` times the initial lr.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            Optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_stable_steps (`int`):
            Length of the constant phase, in steps.
        num_decay_steps (`int`):
            Length of the cosine annealing phase, in steps.
        min_lr_ratio (`float`, *optional*, defaults to 0):
            Minimum learning rate as a ratio of the initial learning rate.
        num_cycles (`float`, *optional*, defaults to 0.5):
            Number of cosine waves in the decay phase; the default 0.5 is a single half-cosine.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch; pass it when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    phase_settings = dict(
        num_warmup_steps=num_warmup_steps,
        num_stable_steps=num_stable_steps,
        num_decay_steps=num_decay_steps,
        min_lr_ratio=min_lr_ratio,
        num_cycles=num_cycles,
    )
    wsd_fn = partial(_get_wsd_scheduler_lambda, **phase_settings)
    return LambdaLR(optimizer, lr_lambda=wsd_fn, last_epoch=last_epoch)