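# NOTE: the page this file was extracted from showed the banner
# "Spaces: Runtime error" (apparently a hosting status message); it is not
# part of the module.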
import unittest
import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
    """Run every callable on every input dtype and check the result types.

    Args:
        test_case: Object providing ``assertEqual`` (a ``unittest.TestCase``
            everywhere in this file).
        fns: Iterable of callables that take one tensor and return a tensor.
        expected: Mapping from input dtype to the type string that
            ``fn(x).type()`` must equal for that dtype. Its keys are the
            dtypes that get tested.
        input_shape: Shape handed to ``torch.randn`` to build the input.
        test_backward: When true, additionally backpropagate
            ``y.float().sum()`` and compare ``x.grad.type()`` against
            ``MATCH_INPUT`` for the input dtype.
    """
    for fn, typ in it.product(fns, expected):
        # Name the combination under test: without it, a failure inside this
        # loop does not reveal which callable/dtype pair was responsible.
        where = 'fn={!r}, input dtype={}'.format(fn, typ)
        x = torch.randn(input_shape, dtype=typ).requires_grad_()
        y = fn(x)
        test_case.assertEqual(y.type(), expected[typ], msg=where)
        if test_backward:
            # Cast to float before reducing so the scalar being
            # differentiated is float regardless of the output dtype.
            y.float().sum().backward()
            test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ], msg=where)
class TestBasicCasts(unittest.TestCase): | |
def setUp(self): | |
self.handle = amp.init(enabled=True) | |
common_init(self) | |
def tearDown(self): | |
self.handle._deactivate() | |
def test_linear_is_half(self): | |
m = nn.Linear(self.h, self.h) | |
f = ft.partial(F.linear, weight=m.weight, bias=m.bias) | |
run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h)) | |
def test_conv2d_is_half(self): | |
m = nn.Conv2d(self.c, self.c, self.k) | |
f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias) | |
run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h)) | |
def test_softmax_is_float(self): | |
m = nn.Softmax(dim=1) | |
f = ft.partial(F.softmax, dim=1) | |
run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h)) | |
def test_group_norm_is_float(self): | |
m = nn.GroupNorm(num_groups=4, num_channels=self.c) | |
run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h)) | |
def test_mse_loss_is_float(self): | |
shape = (self.b, self.h) | |
target = torch.randn(shape) | |
mod = nn.MSELoss() | |
m = lambda x: mod(x, target) | |
f = ft.partial(F.mse_loss, target=target) | |
run_layer_test(self, [m], ALWAYS_FLOAT, shape) | |
def test_relu_is_match(self): | |
run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h)) | |
def test_batch_norm_is_match(self): | |
m = nn.BatchNorm2d(num_features=self.c) | |
f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var, | |
weight=m.weight, bias=m.bias, training=True) | |
run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h)) | |
# Test forward-only for BN inference | |
m.eval() | |
f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var, | |
weight=m.weight, bias=m.bias, training=False) | |
run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h), | |
test_backward=False) | |
class TestBannedMethods(unittest.TestCase): | |
def setUp(self): | |
self.handle = amp.init(enabled=True) | |
common_init(self) | |
def tearDown(self): | |
self.handle._deactivate() | |
def bce_common(self, assertion): | |
shape = (self.b, self.h) | |
target = torch.rand(shape) | |
mod = nn.BCELoss() | |
m = lambda x: mod(x, target) | |
f = ft.partial(F.binary_cross_entropy, target=target) | |
for fn in [m, f]: | |
x = torch.rand(shape, dtype=torch.half) | |
assertion(fn, x) | |
def test_bce_raises_by_default(self): | |
assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x) | |
self.bce_common(assertion) | |
def test_bce_is_float_with_allow_banned(self): | |
self.handle._deactivate() | |
self.handle = amp.init(enabled=True, allow_banned=True) | |
assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT) | |
self.bce_common(assertion) | |
class TestTensorCasts(unittest.TestCase): | |
def setUp(self): | |
self.handle = amp.init(enabled=True) | |
common_init(self) | |
def tearDown(self): | |
self.handle._deactivate() | |
def test_matmul_method_is_half(self): | |
other = torch.randn(self.h, self.h) | |
lhs = lambda x: x.matmul(other) | |
rhs = lambda x: other.matmul(x) | |
run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h)) | |
def test_matmul_op_is_half(self): | |
other = torch.randn(self.h, self.h) | |
lhs = lambda x: x @ other | |
rhs = lambda x: other @ x | |
run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h)) | |
def test_pow_method_is_float(self): | |
fn = lambda x: x.pow(2.) | |
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h)) | |
def test_pow_op_is_float(self): | |
fn = lambda x: x ** 2. | |
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h)) | |
def test_cpu_is_float(self): | |
fn = lambda x: x.cpu() | |
always_cpu_float = {torch.float: 'torch.FloatTensor', | |
torch.half: 'torch.FloatTensor'} | |
run_layer_test(self, [fn], always_cpu_float, (self.b, self.h)) | |
def test_sum_is_float(self): | |
fn = lambda x: x.sum() | |
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h)) | |
# TODO: maybe add more tests for disabled casting?
if __name__ == '__main__':
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()