# Hosting-page metadata captured together with the file; commented out so the module parses.
# Draken007's picture
# Upload 7228 files
# 2a0bc63 verified
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from accelerate.accelerator import Accelerator, GradientAccumulationPlugin
from accelerate.state import GradientState
from accelerate.test_utils import RegressionDataset, RegressionModel
from accelerate.utils import DistributedType, set_seed
def check_model_parameters(model_a, model_b, did_step, iteration, **kwargs):
    """Assert that the gradients of two models are (or are not) in sync.

    The parameters of both models are walked pairwise; a pair is skipped when
    the `model_a` parameter does not require grad.

    Args:
        model_a: First model; must expose `parameters()`.
        model_b: Second model, iterated in lockstep with `model_a`.
        did_step: When truthy every gradient pair must be close, otherwise
            every gradient pair must differ.
        iteration: Iteration index, used only in the failure message.
        **kwargs: Forwarded to `torch.allclose` (e.g. `rtol`, `atol`).

    Raises:
        AssertionError: If any gradient pair violates the expectation.
    """
    for param, grad_param in zip(model_a.parameters(), model_b.parameters()):
        if param.requires_grad:
            # Compare once, then check the result against the expectation.
            grads_match = torch.allclose(param.grad, grad_param.grad, **kwargs)
            if did_step:
                # Grads should be in sync
                assert (
                    grads_match
                ), f"Gradients not in sync when they should be at iteration {iteration}:\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})"
            else:
                # Grads should not be in sync
                assert (
                    not grads_match
                ), f"Gradients in sync when they should not be at iteration {iteration}:\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})"
def step_model(model, input, target, accelerator, do_backward=True):
    """Run one forward pass on `model` and backpropagate an MSE loss.

    Args:
        model: Module that is put in train mode and called on `input`.
        input: Batch passed to `model`.
        target: Regression targets; moved to the output's device before the
            loss is computed.
        accelerator: Object providing `backward(loss)` and
            `gradient_accumulation_steps`.
        do_backward: When truthy (default) the backward pass goes through
            `accelerator.backward`. Otherwise the loss is divided by
            `accelerator.gradient_accumulation_steps` and `loss.backward()`
            is called directly.
    """
    model.train()
    prediction = model(input)
    loss = F.mse_loss(prediction, target.to(prediction.device))
    if do_backward:
        accelerator.backward(loss)
        return
    # Manual path: scale the loss by hand, then call autograd directly.
    loss /= accelerator.gradient_accumulation_steps
    loss.backward()
def get_training_setup(accelerator, sched=False):
    """Returns everything needed to perform basic training.

    Seeds with `set_seed(42)`, builds a `RegressionModel`, deep-copies it as
    the "DDP" model, and builds a batch-size-16 dataloader over an 80-sample
    `RegressionDataset`. The reference model is moved to `accelerator.device`;
    the copy and the dataloader go through `accelerator.prepare`.

    Args:
        accelerator: Provides `device` and `prepare`.
        sched: When truthy, an `AdamW` optimizer and a `LambdaLR` scheduler are
            also created for each model, and the copy's pair is passed to
            `accelerator.prepare` as well.

    Returns:
        `(model, ddp_model, dataloader)` when `sched` is falsy, otherwise
        `(model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched)`.
    """
    set_seed(42)
    model = RegressionModel()
    # Copy before the reference model is moved to the accelerator device.
    ddp_model = deepcopy(model)
    dataloader = DataLoader(RegressionDataset(length=80), batch_size=16)
    model.to(accelerator.device)

    if not sched:
        ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader)
        return model, ddp_model, dataloader

    opt = AdamW(params=model.parameters(), lr=1e-3)
    ddp_opt = AdamW(params=ddp_model.parameters(), lr=1e-3)
    base_sched = LambdaLR(opt, lr_lambda=lambda epoch: epoch**0.65)
    ddp_sched = LambdaLR(ddp_opt, lr_lambda=lambda epoch: epoch**0.65)
    # Only the copy's objects are prepared; the reference model/optimizer/scheduler are not.
    ddp_model, ddp_opt, ddp_sched, dataloader = accelerator.prepare(ddp_model, ddp_opt, ddp_sched, dataloader)
    return (model, opt, base_sched, dataloader, ddp_model, ddp_opt, ddp_sched)
