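"""Tests that RNN cells and full RNNs produce half-precision outputs, and
gradients matching each input's dtype, when run under apex amp's
monkey-patching via the handle-based `amp.init()` API."""
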
import random
import unittest

import torch
from torch import nn

from apex import amp
from utils import common_init, HALF
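
# `common_init` (from this repo's test utils) is assumed to set the batch,
# hidden, and time dimensions used below (self.b, self.h, self.t) and to make
# CUDA tensors the default type; `HALF` is assumed to be the type string
# 'torch.cuda.HalfTensor', since it is compared against `.type()` results.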


class TestRnnCells(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        # _deactivate() is a private helper that restores the functions amp
        # patched, so state does not leak between tests.
        self.handle._deactivate()

    def run_cell_test(self, cell, state_tuple=False):
        shape = (self.b, self.h)
        for typ in [torch.float, torch.half]:
            xs = [torch.randn(shape, dtype=typ).requires_grad_()
                  for _ in range(self.t)]
            hidden_fn = lambda: torch.zeros(shape, dtype=typ)
            if state_tuple:
                # LSTM cells carry a (hidden, cell-state) tuple.
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            outputs = []
            for i in range(self.t):
                hidden = cell(xs[i], hidden)
                if state_tuple:
                    output = hidden[0]
                else:
                    output = hidden
                outputs.append(output)
            # Under amp, every step's output should be half precision.
            for y in outputs:
                self.assertEqual(y.type(), HALF)
            outputs[-1].float().sum().backward()
            # Gradients should come back in the dtype of each input.
            for x in xs:
                self.assertEqual(x.grad.dtype, x.dtype)

    def test_rnn_cell_is_half(self):
        cell = nn.RNNCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_gru_cell_is_half(self):
        cell = nn.GRUCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_lstm_cell_is_half(self):
        cell = nn.LSTMCell(self.h, self.h)
        self.run_cell_test(cell, state_tuple=True)
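
# For orientation, the behavior these tests exercise looks roughly like the
# sketch below (not executed here; it assumes a CUDA build with apex
# installed and that `HALF` names 'torch.cuda.HalfTensor'):
#
#   handle = amp.init(enabled=True)
#   cell = nn.LSTMCell(32, 32)                  # parameters stay float32
#   h = c = torch.zeros(8, 32).cuda()
#   out, _ = cell(torch.randn(8, 32).cuda(), (h, c))
#   out.type()                                  # -> 'torch.cuda.HalfTensor'
#   handle._deactivate()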


class TestRnns(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def run_rnn_test(self, rnn, layers, bidir, state_tuple=False):
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
            # Hidden state is (num_layers * num_directions, batch, hidden);
            # `bidir` is a bool, so `layers + layers * bidir` doubles the
            # first dimension for bidirectional RNNs.
            hidden_fn = lambda: torch.zeros((layers + (layers * bidir),
                                             self.b, self.h), dtype=typ)
            if state_tuple:
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            output, _ = rnn(x, hidden)
            self.assertEqual(output.type(), HALF)
            output[-1, :, :].float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype)

    def test_rnn_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=layers,
                         nonlinearity='relu', bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_gru_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.GRU(input_size=self.h, hidden_size=self.h, num_layers=layers,
                         bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_lstm_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.LSTM(input_size=self.h, hidden_size=self.h, num_layers=layers,
                          bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir, state_tuple=True)

    def test_rnn_packed_sequence(self):
        num_layers = 2
        rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=num_layers)
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
            lens = sorted([random.randint(self.t // 2, self.t) for _ in range(self.b)],
                          reverse=True)
            # `pack_padded_sequence` breaks if the default tensor type is
            # non-CPU, so temporarily switch back while building the
            # lengths tensor.
            torch.set_default_tensor_type(torch.FloatTensor)
            lens = torch.tensor(lens, dtype=torch.int64, device=torch.device('cpu'))
            packed_seq = nn.utils.rnn.pack_padded_sequence(x, lens)
            torch.set_default_tensor_type(torch.cuda.FloatTensor)
            hidden = torch.zeros((num_layers, self.b, self.h), dtype=typ)
            output, _ = rnn(packed_seq, hidden)
            self.assertEqual(output.data.type(), HALF)
            output.data.float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype)
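
# Note: `output` above is a PackedSequence, so the dtype check and backward
# pass go through its flat `.data` tensor rather than a padded view. A padded
# result could be recovered with (sketch, not run here):
#
#   padded, out_lens = nn.utils.rnn.pad_packed_sequence(output)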


if __name__ == '__main__':
    unittest.main()