Spaces:
Running
Running
File size: 80,183 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 |
import types
import math
from torch import inf
from functools import wraps, partial
import warnings
import weakref
from collections import Counter
from bisect import bisect_right
from .optimizer import Optimizer
# Names exported by a star-import of this module.
__all__ = ['LambdaLR', 'MultiplicativeLR', 'StepLR', 'MultiStepLR', 'ConstantLR', 'LinearLR',
'ExponentialLR', 'SequentialLR', 'CosineAnnealingLR', 'ChainedScheduler', 'ReduceLROnPlateau',
'CyclicLR', 'CosineAnnealingWarmRestarts', 'OneCycleLR', 'PolynomialLR', 'LRScheduler']
# Message of the UserWarning raised by ``LRScheduler.step`` whenever it is
# called with an explicit (non-None) ``epoch`` argument.
EPOCH_DEPRECATION_WARNING = (
"The epoch parameter in `scheduler.step()` was not necessary and is being "
"deprecated where possible. Please use `scheduler.step()` to step the "
"scheduler. During the deprecation, if epoch is different from None, the "
"closed form is used instead of the new chainable form, where available. "
"Please open an issue if you are unable to replicate your use case: "
"https://github.com/pytorch/pytorch/issues/new/choose."
)
def _check_verbose_deprecated_warning(verbose):
    """Warn when ``verbose`` was set explicitly and normalise its value.

    Args:
        verbose: The ``verbose`` argument received by a scheduler. The
            sentinel string ``"deprecated"`` is the default and means the
            caller did not supply a value.

    Returns:
        ``False`` for the sentinel; otherwise ``verbose`` itself, after a
        ``UserWarning`` about the deprecation has been issued.
    """
    explicitly_set = verbose != "deprecated"
    if not explicitly_set:
        # Default sentinel: nothing to warn about.
        return False
    warnings.warn("The verbose parameter is deprecated. Please use get_last_lr() "
                  "to access the learning rate.", UserWarning)
    return verbose
class LRScheduler:
    """Base class for schedulers that rewrite the ``lr`` of an optimizer's param groups.

    Construction records every group's ``initial_lr`` in ``base_lrs``, wraps
    ``optimizer.step`` so that calls to it are counted, and performs one
    initial :meth:`step`. Subclasses implement :meth:`get_lr` and may also
    define ``_get_closed_form_lr``, which :meth:`step` prefers when it is
    called with an explicit epoch.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose: Deprecated. Any value other than the default triggers a
            ``UserWarning``; the normalised value is stored in ``self.verbose``.

    Raises:
        TypeError: If ``optimizer`` is not an ``Optimizer``.
        KeyError: If ``last_epoch != -1`` and a param group has no ``'initial_lr'``.
    """

    def __init__(self, optimizer, last_epoch=-1, verbose="deprecated"):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        param_groups = optimizer.param_groups
        if last_epoch == -1:
            # Fresh start: remember each group's current lr unless one is already stored.
            for group in param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            # Resuming: the starting lr must already have been recorded on every group.
            for index, group in enumerate(param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   f"in param_groups[{index}] when resuming an optimizer")
        self.base_lrs = [group['initial_lr'] for group in param_groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method):
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Hold the optimizer only weakly so that the wrapper, which is
            # stored on the optimizer itself, does not create a reference cycle.
            owner_ref = weakref.ref(method.__self__)
            # Work with the plain function for the same reason.
            plain_func = method.__func__
            owner_cls = owner_ref().__class__
            del method

            @wraps(plain_func)
            def wrapper(*args, **kwargs):
                owner = owner_ref()
                owner._step_count += 1
                # Re-bind the function to the live instance before calling it.
                return plain_func.__get__(owner, owner_cls)(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.verbose = _check_verbose_deprecated_warning(verbose)

        self._initial_step()

    def _initial_step(self):
        """Reset both step counters to zero and perform one scheduler step."""
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {name: attr for name, attr in self.__dict__.items() if name != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self):
        """Return the learning rates stored by the most recent :meth:`step`."""
        return self._last_lr

    def get_lr(self):
        """Compute the learning rates using the chainable form; subclasses override this."""
        raise NotImplementedError

    def print_lr(self, is_verbose, group, lr, epoch=None):
        """Display the current learning rate.

        Nothing is printed unless ``is_verbose`` is truthy. When ``epoch`` is
        given it is prefixed to the message, formatted with two decimals for
        floats and zero-padded to five digits otherwise.
        """
        if not is_verbose:
            return
        if epoch is None:
            print(f'Adjusting learning rate of group {group} to {lr:.4e}.')
            return
        epoch_format = "%.2f" if isinstance(epoch, float) else "%.5d"
        epoch_str = epoch_format % epoch
        print(f'Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.4e}.')

    def step(self, epoch=None):
        """Advance the schedule and write the new learning rates to the optimizer.

        Args:
            epoch: When ``None`` (default) ``last_epoch`` is incremented and
                :meth:`get_lr` is used. Otherwise a deprecation warning is
                issued, ``last_epoch`` is set to ``epoch`` and
                ``_get_closed_form_lr`` is used when the scheduler defines it.
        """
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        first_call_after_init = self._step_count == 1
        if first_call_after_init and not hasattr(self.optimizer.step, "_with_counter"):
            warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                          "initialization. Please, make sure to call `optimizer.step()` before "
                          "`lr_scheduler.step()`. See more details at "
                          "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
        elif first_call_after_init and self.optimizer._step_count < 1:
            warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                          "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                          "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this "
                          "will result in PyTorch skipping the first value of the learning rate schedule. "
                          "See more details at "
                          "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        self._step_count += 1

        with _enable_get_lr_call(self):
            if epoch is None:
                self.last_epoch += 1
                values = self.get_lr()
            else:
                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
                self.last_epoch = epoch
                compute = self._get_closed_form_lr if hasattr(self, "_get_closed_form_lr") else self.get_lr
                values = compute()

        # Pairs are consumed with zip, so surplus groups or values are ignored.
        for param_group, lr in zip(self.optimizer.param_groups, values):
            param_group['lr'] = lr

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
# Backwards-compatible alias kept under the old private name.
# It is a real subclass rather than a plain assignment so that
# `_LRScheduler.__name__` stays '_LRScheduler' (an assignment would report 'LRScheduler').
class _LRScheduler(LRScheduler):
    pass
class _enable_get_lr_call:
    """Context manager that flags its target while the ``with`` body runs.

    On entry it sets ``o._get_lr_called_within_step`` to ``True`` and on exit,
    whether or not an exception occurred, back to ``False``. Exceptions raised
    inside the body are not suppressed.
    """

    def __init__(self, o):
        # Object whose `_get_lr_called_within_step` attribute is toggled.
        self.o = o

    def __enter__(self):
        setattr(self.o, '_get_lr_called_within_step', True)
        return self

    def __exit__(self, type, value, traceback):
        # Implicitly returns None, so any in-flight exception propagates.
        setattr(self.o, '_get_lr_called_within_step', False)
class LambdaLR(LRScheduler):
    """Sets the learning rate of each parameter group to the initial lr
    times a given function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose: Deprecated. Any value other than the default triggers a
            ``UserWarning``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Raises:
        ValueError: If ``lr_lambda`` is a list/tuple whose length differs from
            the number of param groups.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose="deprecated"):
        self.optimizer = optimizer

        if not isinstance(lr_lambda, (list, tuple)):
            # A single callable is shared by every param group.
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError(f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}")
            self.lr_lambdas = list(lr_lambda)
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate lambda functions will only be saved if they are callable objects
        and not if they are functions or lambdas.

        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
        # Plain functions/lambdas are stored as None; for callable objects a
        # copy of their attribute dict is stored.
        state_dict['lr_lambdas'] = [
            None if isinstance(fn, types.FunctionType) else fn.__dict__.copy()
            for fn in self.lr_lambdas
        ]
        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.

        The passed ``state_dict`` is only read; it is never modified.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            KeyError: If ``state_dict`` has no ``'lr_lambdas'`` entry.
        """
        lr_lambdas = state_dict['lr_lambdas']
        # Copy everything except 'lr_lambdas' without popping from the caller's
        # dict, so the caller's object keeps its keys (and their order) intact.
        # https://github.com/pytorch/pytorch/issues/32756
        self.__dict__.update({key: value for key, value in state_dict.items() if key != 'lr_lambdas'})

        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                # Restore saved attributes onto the existing callable object.
                self.lr_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        """Return ``base_lr * lr_lambda(last_epoch)`` for every param group."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        return [base_lr * lmbda(self.last_epoch)
                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
class MultiplicativeLR(LRScheduler):
    """Multiply the learning rate of each parameter group by the factor given
    in the specified function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose: Deprecated. Any value other than the default triggers a
            ``UserWarning``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Raises:
        ValueError: If ``lr_lambda`` is a list/tuple whose length differs from
            the number of param groups.

    Example:
        >>> # xdoctest: +SKIP
        >>> lmbda = lambda epoch: 0.95
        >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose="deprecated"):
        self.optimizer = optimizer

        if not isinstance(lr_lambda, (list, tuple)):
            # A single callable is shared by every param group.
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError(f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}")
            self.lr_lambdas = list(lr_lambda)
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate lambda functions will only be saved if they are callable objects
        and not if they are functions or lambdas.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
        # Plain functions/lambdas are stored as None; for callable objects a
        # copy of their attribute dict is stored.
        state_dict['lr_lambdas'] = [
            None if isinstance(fn, types.FunctionType) else fn.__dict__.copy()
            for fn in self.lr_lambdas
        ]
        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        The passed ``state_dict`` is only read; it is never modified.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            KeyError: If ``state_dict`` has no ``'lr_lambdas'`` entry.
        """
        lr_lambdas = state_dict['lr_lambdas']
        # Copy everything except 'lr_lambdas' without popping from the caller's
        # dict, so the caller's object keeps its keys (and their order) intact.
        # https://github.com/pytorch/pytorch/issues/32756
        self.__dict__.update({key: value for key, value in state_dict.items() if key != 'lr_lambdas'})

        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                # Restore saved attributes onto the existing callable object.
                self.lr_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        """Return each group's current lr times ``lr_lambda(last_epoch)``.

        While ``last_epoch`` is not positive the current learning rates are
        returned unchanged.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch > 0:
            return [group['lr'] * lmbda(self.last_epoch)
                    for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)]
        else:
            return [group['lr'] for group in self.optimizer.param_groups]
