# Copyright (c) Facebook, Inc. and its affiliates.

import math
import numpy as np
from unittest import TestCase
import torch
from fvcore.common.param_scheduler import (
    CosineParamScheduler,
    MultiStepParamScheduler,
    StepWithFixedGammaParamScheduler,
)
from torch import nn

from detectron2.solver import LRMultiplier, WarmupParamScheduler, build_lr_scheduler


class TestScheduler(TestCase):
    """Tests for learning-rate schedules driven by ``LRMultiplier``.

    Each test wraps a parameter scheduler in ``WarmupParamScheduler`` (or builds one
    through ``build_lr_scheduler``), steps it, and compares the learning rate read
    from ``opt.param_groups[0]["lr"]`` against hand-computed values.
    """

    def _assert_lrs_close(self, actual, expected):
        """Assert that ``actual`` matches ``expected`` element-wise.

        ``np.testing.assert_allclose`` is used instead of
        ``assertTrue(np.allclose(...))`` so that a failure reports the mismatching
        values. ``rtol``/``atol`` are pinned to the defaults of ``np.allclose`` so
        the pass/fail criterion is the same.

        Args:
            actual: sequence of recorded learning rates.
            expected: a scalar or a sequence of expected learning rates.

        Raises:
            AssertionError: if any element is outside the tolerance.
        """
        np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-8)

    def test_warmup_multistep(self):
        """Warmup followed by a multi-step schedule with milestones at 10, 15 and 20."""
        p = nn.Parameter(torch.zeros(0))
        opt = torch.optim.SGD([p], lr=5)

        multiplier = WarmupParamScheduler(
            MultiStepParamScheduler(
                [1, 0.1, 0.01, 0.001],
                milestones=[10, 15, 20],
                num_updates=30,
            ),
            0.001,
            5 / 30,
        )
        sched = LRMultiplier(opt, multiplier, 30)
        # This is an equivalent of:
        # sched = WarmupMultiStepLR(
        # opt, milestones=[10, 15, 20], gamma=0.1, warmup_factor=0.001, warmup_iters=5)

        p.sum().backward()
        opt.step()

        # lrs[0] is the learning rate before any sched.step() call.
        lrs = [0.005]
        for _ in range(30):
            sched.step()
            lrs.append(opt.param_groups[0]["lr"])
        self._assert_lrs_close(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001])
        self._assert_lrs_close(lrs[5:10], 5.0)
        self._assert_lrs_close(lrs[10:15], 0.5)
        self._assert_lrs_close(lrs[15:20], 0.05)
        self._assert_lrs_close(lrs[20:], 0.005)

    def test_warmup_cosine(self):
        """Warmup followed by a cosine schedule decaying from 1 to 0."""
        p = nn.Parameter(torch.zeros(0))
        opt = torch.optim.SGD([p], lr=5)
        multiplier = WarmupParamScheduler(
            CosineParamScheduler(1, 0),
            0.001,
            5 / 30,
        )
        sched = LRMultiplier(opt, multiplier, 30)

        p.sum().backward()
        opt.step()
        self.assertEqual(opt.param_groups[0]["lr"], 0.005)
        lrs = [0.005]

        for _ in range(30):
            sched.step()
            lrs.append(opt.param_groups[0]["lr"])
        for idx, lr in enumerate(lrs):
            expected_cosine = 2.5 * (1.0 + math.cos(math.pi * idx / 30))
            if idx >= 5:
                # From iteration 5 onwards the lr must follow the plain cosine curve.
                self.assertAlmostEqual(lr, expected_cosine)
            else:
                # During the first 5 iterations the lr must differ from it.
                self.assertNotAlmostEqual(lr, expected_cosine)

    def test_warmup_cosine_end_value(self):
        """The "WarmupCosineLR" schedule must finish at ``SOLVER.BASE_LR_END``."""
        from detectron2.config import CfgNode, get_cfg

        def _test_end_value(cfg_dict):
            """Build the scheduler from ``cfg_dict`` overrides and check first/last lr."""
            cfg = get_cfg()
            cfg.merge_from_other_cfg(CfgNode(cfg_dict))

            p = nn.Parameter(torch.zeros(0))
            opt = torch.optim.SGD([p], lr=cfg.SOLVER.BASE_LR)

            scheduler = build_lr_scheduler(cfg, opt)

            p.sum().backward()
            opt.step()
            # Before any scheduler.step(), the lr is the base lr scaled by the warmup factor.
            self.assertEqual(
                opt.param_groups[0]["lr"], cfg.SOLVER.BASE_LR * cfg.SOLVER.WARMUP_FACTOR
            )

            lrs = []
            for _ in range(cfg.SOLVER.MAX_ITER):
                scheduler.step()
                lrs.append(opt.param_groups[0]["lr"])

            self.assertAlmostEqual(lrs[-1], cfg.SOLVER.BASE_LR_END)

        _test_end_value(
            {
                "SOLVER": {
                    "LR_SCHEDULER_NAME": "WarmupCosineLR",
                    "MAX_ITER": 100,
                    "WARMUP_ITERS": 10,
                    "WARMUP_FACTOR": 0.1,
                    "BASE_LR": 5.0,
                    "BASE_LR_END": 0.0,
                }
            }
        )

        _test_end_value(
            {
                "SOLVER": {
                    "LR_SCHEDULER_NAME": "WarmupCosineLR",
                    "MAX_ITER": 100,
                    "WARMUP_ITERS": 10,
                    "WARMUP_FACTOR": 0.1,
                    "BASE_LR": 5.0,
                    "BASE_LR_END": 0.5,
                }
            }
        )

    def test_warmup_stepwithfixedgamma(self):
        """Warmup followed by fixed-gamma step decay; stepping past the end must raise."""
        p = nn.Parameter(torch.zeros(0))
        opt = torch.optim.SGD([p], lr=5)

        multiplier = WarmupParamScheduler(
            StepWithFixedGammaParamScheduler(
                base_value=1.0,
                gamma=0.1,
                num_decays=4,
                num_updates=30,
            ),
            0.001,
            5 / 30,
            rescale_interval=True,
        )
        sched = LRMultiplier(opt, multiplier, 30)

        p.sum().backward()
        opt.step()

        # Only 29 steps are taken here; the 30th is exercised separately below.
        lrs = [0.005]
        for _ in range(29):
            sched.step()
            lrs.append(opt.param_groups[0]["lr"])
        self._assert_lrs_close(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001])
        self._assert_lrs_close(lrs[5:10], 5.0)
        self._assert_lrs_close(lrs[10:15], 0.5)
        self._assert_lrs_close(lrs[15:20], 0.05)
        self._assert_lrs_close(lrs[20:25], 0.005)
        self._assert_lrs_close(lrs[25:], 0.0005)

        # Calling sched.step() after the last training iteration is done will trigger IndexError.
        # The ``msg`` argument of ``assertRaises`` is only the failure message and is never
        # compared with the exception text, so ``assertRaisesRegex`` is used to check it.
        with self.assertRaisesRegex(IndexError, "list index out of range"):
            sched.step()