def test_noop_sync(accelerator):
    """Check that `accelerator.no_sync` changes nothing on a single CPU or GPU.

    A single batch is reused for three iterations. On every iteration the
    reference model and the prepared model each take one step; on even
    iterations the prepared model's step runs inside `accelerator.no_sync`.
    Because `no_sync` is expected to be a no-op here, the gradients of both
    models must match after every iteration.

    Args:
        accelerator: The `Accelerator` under test.

    Raises:
        AssertionError: If the gradients diverge on any iteration.
    """
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Use a single batch
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for iteration in range(3):
        # Gather the distributed inputs and targets for the base model
        gathered_input, gathered_target = accelerator.gather((ddp_input, ddp_target))
        gathered_input = gathered_input.to(accelerator.device)
        gathered_target = gathered_target.to(accelerator.device)
        # Perform our initial ground truth step in non "DDP"
        step_model(model, gathered_input, gathered_target, accelerator)
        if iteration % 2 == 0:
            # Step under `no_sync` (expected to be a no-op in this setup)
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            # Plain step
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # Since `no_sync` is a noop, `ddp_model` and `model` grads should always be in sync.
        # The helper asserts `torch.allclose` on every trainable parameter pair, so no
        # additional per-parameter loop is needed afterwards.
        check_model_parameters(model, ddp_model, True, iteration)

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
def test_distributed_sync(accelerator):
    """Check `accelerator.no_sync` on a distributed setup.

    A single batch is reused for three iterations. The reference model steps
    on the gathered batch every time. The prepared model steps inside
    `accelerator.no_sync` on even iterations and outside it on odd ones.
    Gradients are expected to differ on even iterations and match on odd ones.

    Args:
        accelerator: The `Accelerator` under test.

    Raises:
        AssertionError: If the gradients are in sync when they should not be,
            or the other way round.
    """
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Use a single batch
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for iteration in range(3):
        # Odd iterations step outside `no_sync`, so grads are expected to match then.
        expect_sync = iteration % 2 != 0

        # Gather the distributed inputs and targets for the base model
        gathered_input, gathered_target = accelerator.gather((ddp_input, ddp_target))
        gathered_input = gathered_input.to(accelerator.device)
        gathered_target = gathered_target.to(accelerator.device)
        # Ground-truth step on the non-"DDP" model
        step_model(model, gathered_input, gathered_target, accelerator)

        if expect_sync:
            # Sync grads
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            # Accumulate grads locally
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # DDP model and model should only be in sync on odd iterations
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            in_sync = torch.allclose(param.grad, ddp_param.grad)
            if expect_sync:
                assert (
                    in_sync
                ), f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
            else:
                assert (
                    not in_sync
                ), f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
def test_distributed_sync_multiple_fwd(accelerator):
    """Check `no_sync` with several forwards followed by several backwards.

    Phase one runs three forward passes of the prepared model inside
    `accelerator.no_sync` and keeps the losses, while the reference model takes
    a full step on the gathered batch each time. Phase two backpropagates the
    stored losses one by one; only the last backward is wrapped in
    `accelerator.trigger_sync_in_backward`. Gradients must differ after every
    backward except the last, after which they must match.

    Args:
        accelerator: The `Accelerator` under test.

    Raises:
        AssertionError: If the gradients are in sync before the last backward
            or out of sync after it.
    """
    model, ddp_model, dataloader = get_training_setup(accelerator)
    num_iterations = 3

    # Phase 1: multiple forwards
    losses = []
    for _ in range(num_iterations):
        # A fresh iterator is created on every pass and its first batch is taken.
        ddp_input, ddp_target = next(iter(dataloader)).values()

        # Gather the distributed inputs and targets for the base model
        gathered_input, gathered_target = accelerator.gather((ddp_input, ddp_target))
        gathered_input = gathered_input.to(accelerator.device)
        gathered_target = gathered_target.to(accelerator.device)

        # Ground-truth step on the non-"DDP" model
        step_model(model, gathered_input, gathered_target, accelerator)

        # Forward only, under `no_sync`; the backward happens in phase 2
        with accelerator.no_sync(ddp_model):
            ddp_output = ddp_model(ddp_input)
            loss = F.mse_loss(ddp_output, ddp_target.to(ddp_output.device))
            losses.append(loss)

    # Phase 2: multiple backwards, syncing only at the last one
    last_index = num_iterations - 1
    for iteration, loss in enumerate(losses):
        is_last = iteration == last_index
        if is_last:
            # Sync grads if last backward
            with accelerator.trigger_sync_in_backward(ddp_model):
                accelerator.backward(loss)
        else:
            # Accumulate grads locally
            accelerator.backward(loss)

        # DDP model and model should only be in sync after last backward
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            in_sync = torch.allclose(param.grad, ddp_param.grad)
            if is_last:
                assert (
                    in_sync
                ), f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
            else:
                assert (
                    not in_sync
                ), f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