class StepLR(LRScheduler):
    """Decays the learning rate of each parameter group by gamma every
    step_size epochs. Notice that such decay can happen simultaneously with
    other changes to the learning rate from outside this scheduler. When
    last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose: Deprecated. Any value other than the default triggers a
            ``UserWarning``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 60
        >>> # lr = 0.0005   if 60 <= epoch < 90
        >>> # ...
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose="deprecated"):
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable form: scale the current lrs by ``gamma`` on every ``step_size``-th epoch."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        # Epoch 0 never decays; afterwards only exact multiples of step_size do.
        decay_now = self.last_epoch != 0 and self.last_epoch % self.step_size == 0
        if decay_now:
            return [group['lr'] * self.gamma for group in groups]
        return [group['lr'] for group in groups]

    def _get_closed_form_lr(self):
        """Closed form: ``base_lr * gamma ** (last_epoch // step_size)`` per group."""
        return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
                for base_lr in self.base_lrs]
class MultiStepLR(LRScheduler):
    """Decays the learning rate of each parameter group by gamma once the
    number of epoch reaches one of the milestones. Notice that such decay can
    happen simultaneously with other changes to the learning rate from outside
    this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose: Deprecated. Any value other than the default triggers a
            ``UserWarning``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 80
        >>> # lr = 0.0005   if epoch >= 80
        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose="deprecated"):
        # A Counter keeps how many times each milestone was listed.
        self.milestones = Counter(milestones)
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable form: on a milestone epoch scale the current lrs by ``gamma`` once per listing."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        repeats = self.milestones.get(self.last_epoch)
        if repeats is None:
            # Not a milestone: keep every group's lr as it is.
            return [group['lr'] for group in groups]
        scale = self.gamma ** repeats
        return [group['lr'] * scale for group in groups]

    def _get_closed_form_lr(self):
        """Closed form: ``base_lr * gamma ** k`` with ``k`` the number of milestones <= ``last_epoch``."""
        ordered = sorted(self.milestones.elements())
        passed = bisect_right(ordered, self.last_epoch)
        return [base_lr * self.gamma ** passed for base_lr in self.base_lrs]
class ConstantLR(LRScheduler):
    """Multiply the learning rate of each parameter group by a small constant factor until the
    number of epoch reaches a pre-defined milestone: total_iters.
    Notice that such multiplication of the small constant factor can
    happen simultaneously with other changes to the learning rate from outside this scheduler.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        factor (float): The number we multiply learning rate until the milestone.
            Must lie in ``[0, 1]``. Default: 1./3.
        total_iters (int): The number of steps that the scheduler multiplies the learning rate by the factor.
            Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose: Deprecated. Any value other than the default triggers a
            ``UserWarning``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Raises:
        ValueError: If ``factor`` is outside ``[0, 1]``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025   if epoch == 0
        >>> # lr = 0.025   if epoch == 1
        >>> # lr = 0.025   if epoch == 2
        >>> # lr = 0.025   if epoch == 3
        >>> # lr = 0.05    if epoch >= 4
        >>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose="deprecated"):
        if factor > 1.0 or factor < 0:
            raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')

        self.factor = factor
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable form: apply ``factor`` at epoch 0 and undo it at epoch ``total_iters``.

        With ``factor == 0`` the scaling cannot be undone by division, so the
        base learning rates are restored instead (the same values
        ``_get_closed_form_lr`` yields from ``total_iters`` onwards).
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] * self.factor for group in self.optimizer.param_groups]

        if self.last_epoch != self.total_iters:
            return [group['lr'] for group in self.optimizer.param_groups]

        if self.factor == 0:
            # factor=0 is accepted by __init__; dividing by it here would raise
            # ZeroDivisionError, so fall back to the recorded base lrs.
            return list(self.base_lrs)

        return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Closed form: ``base_lr * factor`` before ``total_iters``, ``base_lr`` from then on."""
        return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
                for base_lr in self.base_lrs]
class LinearLR(LRScheduler):
    """Decays the learning rate of each parameter group by linearly changing small
    multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters.
    Notice that such decay can happen simultaneously with other changes to the learning rate
    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The number we multiply learning rate in the first epoch.
            The multiplication factor changes towards end_factor in the following epochs.
            Must lie in ``(0, 1]``. Default: 1./3.
        end_factor (float): The number we multiply learning rate at the end of linear changing
            process. Must lie in ``[0, 1]``. Default: 1.0.
        total_iters (int): The number of iterations after which the multiplicative
            factor reaches ``end_factor``. Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose: Deprecated. Any value other than the default triggers a
            ``UserWarning``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Raises:
        ValueError: If ``start_factor`` or ``end_factor`` is out of range.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025    if epoch == 0
        >>> # lr = 0.03125  if epoch == 1
        >>> # lr = 0.0375   if epoch == 2
        >>> # lr = 0.04375  if epoch == 3
        >>> # lr = 0.05    if epoch >= 4
        >>> scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
                 verbose="deprecated"):
        if start_factor > 1.0 or start_factor <= 0:
            raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.')

        if end_factor > 1.0 or end_factor < 0:
            raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')

        self.start_factor = start_factor
        self.end_factor = end_factor
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable form: move the current lrs one linear step from ``start_factor`` towards ``end_factor``."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        if self.last_epoch == 0:
            return [group['lr'] * self.start_factor for group in groups]

        if self.last_epoch > self.total_iters:
            # Ramp finished: leave the learning rates alone.
            return [group['lr'] for group in groups]

        # Ratio between this epoch's factor and the previous epoch's factor.
        span = self.end_factor - self.start_factor
        previous = self.total_iters * self.start_factor + (self.last_epoch - 1) * span
        ratio = 1. + span / previous
        return [group['lr'] * ratio for group in groups]

    def _get_closed_form_lr(self):
        """Closed form: interpolate the factor linearly, clamped once ``total_iters`` is reached."""
        span = self.end_factor - self.start_factor
        factor = self.start_factor + span * min(self.total_iters, self.last_epoch) / self.total_iters
        return [base_lr * factor for base_lr in self.base_lrs]
