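# Spaces: Runtime error
# Runtime error
# NOTE(review): the two lines above look like page-status text captured along
# with this file rather than part of the test module -- confirm and drop.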
import unittest
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT, DTYPES
class TestPromotion(unittest.TestCase):
    """Check the result type of ops called on mixed-precision inputs under amp.

    Each test runs with an amp handle created by ``amp.init(enabled=True)`` in
    ``setUp`` and deactivated again in ``tearDown``.
    """

    def setUp(self):
        """Create the amp handle and run the shared test initialisation."""
        self.handle = amp.init(enabled=True)
        # NOTE(review): common_init comes from utils; it presumably sets
        # self.b, which the tests below read -- confirm against utils.
        common_init(self)

    def tearDown(self):
        """Deactivate the amp handle created in ``setUp``."""
        self.handle._deactivate()

    def run_binary_promote_test(self, fns, input_shape, x_inplace=False):
        """Run every function in ``fns`` on every dtype pair from ``DTYPES``.

        For each combination the result type is checked, then a backward pass
        is run and the gradient dtype of the original leaf ``x`` is checked
        against ``xtype``.

        Args:
            fns: Callables taking ``(x, y)`` and returning a tensor.
            input_shape: Shape passed to ``torch.randn`` for both operands.
            x_inplace: When true, the result is expected to have the same
                type as ``x``; otherwise FLOAT is expected if either input
                is ``torch.float`` and HALF if neither is.
        """
        type_pairs = it.product(DTYPES, DTYPES)
        cases = it.product(enumerate(fns), type_pairs)
        for (fn_index, fn), (xtype, ytype) in cases:
            # subTest keeps the remaining combinations running after a
            # failure and names the failing combination in the report.
            with self.subTest(fn_index=fn_index, xtype=xtype, ytype=ytype):
                x = torch.randn(input_shape, dtype=xtype).requires_grad_()
                x_leaf = x
                if x_inplace:
                    # We need a non-leaf to call in place on
                    x = x.clone()
                y = torch.randn(input_shape, dtype=ytype)
                out = fn(x, y)

                if x_inplace:
                    # In place: always match xtype
                    expected_type = x.type()
                elif xtype == torch.float or ytype == torch.float:
                    # Out of place: match widest type
                    expected_type = FLOAT
                else:
                    expected_type = HALF
                self.assertEqual(out.type(), expected_type)

                out.float().sum().backward()
                self.assertEqual(x_leaf.grad.dtype, xtype)

    def test_atan2_matches_widest(self):
        """atan2, as a function and as a method, yields the widest input type."""
        fns = [lambda x, y: torch.atan2(x, y),
               lambda x, y: x.atan2(y)]
        self.run_binary_promote_test(fns, (self.b,))

    def test_mul_matches_widest(self):
        """mul, as a function and as a method, yields the widest input type."""
        fns = [lambda x, y: torch.mul(x, y),
               lambda x, y: x.mul(y)]
        self.run_binary_promote_test(fns, (self.b,))

    def test_cat_matches_widest(self):
        """cat gives FLOAT if one input is float, HALF if all are half."""
        shape = self.b
        ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)]

        x_float = torch.randn(shape)
        out = torch.cat(ys + [x_float])
        self.assertEqual(out.type(), FLOAT)

        x_half = torch.randn(shape, dtype=torch.half)
        out = torch.cat(ys + [x_half])
        self.assertEqual(out.type(), HALF)

    def test_inplace_exp_is_error_for_half(self):
        """In-place exp_ keeps a float tensor FLOAT and raises on a half tensor."""
        xs = torch.randn(self.b)
        xs.exp_()
        self.assertEqual(xs.type(), FLOAT)

        xs = torch.randn(self.b, dtype=torch.half)
        with self.assertRaises(NotImplementedError):
            xs.exp_()

    def test_inplace_add_matches_self(self):
        """In-place add_ leaves the result with the type of ``x``."""
        fn = lambda x, y: x.add_(y)
        self.run_binary_promote_test([fn], (self.b,), x_inplace=True)
# Run the test suite only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()