# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
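"""
Sanity checks for gradient synchronization in Accelerate: `Accelerator.no_sync`,
`Accelerator.accumulate` with a `GradientAccumulationPlugin`, and `GradientState`
dataloader tracking. Meant to be run as a script, e.g. under a multi-process launcher
such as `accelerate launch` (or `xla_spawn` for TPUs), so the distributed branches run.
"""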
from copy import deepcopy

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

from accelerate.accelerator import Accelerator, GradientAccumulationPlugin
from accelerate.state import GradientState
from accelerate.test_utils import RegressionDataset, RegressionModel
from accelerate.utils import DistributedType, set_seed


def check_model_parameters(model_a, model_b, did_step, iteration, **kwargs):
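    """Assert that the gradients of `model_a` and `model_b` are in sync when `did_step` is True and
    out of sync otherwise. Extra `**kwargs` (e.g. `rtol`) are forwarded to `torch.allclose`."""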
    for param, grad_param in zip(model_a.parameters(), model_b.parameters()):
        if not param.requires_grad:
            continue
        if not did_step:
            # Grads should not be in sync
            assert (
                torch.allclose(param.grad, grad_param.grad, **kwargs) is False
            ), f"Gradients in sync when they should not be at iteration {iteration}:\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})"
        else:
            # Grads should be in sync
            assert (
                torch.allclose(param.grad, grad_param.grad, **kwargs) is True
            ), f"Gradients not in sync when they should be at iteration {iteration}:\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})"


def step_model(model, input, target, accelerator, do_backward=True):
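    """Run one forward/backward step on `model`. With `do_backward=False` the loss is divided by
    `accelerator.gradient_accumulation_steps` and backpropagated with `loss.backward()`; otherwise
    `accelerator.backward(loss)` is used."""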
    model.train()
    output = model(input)
    loss = F.mse_loss(output, target.to(output.device))
    if not do_backward:
        loss /= accelerator.gradient_accumulation_steps
        loss.backward()
    else:
        accelerator.backward(loss)


def get_training_setup(accelerator, sched=False):
    "Returns everything needed to perform basic training"
    set_seed(42)
    model = RegressionModel()
    ddp_model = deepcopy(model)
    dset = RegressionDataset(length=80)
    dataloader = DataLoader(dset, batch_size=16)
    model.to(accelerator.device)
    if sched:
        opt = AdamW(params=model.parameters(), lr=1e-3)
        ddp_opt = AdamW(params=ddp_model.parameters(), lr=1e-3)
        sched = LambdaLR(opt, lr_lambda=lambda epoch: epoch**0.65)
        ddp_sched = LambdaLR(ddp_opt, lr_lambda=lambda epoch: epoch**0.65)
    # Wrap the copy of `model` (and optionally its optimizer/scheduler) with the accelerator
    if sched:
        ddp_model, ddp_opt, ddp_sched, dataloader = accelerator.prepare(ddp_model, ddp_opt, ddp_sched, dataloader)
    else:
        ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader)
    if sched:
        return (model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched)
    return model, ddp_model, dataloader


def test_noop_sync(accelerator):
    # Test that on a single CPU or GPU the context manager does nothing
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Use a single batch
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for iteration in range(3):
        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform the ground-truth step on the base (non-DDP) model
        step_model(model, input, target, accelerator)
        # Do "gradient accumulation" (noop)
        if iteration % 2 == 0:
            # Accumulate grads locally
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            # Sync grads
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # Since `no_sync` is a noop here, `ddp_model` and `model` grads should always be in sync
        check_model_parameters(model, ddp_model, True, iteration)
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            assert torch.allclose(
                param.grad, ddp_param.grad
            ), f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]


def test_distributed_sync(accelerator):
    # Test that in a distributed setup the `no_sync` context manager behaves properly
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Use a single batch
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for iteration in range(3):
        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform the ground-truth step on the base (non-DDP) model
        step_model(model, input, target, accelerator)
        # Accumulate grads locally on even iterations, sync them on odd ones
        if iteration % 2 == 0:
            # Accumulate grads locally
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            # Sync grads
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # DDP model and model should only be in sync when not (iteration % 2 == 0)
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            if iteration % 2 == 0:
                # Grads should not be in sync
                assert (
                    torch.allclose(param.grad, ddp_param.grad) is False
                ), f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
            else:
                # Grads should be in sync
                assert (
                    torch.allclose(param.grad, ddp_param.grad) is True
                ), f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]


def test_distributed_sync_multiple_fwd(accelerator):
    # Test that the context manager behaves properly when multiple forwards are followed by multiple backwards
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Do multiple forwards
    losses = []
    num_iterations = 3
    for iteration in range(num_iterations):
        ddp_input, ddp_target = next(iter(dataloader)).values()
        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform the ground-truth step on the base (non-DDP) model
        step_model(model, input, target, accelerator)
        # Accumulate grads locally
        with accelerator.no_sync(ddp_model):
            ddp_output = ddp_model(ddp_input)
            loss = F.mse_loss(ddp_output, ddp_target.to(ddp_output.device))
            losses.append(loss)

    # Do multiple backwards and sync only at the last backward
    for iteration in range(num_iterations):
        loss = losses[iteration]
        if iteration < num_iterations - 1:
            # Accumulate grads locally
            accelerator.backward(loss)
            # DDP model and model should only be in sync after the last backward
            for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
                if not param.requires_grad:
                    continue
                # Grads should not be in sync
                assert (
                    torch.allclose(param.grad, ddp_param.grad) is False
                ), f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
        else:
            # Sync grads on the last backward
            with accelerator.trigger_sync_in_backward(ddp_model):
                accelerator.backward(loss)
            # DDP model and model should only be in sync after the last backward
            for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
                if not param.requires_grad:
                    continue
                # Grads should be in sync
                assert (
                    torch.allclose(param.grad, ddp_param.grad) is True
                ), f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"


def test_gradient_accumulation(split_batches=False, dispatch_batches=False, sync_each_batch=False):
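    """Check that under `accelerator.accumulate` the DDP gradients only match the baseline model on
    sync steps: every `num_steps`-th batch, the final batch, or every batch when `sync_each_batch`."""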
    gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
    accelerator = Accelerator(
        split_batches=split_batches,
        dispatch_batches=dispatch_batches,
        gradient_accumulation_plugin=gradient_accumulation_plugin,
    )
    # Test that the context manager behaves properly
    model, ddp_model, dataloader = get_training_setup(accelerator)
    for iteration, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform the ground-truth step on the base (non-DDP) model
        step_model(model, input, target, accelerator, False)
        # Do gradient accumulation under the `accumulate` context manager
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
        # Grads should only be in sync on sync steps: every other batch, the last batch, or every batch with `sync_each_batch`
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            if ((iteration + 1) % 2 == 0) or (iteration == len(dataloader) - 1) or sync_each_batch:
                # Grads should be in sync
                assert (
                    torch.allclose(param.grad, ddp_param.grad) is True
                ), f"Gradients not in sync when they should be at iteration {iteration}:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
            else:
                # Grads should not be in sync
                assert (
                    torch.allclose(param.grad, ddp_param.grad) is False
                ), f"Gradients in sync when they should not be at iteration {iteration}:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
    GradientState._reset_state()


def test_gradient_accumulation_with_opt_and_scheduler(
    split_batches=False, dispatch_batches=False, sync_each_batch=False
):
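    """Same as `test_gradient_accumulation`, but also steps an optimizer and LR scheduler on both the
    baseline and the prepared model, checking that learning rates and (on multi-process setups)
    gradients stay consistent."""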
    gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
    accelerator = Accelerator(
        split_batches=split_batches,
        dispatch_batches=dispatch_batches,
        gradient_accumulation_plugin=gradient_accumulation_plugin,
    )
    # Test that the context manager behaves properly
    model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched = get_training_setup(accelerator, True)
    for iteration, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform the ground-truth step on the base (non-DDP) model
        model.train()
        ddp_model.train()
        step_model(model, input, target, accelerator, False)
        opt.step()

        if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)) or sync_each_batch:
            if split_batches:
                sched.step()
            else:
                for _ in range(accelerator.num_processes):
                    sched.step()

        # Perform gradient accumulation under the wrapper
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
            ddp_opt.step()
            ddp_sched.step()

        # Learning rates should be the same
        assert (
            opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"]
        ), f'Learning rates found in each optimizer did not align\nopt: {opt.param_groups[0]["lr"]}\nDDP opt: {ddp_opt.param_groups[0]["lr"]}\n'
        did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader)) or sync_each_batch
        if accelerator.num_processes > 1:
            check_model_parameters(
                model,
                ddp_model,
                did_step,
                iteration,
                rtol=1e-3,  # somehow needs a relative tolerance
            )

        if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)) or sync_each_batch:
            opt.zero_grad()  # needs to be guarded by logic as to when we should zero grads
            ddp_opt.zero_grad()

        # Shuffle ddp_input on each iteration
        torch.manual_seed(1337 + iteration)
    GradientState._reset_state()


def test_dataloader_break():
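    """Check that `GradientState` tracks the currently active dataloader, including a second dataloader
    iterated inside the first, and that `end_of_dataloader` is only set on the last batch."""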
    accelerator = Accelerator()
    first_dset = RegressionDataset(length=80)
    first_dataloader = DataLoader(first_dset, batch_size=16)
    second_dset = RegressionDataset(length=96)
    second_dataloader = DataLoader(second_dset, batch_size=16)
    first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)

    assert accelerator.gradient_state.active_dataloader is None
    for iteration, _ in enumerate(first_dataloader):
        assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader)
        if iteration < len(first_dataloader) - 1:
            assert not accelerator.gradient_state.end_of_dataloader
            if iteration == 1:
                for batch_num, _ in enumerate(second_dataloader):
                    assert id(accelerator.gradient_state.active_dataloader) == id(second_dataloader)
                    if batch_num < len(second_dataloader) - 1:
                        assert not accelerator.gradient_state.end_of_dataloader
                    else:
                        assert accelerator.gradient_state.end_of_dataloader
        else:
            assert accelerator.gradient_state.end_of_dataloader
    assert accelerator.gradient_state.active_dataloader is None


def main():
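    """Dispatch the tests above based on the current distributed setup, printing a banner on the
    local main process before each group of tests."""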
    accelerator = Accelerator()
    state = accelerator.state
    if state.local_process_index == 0:
        print("**Test `accumulate` gradient accumulation with dataloader break**")
    if state.distributed_type != DistributedType.XLA:
        test_dataloader_break()
    if state.distributed_type == DistributedType.NO:
        if state.local_process_index == 0:
            print("**Test NOOP `no_sync` context manager**")
        test_noop_sync(accelerator)
    if state.distributed_type in (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_NPU,
        DistributedType.MULTI_MLU,
        DistributedType.MULTI_CPU,
    ):
        if state.local_process_index == 0:
            print("**Test Distributed `no_sync` context manager**")
        test_distributed_sync(accelerator)
        if state.local_process_index == 0:
            print("**Test Distributed `no_sync` context manager with multiple forwards**")
        test_distributed_sync_multiple_fwd(accelerator)
    if state.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.MULTI_MLU):
        for split_batch in [True, False]:
            for dispatch_batches in [True, False]:
                for sync_each_batch in [True, False]:
                    if state.local_process_index == 0:
                        print(
                            "**Test `accumulate` gradient accumulation, ",
                            f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
                        )
                    test_gradient_accumulation(split_batch, dispatch_batches, sync_each_batch)

    # Currently will break on torch 2.0 +, need to investigate why
    if state.local_process_index == 0:
        print(
            "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
            "`split_batches=False`, `dispatch_batches=False`, `sync_each_batch=False`**",
        )
    test_gradient_accumulation_with_opt_and_scheduler()
    if state.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.MULTI_MLU):
        for split_batch in [True, False]:
            for dispatch_batches in [True, False]:
                for sync_each_batch in [True, False]:
                    if not split_batch and not dispatch_batches and not sync_each_batch:
                        continue
                    if state.local_process_index == 0:
                        print(
                            "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
                            f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
                        )
                    test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()