class ExponentialLR(LRScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose: Deprecated. Any value other than the default triggers a
            ``UserWarning``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1, verbose="deprecated"):
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable form: multiply the current lrs by ``gamma`` on every epoch after epoch 0."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        if self.last_epoch == 0:
            # Epoch 0 keeps the learning rates unchanged.
            return [group['lr'] for group in groups]
        return [group['lr'] * self.gamma for group in groups]

    def _get_closed_form_lr(self):
        """Closed form: ``base_lr * gamma ** last_epoch`` per group."""
        return [base_lr * self.gamma ** self.last_epoch
                for base_lr in self.base_lrs]
class SequentialLR(LRScheduler):
"""Receives the list of schedulers that is expected to be called sequentially during
optimization process and milestone points that provides exact intervals to reflect
which scheduler is supposed to be called at a given epoch.
Args:
optimizer (Optimizer): Wrapped optimizer.
schedulers (list): List of chained schedulers.
milestones (list): List of integers that reflects milestone points.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): Does nothing.
.. deprecated:: 2.2
``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
learning rate.
Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 1. for all groups
>>> # lr = 0.1 if epoch == 0
>>> # lr = 0.1 if epoch == 1
>>> # lr = 0.9 if epoch == 2
>>> # lr = 0.81 if epoch == 3
>>> # lr = 0.729 if epoch == 4
>>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
>>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
>>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2])
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
"""
def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose="deprecated"):
for scheduler_idx in range(len(schedulers)):
if schedulers[scheduler_idx].optimizer != optimizer:
raise ValueError(
"Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in."
)
if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
raise ValueError(
"Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
f"got schedulers at index {0} and {scheduler_idx} to be different."
)
if (len(milestones) != len(schedulers) - 1):
raise ValueError(
"Sequential Schedulers expects number of schedulers provided to be one more "
f"than the number of milestone points, but got number of schedulers {len(schedulers)} and the "
f"number of milestones to be equal to {len(milestones)}"
)
_check_verbose_deprecated_warning(verbose)
self._schedulers = schedulers
self._milestones = milestones
self.last_epoch = last_epoch + 1
self.optimizer = optimizer
# Reset learning rates back to initial values
for group in self.optimizer.param_groups:
group["lr"] = group["initial_lr"]
# "Undo" the step performed by other schedulers
for scheduler in self._schedulers:
scheduler.last_epoch -= 1
# Perform the initial step for only the first scheduler
self._schedulers[0]._initial_step()
self._last_lr = schedulers[0].get_last_lr()
def step(self):
    """Advance one epoch and step the scheduler selected by the milestones.

    The scheduler index is the number of milestones that are less than or
    equal to the new ``last_epoch``. When the new epoch is exactly the
    milestone that precedes the selected scheduler, that scheduler is
    stepped with ``0``; otherwise it is stepped without arguments.
    """
    self.last_epoch += 1
    active_idx = bisect_right(self._milestones, self.last_epoch)
    active = self._schedulers[active_idx]

    at_milestone = active_idx > 0 and self._milestones[active_idx - 1] == self.last_epoch
    step_args = (0,) if at_milestone else ()
    active.step(*step_args)

    self._last_lr = active.get_last_lr()
