# Open-Sora/apex/tests/L0/run_amp/test_fused_sgd.py
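# Summary (drawn from the tests below): this module exercises apex.amp together with
# apex.optimizers.FusedSGD. Each test builds small models holding one fp32 and one fp16
# parameter, computes reference gradients and parameters with plain torch.optim.SGD, then
# replays the same steps through amp.initialize()/amp.scale_loss() with FusedSGD across
# opt levels O0-O3, several zero_grad strategies, single vs. multiple loss scalers, and
# injected inf gradients that force the loss scaler to skip an optimizer step.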
import unittest
import functools as ft
import itertools as it
from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
try:
import amp_C
disabled = False
from apex.optimizers import FusedSGD
except ImportError as err:
print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
disabled = True
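# MyModel keeps one fp32 parameter (weight0) and one fp16 parameter (weight1) so the
# tests below can inject an inf into either precision's gradient.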
class MyModel(torch.nn.Module):
def __init__(self, unique):
super(MyModel, self).__init__()
self.weight0 = Parameter(unique +
torch.arange(2, device='cuda', dtype=torch.float32))
self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16))
@staticmethod
def ops(input, weight0, weight1):
return ((input*(weight0.float()))*(weight1.float())).sum()
def forward(self, input):
return self.ops(input, self.weight0, self.weight1)
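# For illustration (values follow from the constructor above): MyModel(1) has
# weight0 = [1., 2.] (fp32) and weight1 = [2., 3.] (fp16), so with the all-ones
# input created in setUp the loss is 1*1*2 + 1*2*3 = 8.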
# Abandon all hope, ye who enter here.
# This is hands down the ugliest code I have ever written, but it succeeds in testing
# multiple models/optimizers/losses fairly thoroughly. Many of the different test cases
# require slightly divergent code in a way that seems near-impossible to genericize into a simple
# cross product or nested loops.
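# Each test sweeps materialize_master_grads x opt_level (O0-O3) x zero_grad style
# (none/model/optimizer) x single vs. multiple loss scalers, and for O1/O2 also
# injects an inf gradient on iteration 0 or 1 to exercise step skipping.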
class TestMultipleModelsOptimizersLosses(unittest.TestCase):
def setUp(self):
self.x = torch.ones((2), device='cuda', dtype=torch.float32)
common_init(self)
def tearDown(self):
pass
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_2models2losses1optimizer(self):
model0 = MyModel(1)
model1 = MyModel(2)
optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.125)
reference_grads = []
for i in range(2):
optimizer.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
loss0.backward()
loss1.backward()
reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
optimizer.step()
final_params = [param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()]
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
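# loss_ids of [0, 0] reuse a single loss scaler for both losses, while [0, 1]
# gives each loss its own scaler (matching the num_losses passed to
# amp.initialize below).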
if inject_inf >= 0:
iters = 3
else:
iters = 2
model0 = MyModel(1)
model1 = MyModel(2)
models = [model0, model1]
optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.125,
materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1], optimizer = amp.initialize(
[model0, model1],
optimizer,
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
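# The loss scales are pinned to exact powers of two (4.0 and 16.0), presumably so
# that scaling and unscaling introduce no rounding and the gradients stay directly
# comparable to the unscaled reference run.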
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if inject_inf_loc == "fp32":
model0.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model0.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if inject_inf_loc == "fp32":
model1.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model1.weight1.grad[0] = float('inf')
if i != inject_inf:
master_params = amp.master_params(optimizer)
for param, reference_grad in zip(master_params, reference_grads[unskipped]):
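# Under O2 with materialize_master_grads=False, FusedSGD is expected to consume the
# fp16 model grads directly rather than materializing fp32 master grads, which is
# presumably why the comparison below is skipped for that combination.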
if opt_level == "O2" and not materialize_master_grads:
continue
else:
torch.testing.assert_close(param.grad.float(), reference_grad.float(),
msg="opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))
unskipped += 1
optimizer.step()
model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
for model, master, reference in zip(
model_params,
amp.master_params(optimizer),
final_params):
torch.testing.assert_close(model, reference)
torch.testing.assert_close(model, master.to(model.dtype))
if opt_level == "O1":
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_3models2losses1optimizer(self):
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5},
{'params' : model2.parameters(), 'lr' : 0.125}],
momentum=0.125)
reference_grads = []
for i in range(2):
optimizer.zero_grad()
loss0 = model0(self.x) + model2(self.x)
loss1 = model1(self.x) + model2(self.x)
loss0.backward()
loss1.backward()
reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()] +
[param.grad.data.clone() for param in model2.parameters()])
optimizer.step()
final_params = [param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()] + \
[param.data.clone() for param in model2.parameters()]
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
if which_backward == 0:
which_models = (0, 2)
elif which_backward == 1:
which_models = (1, 2)
else:
iters = 2
which_models = (None,)
for which_model in which_models:
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
models = [model0, model1, model2]
optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5},
{'params' : model2.parameters(), 'lr' : 0.125}],
momentum=0.125,
materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1, model2], optimizer = amp.initialize(
[model0, model1, model2],
optimizer,
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer.zero_grad()
loss0 = model0(self.x) + model2(self.x)
loss1 = model1(self.x) + model2(self.x)
with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if which_model == 0:
inj_model = model0
elif which_model == 2:
inj_model = model2
else:
raise RuntimeError("{} invalid for loss 0".format(which_model))
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if which_model == 1:
inj_model = model1
elif which_model == 2:
inj_model = model2
else:
raise RuntimeError("{} invalid for loss 1".format(which_model))
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
if i != inject_inf:
master_params = amp.master_params(optimizer)
for param, reference_grad in zip(master_params, reference_grads[unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
torch.testing.assert_close(param.grad.float(), reference_grad.float(),
msg="opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} which_model {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, which_model, use_multiple_loss_scalers))
unskipped += 1
optimizer.step()
model_params = [p for p in model0.parameters()] + \
[p for p in model1.parameters()] + \
[p for p in model2.parameters()]
for model, master, reference in zip(
model_params,
amp.master_params(optimizer),
final_params):
torch.testing.assert_close(model, reference)
torch.testing.assert_close(model, master.to(model.dtype))
if opt_level == "O1":
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_2models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125)
optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.25)
# Don't do it like this: reference_grads = [[]]*5
# because then it creates a list of 5 references to the same "[]" and appending
# to any of them effectively makes you append to all of them, which multiplies
# the resulting size of reference_grads by 5x and needless to say makes the test fail.
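# For illustration: after buggy = [[]] * 2; buggy[0].append(1), both entries alias the
# same list and buggy == [[1], [1]]; a comprehension such as [[] for _ in range(5)]
# (or the explicit literal below) keeps the sublists independent.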
reference_grads = [[], [], [], [], []]
final_params = [None, None, None, None, None]
for i in range(2):
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
loss0.backward()
loss1.backward()
reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
optimizer0.step()
optimizer1.step()
final_params[0] = [param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()]
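# what_got_skipped maps (which_iter, which_backward) to an index 1..4 into
# reference_grads/final_params; index 0 is the reference run in which no step
# was skipped.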
def what_got_skipped(which_iter, which_backward):
if which_iter == 0 and which_backward == 0:
return 1
if which_iter == 0 and which_backward == 1:
return 2
if which_iter == 1 and which_backward == 0:
return 3
if which_iter == 1 and which_backward == 1:
return 4
return 0
for which_iter in (0,1):
for which_backward in (0,1):
model0 = MyModel(1)
model1 = MyModel(2)
optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125)
optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.25)
for i in range(3):
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
loss0.backward()
loss1.backward()
if i != which_iter:
reference_grads[what_got_skipped(which_iter, which_backward)].append(
[param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
if i == which_iter:
if which_backward == 0:
optimizer1.step()
else:
optimizer0.step()
else:
optimizer0.step()
optimizer1.step()
final_params[what_got_skipped(which_iter, which_backward)] = \
[param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()]
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
else:
iters = 2
model0 = MyModel(1)
model1 = MyModel(2)
models = [model0, model1]
optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125, materialize_master_grads=materialize_master_grads)
optimizer1 = FusedSGD([{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.25, materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1], [optimizer0, optimizer1] = amp.initialize(
[model0, model1],
[optimizer0, optimizer1],
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if inject_inf_loc == "fp32":
model0.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model0.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer1, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if inject_inf_loc == "fp32":
model1.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model1.weight1.grad[0] = float('inf')
# print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))
if i != inject_inf:
master_params = list(amp.master_params(optimizer0)) + \
list(amp.master_params(optimizer1))
for param, reference_grad in zip(master_params,
reference_grads[what_got_skipped(inject_inf, which_backward)][unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
torch.testing.assert_close(param.grad.float(), reference_grad.float())
unskipped += 1
optimizer0.step()
optimizer1.step()
model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
master_params = [p for p in amp.master_params(optimizer0)] + \
[p for p in amp.master_params(optimizer1)]
for model, master, reference in zip(
model_params,
master_params,
final_params[what_got_skipped(inject_inf, which_backward)]):
torch.testing.assert_close(model, reference)
torch.testing.assert_close(model, master.to(model.dtype))
if opt_level == "O1":
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_3models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 1.0}],
momentum=0.5)
optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],
momentum=0.25)
# Again, can't do this: reference_grads = [[]]*9
reference_grads = [[], [], [], [], [], [], [], [], []]
final_params = [None, None, None, None, None, None, None, None, None]
for i in range(2):
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x) + model1(self.x)
loss1 = model2(self.x) + model1(self.x)
loss0.backward()
loss1.backward()
reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
optimizer0.step()
optimizer1.step()
final_params[0] = \
[param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()] + \
[param.data.clone() for param in model2.parameters()]
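# As above, what_got_skipped maps (which_iter, which_backward, which_model) to an
# index 1..8 into reference_grads/final_params, with 0 reserved for the run in
# which nothing was skipped.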
def what_got_skipped(which_iter, which_backward, which_model):
if which_iter == 0:
if which_backward == 0:
if which_model == 0:
return 1
if which_model == 1:
return 2
if which_backward == 1:
if which_model == 2:
return 3
if which_model == 1:
return 4
if which_iter == 1:
if which_backward == 0:
if which_model == 0:
return 5
if which_model == 1:
return 6
if which_backward == 1:
if which_model == 2:
return 7
if which_model == 1:
return 8
return 0
for which_iter in (0,1):
for which_backward in (0,1):
if which_backward == 0:
which_models = (0,1)
if which_backward == 1:
which_models = (2,1)
for which_model in which_models:
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 1.0}],
momentum=0.5)
optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],
momentum=0.25)
for i in range(3):
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x) + model1(self.x)
loss1 = model2(self.x) + model1(self.x)
loss0.backward()
loss1.backward()
if i != which_iter:
reference_grads[what_got_skipped(which_iter,
which_backward, which_model)].append(
[param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
if i == which_iter:
if which_backward == 0:
# if which_model == 0:
optimizer1.step()
# if which_model == 1:
# optimizer1.step()
if which_backward == 1:
# if which_model == 2:
# optimizer0.step()
# if which_model == 1:
continue
else:
optimizer0.step()
optimizer1.step()
final_params[what_got_skipped(which_iter, which_backward, which_model)] = \
[param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()] + \
[param.data.clone() for param in model2.parameters()]
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
if which_backward == 0:
which_models = (0, 1)
elif which_backward == 1:
which_models = (2, 1)
else:
iters = 2
which_models = (None,)
for which_model in which_models:
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
models = [model0, model1, model2]
optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 1.0}],
momentum=0.5, materialize_master_grads=materialize_master_grads)
optimizer1 = FusedSGD([{'params' : model2.parameters(), 'lr' : 0.5}],
momentum=0.25, materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1, model2], [optimizer0, optimizer1] = amp.initialize(
[model0, model1, model2],
[optimizer0, optimizer1],
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x) + model1(self.x)
loss1 = model2(self.x) + model1(self.x)
with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if which_model == 0:
inj_model = model0
elif which_model == 1:
inj_model = model1
else:
raise RuntimeError("{} invalid for loss 0".format(which_model))
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, [optimizer0, optimizer1], loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if which_model == 2:
inj_model = model2
elif which_model == 1:
inj_model = model1
else:
raise RuntimeError("{} invalid for loss 1".format(which_model))
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
if i != inject_inf:
master_params = list(amp.master_params(optimizer0)) + \
list(amp.master_params(optimizer1))
for param, reference_grad in zip(master_params,
reference_grads[what_got_skipped(inject_inf,
which_backward, which_model)][unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
torch.testing.assert_close(param.grad.float(), reference_grad.float())
unskipped += 1
optimizer0.step()
optimizer1.step()
model_params = [p for p in model0.parameters()] + \
[p for p in model1.parameters()] + \
[p for p in model2.parameters()]
master_params = [p for p in amp.master_params(optimizer0)] + \
[p for p in amp.master_params(optimizer1)]
# print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {} which_model {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers, which_model))
for model, master, reference in zip(
model_params,
master_params,
final_params[what_got_skipped(inject_inf, which_backward, which_model)]):
torch.testing.assert_close(model, reference)
torch.testing.assert_close(model, master.to(model.dtype))
if opt_level == "O1":
_amp_state.handle._deactivate()
if __name__ == '__main__':
unittest.main()