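"""Unit tests for apex.amp automatic type casting.

Covers casting of common layers and their functional counterparts,
the handling of banned functions such as binary_cross_entropy, and
casts applied to torch.Tensor methods and operators.
"""
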
import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import (common_init, HALF, FLOAT,
                   ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT)

def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
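    """Check output (and optionally gradient) types for each fn and input dtype.

    For every callable in ``fns`` and every input dtype in ``expected``, run the
    callable on a random tensor of ``input_shape`` and assert that the output
    tensor type is ``expected[dtype]``. If ``test_backward`` is True, also run a
    backward pass and assert that the gradient on the input matches the input's
    own type.
    """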
    for fn, typ in it.product(fns, expected.keys()):
        x = torch.randn(input_shape, dtype=typ).requires_grad_()
        y = fn(x)
        test_case.assertEqual(y.type(), expected[typ])
        if test_backward:
            y.float().sum().backward()
            test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])

class TestBasicCasts(unittest.TestCase):
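    """Casting behavior of common modules and their functional counterparts.

    Half-precision ops (linear, conv2d) should always produce half outputs,
    float ops (softmax, group norm, mse_loss) should always produce float
    outputs, and type-preserving ops (relu, batch norm) should match the
    input type.
    """
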
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_linear_is_half(self):
        m = nn.Linear(self.h, self.h)
        f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h))

    def test_conv2d_is_half(self):
        m = nn.Conv2d(self.c, self.c, self.k)
        f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))

    def test_softmax_is_float(self):
        m = nn.Softmax(dim=1)
        f = ft.partial(F.softmax, dim=1)
        run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h))

    def test_group_norm_is_float(self):
        m = nn.GroupNorm(num_groups=4, num_channels=self.c)
        run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))

    def test_mse_loss_is_float(self):
        shape = (self.b, self.h)
        target = torch.randn(shape)
        mod = nn.MSELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.mse_loss, target=target)
        run_layer_test(self, [m, f], ALWAYS_FLOAT, shape)

    def test_relu_is_match(self):
        run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))

    def test_batch_norm_is_match(self):
        m = nn.BatchNorm2d(num_features=self.c)
        f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                       weight=m.weight, bias=m.bias, training=True)
        run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h))

        # Test forward-only for BN inference
        m.eval()
        f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                       weight=m.weight, bias=m.bias, training=False)
        run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
                       test_backward=False)

class TestBannedMethods(unittest.TestCase):
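    """Handling of functions banned under amp (binary_cross_entropy).

    By default, calling a banned function raises NotImplementedError; with
    ``allow_banned=True`` it runs and returns float output instead.
    """
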
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def bce_common(self, assertion):
        shape = (self.b, self.h)
        target = torch.rand(shape)
        mod = nn.BCELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.binary_cross_entropy, target=target)
        for fn in [m, f]:
            x = torch.rand(shape, dtype=torch.half)
            assertion(fn, x)

    def test_bce_raises_by_default(self):
        assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
        self.bce_common(assertion)

    def test_bce_is_float_with_allow_banned(self):
        self.handle._deactivate()
        self.handle = amp.init(enabled=True, allow_banned=True)
        assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
        self.bce_common(assertion)

class TestTensorCasts(unittest.TestCase):
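    """Casting behavior of torch.Tensor methods and operators.

    matmul (method and ``@`` operator) should produce half outputs, while
    pow and sum produce float outputs and ``.cpu()`` returns a CPU float
    tensor.
    """
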
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_matmul_method_is_half(self):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x.matmul(other)
        rhs = lambda x: other.matmul(x)
        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))

    def test_matmul_op_is_half(self):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x @ other
        rhs = lambda x: other @ x
        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))

    def test_pow_method_is_float(self):
        fn = lambda x: x.pow(2.)
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    def test_pow_op_is_float(self):
        fn = lambda x: x ** 2.
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    def test_cpu_is_float(self):
        fn = lambda x: x.cpu()
        always_cpu_float = {torch.float: 'torch.FloatTensor',
                            torch.half: 'torch.FloatTensor'}
        run_layer_test(self, [fn], always_cpu_float, (self.b, self.h))

    def test_sum_is_float(self):
        fn = lambda x: x.sum()
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    # TODO: maybe more tests on disabled casting?

if __name__ == '__main__':
    unittest.main()