def state_dict(self):
    """Returns the state of the scheduler as a :class:`dict`.

    Every entry of ``self.__dict__`` is copied except ``optimizer`` and
    ``_schedulers``. The ``_schedulers`` key instead holds the list of
    ``state_dict()`` results of the wrapped schedulers, in order.
    """
    excluded = ('optimizer', '_schedulers')
    state = {}
    for name, value in self.__dict__.items():
        if name in excluded:
            continue
        state[name] = value

    state['_schedulers'] = [wrapped.state_dict() for wrapped in self._schedulers]
    return state
def load_state_dict(self, state_dict):
    """Loads the schedulers state.

    Args:
        state_dict (dict): scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    wrapped_states = state_dict.pop('_schedulers')
    vars(self).update(state_dict)
    # Put the popped key back so the caller's dict is left as it was passed in
    # https://github.com/pytorch/pytorch/issues/32756
    state_dict['_schedulers'] = wrapped_states

    for position, wrapped_state in enumerate(wrapped_states):
        target = self._schedulers[position]
        target.load_state_dict(wrapped_state)
class PolynomialLR(LRScheduler):
    """Decays the learning rate of each parameter group using a polynomial function
    in the given total_iters. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5.
        power (float): The power of the polynomial. Default: 1.0.
        last_epoch (int): Passed through to the base scheduler. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Assuming optimizer uses lr = 0.001 for all groups
        >>> # lr = 0.001     if epoch == 0
        >>> # lr = 0.00075   if epoch == 1
        >>> # lr = 0.00050   if epoch == 2
        >>> # lr = 0.00025   if epoch == 3
        >>> # lr = 0.0       if epoch >= 4
        >>> scheduler = PolynomialLR(optimizer, total_iters=4, power=1.0)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose="deprecated"):
        # These must exist before the base constructor runs.
        self.total_iters = total_iters
        self.power = power
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the next lr of every group, derived from the group's current lr.

        The current lr is returned unchanged at epoch 0 and once ``last_epoch``
        exceeds ``total_iters``; otherwise it is multiplied by the ratio of the
        remaining fraction at this epoch to the remaining fraction at the
        previous epoch, raised to ``power``.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        param_groups = self.optimizer.param_groups
        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
            return [param_group["lr"] for param_group in param_groups]

        remaining_now = 1.0 - self.last_epoch / self.total_iters
        remaining_before = 1.0 - (self.last_epoch - 1) / self.total_iters
        decay_factor = (remaining_now / remaining_before) ** self.power
        return [param_group["lr"] * decay_factor for param_group in param_groups]

    def _get_closed_form_lr(self):
        """Return the lr of every group computed directly from ``base_lrs``."""
        closed_form = []
        for base_lr in self.base_lrs:
            elapsed = min(self.total_iters, self.last_epoch) / self.total_iters
            closed_form.append(base_lr * (1.0 - elapsed) ** self.power)
        return closed_form
class CosineAnnealingLR(LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \begin{aligned}
            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
            & T_{cur} \neq (2k+1)T_{max}; \\
            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
            & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
    is defined recursively, the learning rate can be simultaneously modified
    outside this scheduler by other operators. If the learning rate is set
    solely by this scheduler, the learning rate at each step becomes:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose="deprecated"):
        # These must exist before the base constructor runs.
        self.T_max = T_max
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the next lr of every group using the recursive cosine rule.

        Four cases are handled in order: epoch 0 (current lr unchanged), the
        first step of a scheduler created with a positive ``last_epoch``
        (closed form from ``base_lrs``), the epoch right after an odd multiple
        of ``T_max`` (additive update), and the general multiplicative update.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        param_groups = self.optimizer.param_groups
        eta_min = self.eta_min
        T_max = self.T_max
        epoch = self.last_epoch

        if epoch == 0:
            return [param_group['lr'] for param_group in param_groups]

        if self._step_count == 1 and epoch > 0:
            # zip() is kept so the result length is bounded by both sequences
            return [
                eta_min + (base_lr - eta_min) * (1 + math.cos((epoch) * math.pi / T_max)) / 2
                for base_lr, _ in zip(self.base_lrs, param_groups)
            ]

        if (epoch - 1 - T_max) % (2 * T_max) == 0:
            return [
                param_group['lr'] + (base_lr - eta_min) * (1 - math.cos(math.pi / T_max)) / 2
                for base_lr, param_group in zip(self.base_lrs, param_groups)
            ]

        return [
            (1 + math.cos(math.pi * epoch / T_max))
            / (1 + math.cos(math.pi * (epoch - 1) / T_max))
            * (param_group['lr'] - eta_min) + eta_min
            for param_group in param_groups
        ]

    def _get_closed_form_lr(self):
        """Return the lr of every group computed directly from ``base_lrs``."""
        closed_form = []
        for base_lr in self.base_lrs:
            cosine_term = 1 + math.cos(math.pi * self.last_epoch / self.T_max)
            closed_form.append(self.eta_min + (base_lr - self.eta_min) * cosine_term / 2)
        return closed_form
