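# Tests for apex.amp's cached weight casts under opt_level O1. Each test runs a
# small module through a train -> eval -> train cycle, with the module's op drawn
# from the whitelist (mm), blacklist (torch.pow), or promote (elementwise mul)
# category and the weight held in fp16 or fp32, and checks the resulting gradient
# against a pure-fp32 reference computed outside of amp.
# Running these tests requires a CUDA device, an apex installation with amp, and
# the companion `utils` helper module (common_init and the dtype constants) on the
# import path.
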
import unittest
import functools as ft
import itertools as it
from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

def get_reference_grad(i, w, ops):
    # Creating new tensors ensures, among other things, that the new tensors are not in the cache.
    # In fact, they are guaranteed not to use the cache because they are not torch.nn.Parameters.
    fp32_i = i.detach().clone().float()
    fp32_w = w.detach().clone().float().requires_grad_()
    loss = ops(fp32_i, fp32_w)
    loss.backward()
    return fp32_w.grad


class WhitelistModule(torch.nn.Module):
    def __init__(self, dtype):
        super(WhitelistModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(8*8, device='cuda', dtype=dtype).view(8,8))

    @staticmethod
    def ops(input, weight):
        return (input.mm(weight)).mm(weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)


class BlacklistModule(torch.nn.Module):
    def __init__(self, dtype):
        super(BlacklistModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2,8))

    @staticmethod
    def ops(input, weight):
        return (input + torch.pow(weight, 2) + torch.pow(weight, 2)).sum()

    def forward(self, input):
        return self.ops(input, self.weight)


class PromoteModule(torch.nn.Module):
    def __init__(self, dtype):
        super(PromoteModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2,8))

    @staticmethod
    def ops(input, weight):
        return ((input*weight)*weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)

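
# A note on the cache these tests exercise (my reading of apex's O1 implementation,
# not something this file asserts): when a parameter is cast for a whitelist or
# promote op, the casted copy is memoized, and an eval pass under torch.no_grad()
# reuses that cache without the backward pass that normally flushes it.  The
# train -> eval -> train cycle below checks that the cached casts don't go stale
# once the weights have been updated and training resumes.
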

class TestCache(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones((2, 8), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass

    def train_eval_train_test(self, module, t):
        model = module(t).cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

        _amp_state.allow_incoming_model_not_fp32 = True
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
        _amp_state.allow_incoming_model_not_fp32 = False

        def training_step():
            for param in model.parameters():
                param.grad = None

            loss = model(self.x).sum()
            _amp_state.loss_scalers[0]._loss_scale = 4.0
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            self.assertEqual(len([p.grad for p in model.parameters() if p.grad is not None]), 1)
            self.assertEqual(model.weight.grad.type(), model.weight.type())

            reference_grad = get_reference_grad(self.x, model.weight, model.ops)

            # Currently the fp16 and fp32 branches make identical assert_close calls, so the
            # branching isn't strictly necessary, but I'm keeping it in case we want different
            # tolerances for the fp16 and fp32 checks.
            if model.weight.grad.type() == "torch.cuda.HalfTensor":
                torch.testing.assert_close(model.weight.grad.float(), reference_grad)
            elif model.weight.grad.type() == "torch.cuda.FloatTensor":
                torch.testing.assert_close(model.weight.grad.float(), reference_grad)
            else:
                raise RuntimeError("model.weight.grad.type = {}".format(model.weight.grad.type()))

            model.weight.data -= 1.

        # Simulates first epoch
        training_step()

        # Simulates eval
        with torch.no_grad():
            loss = model(self.x).sum()

        # Simulates resuming training after eval
        training_step()

        _amp_state.handle._deactivate()

    # I could easily have these as a set of for loops in a single test,
    # instead of going for granularity.
    def test_whitelist_module_fp16_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float16)

    def test_whitelist_module_fp32_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float32)

    def test_blacklist_module_fp16_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float16)

    def test_blacklist_module_fp32_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float32)

    def test_promote_module_fp16_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float16)

    def test_promote_module_fp32_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float32)

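
# Example invocation for a single case (the file name here is hypothetical; use
# whatever name this module is saved under):
#   python -m unittest -v test_cache.TestCache.test_whitelist_module_fp16_weight
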
if __name__ == '__main__':
    unittest.main()