# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention


class TestSparseMultiheadAttention(unittest.TestCase):
    def test_sparse_multihead_attention(self):
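        """Check buffered_sparse_mask() against hand-constructed expected masks,
        using stride=4 and expressivity=1 on an 8x8 attention grid."""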
        # Dummy attention weights for an 8x8 attention grid (batch size 1).
        attn_weights = torch.randn(1, 8, 8)
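        # Expected bidirectional mask: 0 marks an allowed (query, key) pair,
        # float('-inf') a blocked one.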
        bidirectional_sparse_mask = torch.tensor([
                [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
                [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
                [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
                [0, 0, 0, 0, 0, float('-inf'), float('-inf'), 0],
                [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
                [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
                [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
                [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0]
            ])

        # embed_dim=16 with a single attention head.
        bidirectional_attention = SparseMultiheadAttention(
            16, 1, stride=4, expressivity=1, is_bidirectional=True
        )
        bidirectional_attention_sparse_mask = bidirectional_attention.buffered_sparse_mask(
            attn_weights, 8, 8
        )
        self.assertTrue(
            torch.all(torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask))
        )

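        # Expected unidirectional (causal) mask: each query attends only to the
        # current or earlier positions within the sparse pattern.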
        sparse_mask = torch.tensor([
                [0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'),
                 float('-inf'), float('-inf')],
                [0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
                [0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf'), float('-inf')],
                [0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf'), float('-inf')],
                [0, 0, 0, 0, 0, float('-inf'), float('-inf'), float('-inf')],
                [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, float('-inf'), float('-inf')],
                [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, float('-inf')],
                [float('-inf'), float('-inf'), float('-inf'), 0, 0, 0, 0, 0],
            ])

        # Same configuration, but unidirectional.
        attention = SparseMultiheadAttention(
            16, 1, stride=4, expressivity=1, is_bidirectional=False
        )
        attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8)

        self.assertTrue(torch.all(torch.eq(attention_sparse_mask, sparse_mask)))


if __name__ == '__main__':
    unittest.main()