# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
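"""
Test script for gradient synchronization in Accelerate: covers the `no_sync` and `accumulate`
context managers, `trigger_sync_in_backward`, and `GradientAccumulationPlugin` on single-process
and distributed setups.
"""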

from copy import deepcopy

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

from accelerate.accelerator import Accelerator, GradientAccumulationPlugin
from accelerate.state import GradientState
from accelerate.test_utils import RegressionDataset, RegressionModel
from accelerate.utils import DistributedType, set_seed


def check_model_parameters(model_a, model_b, did_step, iteration, **kwargs):
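    """Assert that the gradients of `model_a` and `model_b` are out of sync when `did_step` is False
    and in sync when it is True; extra `kwargs` are forwarded to `torch.allclose`."""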
    for param, grad_param in zip(model_a.parameters(), model_b.parameters()):
        if not param.requires_grad:
            continue
        if not did_step:
            # Grads should not be in sync
            assert not torch.allclose(
                param.grad, grad_param.grad, **kwargs
            ), f"Gradients in sync when they should not be at iteration {iteration}:\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})"
        else:
            # Grads should be in sync
            assert torch.allclose(
                param.grad, grad_param.grad, **kwargs
            ), f"Gradients not in sync when they should be at iteration {iteration}:\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})"


def step_model(model, input, target, accelerator, do_backward=True):
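    """Run one forward/backward pass. With `do_backward=False` the loss is divided by
    `gradient_accumulation_steps` and `.backward()` is called directly (the manual baseline);
    otherwise the backward pass is delegated to `accelerator.backward`."""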
    model.train()
    output = model(input)
    loss = F.mse_loss(output, target.to(output.device))
    if not do_backward:
        loss /= accelerator.gradient_accumulation_steps
        loss.backward()
    else:
        accelerator.backward(loss)


def get_training_setup(accelerator, sched=False):
    "Returns everything needed to perform basic training"
    set_seed(42)
    model = RegressionModel()
    ddp_model = deepcopy(model)
    dset = RegressionDataset(length=80)
    dataloader = DataLoader(dset, batch_size=16)
    model.to(accelerator.device)
    if sched:
        opt = AdamW(params=model.parameters(), lr=1e-3)
        ddp_opt = AdamW(params=ddp_model.parameters(), lr=1e-3)
        sched = LambdaLR(opt, lr_lambda=lambda epoch: epoch**0.65)
        ddp_sched = LambdaLR(ddp_opt, lr_lambda=lambda epoch: epoch**0.65)
    # Prepare the copy (`ddp_model`) and the dataloader, plus the optimizer and scheduler when requested
    if sched:
        ddp_model, ddp_opt, ddp_sched, dataloader = accelerator.prepare(ddp_model, ddp_opt, ddp_sched, dataloader)
    else:
        ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader)
    if sched:
        return (model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched)
    return model, ddp_model, dataloader


def test_noop_sync(accelerator):
    # Test that on a single CPU or GPU the `no_sync` context manager is a no-op
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Use a single batch
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for iteration in range(3):
        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform our initial ground-truth step on the non-DDP baseline model
        step_model(model, input, target, accelerator)
        # Do "gradient accumulation" (noop)
        if iteration % 2 == 0:
            # Accumulate grads locally
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            # Sync grads
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # Since `no_sync` is a noop, `ddp_model` and `model` grads should always be in sync
        check_model_parameters(model, ddp_model, True, iteration)
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            assert torch.allclose(
                param.grad, ddp_param.grad
            ), f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]


def test_distributed_sync(accelerator):
    # Test that on a distributed setup `no_sync` skips gradient synchronization inside the context
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Use a single batch
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for iteration in range(3):
        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform our initial ground-truth step on the non-DDP baseline model
        step_model(model, input, target, accelerator)
        # Do "gradient accumulation" (noop)
        if iteration % 2 == 0:
            # Accumulate grads locally
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            # Sync grads
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # DDP model and model should only be in sync when not (iteration % 2 == 0)
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            if iteration % 2 == 0:
                # Grads should not be in sync
                assert not torch.allclose(
                    param.grad, ddp_param.grad
                ), f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
            else:
                # Grads should be in sync
                assert torch.allclose(
                    param.grad, ddp_param.grad
                ), f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]


def test_distributed_sync_multiple_fwd(accelerator):
    # Test that on a distributed setup `no_sync` behaves properly when multiple forwards are
    # followed by multiple backwards, syncing only on the final backward
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Do multiple forwards
    losses = []
    num_iterations = 3
    for iteration in range(num_iterations):
        ddp_input, ddp_target = next(iter(dataloader)).values()

        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)

        # Perform our initial ground-truth step on the non-DDP baseline model
        step_model(model, input, target, accelerator)

        # Accumulate grads locally
        with accelerator.no_sync(ddp_model):
            ddp_output = ddp_model(ddp_input)
            loss = F.mse_loss(ddp_output, ddp_target.to(ddp_output.device))
            losses.append(loss)

    # Do multiple backwards and sync only at the last backward
    for iteration in range(num_iterations):
        loss = losses[iteration]

        if iteration < num_iterations - 1:
            # Accumulate grads locally
            accelerator.backward(loss)

            # DDP model and model should only be in sync after last backward
            for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
                if not param.requires_grad:
                    continue
                # Grads should not be in sync
                assert not torch.allclose(
                    param.grad, ddp_param.grad
                ), f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"

        else:
            # Sync grads if last backward
            with accelerator.trigger_sync_in_backward(ddp_model):
                accelerator.backward(loss)

            # DDP model and model should only be in sync after last backward
            for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
                if not param.requires_grad:
                    continue
                # Grads should be in sync
                assert torch.allclose(
                    param.grad, ddp_param.grad
                ), f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"


def test_gradient_accumulation(split_batches=False, dispatch_batches=False, sync_each_batch=False):
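    """Check that under `accelerator.accumulate` gradients only sync on step boundaries (every
    `num_steps` batches), at the end of the dataloader, or on every batch when `sync_each_batch` is set."""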
    gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
    accelerator = Accelerator(
        split_batches=split_batches,
        dispatch_batches=dispatch_batches,
        gradient_accumulation_plugin=gradient_accumulation_plugin,
    )
    # Test that the `accumulate` context manager only syncs gradients on step boundaries
    model, ddp_model, dataloader = get_training_setup(accelerator)
    for iteration, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform our initial ground-truth step on the non-DDP baseline model
        step_model(model, input, target, accelerator, False)
        # Do gradient accumulation under the `accumulate` context manager
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # DDP model and model should only be in sync on step boundaries (every other batch),
        # at the end of the dataloader, or when `sync_each_batch` is set
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            if ((iteration + 1) % 2 == 0) or (iteration == len(dataloader) - 1) or sync_each_batch:
                # Grads should be in sync
                assert torch.allclose(
                    param.grad, ddp_param.grad
                ), f"Gradients not in sync when they should be at iteration {iteration}:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
            else:
                # Grads should not be in sync
                assert not torch.allclose(
                    param.grad, ddp_param.grad
                ), f"Gradients in sync when they should not be at iteration {iteration}:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
    GradientState._reset_state()


def test_gradient_accumulation_with_opt_and_scheduler(
    split_batches=False, dispatch_batches=False, sync_each_batch=False
):
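    """Same as `test_gradient_accumulation`, but also checks that the prepared optimizer and LR
    scheduler stay aligned with a manually stepped baseline."""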
    gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
    accelerator = Accelerator(
        split_batches=split_batches,
        dispatch_batches=dispatch_batches,
        gradient_accumulation_plugin=gradient_accumulation_plugin,
    )
    # Test that `accumulate` behaves correctly alongside a prepared optimizer and LR scheduler
    model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched = get_training_setup(accelerator, True)
    for iteration, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform our initial ground-truth step on the non-DDP baseline model
        model.train()
        ddp_model.train()
        step_model(model, input, target, accelerator, False)
        opt.step()

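        # With `split_batches=False`, the baseline scheduler steps once per process to mirror the
        # prepared `ddp_sched`, which Accelerate steps on every process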
        if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)) or sync_each_batch:
            if split_batches:
                sched.step()
            else:
                for _ in range(accelerator.num_processes):
                    sched.step()

        # Perform gradient accumulation under wrapper
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
            ddp_opt.step()
            ddp_sched.step()

        # Learning rates should be the same
        assert (
            opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"]
        ), f'Learning rates found in each optimizer did not align\nopt: {opt.param_groups[0]["lr"]}\nDDP opt: {ddp_opt.param_groups[0]["lr"]}\n'
        did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader)) or sync_each_batch
        if accelerator.num_processes > 1:
            check_model_parameters(
                model,
                ddp_model,
                did_step,
                iteration,
                rtol=1e-3,  # somehow needs a relative tolerance
            )

        if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)) or sync_each_batch:
            opt.zero_grad()  # needs to be guarded by logic as to when we should zero grads
        ddp_opt.zero_grad()

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
    GradientState._reset_state()


def test_dataloader_break():
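    """Check that `GradientState` tracks the active dataloader (including a nested one) and only
    flags `end_of_dataloader` on the final batch."""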
    accelerator = Accelerator()

    first_dset = RegressionDataset(length=80)
    first_dataloader = DataLoader(first_dset, batch_size=16)
    second_dset = RegressionDataset(length=96)
    second_dataloader = DataLoader(second_dset, batch_size=16)
    first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)
    assert accelerator.gradient_state.active_dataloader is None
    for iteration, _ in enumerate(first_dataloader):
        assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader)
        if iteration < len(first_dataloader) - 1:
            assert not accelerator.gradient_state.end_of_dataloader
            if iteration == 1:
                for batch_num, _ in enumerate(second_dataloader):
                    assert id(accelerator.gradient_state.active_dataloader) == id(second_dataloader)
                    if batch_num < len(second_dataloader) - 1:
                        assert not accelerator.gradient_state.end_of_dataloader
                    else:
                        assert accelerator.gradient_state.end_of_dataloader
        else:
            assert accelerator.gradient_state.end_of_dataloader
    assert accelerator.gradient_state.active_dataloader is None


def main():
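    """Run the subset of synchronization tests that applies to the current distributed setup."""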
    accelerator = Accelerator()
    state = accelerator.state
    if state.local_process_index == 0:
        print("**Test `accumulate` gradient accumulation with dataloader break**")
    if state.distributed_type != DistributedType.XLA:
        test_dataloader_break()
    if state.distributed_type == DistributedType.NO:
        if state.local_process_index == 0:
            print("**Test NOOP `no_sync` context manager**")
        test_noop_sync(accelerator)
    if state.distributed_type in (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_NPU,
        DistributedType.MULTI_MLU,
        DistributedType.MULTI_CPU,
    ):
        if state.local_process_index == 0:
            print("**Test Distributed `no_sync` context manager**")
        test_distributed_sync(accelerator)
        if state.local_process_index == 0:
            print("**Test Distributed `no_sync` context manager with multiple forwards**")
        test_distributed_sync_multiple_fwd(accelerator)
    if state.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.MULTI_MLU):
        for split_batch in [True, False]:
            for dispatch_batches in [True, False]:
                for sync_each_batch in [True, False]:
                    if state.local_process_index == 0:
                        print(
                            "**Test `accumulate` gradient accumulation, ",
                            f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
                        )
                    test_gradient_accumulation(split_batch, dispatch_batches, sync_each_batch)

    # Currently will break on torch 2.0 +, need to investigate why
    if state.local_process_index == 0:
        print(
            "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
            "`split_batches=False`, `dispatch_batches=False`, `sync_each_batch=False`**",
        )
    test_gradient_accumulation_with_opt_and_scheduler()
    if state.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.MULTI_MLU):
        for split_batch in [True, False]:
            for dispatch_batches in [True, False]:
                for sync_each_batch in [True, False]:
                    if not split_batch and not dispatch_batches and not sync_each_batch:
                        continue
                    if state.local_process_index == 0:
                        print(
                            "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
                            f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
                        )
                    test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()