def test_gradient_accumulation(split_batches=False, dispatch_batches=False, sync_each_batch=False):
    """Check gradient syncing under `accelerator.accumulate` with two accumulation steps.

    A new `Accelerator` is built with a `GradientAccumulationPlugin(num_steps=2)`.
    For every batch the reference model steps on the gathered batch with a
    manually scaled loss, and the prepared model steps inside
    `accelerator.accumulate`. Gradients must match on every second batch, on
    the last batch, or always when `sync_each_batch` is truthy; otherwise they
    must differ. `GradientState._reset_state()` is called once the loop ends.

    Args:
        split_batches: Forwarded to `Accelerator`.
        dispatch_batches: Forwarded to `Accelerator`.
        sync_each_batch: Forwarded to the plugin; also forces the "in sync"
            expectation on every batch.

    Raises:
        AssertionError: If the gradient state does not match the expectation.
    """
    accelerator = Accelerator(
        split_batches=split_batches,
        dispatch_batches=dispatch_batches,
        gradient_accumulation_plugin=GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch),
    )
    model, ddp_model, dataloader = get_training_setup(accelerator)
    for iteration, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Gather the distributed inputs and targets for the base model
        gathered_input, gathered_target = accelerator.gather((ddp_input, ddp_target))
        gathered_input = gathered_input.to(accelerator.device)
        gathered_target = gathered_target.to(accelerator.device)
        # Ground-truth step: loss scaled by hand, plain `loss.backward()`
        step_model(model, gathered_input, gathered_target, accelerator, False)
        # Same step for the prepared model, under the accumulation wrapper
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # Sync is expected every second batch, on the final batch, or always with `sync_each_batch`
        on_step_boundary = (iteration + 1) % 2 == 0
        on_last_batch = iteration == len(dataloader) - 1
        expect_sync = on_step_boundary or on_last_batch or sync_each_batch

        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            in_sync = torch.allclose(param.grad, ddp_param.grad)
            if expect_sync:
                assert (
                    in_sync
                ), f"Gradients not in sync when they should be at iteration {iteration}:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
            else:
                assert (
                    not in_sync
                ), f"Gradients in sync when they should not be at iteration {iteration}:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"

        # Shuffle ddp_input on each iteration
        # NOTE(review): `ddp_input` is rebound from `batch` at the top of the next pass,
        # so the permuted tensor is never read.
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
    GradientState._reset_state()
def test_gradient_accumulation_with_opt_and_scheduler(
    split_batches=False, dispatch_batches=False, sync_each_batch=False
):
    """Check `accelerator.accumulate` together with an optimizer and LR scheduler.

    A new `Accelerator` is built with a `GradientAccumulationPlugin(num_steps=2)`.
    For every batch the reference model, optimizer and scheduler are stepped by
    hand, and the prepared ones are stepped inside `accelerator.accumulate`.
    After each batch the learning rates of both optimizers must be equal. When
    more than one process is in use, the gradients are also compared through
    `check_model_parameters`. `GradientState._reset_state()` is called once the
    loop ends.

    Args:
        split_batches: Forwarded to `Accelerator`; when truthy the reference
            scheduler is stepped once per update instead of once per process.
        dispatch_batches: Forwarded to `Accelerator`.
        sync_each_batch: Forwarded to the plugin; also makes every batch count
            as an update step.

    Raises:
        AssertionError: If the learning rates or the gradient state do not
            match the expectation.
    """
    accelerator = Accelerator(
        split_batches=split_batches,
        dispatch_batches=dispatch_batches,
        gradient_accumulation_plugin=GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch),
    )
    model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched = get_training_setup(accelerator, True)
    for iteration, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Gather the distributed inputs and targets for the base model
        gathered_input, gathered_target = accelerator.gather((ddp_input, ddp_target))
        gathered_input = gathered_input.to(accelerator.device)
        gathered_target = gathered_target.to(accelerator.device)

        # Ground-truth step on the non-"DDP" model
        model.train()
        ddp_model.train()
        step_model(model, gathered_input, gathered_target, accelerator, False)
        opt.step()

        # An update happens every second batch, on the final batch, or always with `sync_each_batch`
        did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader)) or sync_each_batch

        if did_step:
            # One scheduler step with split batches, otherwise one per process
            reference_sched_steps = 1 if split_batches else accelerator.num_processes
            for _ in range(reference_sched_steps):
                sched.step()

        # Perform gradient accumulation under wrapper
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
            ddp_opt.step()
            ddp_sched.step()

        # Learning rates should be the same
        assert (
            opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"]
        ), f'Learning rates found in each optimizer did not align\nopt: {opt.param_groups[0]["lr"]}\nDDP opt: {ddp_opt.param_groups[0]["lr"]}\n'

        if accelerator.num_processes > 1:
            # somehow needs a relative tolerance
            check_model_parameters(model, ddp_model, did_step, iteration, rtol=1e-3)

        if did_step:
            opt.zero_grad()  # needs to be guarded by logic as to when we should zero grads
        # NOTE(review): called on every batch; presumably the prepared optimizer decides by
        # itself whether to zero while accumulating - confirm against the wrapper.
        ddp_opt.zero_grad()

        # Re-seed the global torch RNG on each iteration
        torch.manual_seed(1337 + iteration)
    GradientState._reset_state()