class ChainedScheduler(LRScheduler):
    """Chains list of learning rate schedulers. It takes a list of chainable learning
    rate schedulers and performs consecutive step() functions belonging to them by just
    one call.

    Args:
        schedulers (list): List of chained schedulers. Any non-empty iterable of
            schedulers that share one optimizer is accepted.

    Raises:
        ValueError: If ``schedulers`` is empty, or if the schedulers do not all
            belong to the same optimizer.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.09     if epoch == 0
        >>> # lr = 0.081    if epoch == 1
        >>> # lr = 0.729    if epoch == 2
        >>> # lr = 0.6561   if epoch == 3
        >>> # lr = 0.59049  if epoch >= 4
        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
        >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, schedulers):
        # Materialize once so that generators and other one-shot iterables work,
        # and so that later changes to the caller's list do not affect this object.
        schedulers = list(schedulers)
        if not schedulers:
            # Fail with a clear message instead of an IndexError on schedulers[0]
            raise ValueError(
                "ChainedScheduler expects at least one scheduler to be chained, but got no scheduler."
            )

        for scheduler_idx, scheduler in enumerate(schedulers[1:], start=1):
            if scheduler.optimizer != schedulers[0].optimizer:
                raise ValueError(
                    "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
                    f"got schedulers at index {0} and {scheduler_idx} to be different"
                )

        self._schedulers = schedulers
        self.optimizer = schedulers[0].optimizer
        self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]

    def step(self):
        """Step every chained scheduler once, in order, then record the resulting lrs."""
        for scheduler in self._schedulers:
            scheduler.step()
        self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        state_dict['_schedulers'] = [s.state_dict() for s in self._schedulers]
        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['_schedulers'] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)
class ReduceLROnPlateau(LRScheduler):
    """Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metrics
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): The number of allowed epochs with no improvement after
            which the learning rate will be reduced.
            For example, consider the case of having no patience (`patience = 0`).
            In the first epoch, a baseline is established and is always considered good as there's no previous baseline.
            In the second epoch, if the performance is worse than the baseline,
            we have what is considered an intolerable epoch.
            Since the count of intolerable epochs (1) is greater than the patience level (0),
            the learning rate is reduced at the end of this epoch.
            From the third epoch onwards, the learning rate continues to be reduced at the end of each epoch
            if the performance is worse than the baseline. If the performance improves or remains the same,
            the learning rate is not adjusted.
            Default: 10.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Raises:
        ValueError: If ``factor >= 1.0``, if ``min_lr`` is a list/tuple whose
            length differs from the number of param groups, or if ``mode`` /
            ``threshold_mode`` is not one of the supported values.
        TypeError: If ``optimizer`` is not an ``Optimizer``.

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 min_lr=0, eps=1e-8, verbose="deprecated"):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        # Normalize min_lr to one lower bound per param group
        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience

        self.verbose = _check_verbose_deprecated_warning(verbose)
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.eps = eps
        self.last_epoch = 0
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        """Record one metric observation and reduce the lr if patience is exceeded.

        Args:
            metrics: Monitored quantity; converted with ``float()``.
            epoch (int, optional): Explicit epoch index. Passing it emits a
                deprecation warning; when omitted, ``last_epoch + 1`` is used.
        """
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        if epoch is None:
            epoch = self.last_epoch + 1
        else:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def _reduce_lr(self, epoch):
        """Multiply each group's lr by ``factor``, clamped below by its ``min_lrs`` entry.

        ``epoch`` is accepted for interface compatibility and is not used here.
        """
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            # Skip updates that would change the lr by no more than eps
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr

    @property
    def in_cooldown(self):
        """Whether the cooldown counter is still positive."""
        return self.cooldown_counter > 0

    def is_better(self, a, best):
        """Return whether ``a`` improves on ``best`` under the configured mode and threshold."""
        if self.mode == 'min' and self.threshold_mode == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            rel_epsilon = self.threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and epsilon_mode == 'abs':
            return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate the comparison settings and set ``mode_worse`` accordingly.

        Raises:
            ValueError: If ``mode`` is not 'min'/'max' or ``threshold_mode`` is
                not 'rel'/'abs'.
        """
        # f-strings and tuple membership keep the intended ValueError even when
        # the offending value is not a string (e.g. None) or is unhashable.
        if mode not in ('min', 'max'):
            raise ValueError(f'mode {mode} is unknown!')
        if threshold_mode not in ('rel', 'abs'):
            raise ValueError(f'threshold mode {threshold_mode} is unknown!')

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

    def state_dict(self):
        """Return every entry of ``self.__dict__`` except the optimizer."""
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Restore attributes from ``state_dict`` and re-validate the comparison settings."""
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
class CyclicLR(LRScheduler):
    r"""Sets the learning rate of each parameter group according to
    cyclical learning rate policy (CLR). The policy cycles the learning
    rate between two boundaries with a constant frequency, as detailed in
    the paper `Cyclical Learning Rates for Training Neural Networks`_.
    The distance between the two boundaries can be scaled on a per-iteration
    or per-cycle basis.

    Cyclical learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This class has three built-in policies, as put forth in the paper:

    * "triangular": A basic triangular cycle without amplitude scaling.
    * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
    * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
      at each cycle iteration.

    This implementation was adapted from the github repo: `bckenstler/CLR`_

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_lr - base_lr).
            The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore
            max_lr may not actually be reached depending on
            scaling function.
        step_size_up (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        step_size_down (int): Number of training iterations in the
            decreasing half of a cycle. If step_size_down is None,
            it is set to step_size_up. Default: None
        mode (str): One of {triangular, triangular2, exp_range}.
            Values correspond to policies detailed above.
            If scale_fn is not None, this argument is ignored.
            Default: 'triangular'
        gamma (float): Constant in 'exp_range' scaling function:
            gamma**(cycle iterations)
            Default: 1.0
        scale_fn (function): Custom scaling policy defined by a single
            argument lambda function, where
            0 <= scale_fn(x) <= 1 for all x >= 0.
            If specified, then 'mode' is ignored.
            Default: None
        scale_mode (str): {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on
            cycle number or cycle iterations (training
            iterations since start of cycle).
            Default: 'cycle'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.8
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            The momentum at any cycle is the difference of max_momentum
            and some scaling of the amplitude; therefore
            base_momentum may not actually be reached depending on
            scaling function. Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.9
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()


    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """

    def __init__(self,
                 optimizer,
                 base_lr,
                 max_lr,
                 step_size_up=2000,
                 step_size_down=None,
                 mode='triangular',
                 gamma=1.,
                 scale_fn=None,
                 scale_mode='cycle',
                 cycle_momentum=True,
                 base_momentum=0.8,
                 max_momentum=0.9,
                 last_epoch=-1,
                 verbose="deprecated"):

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
        self.optimizer = optimizer

        base_lrs = self._format_param('base_lr', optimizer, base_lr)
        # On a fresh start, every group begins at its lower boundary
        if last_epoch == -1:
            for lr, group in zip(base_lrs, optimizer.param_groups):
                group['lr'] = lr

        self.max_lrs = self._format_param('max_lr', optimizer, max_lr)

        step_size_up = float(step_size_up)
        step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
        self.total_size = step_size_up + step_size_down
        self.step_ratio = step_size_up / self.total_size

        if mode not in ['triangular', 'triangular2', 'exp_range'] \
                and scale_fn is None:
            raise ValueError('mode is invalid and scale_fn is None')

        self.mode = mode
        self.gamma = gamma

        self._scale_fn_ref = None
        self._scale_fn_custom = scale_fn
        self.scale_mode = scale_mode
        self._init_scale_fn()

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if 'momentum' not in optimizer.defaults and 'betas' not in optimizer.defaults:
                raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')

            # When the optimizer has 'betas', the first beta is cycled instead of 'momentum'
            self.use_beta1 = 'betas' in self.optimizer.defaults
            self.base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
            if last_epoch == -1:
                for m_momentum, b_momentum, group in zip(self.max_momentums, self.base_momentums, optimizer.param_groups):
                    if self.use_beta1:
                        group['betas'] = (m_momentum, *group['betas'][1:])
                    else:
                        group['momentum'] = m_momentum
                    group['max_momentum'] = m_momentum
                    group['base_momentum'] = b_momentum

        super().__init__(optimizer, last_epoch, verbose)
        # Set after the base constructor so the user-supplied base_lr values are the ones kept
        self.base_lrs = base_lrs

    def _init_scale_fn(self):
        """Select the built-in scale function for ``mode`` unless a custom one is set."""
        if self._scale_fn_custom is not None:
            return
        if self.mode == 'triangular':
            self._scale_fn_ref = self._triangular_scale_fn
            self.scale_mode = 'cycle'
        elif self.mode == 'triangular2':
            self._scale_fn_ref = self._triangular2_scale_fn
            self.scale_mode = 'cycle'
        elif self.mode == 'exp_range':
            self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma)
            self.scale_mode = 'iterations'

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError(f"expected {len(optimizer.param_groups)} values for {name}, got {len(param)}")
            return param
        else:
            return [param] * len(optimizer.param_groups)

    def scale_fn(self, x):
        """Evaluate the custom scale function if one was given, else the built-in one."""
        if self._scale_fn_custom is not None:
            return self._scale_fn_custom(x)
        else:
            return self._scale_fn_ref(x)  # static method

    @staticmethod
    def _triangular_scale_fn(x):
        return 1.

    @staticmethod
    def _triangular2_scale_fn(x):
        return 1 / (2. ** (x - 1))

    @staticmethod
    def _exp_range_scale_fn(gamma, x):
        return gamma ** x

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_epoch` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """

        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        # `cycle` is the 1-based cycle number, `x` the fractional position inside it
        cycle = math.floor(1 + self.last_epoch / self.total_size)
        x = 1. + self.last_epoch / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            if self.scale_mode == 'cycle':
                lr = base_lr + base_height * self.scale_fn(cycle)
            else:
                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
            lrs.append(lr)

        if self.cycle_momentum:
            momentums = []
            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
                base_height = (max_momentum - base_momentum) * scale_factor
                if self.scale_mode == 'cycle':
                    momentum = max_momentum - base_height * self.scale_fn(cycle)
                else:
                    momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
                momentums.append(momentum)
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                if self.use_beta1:
                    param_group['betas'] = (momentum, *param_group['betas'][1:])
                else:
                    param_group['momentum'] = momentum

        return lrs

    def state_dict(self):
        """Return the scheduler state without the scale-function references.

        ``_scale_fn_ref`` is dropped (``load_state_dict`` rebuilds it through
        ``_init_scale_fn``). ``_scale_fn_custom`` is stored as a copy of the
        callable object's ``__dict__``, or as ``None`` when it is unset or is a
        plain function/lambda.
        """
        state = super().state_dict()
        # `_scale_fn_ref` is derived from `mode`/`gamma`, so it is not saved;
        # it is re-created by `_init_scale_fn()` when the state is loaded.
        state.pop('_scale_fn_ref')
        fn = state.pop('_scale_fn_custom')
        state['_scale_fn_custom'] = None
        if fn is not None and not isinstance(fn, types.FunctionType):
            # The _scale_fn_custom will only be saved if it is a callable object
            # and not if it is a function or lambda.
            state['_scale_fn_custom'] = fn.__dict__.copy()

        return state

    def load_state_dict(self, state_dict):
        """Load the scheduler state without modifying ``state_dict``.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        fn = state_dict['_scale_fn_custom']
        # Load from a filtered copy: the saved `_scale_fn_custom` entry must not
        # overwrite the live callable, and the caller's dict must stay intact
        # (same concern as https://github.com/pytorch/pytorch/issues/32756).
        remaining = {key: value for key, value in state_dict.items() if key != '_scale_fn_custom'}
        super().load_state_dict(remaining)
        if fn is not None:
            self._scale_fn_custom.__dict__.update(fn)
        self._init_scale_fn()
class CosineAnnealingWarmRestarts(LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

            .. deprecated:: 2.2
                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
                learning rate.

    Raises:
        ValueError: If ``T_0`` is not a positive ``int``, ``T_mult`` is not an
            ``int`` >= 1, or ``eta_min`` is not a ``float``/``int``.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose="deprecated"):
        # Check the type first so a non-numeric argument yields the intended
        # ValueError rather than a TypeError from the comparison.
        if not isinstance(T_0, int) or T_0 <= 0:
            raise ValueError(f"Expected positive integer T_0, but got {T_0}")
        if not isinstance(T_mult, int) or T_mult < 1:
            raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")
        if not isinstance(eta_min, (float, int)):
            raise ValueError(f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}")
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.T_cur = last_epoch
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the cosine-annealed lr of every group for the current ``T_cur`` / ``T_i``."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Step could be called after every batch update

        Example:
            >>> # xdoctest: +SKIP("Undefined vars")
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)

        This function can be called in an interleaved way.

        Example:
            >>> # xdoctest: +SKIP("Undefined vars")
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)

        Args:
            epoch (int or float, optional): Explicit (possibly fractional)
                epoch to jump to. When omitted, the schedule advances by one.

        Raises:
            ValueError: If ``epoch`` is negative.
        """

        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            # Implicit step: advance within the current period, restarting when it ends
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError(f"Expected non-negative epoch, but got {epoch}")
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    # n = number of completed periods; period lengths form a
                    # geometric series T_0, T_0*T_mult, T_0*T_mult**2, ...
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        # try/finally always resets the flag and lets errors from get_lr()
        # propagate; a context manager whose __exit__ returns a truthy value
        # would silently swallow them.
        self._get_lr_called_within_step = True
        try:
            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                param_group['lr'] = lr
        finally:
            self._get_lr_called_within_step = False

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
class OneCycleLR(LRScheduler):
r"""Sets the learning rate of each parameter group according to the
1cycle learning rate policy. The 1cycle policy anneals the learning
rate from an initial learning rate to some maximum learning rate and then
from that maximum learning rate to some minimum learning rate much lower
than the initial learning rate.
This policy was initially described in the paper `Super-Convergence:
Very Fast Training of Neural Networks Using Large Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every batch.
`step` should be called after a batch has been used for training.
This scheduler is not chainable.
Note also that the total number of steps in the cycle can be determined in one
of two ways (listed in order of precedence):
#. A value for total_steps is explicitly provided.
#. A number of epochs (epochs) and a number of steps per epoch
(steps_per_epoch) are provided.
In this case, the number of total steps is inferred by
total_steps = epochs * steps_per_epoch
You must either provide a value for total_steps or provide a value for both
epochs and steps_per_epoch.
The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
claims that "unpublished work has shown even better results by using only two phases". To
mimic the behaviour of the original paper instead, set ``three_phase=True``.
Args:
optimizer (Optimizer): Wrapped optimizer.
max_lr (float or list): Upper learning rate boundaries in the cycle
for each parameter group.
total_steps (int): The total number of steps in the cycle. Note that
if a value is not provided here, then it must be inferred by providing
a value for epochs and steps_per_epoch.
Default: None
epochs (int): The number of epochs to train for. This is used along
with steps_per_epoch in order to infer the total number of steps in the cycle
if a value for total_steps is not provided.
Default: None
steps_per_epoch (int): The number of steps per epoch to train for. This is
used along with epochs in order to infer the total number of steps in the
cycle if a value for total_steps is not provided.
Default: None
pct_start (float): The percentage of the cycle (in number of steps) spent
increasing the learning rate.
Default: 0.3
anneal_strategy (str): {'cos', 'linear'}
Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
linear annealing.
Default: 'cos'
cycle_momentum (bool): If ``True``, momentum is cycled inversely
to learning rate between 'base_momentum' and 'max_momentum'.
Default: True
base_momentum (float or list): Lower momentum boundaries in the cycle
for each parameter group. Note that momentum is cycled inversely
to learning rate; at the peak of a cycle, momentum is
'base_momentum' and learning rate is 'max_lr'.
Default: 0.85
max_momentum (float or list): Upper momentum boundaries in the cycle
for each parameter group. Functionally,
it defines the cycle amplitude (max_momentum - base_momentum).
Note that momentum is cycled inversely
to learning rate; at the start of a cycle, momentum is 'max_momentum'
and learning rate is 'base_lr'
Default: 0.95
div_factor (float): Determines the initial learning rate via
initial_lr = max_lr/div_factor
Default: 25
final_div_factor (float): Determines the minimum learning rate via
min_lr = initial_lr/final_div_factor
Default: 1e4
three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
learning rate according to 'final_div_factor' instead of modifying the second
phase (the first two phases will be symmetrical about the step indicated by
'pct_start').
last_epoch (int): The index of the last batch. This parameter is used when
resuming a training job. Since `step()` should be invoked after each
batch instead of after each epoch, this number represents the total
number of *batches* computed, not the total number of epochs computed.
When last_epoch=-1, the schedule is started from the beginning.
Default: -1
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
.. deprecated:: 2.2
``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
learning rate.
Example:
>>> # xdoctest: +SKIP
>>> data_loader = torch.utils.data.DataLoader(...)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
>>> for epoch in range(10):
>>> for batch in data_loader:
>>> train_batch(...)
>>> optimizer.step()
>>> scheduler.step()
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
https://arxiv.org/abs/1708.07120
"""
def __init__(self,
             optimizer,
             max_lr,
             total_steps=None,
             epochs=None,
             steps_per_epoch=None,
             pct_start=0.3,
             anneal_strategy='cos',
             cycle_momentum=True,
             base_momentum=0.85,
             max_momentum=0.95,
             div_factor=25.,
             final_div_factor=1e4,
             three_phase=False,
             last_epoch=-1,
             verbose="deprecated"):
    """Validate the configuration, build the phase table and seed the param groups.

    Args:
        optimizer: The wrapped optimizer; must be an ``Optimizer`` instance.
        max_lr (float or list): Peak learning rate, either one value shared by
            all param groups or one value per group.
        total_steps (int): Total number of steps in the schedule. When given
            it takes precedence over ``epochs`` / ``steps_per_epoch``.
        epochs (int): Number of epochs; only used together with
            ``steps_per_epoch`` when ``total_steps`` is ``None``.
        steps_per_epoch (int): Steps per epoch; see ``epochs``.
        pct_start (float): Fraction of ``total_steps`` (in ``[0, 1]``) covered
            by the first phase, in which the learning rate rises.
        anneal_strategy (str): ``'cos'`` or ``'linear'``.
        cycle_momentum (bool): If true, momentum (or ``betas[0]``) is annealed
            between ``base_momentum`` and ``max_momentum``.
        base_momentum (float or list): Lower momentum bound per param group.
        max_momentum (float or list): Upper momentum bound per param group.
        div_factor (float): ``initial_lr = max_lr / div_factor``.
        final_div_factor (float): ``min_lr = initial_lr / final_div_factor``.
        three_phase (bool): If true, use three phases (up, symmetric down,
            final annihilation) instead of two.
        last_epoch (int): Index of the last step; ``-1`` starts a fresh
            schedule and writes the initial values into the param groups.
        verbose: Forwarded unchanged to the base class constructor.

    Raises:
        TypeError: If ``optimizer`` is not an ``Optimizer``.
        ValueError: If neither ``total_steps`` nor both ``epochs`` and
            ``steps_per_epoch`` are given, if any of them is not a positive
            integer, if ``pct_start`` is not a float in ``[0, 1]``, if
            ``anneal_strategy`` is unknown, if a per-group list has the wrong
            length, or if ``cycle_momentum`` is requested for an optimizer
            without ``momentum`` / ``betas``.
    """
    # Validate optimizer
    if not isinstance(optimizer, Optimizer):
        raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')
    self.optimizer = optimizer

    # Validate total_steps. The isinstance test runs first so that a
    # non-numeric argument yields the descriptive ValueError instead of a
    # TypeError from the ``<=`` comparison.
    if total_steps is not None:
        if not isinstance(total_steps, int) or total_steps <= 0:
            raise ValueError(f"Expected positive integer total_steps, but got {total_steps}")
        self.total_steps = total_steps
    elif epochs is not None and steps_per_epoch is not None:
        if not isinstance(epochs, int) or epochs <= 0:
            raise ValueError(f"Expected positive integer epochs, but got {epochs}")
        if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
            raise ValueError(f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}")
        self.total_steps = epochs * steps_per_epoch
    else:
        # Covers "nothing given" as well as only one of epochs/steps_per_epoch.
        raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")

    # Validate pct_start before it is used to compute the phase boundaries.
    if not isinstance(pct_start, float) or pct_start < 0 or pct_start > 1:
        raise ValueError(f"Expected float between 0 and 1 pct_start, but got {pct_start}")

    # Each phase names the param-group keys holding its start/end values and
    # the (possibly fractional) step index at which it ends.
    if three_phase:
        self._schedule_phases = [
            {
                'end_step': float(pct_start * self.total_steps) - 1,
                'start_lr': 'initial_lr',
                'end_lr': 'max_lr',
                'start_momentum': 'max_momentum',
                'end_momentum': 'base_momentum',
            },
            {
                'end_step': float(2 * pct_start * self.total_steps) - 2,
                'start_lr': 'max_lr',
                'end_lr': 'initial_lr',
                'start_momentum': 'base_momentum',
                'end_momentum': 'max_momentum',
            },
            {
                'end_step': self.total_steps - 1,
                'start_lr': 'initial_lr',
                'end_lr': 'min_lr',
                'start_momentum': 'max_momentum',
                'end_momentum': 'max_momentum',
            },
        ]
    else:
        self._schedule_phases = [
            {
                'end_step': float(pct_start * self.total_steps) - 1,
                'start_lr': 'initial_lr',
                'end_lr': 'max_lr',
                'start_momentum': 'max_momentum',
                'end_momentum': 'base_momentum',
            },
            {
                'end_step': self.total_steps - 1,
                'start_lr': 'max_lr',
                'end_lr': 'min_lr',
                'start_momentum': 'base_momentum',
                'end_momentum': 'max_momentum',
            },
        ]

    # Validate anneal_strategy
    if anneal_strategy not in ['cos', 'linear']:
        raise ValueError(f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}")
    elif anneal_strategy == 'cos':
        self.anneal_func = self._annealing_cos
    elif anneal_strategy == 'linear':
        self.anneal_func = self._annealing_linear

    # Initialize learning rate variables
    max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
    if last_epoch == -1:
        # Fresh schedule: store the three reference rates on every group.
        for idx, group in enumerate(self.optimizer.param_groups):
            group['initial_lr'] = max_lrs[idx] / div_factor
            group['max_lr'] = max_lrs[idx]
            group['min_lr'] = group['initial_lr'] / final_div_factor

    # Initialize momentum variables
    self.cycle_momentum = cycle_momentum
    if self.cycle_momentum:
        if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
            raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')
        # Optimizers exposing ``betas`` have their first beta cycled instead
        # of a ``momentum`` entry.
        self.use_beta1 = 'betas' in self.optimizer.defaults
        max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
        base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
        if last_epoch == -1:
            for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
                if self.use_beta1:
                    group['betas'] = (m_momentum, *group['betas'][1:])
                else:
                    group['momentum'] = m_momentum
                group['max_momentum'] = m_momentum
                group['base_momentum'] = b_momentum

    super().__init__(optimizer, last_epoch, verbose)
def _format_param(self, name, optimizer, param):
    """Expand ``param`` to one value per optimizer param group.

    A list or tuple is returned unchanged after its length has been checked
    against the number of param groups; any other value is repeated once per
    group. ``name`` is only used in the error message.

    Raises:
        ValueError: If a list/tuple does not have one entry per param group.
    """
    group_count = len(optimizer.param_groups)
    if not isinstance(param, (list, tuple)):
        # Scalar: broadcast the same value to every group.
        return [param] * group_count
    if len(param) != group_count:
        raise ValueError(f"expected {group_count} values for {name}, got {len(param)}")
    return param
@staticmethod
def _annealing_cos(start, end, pct):
    """Cosine-interpolate between two values.

    Returns ``start`` at ``pct == 0.0`` and ``end`` at ``pct == 1.0``,
    following half a cosine wave in between.
    """
    half_span = (start - end) / 2.0
    # The cosine term runs from 2 (pct = 0) down to 0 (pct = 1).
    cosine_term = math.cos(math.pi * pct) + 1
    return end + half_span * cosine_term
@staticmethod
def _annealing_linear(start, end, pct):
    """Linearly interpolate between two values.

    Returns ``start`` at ``pct == 0.0`` and ``end`` at ``pct == 1.0``.
    """
    delta = end - start
    return delta * pct + start
def get_lr(self):
    """Compute the learning rate of every param group for the current step.

    The current step (``self.last_epoch``) is located in the phase table
    built by the constructor, and the learning rate is annealed between that
    phase's start and end values with ``self.anneal_func``. When
    ``cycle_momentum`` is enabled, the momentum (or the first ``betas``
    entry) of each group is updated in place the same way.

    Returns:
        list: One learning rate per entry of ``self.optimizer.param_groups``.

    Raises:
        ValueError: If the step index exceeds ``self.total_steps``.
    """
    if not self._get_lr_called_within_step:
        warnings.warn("To get the last learning rate computed by the scheduler, "
                      "please use `get_last_lr()`.", UserWarning)

    lrs = []
    step_num = self.last_epoch

    if step_num > self.total_steps:
        raise ValueError(
            f"Tried to step {step_num} times. "
            f"The specified number of total steps is {self.total_steps}"
        )

    last_phase = len(self._schedule_phases) - 1
    for group in self.optimizer.param_groups:
        start_step = 0
        for i, phase in enumerate(self._schedule_phases):
            end_step = phase['end_step']
            # The last phase also absorbs any step beyond its nominal end.
            if step_num <= end_step or i == last_phase:
                phase_length = end_step - start_step
                if phase_length == 0:
                    # A zero-length phase (e.g. pct_start * total_steps == 1)
                    # is already complete; use its end value instead of
                    # dividing by zero.
                    pct = 1.0
                else:
                    pct = (step_num - start_step) / phase_length
                computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct)
                if self.cycle_momentum:
                    computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct)
                break
            start_step = phase['end_step']

        lrs.append(computed_lr)
        if self.cycle_momentum:
            if self.use_beta1:
                # Replace only beta1, keep the remaining betas untouched.
                group['betas'] = (computed_momentum, *group['betas'][1:])
            else:
                group['momentum'] = computed_momentum

    return lrs
|