def test_dataloader_break():
    """Check `gradient_state` bookkeeping when one dataloader loop nests another.

    Two prepared dataloaders are used (80 and 96 samples, batch size 16). While
    iterating the first one, the second one is fully iterated during the first
    loop's iteration 1. The test asserts that
    `accelerator.gradient_state.active_dataloader` is the loader currently
    being iterated, that `end_of_dataloader` is set only on a loader's final
    batch, and that no dataloader is active before and after the outer loop.

    Raises:
        AssertionError: If any of those expectations is violated.
    """
    accelerator = Accelerator()

    first_dataloader = DataLoader(RegressionDataset(length=80), batch_size=16)
    second_dataloader = DataLoader(RegressionDataset(length=96), batch_size=16)
    first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)

    assert accelerator.gradient_state.active_dataloader is None
    for iteration, _ in enumerate(first_dataloader):
        assert accelerator.gradient_state.active_dataloader is first_dataloader

        if iteration >= len(first_dataloader) - 1:
            # Final batch of the outer loader
            assert accelerator.gradient_state.end_of_dataloader
            continue

        assert not accelerator.gradient_state.end_of_dataloader
        if iteration != 1:
            continue

        # Run the whole second loader in the middle of the first one
        for batch_num, _ in enumerate(second_dataloader):
            assert accelerator.gradient_state.active_dataloader is second_dataloader
            if batch_num < len(second_dataloader) - 1:
                assert not accelerator.gradient_state.end_of_dataloader
            else:
                assert accelerator.gradient_state.end_of_dataloader
    assert accelerator.gradient_state.active_dataloader is None
def main():
    """Run the sync and gradient-accumulation checks that fit the current distributed setup.

    - The dataloader-break test runs unless the distributed type is XLA.
    - The no-op `no_sync` test runs when the distributed type is `NO`.
    - The distributed `no_sync` tests run on MULTI_GPU / MULTI_NPU / MULTI_MLU / MULTI_CPU.
    - The `accumulate` tests run over every flag combination on
      MULTI_GPU / MULTI_NPU / MULTI_MLU.
    - The optimizer-and-scheduler test always runs once with default flags.

    Banners are printed only by the process whose `local_process_index` is 0.
    """
    accelerator = Accelerator()
    state = accelerator.state

    def announce(*parts):
        # Print a banner from local process 0 only.
        if state.local_process_index == 0:
            print(*parts)

    announce("**Test `accumulate` gradient accumulation with dataloader break**")
    if state.distributed_type != DistributedType.XLA:
        test_dataloader_break()

    if state.distributed_type == DistributedType.NO:
        announce("**Test NOOP `no_sync` context manager**")
        test_noop_sync(accelerator)

    if state.distributed_type in (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_NPU,
        DistributedType.MULTI_MLU,
        DistributedType.MULTI_CPU,
    ):
        announce("**Test Distributed `no_sync` context manager**")
        test_distributed_sync(accelerator)
        announce("**Test Distributed `no_sync` context manager with multiple forwards**")
        test_distributed_sync_multiple_fwd(accelerator)

    # Distributed types on which the full flag matrix is exercised
    accumulation_types = (DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.MULTI_MLU)
    # All (split_batch, dispatch_batches, sync_each_batch) combinations, True first
    flag_combos = [
        (split, dispatch, sync)
        for split in (True, False)
        for dispatch in (True, False)
        for sync in (True, False)
    ]

    if state.distributed_type in accumulation_types:
        for split_batch, dispatch_batches, sync_each_batch in flag_combos:
            announce(
                "**Test `accumulate` gradient accumulation, ",
                f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
            )
            test_gradient_accumulation(split_batch, dispatch_batches, sync_each_batch)

    # Currently will break on torch 2.0 +, need to investigate why
    announce(
        "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
        "`split_batches=False`, `dispatch_batches=False`, `sync_each_batch=False`**",
    )
    test_gradient_accumulation_with_opt_and_scheduler()

    if state.distributed_type in accumulation_types:
        for split_batch, dispatch_batches, sync_each_batch in flag_combos:
            # The all-False combination was already covered by the default call above
            if not (split_batch or dispatch_batches or sync_each_batch):
                continue
            announce(
                "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
                f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
            )
            test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch)
def _mp_fn(index):
    """Entry point for `xla_spawn` (TPUs); `index` is accepted but not used."""
    main()
# Run every applicable check when the file is executed as a script.
if __name__ == "__main__":
    main()