Enzo Reis de Oliveira committed
Commit b60e08a · 1 Parent(s): 30f063f

Fixing again

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. smi-ted/inference/smi_ted_light/.gitattributes +2 -0
  2. smi-ted/inference/smi_ted_light/fast_transformers/__init__.py +15 -0
  3. smi-ted/inference/smi_ted_light/fast_transformers/aggregate/__init__.py +128 -0
  4. smi-ted/inference/smi_ted_light/fast_transformers/aggregate/aggregate_cpu.cpython-39-x86_64-linux-gnu.so +3 -0
  5. smi-ted/inference/smi_ted_light/fast_transformers/attention/__init__.py +20 -0
  6. smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/__init__.cpython-310.pyc +0 -0
  7. smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/attention_layer.cpython-310.pyc +0 -0
  8. smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/full_attention.cpython-310.pyc +0 -0
  9. smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/linear_attention.cpython-310.pyc +0 -0
  10. smi-ted/inference/smi_ted_light/fast_transformers/attention/attention_layer.py +113 -0
  11. smi-ted/inference/smi_ted_light/fast_transformers/attention/causal_linear_attention.py +116 -0
  12. smi-ted/inference/smi_ted_light/fast_transformers/attention/clustered_attention.py +195 -0
  13. smi-ted/inference/smi_ted_light/fast_transformers/attention/conditional_full_attention.py +66 -0
  14. smi-ted/inference/smi_ted_light/fast_transformers/attention/exact_topk_attention.py +88 -0
  15. smi-ted/inference/smi_ted_light/fast_transformers/attention/full_attention.py +95 -0
  16. smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_attention.py +268 -0
  17. smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_causal_attention.py +257 -0
  18. smi-ted/inference/smi_ted_light/fast_transformers/attention/linear_attention.py +92 -0
  19. smi-ted/inference/smi_ted_light/fast_transformers/attention/local_attention.py +101 -0
  20. smi-ted/inference/smi_ted_light/fast_transformers/attention/reformer_attention.py +166 -0
  21. smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__init__.py +17 -0
  22. smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/__init__.cpython-310.pyc +0 -0
  23. smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/registry.cpython-310.pyc +0 -0
  24. smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/spec.cpython-310.pyc +0 -0
  25. smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/registry.py +61 -0
  26. smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/spec.py +126 -0
  27. smi-ted/inference/smi_ted_light/fast_transformers/builders/__init__.py +59 -0
  28. smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/__init__.cpython-310.pyc +0 -0
  29. smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/attention_builders.cpython-310.pyc +0 -0
  30. smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/base.cpython-310.pyc +0 -0
  31. smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/transformer_builders.cpython-310.pyc +0 -0
  32. smi-ted/inference/smi_ted_light/fast_transformers/builders/attention_builders.py +139 -0
  33. smi-ted/inference/smi_ted_light/fast_transformers/builders/base.py +67 -0
  34. smi-ted/inference/smi_ted_light/fast_transformers/builders/transformer_builders.py +550 -0
  35. smi-ted/inference/smi_ted_light/fast_transformers/causal_product/__init__.py +78 -0
  36. smi-ted/inference/smi_ted_light/fast_transformers/causal_product/causal_product_cpu.cpython-39-x86_64-linux-gnu.so +3 -0
  37. smi-ted/inference/smi_ted_light/fast_transformers/clustering/__init__.py +0 -0
  38. smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/__init__.py +115 -0
  39. smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/cluster_cpu.cpython-39-x86_64-linux-gnu.so +3 -0
  40. smi-ted/inference/smi_ted_light/fast_transformers/events/__init__.py +10 -0
  41. smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/__init__.cpython-310.pyc +0 -0
  42. smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event.cpython-310.pyc +0 -0
  43. smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event_dispatcher.cpython-310.pyc +0 -0
  44. smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/filters.cpython-310.pyc +0 -0
  45. smi-ted/inference/smi_ted_light/fast_transformers/events/event.py +51 -0
  46. smi-ted/inference/smi_ted_light/fast_transformers/events/event_dispatcher.py +92 -0
  47. smi-ted/inference/smi_ted_light/fast_transformers/events/filters.py +141 -0
  48. smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__init__.py +12 -0
  49. smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/__init__.cpython-310.pyc +0 -0
  50. smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/base.cpython-310.pyc +0 -0
smi-ted/inference/smi_ted_light/.gitattributes ADDED
@@ -0,0 +1,2 @@
+ smi-ted/inference/smi_ted_light/fast_transformers/**/*.so filter=lfs diff=lfs merge=lfs -text
+ *.so filter=lfs diff=lfs merge=lfs -text
smi-ted/inference/smi_ted_light/fast_transformers/__init__.py ADDED
@@ -0,0 +1,15 @@
+ #
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+ # Written by Angelos Katharopoulos <[email protected]>,
+ # Apoorv Vyas <[email protected]>
+ #
+
+ """Provide a library with fast transformer implementations."""
+
+ __author__ = "Angelos Katharopoulos, Apoorv Vyas"
+ __copyright__ = "Copyright (c) 2020 Idiap Research Institute"
+ __license__ = "MIT"
+ __maintainer__ = "Angelos Katharopoulos, Apoorv Vyas"
+
+ __url__ = "https://github.com/idiap/fast-transformers"
+ __version__ = "0.4.0"
smi-ted/inference/smi_ted_light/fast_transformers/aggregate/__init__.py ADDED
@@ -0,0 +1,128 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+
8
+ import torch
9
+
10
+ from .aggregate_cpu import aggregate as aggregate_cpu, \
11
+ broadcast as broadcast_cpu
12
+ try:
13
+ from .aggregate_cuda import aggregate as aggregate_gpu, \
14
+ broadcast as broadcast_gpu
15
+ from .clustered_aggregate_cuda import \
16
+ clustered_broadcast as clustered_broadcast_gpu, \
17
+ clustered_aggregate as clustered_aggregate_gpu
18
+
19
+ except ImportError:
20
+ pass
21
+
22
+
23
+ def aggregate(X, G, F, Y=None):
24
+ device = X.device
25
+ if Y is None:
26
+ Y = torch.zeros(
27
+ F.shape + (X.shape[-1],),
28
+ device=device,
29
+ dtype=X.dtype
30
+ )
31
+ else:
32
+ Y.zero_()
33
+
34
+ if device.type == "cpu":
35
+ aggregate_cpu(X, G, F, Y)
36
+ else:
37
+ aggregate_gpu(X, G, F, Y)
38
+
39
+ return Y
40
+
41
+
42
+ def broadcast(Y, G, F, X=None):
43
+ device = Y.device
44
+ if X is None:
45
+ X = torch.zeros(
46
+ G.shape + (Y.shape[-1],),
47
+ device=device,
48
+ dtype=Y.dtype
49
+ )
50
+
51
+ if device.type == "cpu":
52
+ broadcast_cpu(Y, G, F, X)
53
+ else:
54
+ broadcast_gpu(Y, G, F, X)
55
+
56
+ return X
57
+
58
+
59
+ # Divide the clusters into groups of equal size
60
+ # as constrained by the shared memory
61
+ def set_group(C, E):
62
+ C_per_block = int(192 * 64 / (E+1))
63
+ G_min = (C + C_per_block - 1) // C_per_block
64
+ for G in range(G_min, C+1):
65
+ if C % G == 0:
66
+ return G
67
+
68
+
69
+ def clustered_broadcast(Y, groups, counts, factors, X=None):
70
+ device = Y.device
71
+ if X is None:
72
+ X = torch.zeros(
73
+ groups.shape + (Y.shape[-1],),
74
+ device=device,
75
+ dtype=Y.dtype
76
+ )
77
+ if device.type == "cpu":
78
+ broadcast_cpu(Y, groups, factors, X)
79
+ else:
80
+ N, H, C, E = Y.shape
81
+ _, _, L, _ = X.shape
82
+
83
+ # The following are some bookkeeping parameters to facilitate the
84
+ # broadcast kernel that takes advantage of clustering
85
+ # More information can be found in the cuda file
86
+ with torch.no_grad():
87
+ threads = 256
88
+ G = set_group(C, E)
89
+ group_counts = counts.view(N, H, G, -1).sum(-1)
90
+ block_counts = (group_counts + threads - 1) // threads
91
+ total_blocks = block_counts.sum().item()
92
+ indx_maps = torch.ones(
93
+ (total_blocks, 5),
94
+ device=X.device,
95
+ dtype=torch.int32
96
+ )
97
+
98
+ clustered_broadcast_gpu(
99
+ Y,
100
+ groups,
101
+ factors,
102
+ X,
103
+ block_counts.int(),
104
+ group_counts.int(),
105
+ threads,
106
+ G,
107
+ total_blocks,
108
+ indx_maps
109
+ )
110
+ return X
111
+
112
+
113
+ def clustered_aggregate(X, G, F, lengths, Y=None):
114
+ device = X.device
115
+ if Y is None:
116
+ Y = torch.zeros(
117
+ F.shape + (X.shape[-1],),
118
+ device=device,
119
+ dtype=X.dtype
120
+ )
121
+ else:
122
+ Y.zero_()
123
+
124
+ if device.type == "cpu":
125
+ aggregate_cpu(X, G, F, Y)
126
+ else:
127
+ clustered_aggregate_gpu(X, G, F, lengths, Y)
128
+ return Y
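For readers unfamiliar with these compiled kernels, their semantics (inferred from how the clustered attention modules call them) can be sketched with a slow, pure-PyTorch reference. The helper names below are hypothetical and only illustrate what the C++/CUDA extensions are assumed to compute:

```python
import torch

def aggregate_reference(X, G, F):
    # Assumed semantics: for each position l, add F[n, h, g] * X[n, h, l]
    # into bucket g = G[n, h, l]. With F = 1/counts this yields per-cluster
    # centroids, which is how the clustered attention modules use it.
    N, H, L, E = X.shape
    C = int(G.max()) + 1
    Y = torch.zeros(N, H, C, E, dtype=X.dtype)
    for n in range(N):
        for h in range(H):
            for l in range(L):
                g = int(G[n, h, l])
                Y[n, h, g] += F[n, h, g] * X[n, h, l]
    return Y

def broadcast_reference(Y, G, F):
    # Assumed semantics: copy each bucket's row back to all of its member
    # positions, scaled by the per-bucket factor (ones for a plain copy).
    N, H, L = G.shape
    X = torch.zeros(N, H, L, Y.shape[-1], dtype=Y.dtype)
    for n in range(N):
        for h in range(H):
            for l in range(L):
                g = int(G[n, h, l])
                X[n, h, l] = F[n, h, g] * Y[n, h, g]
    return X
```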
smi-ted/inference/smi_ted_light/fast_transformers/aggregate/aggregate_cpu.cpython-39-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6bccb1a374d4649aaef6361cc41c9ffb471086464cc07a0d6d21c5b65adb0711
+ size 138248
smi-ted/inference/smi_ted_light/fast_transformers/attention/__init__.py ADDED
@@ -0,0 +1,20 @@
+ #
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+ # Written by Angelos Katharopoulos <[email protected]>,
+ # Apoorv Vyas <[email protected]>
+ #
+
+ """Implementations of different types of attention mechanisms."""
+
+
+ from .attention_layer import AttentionLayer
+ from .full_attention import FullAttention
+ from .linear_attention import LinearAttention
+ #from .causal_linear_attention import CausalLinearAttention
+ #from .clustered_attention import ClusteredAttention
+ #from .improved_clustered_attention import ImprovedClusteredAttention
+ #from .reformer_attention import ReformerAttention
+ #from .conditional_full_attention import ConditionalFullAttention
+ #from .exact_topk_attention import ExactTopKAttention
+ #from .improved_clustered_causal_attention import ImprovedClusteredCausalAttention
+ #from .local_attention import LocalAttention
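For context, the names these modules register ("full", "linear", and the commented-out variants) are the keys that the builders in `fast_transformers.builders` resolve. A minimal usage sketch, assuming this vendored copy mirrors the upstream fast-transformers 0.4.0 builder API (the import path may need the `smi_ted_light` prefix in this repository):

```python
import torch
from fast_transformers.builders import TransformerEncoderBuilder

# Build a two-layer encoder that uses the registered "linear" attention.
model = TransformerEncoderBuilder.from_kwargs(
    n_layers=2,
    n_heads=4,
    query_dimensions=32,
    value_dimensions=32,
    feed_forward_dimensions=256,
    attention_type="linear",  # any key registered in AttentionRegistry
).get()

x = torch.rand(8, 100, 4 * 32)  # (batch, length, n_heads * query_dimensions)
y = model(x)                    # same shape as x
```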
smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (502 Bytes).
smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/attention_layer.cpython-310.pyc ADDED
Binary file (4.14 kB).
smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/full_attention.cpython-310.pyc ADDED
Binary file (3.32 kB).
smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/linear_attention.cpython-310.pyc ADDED
Binary file (2.96 kB).
smi-ted/inference/smi_ted_light/fast_transformers/attention/attention_layer.py ADDED
@@ -0,0 +1,113 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ """The base attention layer performs all the query key value projections and
8
+ output projections leaving the implementation of the attention to the inner
9
+ attention module.
10
+
11
+ The transformer layers, however, are agnostic of the attention implementation
12
+ and any layer that implements the same interface can substitute for the
13
+ attention layer.
14
+ """
15
+
16
+ from torch.nn import Linear, Module
17
+
18
+ from ..events import EventDispatcher, QKVEvent
19
+
20
+
21
+ class AttentionLayer(Module):
22
+ """Implement the attention layer. Namely project the inputs to multi-head
23
+ queries, keys and values, call the attention implementation and then
24
+ reproject the output.
25
+
26
+ It can be thought of as a decorator (see the decorator design pattern) of an
27
+ attention layer.
28
+
29
+ Arguments
30
+ ---------
31
+ attention: Specific inner attention implementation that just computes a
32
+ weighted average of values given a similarity of queries and
33
+ keys.
34
+ d_model: The input feature dimensionality
35
+ n_heads: The number of heads for the multi head attention
36
+ d_keys: The dimensionality of the keys/queries
37
+ (default: d_model/n_heads)
38
+ d_values: The dimensionality of the values (default: d_model/n_heads)
39
+ event_dispatcher: str or EventDispatcher instance to be used by this
40
+ module for dispatching events (default: the default
41
+ global dispatcher)
42
+ """
43
+ def __init__(self, attention, d_model, n_heads, d_keys=None,
44
+ d_values=None, event_dispatcher=""):
45
+ super(AttentionLayer, self).__init__()
46
+
47
+ # Fill d_keys and d_values
48
+ d_keys = d_keys or (d_model//n_heads)
49
+ d_values = d_values or (d_model//n_heads)
50
+
51
+ self.inner_attention = attention
52
+ self.query_projection = Linear(d_model, d_keys * n_heads)
53
+ self.key_projection = Linear(d_model, d_keys * n_heads)
54
+ self.value_projection = Linear(d_model, d_values * n_heads)
55
+ self.out_projection = Linear(d_values * n_heads, d_model)
56
+ self.n_heads = n_heads
57
+ self.event_dispatcher = EventDispatcher.get(event_dispatcher)
58
+
59
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
60
+ key_lengths):
61
+ """Apply attention to the passed in queries/keys/values after
62
+ projecting them to multiple heads.
63
+
64
+ In the argument description we make use of the following sizes
65
+
66
+ - N: the batch size
67
+ - L: The maximum length of the queries
68
+ - S: The maximum length of the keys (the actual length per sequence
69
+ is given by the length mask)
70
+ - D: The input feature dimensionality passed in the constructor as
71
+ 'd_model'
72
+
73
+ Arguments
74
+ ---------
75
+ queries: (N, L, D) The tensor containing the queries
76
+ keys: (N, S, D) The tensor containing the keys
77
+ values: (N, S, D) The tensor containing the values
78
+ attn_mask: An implementation of BaseMask that encodes where each
79
+ query can attend to
80
+ query_lengths: An implementation of BaseMask that encodes how
81
+ many queries each sequence in the batch consists of
82
+ key_lengths: An implementation of BaseMask that encodes how
83
+ many keys each sequence in the batch consists of
84
+
85
+ Returns
86
+ -------
87
+ The new value for each query as a tensor of shape (N, L, D).
88
+ """
89
+ # Extract the dimensions into local variables
90
+ N, L, _ = queries.shape
91
+ _, S, _ = keys.shape
92
+ H = self.n_heads
93
+
94
+ # Project the queries/keys/values
95
+ queries = self.query_projection(queries).view(N, L, H, -1)
96
+ keys = self.key_projection(keys).view(N, S, H, -1)
97
+ values = self.value_projection(values).view(N, S, H, -1)
98
+
99
+ # Let the world know of the qkv
100
+ self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values))
101
+
102
+ # Compute the attention
103
+ new_values = self.inner_attention(
104
+ queries,
105
+ keys,
106
+ values,
107
+ attn_mask,
108
+ query_lengths,
109
+ key_lengths
110
+ ).view(N, L, -1)
111
+
112
+ # Project the output and return
113
+ return self.out_projection(new_values)
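A minimal sketch of how this layer is typically driven, assuming the `FullMask`/`LengthMask` helpers from `fast_transformers.masking` (part of the same upstream library) are available:

```python
import torch
from fast_transformers.attention import AttentionLayer, FullAttention
from fast_transformers.masking import FullMask, LengthMask

N, L, S, d_model, n_heads = 2, 10, 12, 64, 4
layer = AttentionLayer(FullAttention(), d_model, n_heads)

queries = torch.rand(N, L, d_model)
keys = torch.rand(N, S, d_model)
values = torch.rand(N, S, d_model)

# Every query may attend to every key; all positions are valid.
attn_mask = FullMask(L, S)
query_lengths = LengthMask(torch.full((N,), L, dtype=torch.long), max_len=L)
key_lengths = LengthMask(torch.full((N,), S, dtype=torch.long), max_len=S)

out = layer(queries, keys, values, attn_mask, query_lengths, key_lengths)
print(out.shape)  # torch.Size([2, 10, 64])
```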
smi-ted/inference/smi_ted_light/fast_transformers/attention/causal_linear_attention.py ADDED
@@ -0,0 +1,116 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ """Implement causally masked linear attention."""
8
+
9
+ import torch
10
+ from torch.nn import Module
11
+
12
+ from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \
13
+ EventDispatcherInstance
14
+ from ..events import EventDispatcher
15
+ from ..causal_product import causal_dot_product
16
+ from ..feature_maps import elu_feature_map
17
+
18
+
19
+ def causal_linear(Q, K, V):
20
+ Q = Q.permute(0,2,1,3).contiguous()
21
+ K = K.permute(0,2,1,3).contiguous()
22
+ V = V.permute(0,2,1,3).contiguous()
23
+ V_new = causal_dot_product(Q, K, V)
24
+ return V_new.permute(0,2,1,3).contiguous()
25
+
26
+
27
+ class CausalLinearAttention(Module):
28
+ """Implement causally masked attention using dot product of feature maps in
29
+ O(N D^2) complexity.
30
+
31
+ See fast_transformers.attention.linear_attention.LinearAttention for the
32
+ general concept of replacing the softmax with feature maps. In addition to
33
+ that, we also make use of the fact that causal masking is a triangular mask
34
+ which allows us to apply the masking and still compute the attention in O(N
35
+ D^2) complexity.
36
+
37
+ Arguments
38
+ ---------
39
+ feature_map: callable, a callable that applies the feature map to the
40
+ last dimension of a tensor (default: elu(x)+1)
41
+ eps: float, a small number to ensure the numerical stability of the
42
+ denominator (default: 1e-6)
43
+ event_dispatcher: str or EventDispatcher instance to be used by this
44
+ module for dispatching events (default: the default
45
+ global dispatcher)
46
+ """
47
+ def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
48
+ event_dispatcher=""):
49
+ super(CausalLinearAttention, self).__init__()
50
+ self.feature_map = (
51
+ feature_map(query_dimensions) if feature_map else
52
+ elu_feature_map(query_dimensions)
53
+ )
54
+ self.eps = eps
55
+ self.event_dispatcher = EventDispatcher.get(event_dispatcher)
56
+
57
+ def _make_sizes_compatible(self, Q, K):
58
+ """Either slice or pad K in case that the sizes do not match between Q
59
+ and K."""
60
+ N, L, H, E = Q.shape
61
+ _, S, _, _ = K.shape
62
+ if L == S:
63
+ return Q, K
64
+
65
+ if L < S:
66
+ return Q, K[:, :L, :, :]
67
+
68
+ if L > S:
69
+ return Q, torch.cat([K, K.new_zeros(N, L-S, H, E)], dim=1)
70
+
71
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
72
+ key_lengths):
73
+ # Apply the feature map to the queries and keys
74
+ self.feature_map.new_feature_map(queries.device)
75
+ Q = self.feature_map.forward_queries(queries)
76
+ K = self.feature_map.forward_keys(keys)
77
+
78
+ # Apply the key padding mask and make sure the attn_mask is a
79
+ # lower triangular causal mask
80
+ if not attn_mask.lower_triangular:
81
+ raise RuntimeError(("CausalLinearAttention only supports full "
82
+ "lower triangular masks"))
83
+ K = K * key_lengths.float_matrix[:, :, None, None]
84
+
85
+ # Ensure that Q and K have compatible sizes for the following
86
+ # computations, namely L == S
87
+ Q, K = self._make_sizes_compatible(Q, K)
88
+
89
+ # TODO: Shall we divide the Q and K with a relatively large number to
90
+ # avoid numerical instabilities in computing the denominator?
91
+ # We used to divide each with the max norm of all q and k but
92
+ # that seems relatively costly for a simple normalization.
93
+
94
+ # Compute the normalizers
95
+ Z = 1/(torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps)
96
+
97
+ # Compute the unnormalized result
98
+ V = causal_linear(
99
+ Q,
100
+ K,
101
+ values
102
+ )
103
+
104
+ return V * Z[:, :, :, None]
105
+
106
+
107
+ # Register the attention implementation so that it becomes available in our
108
+ # builders
109
+ AttentionRegistry.register(
110
+ "causal-linear", CausalLinearAttention,
111
+ [
112
+ ("query_dimensions", Int),
113
+ ("feature_map", Optional(Callable)),
114
+ ("event_dispatcher", Optional(EventDispatcherInstance, ""))
115
+ ]
116
+ )
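The kernel above realizes, for every position i, phi(q_i) · (sum over j <= i of phi(k_j) v_j^T), normalized by phi(q_i) · sum over j <= i of phi(k_j). A slow but explicit reference using only cumulative sums behaves the same way; this is a hypothetical helper for illustration, not part of the package:

```python
import torch

def causal_linear_reference(Q, K, V, eps=1e-6):
    # Q, K: (N, L, H, E) with the feature map already applied; V: (N, L, H, D).
    # Running sum of the outer products k_j v_j^T along the sequence dimension.
    KV = torch.einsum("nlhe,nlhd->nlhed", K, V).cumsum(dim=1)
    numerator = torch.einsum("nlhe,nlhed->nlhd", Q, KV)
    # Same normalizer as in CausalLinearAttention.forward above.
    Z = 1.0 / (torch.einsum("nlhe,nlhe->nlh", Q, K.cumsum(dim=1)) + eps)
    return numerator * Z[..., None]
```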
smi-ted/inference/smi_ted_light/fast_transformers/attention/clustered_attention.py ADDED
@@ -0,0 +1,195 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ """Implement clustered self attention."""
8
+
9
+ from math import sqrt
10
+
11
+ import torch
12
+ import torch.autograd
13
+ from torch.nn import Dropout, Module
14
+ from torch.nn.init import normal_
15
+
16
+ from ..attention_registry import AttentionRegistry, Optional, Float, Int, \
17
+ Bool, EventDispatcherInstance
18
+ from ..events import EventDispatcher
19
+ from ..masking import FullMask
20
+ from ..aggregate import clustered_aggregate, clustered_broadcast
21
+ from ..clustering.hamming import cluster
22
+ from ..hashing import compute_hashes
23
+
24
+
25
+ class _GroupQueries(torch.autograd.Function):
26
+ @staticmethod
27
+ def forward(ctx, Q, clusters, counts, lengths):
28
+ factors = 1./counts.float()
29
+ q_grouped = clustered_aggregate(Q, clusters, factors, lengths)
30
+ ctx.save_for_backward(clusters, counts, factors)
31
+
32
+ return q_grouped
33
+
34
+ @staticmethod
35
+ def backward(ctx, grad_q_grouped):
36
+ clusters, counts, factors = ctx.saved_tensors
37
+ grad_q = clustered_broadcast(grad_q_grouped, clusters, counts, factors)
38
+
39
+ return grad_q, None, None, None
40
+
41
+
42
+ class _BroadcastValues(torch.autograd.Function):
43
+ @staticmethod
44
+ def forward(ctx, v_grouped, clusters, counts, lengths):
45
+ factors = torch.ones_like(counts, dtype=v_grouped.dtype)
46
+ V = clustered_broadcast(v_grouped, clusters, counts, factors)
47
+ ctx.save_for_backward(clusters, counts, factors, lengths)
48
+
49
+ return V
50
+
51
+ @staticmethod
52
+ def backward(ctx, grad_v):
53
+ clusters, counts, factors, lengths = ctx.saved_tensors
54
+ grad_v_grouped = clustered_aggregate(grad_v, clusters, factors, lengths)
55
+
56
+ return grad_v_grouped, None, None, None
57
+
58
+
59
+ class ClusteredAttention(Module):
60
+ """Use LSH and clustering in the resulting Hamming space to group queries
61
+ that will have minimal L2 distance from each other.
62
+
63
+ Given the queries, keys, and values as Q, K, and V respectively, we
64
+ first cluster the queries in "C" groups and compute the "C" query centroids
65
+ Q_c.
66
+
67
+ We now use the centroids Q_c to compute the attention using:
68
+
69
+ V'_c = softmax(Q_c.mm(K.t()), dim=-1).mm(V).
70
+
71
+ Now the computed values V'_c are "broadcasted" back to the query members
72
+ of the corresponding cluster.
73
+
74
+ Arguments
75
+ ---------
76
+ clusters: How many clusters to group the queries into
77
+ iterations: The number of lloyd iterations to perform (default: 10)
78
+ bits: How many bits to use for the hash (default: 32)
79
+ hash_bias: If true, hamming distance proportional to L2 distance
80
+ If false, hamming distance proportional to cosine distance
81
+ (default: True)
82
+ softmax_temp: The temperature to use for the softmax attention.
83
+ (default: 1/sqrt(d_keys) where d_keys is computed at
84
+ runtime)
85
+ attention_dropout: The dropout rate to apply to the attention
86
+ (default: 0.1)
87
+ event_dispatcher: str or EventDispatcher instance to be used by this
88
+ module for dispatching events (default: the default
89
+ global dispatcher)
90
+ """
91
+ def __init__(self, clusters, iterations=10, bits=32,
92
+ hash_bias=True, softmax_temp=None, attention_dropout=0.1,
93
+ event_dispatcher=""):
94
+ super(ClusteredAttention, self).__init__()
95
+ self.clusters = clusters
96
+ self.iterations = iterations
97
+ self.bits = bits
98
+ self.hash_bias = hash_bias
99
+ self.softmax_temp = softmax_temp
100
+ self.dropout = Dropout(attention_dropout)
101
+ self.event_dispatcher = EventDispatcher.get(event_dispatcher)
102
+
103
+ def _create_query_groups(self, Q, query_lengths):
104
+ N, H, L, E = Q.shape
105
+
106
+ # Compute the hashes for all the queries
107
+ planes = Q.new_empty((self.bits, E+1))
108
+ normal_(planes)
109
+ if not self.hash_bias:
110
+ planes[:, -1] = 0
111
+ hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L)
112
+
113
+ # Cluster the hashes and return the cluster index per query
114
+ clusters, counts = cluster(
115
+ hashes,
116
+ query_lengths._lengths.int(),
117
+ clusters=self.clusters,
118
+ iterations=self.iterations,
119
+ bits=self.bits
120
+ )
121
+ sorted_clusters, sorted_indx = torch.sort(clusters, dim=-1)
122
+ return (sorted_clusters, counts), sorted_indx
123
+
124
+ def _group_queries(self, Q, groups, lengths):
125
+ """Aggregate the Qs based on the index of cluster they belong to. Make
126
+ sure to allow for gradient propagation backwards from the grouped
127
+ queries to each query."""
128
+ q_grouped = _GroupQueries.apply(Q, *groups, lengths)
129
+ return q_grouped
130
+
131
+ def _broadcast_values(self, V, groups, lengths):
132
+ """Broadcast the values back to the correct positions but make sure
133
+ that the gradient flows properly."""
134
+ V_new = _BroadcastValues.apply(V.contiguous(), *groups, lengths)
135
+ return V_new
136
+
137
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
138
+ key_lengths):
139
+ # Make sure that there is no attention mask
140
+ assert attn_mask.all_ones, ("Clustered attention cannot use an "
141
+ "arbitrary attention mask.")
142
+
143
+ queries = queries.permute(0,2,1,3).contiguous()
144
+ keys = keys.permute(0,2,1,3).contiguous()
145
+ values = values.permute(0,2,1,3).contiguous()
146
+
147
+ N, H, L, E = queries.shape
148
+ _, _, S, D = values.shape
149
+ softmax_temp = self.softmax_temp or 1./sqrt(E)
150
+
151
+ # Cluster the queries into groups
152
+ groups, sorted_indx = self._create_query_groups(queries, query_lengths)
153
+ # Re-organize queries so that first group belong to first cluster
154
+ # next to second cluster and so on. This improves kernel implementations.
155
+ # Note that this step is introduced after NeurIPS submission and
156
+ # now the complexity is O(N log(N)).
157
+ q_offset = torch.arange(N*H, device=queries.device).unsqueeze(-1) * L
158
+ q_flat = (sorted_indx.view(N*H, -1) + q_offset).reshape(-1)
159
+ s_queries = queries.reshape(-1, E).index_select(0, q_flat).view(N,H,L,E)
160
+
161
+ # Aggregate the re-arranged queries.
162
+ Q_grouped = self._group_queries(s_queries, groups, query_lengths._lengths.int())
163
+ # Compute the attention
164
+ QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys)
165
+ QK = QK + key_lengths.additive_matrix[:, None, None, :]
166
+ A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
167
+ V = torch.einsum("nhls,nhsd->nhld", A, values)
168
+
169
+ # Broadcast grouped attention
170
+ V_broadcast = self._broadcast_values(V, groups, query_lengths._lengths.int())
171
+
172
+ # Reverse the previous mapping
173
+ rev_indx = torch.argsort(sorted_indx, dim=-1)
174
+ q_rev_flat = (rev_indx.view(N*H, -1) + q_offset).reshape(-1)
175
+ V_new = V_broadcast.reshape(-1, D).index_select(0, q_rev_flat).view(N,H,L,D)
176
+ V_new = V_new.permute(0, 2, 1, 3).contiguous()
177
+ return V_new
178
+
179
+
180
+
181
+
182
+ # Register the attention implementation so that it becomes available in our
183
+ # builders
184
+ AttentionRegistry.register(
185
+ "clustered", ClusteredAttention,
186
+ [
187
+ ("clusters", Int),
188
+ ("iterations", Optional(Int, 10)),
189
+ ("bits", Optional(Int, 63)),
190
+ ("hash_bias", Optional(Bool, True)),
191
+ ("softmax_temp", Optional(Float)),
192
+ ("attention_dropout", Optional(Float, 0.1)),
193
+ ("event_dispatcher", Optional(EventDispatcherInstance, ""))
194
+ ]
195
+ )
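The approximation described in the docstring can be reproduced without the LSH/Hamming machinery by assuming cluster labels are already given. The sketch below is a hypothetical reference helper (dropout and length masks omitted) showing the centroid-attend-broadcast pattern:

```python
import torch
import torch.nn.functional as F

def clustered_attention_reference(Q, K, V, labels, softmax_temp):
    # Q: (N, H, L, E), K: (N, H, S, E), V: (N, H, S, D),
    # labels: (N, H, L) long tensor with cluster ids in [0, C).
    C = int(labels.max()) + 1
    one_hot = F.one_hot(labels, C).to(Q.dtype)                    # (N, H, L, C)
    counts = one_hot.sum(dim=2).clamp(min=1)                      # (N, H, C)
    centroids = torch.einsum("nhlc,nhle->nhce", one_hot, Q) / counts[..., None]
    A = torch.softmax(softmax_temp * torch.einsum("nhce,nhse->nhcs", centroids, K), dim=-1)
    V_c = torch.einsum("nhcs,nhsd->nhcd", A, V)                   # one row per centroid
    return torch.einsum("nhlc,nhcd->nhld", one_hot, V_c)          # broadcast to members
```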
smi-ted/inference/smi_ted_light/fast_transformers/attention/conditional_full_attention.py ADDED
@@ -0,0 +1,66 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ """Implement a self attention that delegates to full attention or another
8
+ attention depending on the input sequence length."""
9
+
10
+ import torch
11
+ from torch.nn import Module
12
+
13
+ from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
14
+ EventDispatcherInstance
15
+ from ..events import EventDispatcher
16
+ from .full_attention import FullAttention
17
+
18
+
19
+ class ConditionalFullAttention(Module):
20
+ """"Delegate to full attention if the input sequence is short.
21
+
22
+ Arguments
23
+ ---------
24
+ other_attention: Use the passed attention module if the sequence is
25
+ longer than 'length_limit'.
26
+ length_limit: An integer denoting the maximum sequence length to
27
+ consider.
28
+ softmax_temp: See fast_transformers.attention.full_attention.
29
+ attention_dropout: See fast_transformers.attention.full_attention.
30
+ event_dispatcher: str or EventDispatcher instance to be used by this
31
+ module for dispatching events (default: the default
32
+ global dispatcher)
33
+ """
34
+ def __init__(self, other_attention, length_limit=512, softmax_temp=None,
35
+ attention_dropout=0.1, event_dispatcher=""):
36
+ super(ConditionalFullAttention, self).__init__()
37
+ self.full_attention = FullAttention(softmax_temp, attention_dropout)
38
+ self.other_attention = other_attention
39
+ self.length_limit = length_limit
40
+ self.event_dispatcher = EventDispatcher.get(event_dispatcher)
41
+
42
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
43
+ key_lengths):
44
+ # Extract some shapes to compare with the length limit
45
+ L = queries.shape[1]
46
+ S = values.shape[1]
47
+
48
+ if L > self.length_limit or S > self.length_limit:
49
+ return self.other_attention(queries, keys, values, attn_mask,
50
+ query_lengths, key_lengths)
51
+ else:
52
+ return self.full_attention(queries, keys, values, attn_mask,
53
+ query_lengths, key_lengths)
54
+
55
+
56
+ # Register the attention implementation so that it becomes available in our
57
+ # builders
58
+ AttentionRegistry.register(
59
+ "conditional-full", ConditionalFullAttention,
60
+ [
61
+ ("length_limit", Optional(Int, 512)),
62
+ ("softmax_temp", Optional(Float)),
63
+ ("attention_dropout", Optional(Float, 0.1)),
64
+ ("event_dispatcher", Optional(EventDispatcherInstance, ""))
65
+ ]
66
+ )
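In practice this module simply wraps another inner attention; a minimal sketch by direct instantiation, using the `LinearAttention` shipped in this package (parameter values are illustrative only), might look like:

```python
from fast_transformers.attention import AttentionLayer, LinearAttention
from fast_transformers.attention.conditional_full_attention import ConditionalFullAttention

# Exact softmax attention for sequences up to 512 positions,
# linear attention beyond that.
inner = ConditionalFullAttention(LinearAttention(query_dimensions=32), length_limit=512)
layer = AttentionLayer(inner, d_model=128, n_heads=4)
```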
smi-ted/inference/smi_ted_light/fast_transformers/attention/exact_topk_attention.py ADDED
@@ -0,0 +1,88 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ """Implement the oracle top-k attention. The top-k keys are exact ones.
8
+ MultiHeadAttention module. Note that this module is to be used in conjuction
9
+ with the AttentionLayer in order to work."""
10
+
11
+ from math import sqrt
12
+
13
+ import torch
14
+ from torch.nn import Dropout, Module
15
+
16
+ from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
17
+ EventDispatcherInstance
18
+ from ..events import EventDispatcher
19
+
20
+
21
+ class ExactTopKAttention(Module):
22
+ """Implement the oracle top-k softmax attention.
23
+
24
+ Arguments
25
+ ---------
26
+ topk: The top-k keys to attend to (default: 32)
27
+ softmax_temp: The temperature to use for the softmax attention.
28
+ (default: 1/sqrt(d_keys) where d_keys is computed at
29
+ runtime)
30
+ attention_dropout: The dropout rate to apply to the attention
31
+ (default: 0.1)
32
+ event_dispatcher: str or EventDispatcher instance to be used by this
33
+ module for dispatching events (default: the default
34
+ global dispatcher)
35
+ """
36
+ def __init__(self, topk=32, softmax_temp=None, attention_dropout=0.1,
37
+ event_dispatcher=""):
38
+ super(ExactTopKAttention, self).__init__()
39
+ self.topk = topk
40
+ self.softmax_temp = softmax_temp
41
+ self.dropout = Dropout(attention_dropout)
42
+ self.event_dispatcher = EventDispatcher.get(event_dispatcher)
43
+
44
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
45
+ key_lengths):
46
+ # Extract some shapes and compute the temperature
47
+ N, L, H, E = queries.shape
48
+ _, S, _, D = values.shape
49
+ softmax_temp = self.softmax_temp or 1./sqrt(E)
50
+
51
+ # Compute the unnormalized attention and apply the masks
52
+ QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
53
+ topk = min(self.topk, S)
54
+
55
+ if not attn_mask.all_ones:
56
+ QK = QK + attn_mask.additive_matrix
57
+ QK = QK + key_lengths.additive_matrix[:, None, None]
58
+
59
+ topk_values, topk_idx = torch.topk(QK, topk, sorted=False, dim=-1)
60
+ mask = QK.new_ones(QK.shape) * float("-inf")
61
+ mask[
62
+ torch.arange(N, device=QK.device).view(N, 1, 1, 1),
63
+ torch.arange(H, device=QK.device).view(1, H, 1, 1),
64
+ torch.arange(L, device=QK.device).view(1, 1, L, 1),
65
+ topk_idx,
66
+ ] = 0.
67
+
68
+ QK = QK + mask
69
+
70
+ # Compute the attention and the weighted average
71
+ A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
72
+ V = torch.einsum("nhls,nshd->nlhd", A, values)
73
+
74
+ # Make sure that what we return is contiguous
75
+ return V.contiguous()
76
+
77
+
78
+ # Register the attention implementation so that it becomes available in our
79
+ # builders
80
+ AttentionRegistry.register(
81
+ "exact-topk", ExactTopKAttention,
82
+ [
83
+ ("topk", Optional(Int, 32)),
84
+ ("softmax_temp", Optional(Float)),
85
+ ("attention_dropout", Optional(Float, 0.1)),
86
+ ("event_dispatcher", Optional(EventDispatcherInstance, ""))
87
+ ]
88
+ )
smi-ted/inference/smi_ted_light/fast_transformers/attention/full_attention.py ADDED
@@ -0,0 +1,95 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ """Implement the full attention similar to the one implemented by PyTorch's
8
+ MultiHeadAttention module. Note that this module is to be used in conjunction
9
+ with the `fast_transformers.attention.attention_layer.AttentionLayer` in order
10
+ to work."""
11
+
12
+ from math import sqrt
13
+
14
+ import torch
15
+ from torch.nn import Dropout, Module
16
+
17
+ from ..attention_registry import AttentionRegistry, Optional, Float, \
18
+ EventDispatcherInstance
19
+ from ..events import EventDispatcher, AttentionEvent
20
+
21
+
22
+ class FullAttention(Module):
23
+ """Implement the scaled dot product attention with softmax.
24
+
25
+ Arguments
26
+ ---------
27
+ softmax_temp: The temperature to use for the softmax attention.
28
+ (default: 1/sqrt(d_keys) where d_keys is computed at
29
+ runtime)
30
+ attention_dropout: The dropout rate to apply to the attention
31
+ (default: 0.1)
32
+ event_dispatcher: str or EventDispatcher instance to be used by this
33
+ module for dispatching events (default: the default
34
+ global dispatcher)
35
+ """
36
+ def __init__(self, softmax_temp=None, attention_dropout=0.1,
37
+ event_dispatcher=""):
38
+ super(FullAttention, self).__init__()
39
+ self.softmax_temp = softmax_temp
40
+ self.dropout = Dropout(attention_dropout)
41
+ self.event_dispatcher = EventDispatcher.get(event_dispatcher)
42
+
43
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
44
+ key_lengths):
45
+ """Implements the multihead softmax attention.
46
+
47
+ Arguments
48
+ ---------
49
+ queries: (N, L, H, E) The tensor containing the queries
50
+ keys: (N, S, H, E) The tensor containing the keys
51
+ values: (N, S, H, D) The tensor containing the values
52
+ attn_mask: An implementation of BaseMask that encodes where each
53
+ query can attend to
54
+ query_lengths: An implementation of BaseMask that encodes how
55
+ many queries each sequence in the batch consists of
56
+ key_lengths: An implementation of BaseMask that encodes how
57
+ many queries each sequence in the batch consists of
58
+ """
59
+ # Extract some shapes and compute the temperature
60
+ N, L, H, E = queries.shape
61
+ _, S, _, D = values.shape
62
+ softmax_temp = self.softmax_temp or 1./sqrt(E)
63
+
64
+ # Scale the queries instead of applying the softmax temperature to the
65
+ # dot products
66
+ queries = queries * softmax_temp
67
+
68
+ # Compute the unnormalized attention and apply the masks
69
+ QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
70
+ if not attn_mask.all_ones:
71
+ QK = QK + attn_mask.additive_matrix
72
+ if not key_lengths.all_ones:
73
+ QK = QK + key_lengths.additive_matrix[:, None, None]
74
+
75
+ # Compute the attention and the weighted average
76
+ A = self.dropout(torch.softmax(QK, dim=-1))
77
+ V = torch.einsum("nhls,nshd->nlhd", A, values)
78
+
79
+ # Let the world know of the attention matrix
80
+ self.event_dispatcher.dispatch(AttentionEvent(self, A))
81
+
82
+ # Make sure that what we return is contiguous
83
+ return V.contiguous()
84
+
85
+
86
+ # Register the attention implementation so that it becomes available in our
87
+ # builders
88
+ AttentionRegistry.register(
89
+ "full", FullAttention,
90
+ [
91
+ ("softmax_temp", Optional(Float)),
92
+ ("attention_dropout", Optional(Float, 0.1)),
93
+ ("event_dispatcher", Optional(EventDispatcherInstance, ""))
94
+ ]
95
+ )
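Since the queries are pre-scaled by 1/sqrt(E), the module with dropout disabled should match textbook softmax attention. A small self-check sketch, again assuming the `FullMask`/`LengthMask` helpers from the same package:

```python
import torch
from fast_transformers.attention import FullAttention
from fast_transformers.masking import FullMask, LengthMask

N, L, S, H, E = 2, 5, 7, 4, 16
q, k, v = torch.rand(N, L, H, E), torch.rand(N, S, H, E), torch.rand(N, S, H, E)

attn = FullAttention(attention_dropout=0.0)
out = attn(
    q, k, v,
    FullMask(L, S),
    LengthMask(torch.full((N,), L, dtype=torch.long), max_len=L),
    LengthMask(torch.full((N,), S, dtype=torch.long), max_len=S),
)

# Textbook computation for comparison.
scores = torch.einsum("nlhe,nshe->nhls", q, k) / E ** 0.5
expected = torch.einsum("nhls,nshd->nlhd", torch.softmax(scores, dim=-1), v)
print(torch.allclose(out, expected, atol=1e-6))  # expected: True
```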
smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_attention.py ADDED
@@ -0,0 +1,268 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ """Implement improved clustered self attention."""
8
+
9
+ from math import sqrt
10
+
11
+ import torch
12
+ import torch.autograd
13
+ from torch.nn import Dropout, Module
14
+ from torch.nn.init import normal_
15
+
16
+ from ..attention_registry import AttentionRegistry, Optional, Float, Int, \
17
+ Bool, EventDispatcherInstance
18
+ from ..events import EventDispatcher
19
+ from ..masking import FullMask
20
+ from ..aggregate import clustered_aggregate, clustered_broadcast
21
+ from ..clustering.hamming import cluster
22
+ from ..hashing import compute_hashes
23
+ from ..sparse_product import sparse_dot_product, sparse_weighted_average
24
+ from ..sparse_product import clustered_sparse_dot_product, \
25
+ clustered_sparse_weighted_average
26
+
27
+
28
+ class _GroupQueries(torch.autograd.Function):
29
+ @staticmethod
30
+ def forward(ctx, Q, clusters, counts, lengths):
31
+ factors = 1./counts.float()
32
+ q_grouped = clustered_aggregate(Q, clusters, factors, lengths)
33
+ ctx.save_for_backward(clusters, counts, factors)
34
+
35
+ return q_grouped
36
+
37
+ @staticmethod
38
+ def backward(ctx, grad_q_grouped):
39
+ clusters, counts, factors = ctx.saved_tensors
40
+ grad_q = clustered_broadcast(grad_q_grouped, clusters, counts, factors)
41
+
42
+ return grad_q, None, None, None
43
+
44
+
45
+ class _BroadcastValues(torch.autograd.Function):
46
+ @staticmethod
47
+ def forward(ctx, v_grouped, clusters, counts, lengths):
48
+ factors = torch.ones_like(counts, dtype=v_grouped.dtype)
49
+ V = clustered_broadcast(v_grouped, clusters, counts, factors)
50
+ ctx.save_for_backward(clusters, counts, factors, lengths)
51
+
52
+ return V
53
+
54
+ @staticmethod
55
+ def backward(ctx, grad_v):
56
+ clusters, counts, factors, lengths = ctx.saved_tensors
57
+ grad_v_grouped = clustered_aggregate(grad_v, clusters, factors, lengths)
58
+
59
+ return grad_v_grouped, None, None, None, None
60
+
61
+
62
+ class ImprovedClusteredAttention(Module):
63
+ """
64
+ Improved clustered attention approximation by recomputing attention
65
+ for each query with the top-k keys for the corresponding cluster.
66
+
67
+ Given the queries, keys, and values as Q, K, and V respectively, we
68
+ first cluster the queries in "C" groups and compute the "C" query centroids
69
+ Q_c.
70
+
71
+ We now use the centroids Q_c to identify the top-k keys with the highest
72
+ dot products.
73
+
74
+ Subsequently, for each query we compute the sparse dot product with
75
+ the corresponding top-k keys to improve the attention approximation.
76
+
77
+ Arguments
78
+ ---------
79
+ clusters: How many clusters to group the queries into
80
+ iterations: The number of lloyd iterations to perform (default: 10)
81
+ bits: How many bits to use for the hash (default: 32)
82
+ hash_bias: If true, hamming distance proportional to L2 distance
83
+ If false, hamming distance proportional to cosine distance
84
+ (default: True)
85
+ topk: Number of top-k keys used for the improved approximation (default: 32)
86
+ softmax_temp: The temperature to use for the softmax attention.
87
+ (default: 1/sqrt(d_keys) where d_keys is computed at
88
+ runtime)
89
+ attention_dropout: The dropout rate to apply to the attention
90
+ (default: 0.1)
91
+ event_dispatcher: str or EventDispatcher instance to be used by this
92
+ module for dispatching events (default: the default
93
+ global dispatcher)
94
+ """
95
+ def __init__(self, clusters, iterations=10, bits=32,
96
+ hash_bias=True, topk=32, softmax_temp=None,
97
+ attention_dropout=0.1, event_dispatcher=""):
98
+ super(ImprovedClusteredAttention, self).__init__()
99
+ self.clusters = clusters
100
+ self.iterations = iterations
101
+ self.bits = bits
102
+ self.hash_bias = hash_bias
103
+ self.topk = topk
104
+ self.softmax_temp = softmax_temp
105
+ self.dropout = Dropout(attention_dropout)
106
+ self.event_dispatcher = EventDispatcher.get(event_dispatcher)
107
+
108
+ def _create_query_groups(self, Q, query_lengths):
109
+ N, H, L, E = Q.shape
110
+
111
+ # Compute the hashes for all the queries
112
+ planes = Q.new_empty((self.bits, E+1))
113
+ normal_(planes)
114
+ if not self.hash_bias:
115
+ planes[:, -1] = 0
116
+ hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L)
117
+
118
+ # Cluster the hashes and return the cluster index per query
119
+ clusters, counts = cluster(
120
+ hashes,
121
+ query_lengths._lengths.int(),
122
+ clusters=self.clusters,
123
+ iterations=self.iterations,
124
+ bits=self.bits
125
+ )
126
+ sorted_clusters, sorted_indx = torch.sort(clusters, dim=-1)
127
+ return (sorted_clusters, counts), sorted_indx
128
+
129
+ def _topk_attention(self, Q, K, V,
130
+ clusters, counts,
131
+ topk, topk_values,
132
+ A_bottomk, softmax_temp,
133
+ query_lengths):
134
+ """Return the attention with just the topk heads."""
135
+ # Extract some indices
136
+ N, H, L, E = Q.shape
137
+ _, _, S, _ = K.shape
138
+ _, _, C, k = topk.shape
139
+
140
+ # We need to pass the output tensor to initialize to 0
141
+ QK = clustered_sparse_dot_product(
142
+ Q, K, topk,
143
+ clusters, counts,
144
+ query_lengths._lengths.int()
145
+ )
146
+ # We need to mask the topk dot products if topk > input_length
147
+ QK = QK.masked_fill(
148
+ torch.isinf(topk_values[:,0,0,:]).view(N, 1, 1, k),
149
+ float("-inf")
150
+ )
151
+ A = torch.softmax(softmax_temp * QK, dim=-1)
152
+ assert A_bottomk.is_contiguous()
153
+ A_bottomk = clustered_broadcast(
154
+ A_bottomk.unsqueeze(3),
155
+ clusters,
156
+ counts,
157
+ torch.ones_like(counts, dtype=torch.float32)
158
+ )
159
+ A = A * (1.0 - A_bottomk)
160
+ A = self.dropout(A)
161
+ assert A.is_contiguous()
162
+ V_new = clustered_sparse_weighted_average(A, V, topk, clusters, counts)
163
+
164
+ return V_new
165
+
166
+ def _broadcast_values(self, V, clusters, counts, lengths):
167
+ """Broadcast the values back to the correct positions but make sure
168
+ that the gradient flows properly."""
169
+ V_new = _BroadcastValues.apply(V.contiguous(), clusters, counts, lengths)
170
+ return V_new
171
+
172
+ def _bottomk_attention(self, QK, V, clusters, counts, query_lengths, topk, softmax_temp):
173
+ """Return the attention with just the bottomk keys."""
174
+ N, H, C, S = QK.shape
175
+
176
+ A = torch.softmax(softmax_temp * QK, dim=-1)
177
+ mask = QK.new_ones(QK.shape)
178
+ mask[
179
+ torch.arange(N, device=QK.device).view(N, 1, 1, 1),
180
+ torch.arange(H, device=QK.device).view(1, H, 1, 1),
181
+ torch.arange(C, device=QK.device).view(1, 1, C, 1),
182
+ topk,
183
+ ] = 0
184
+ A = A * mask
185
+ A_bottomk = A.sum(-1)
186
+ A = self.dropout(A)
187
+ # Compute the values
188
+ V_new = torch.einsum("nhls,nhse->nhle", A, V)
189
+ # Broadcast the values back depending on the groups
190
+ V_new = self._broadcast_values(V_new, clusters, counts, query_lengths._lengths.int())
191
+
192
+ return V_new, A_bottomk
193
+
194
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
195
+ key_lengths):
196
+ # Make sure that there is no attention mask
197
+ assert attn_mask.all_ones, ("Improved-clustered attention cannot "
198
+ "use an arbitrary attention mask.")
199
+
200
+ queries = queries.permute(0,2,1,3).contiguous()
201
+ keys = keys.permute(0,2,1,3).contiguous()
202
+ values = values.permute(0,2,1,3).contiguous()
203
+ N, H, L, E = queries.shape
204
+ _, _, S, D = values.shape
205
+ softmax_temp = self.softmax_temp or 1./sqrt(E)
206
+
207
+ # Cluster the queries into groups
208
+ groups, sorted_indx = self._create_query_groups(queries, query_lengths)
209
+ clusters, counts = groups
210
+
211
+ # Re-organize queries so that first group belong to first cluster
212
+ # next to second cluster and so on. This improves kernel implementations.
213
+ # Note that this step is introduced after NeurIPS submission and
214
+ # now the complexity is O(N log(N)).
215
+ q_offset = torch.arange(N*H, device=queries.device).unsqueeze(-1) * L
216
+ q_flat = (sorted_indx.view(N*H, -1) + q_offset).reshape(-1)
217
+ s_queries = queries.reshape(-1, E).index_select(0, q_flat).view(N,H,L,E)
218
+
219
+ # Aggregate the re-arranged queries.
220
+ Q_grouped = _GroupQueries.apply(s_queries, *groups, query_lengths.lengths.int())
221
+ # Compute the attention
222
+ QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys)
223
+ QK = QK + key_lengths.additive_matrix[:, None, None, :]
224
+ topk_values, topk = torch.topk(QK, min(self.topk, S), sorted=False, dim=-1)
225
+ assert topk.is_contiguous()
226
+
227
+ # Now compute the attention with only the bottom keys
228
+ V_bottomk, A_bottomk = self._bottomk_attention(
229
+ QK, values,
230
+ clusters, counts,
231
+ query_lengths,
232
+ topk,
233
+ softmax_temp
234
+ )
235
+
236
+ # Now compute the attention with only the top keys
237
+ V_topk = self._topk_attention(
238
+ s_queries, keys, values,
239
+ clusters, counts,
240
+ topk, topk_values,
241
+ A_bottomk,
242
+ softmax_temp,
243
+ query_lengths
244
+ )
245
+ V_sorted_new = V_topk + V_bottomk
246
+
247
+ # Reverse the previous mapping
248
+ sorted_rev_indx = torch.argsort(sorted_indx, dim=-1)
249
+ q_rev_flat = (sorted_rev_indx.view(N*H, -1) + q_offset).reshape(-1)
250
+ V_new = V_sorted_new.reshape(-1, D).index_select(0, q_rev_flat).view(N,H,L,D)
251
+ return V_new.permute(0, 2, 1, 3).contiguous()
252
+
253
+
254
+ # Register the attention implementation so that it becomes available in our
255
+ # builders
256
+ AttentionRegistry.register(
257
+ "improved-clustered", ImprovedClusteredAttention,
258
+ [
259
+ ("clusters", Int),
260
+ ("iterations", Optional(Int, 10)),
261
+ ("bits", Optional(Int, 63)),
262
+ ("hash_bias", Optional(Bool, True)),
263
+ ("topk", Optional(Int, 32)),
264
+ ("softmax_temp", Optional(Float)),
265
+ ("attention_dropout", Optional(Float, 0.1)),
266
+ ("event_dispatcher", Optional(EventDispatcherInstance, ""))
267
+ ]
268
+ )
smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_causal_attention.py ADDED
@@ -0,0 +1,257 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ """Implement improved clustered causal self attention."""
8
+
9
+ from math import sqrt
10
+
11
+ import torch
12
+ import torch.autograd
13
+ from torch.nn import Dropout, Module
14
+ from torch.nn.init import normal_
15
+
16
+ from ..attention_registry import AttentionRegistry, Optional, Float, Int, \
17
+ Bool, EventDispatcherInstance
18
+ from ..events import EventDispatcher
19
+ from ..masking import FullMask
20
+ from ..aggregate import clustered_aggregate, clustered_broadcast
21
+ from ..clustering.hamming import cluster
22
+ from ..hashing import compute_hashes
23
+ from ..sparse_product import sparse_dot_product, sparse_weighted_average
24
+ from ..sparse_product import clustered_sparse_dot_product, \
25
+ clustered_sparse_weighted_average
26
+
27
+
28
+ class _GroupQueries(torch.autograd.Function):
29
+ @staticmethod
30
+ def forward(ctx, Q, clusters, counts, lengths):
31
+ factors = 1./counts.float()
32
+ q_grouped = clustered_aggregate(Q, clusters, factors, lengths)
33
+ ctx.save_for_backward(clusters, counts, factors)
34
+
35
+ return q_grouped
36
+
37
+ @staticmethod
38
+ def backward(ctx, grad_q_grouped):
39
+ clusters, counts, factors = ctx.saved_tensors
40
+ grad_q = clustered_broadcast(grad_q_grouped, clusters, counts, factors)
41
+
42
+ return grad_q, None, None, None
43
+
44
+
45
+ class _BroadcastValues(torch.autograd.Function):
46
+ @staticmethod
47
+ def forward(ctx, v_grouped, clusters, counts, lengths):
48
+ factors = torch.ones_like(counts, dtype=v_grouped.dtype)
49
+ V = clustered_broadcast(v_grouped, clusters, counts, factors)
50
+ ctx.save_for_backward(clusters, counts, factors, lengths)
51
+
52
+ return V
53
+
54
+ @staticmethod
55
+ def backward(ctx, grad_v):
56
+ clusters, counts, factors, lengths = ctx.saved_tensors
57
+ grad_v_grouped = clustered_aggregate(grad_v, clusters, factors, lengths)
58
+
59
+ return grad_v_grouped, None, None, None, None
60
+
61
+
62
+ class ImprovedClusteredCausalAttention(Module):
63
+ """
64
+ Improved clustered causal attention approximation by recomputing attention
65
+ for each query with the top-k keys for the corresponding cluster.
66
+
67
+ Given the queries, keys, and values as Q, K, and V respectively, we
68
+ first cluster the queries in "C" groups and compute the "C" query centroids
69
+ Q_c.
70
+
71
+ We now use the centroids Q_c to identify the top-k keys with the highest
72
+ dot products.
73
+
74
+ Subsequently, for each query we compute the sparse dot product with
75
+ the corresponding top-k keys to improve the attention approximation.
76
+
77
+ Key difference with improved clustered attention is that we only use
78
+ top-k keys with causal mask, we do not compute attention on the
79
+ bottom-k keys.
80
+
81
+ Arguments
82
+ ---------
83
+ clusters: How many clusters to group the queries into
84
+ iterations: The number of lloyd iterations to perform (default: 10)
85
+ bits: How many bits to use for the hash (default: 32)
86
+ hash_bias: If true, hamming distance proportional to L2 distance
87
+ If false, hamming distance proportional to cosine distance
88
+ (default: True)
89
+ topk: Number of top-k keys used for the improved approximation (default: 32)
90
+ softmax_temp: The temperature to use for the softmax attention.
91
+ (default: 1/sqrt(d_keys) where d_keys is computed at
92
+ runtime)
93
+ attention_dropout: The dropout rate to apply to the attention
94
+ (default: 0.1)
95
+ event_dispatcher: str or EventDispatcher instance to be used by this
96
+ module for dispatching events (default: the default
97
+ global dispatcher)
98
+ """
99
+ def __init__(self, clusters, iterations=10, bits=32,
100
+ hash_bias=True, topk=32, softmax_temp=None,
101
+ attention_dropout=0.1, event_dispatcher=""):
102
+ super(ImprovedClusteredCausalAttention, self).__init__()
103
+ self.clusters = clusters
104
+ self.iterations = iterations
105
+ self.bits = bits
106
+ self.hash_bias = hash_bias
107
+ self.topk = topk
108
+ self.softmax_temp = softmax_temp
109
+ self.dropout = Dropout(attention_dropout)
110
+ self.event_dispatcher = EventDispatcher.get(event_dispatcher)
111
+
112
+ def _create_query_groups(self, Q, query_lengths):
113
+ N, H, L, E = Q.shape
114
+
115
+ # Compute the hashes for all the queries
116
+ planes = Q.new_empty((self.bits, E+1))
117
+ normal_(planes)
118
+ if not self.hash_bias:
119
+ planes[:, -1] = 0
120
+ hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L)
121
+
122
+ # Cluster the hashes and return the cluster index per query
123
+ clusters, counts = cluster(
124
+ hashes,
125
+ query_lengths.lengths.int(),
126
+ clusters=self.clusters,
127
+ iterations=self.iterations,
128
+ bits=self.bits
129
+ )
130
+ sorted_clusters, sorted_indx = torch.sort(clusters, dim=-1)
131
+ return (sorted_clusters, counts), sorted_indx
132
+
133
+ def _topk_attention(self, Q, K, V,
134
+ q_flat, q_rev_flat,
135
+ clusters, counts,
136
+ topk, topk_values,
137
+ softmax_temp,
138
+ query_lengths):
139
+ """Return the attention with just the topk heads."""
140
+ # Extract some indices
141
+ N, H, L, E = Q.shape
142
+ _, _, S, _ = K.shape
143
+ _, _, C, k = topk.shape
144
+
145
+ # We need to pass the output tensor to initialize to 0
146
+ QK = clustered_sparse_dot_product(
147
+ Q, K, topk,
148
+ clusters, counts,
149
+ query_lengths.lengths.int()
150
+ )
151
+ # We need to mask out the future
152
+ assert topk.is_contiguous()
153
+ topk_broadcast = clustered_broadcast(
154
+ topk.float(),
155
+ clusters,
156
+ counts,
157
+ torch.ones_like(counts, dtype=torch.float32)
158
+ )
159
+ # Need to be careful here we changed the order of the keys the
160
+ # masking on future needs to be applied in the same way
161
+ seq_ids = torch.arange(L, device=QK.device).view(1, 1, L, 1).repeat(N, H, 1, 1)
162
+ # permute the ids in the same way as input so as to mask the right
163
+ # entries for each query
164
+ s_seq_ids = seq_ids.reshape(-1, 1).index_select(0, q_flat).view(N,H,L,1)
165
+ future_mask = topk_broadcast.long() > s_seq_ids
166
+ QK = QK.masked_fill(
167
+ future_mask,
168
+ float("-1e7")
169
+ )
170
+ A = torch.softmax(softmax_temp * QK, dim=-1)
171
+ # Mask again to ensure no probabilities leak due to float(-1e7)
172
+ # Leakage could be very high as we use a small top-k
173
+ A = A * (1. - future_mask.float())
174
+ A = self.dropout(A)
175
+ assert A.is_contiguous()
176
+ V_new = clustered_sparse_weighted_average(A, V, topk, clusters, counts)
177
+
178
+ return V_new
179
+
180
+ def _broadcast_values(self, V, clusters, counts, lengths):
181
+ """Broadcast the values back to the correct positions but make sure
182
+ that the gradient flows properly."""
183
+ V_new = _BroadcastValues.apply(V.contiguous(), clusters, counts, lengths)
184
+ return V_new
185
+
186
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
187
+ key_lengths):
188
+
189
+ # Apply the key padding mask and make sure the attn_mask is a
190
+ # lower triangular causal mask
191
+ if not attn_mask.lower_triangular:
192
+ raise RuntimeError(("ImprovedClusteredCausalAttention only supports "
193
+ "lower triangular masks"))
194
+ queries = queries.permute(0,2,1,3).contiguous()
195
+ keys = keys.permute(0,2,1,3).contiguous()
196
+ values = values.permute(0,2,1,3).contiguous()
197
+ N, H, L, E = queries.shape
198
+ _, _, S, D = values.shape
199
+ softmax_temp = self.softmax_temp or 1./sqrt(E)
200
+
201
+ # Cluster the queries into groups
202
+ groups, sorted_indx = self._create_query_groups(queries, query_lengths)
203
+ clusters, counts = groups
204
+
205
+         # Re-organize the queries so that the first group belongs to the first
206
+         # cluster, the next to the second, and so on. This improves the kernel
207
+         # implementations. Note that this step was introduced after the NeurIPS
208
+         # submission and the complexity is now O(N log N).
209
+ q_offset = torch.arange(N*H, device=queries.device).unsqueeze(-1) * L
210
+ q_flat = (sorted_indx.view(N*H, -1) + q_offset).reshape(-1)
211
+ s_queries = queries.reshape(-1, E).index_select(0, q_flat).view(N,H,L,E)
212
+
213
+ # Aggregate the re-arranged queries.
214
+ Q_grouped = _GroupQueries.apply(s_queries, *groups, query_lengths.lengths.int())
215
+ # Compute the attention
216
+ QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys)
217
+ QK = QK + key_lengths.additive_matrix[:, None, None, :]
218
+ # Set topk to minimum of key lengths if it is smaller than self.topk
219
+ cur_topk = min(self.topk, min(key_lengths.lengths).item())
220
+ topk_values, topk = torch.topk(QK, cur_topk, sorted=False, dim=-1)
221
+ assert topk.is_contiguous()
222
+
223
+ # Reverse mapping
224
+ sorted_rev_indx = torch.argsort(sorted_indx, dim=-1)
225
+ q_rev_flat = (sorted_rev_indx.view(N*H, -1) + q_offset).reshape(-1)
226
+
227
+ # Compute the attention with only the top keys
228
+ V_topk = self._topk_attention(
229
+ s_queries, keys, values,
230
+ q_flat, q_rev_flat,
231
+ clusters, counts,
232
+ topk, topk_values,
233
+ softmax_temp,
234
+ query_lengths
235
+ )
236
+ V_sorted_new = V_topk
237
+
238
+ # Reverse the mapping to get correct values
239
+ V_new = V_sorted_new.reshape(-1, D).index_select(0, q_rev_flat).view(N,H,L,D)
240
+ return V_new.permute(0, 2, 1, 3).contiguous()
241
+
242
+
243
+ # Register the attention implementation so that it becomes available in our
244
+ # builders
245
+ AttentionRegistry.register(
246
+ "causal-improved-clustered", ImprovedClusteredCausalAttention,
247
+ [
248
+ ("clusters", Int),
249
+ ("iterations", Optional(Int, 10)),
250
+ ("bits", Optional(Int, 63)),
251
+ ("hash_bias", Optional(Bool, True)),
252
+ ("topk", Optional(Int, 32)),
253
+ ("softmax_temp", Optional(Float)),
254
+ ("attention_dropout", Optional(Float, 0.1)),
255
+ ("event_dispatcher", Optional(EventDispatcherInstance, ""))
256
+ ]
257
+ )
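For reference, a minimal construction sketch (not part of the committed diff): the class is registered above under the key "causal-improved-clustered". The import path below is an assumption about how this vendored copy is exposed on sys.path, and running the forward pass additionally requires the compiled hashing/clustering/sparse-product extensions bundled with the package.

```python
# Hedged sketch: direct construction with the parameters listed in the
# registration above; the import path is an assumption.
from fast_transformers.attention.improved_clustered_causal_attention import (
    ImprovedClusteredCausalAttention,
)

# 63 clusters of queries, keeping the top 32 keys per cluster.
attention = ImprovedClusteredCausalAttention(clusters=63, topk=32,
                                             attention_dropout=0.1)
```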
smi-ted/inference/smi_ted_light/fast_transformers/attention/linear_attention.py ADDED
@@ -0,0 +1,92 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ """Implement unmasked linear attention."""
8
+
9
+ import torch
10
+ from torch.nn import Module
11
+
12
+ from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \
13
+ EventDispatcherInstance
14
+ from ..events import EventDispatcher
15
+ from ..feature_maps import elu_feature_map
16
+
17
+
18
+ class LinearAttention(Module):
19
+ """Implement unmasked attention using dot product of feature maps in
20
+ O(N D^2) complexity.
21
+
22
+ Given the queries, keys and values as Q, K, V instead of computing
23
+
24
+ V' = softmax(Q.mm(K.t()), dim=-1).mm(V),
25
+
26
+ we make use of a feature map function Φ(.) and perform the following
27
+ computation
28
+
29
+ V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).
30
+
31
+ The above can be computed in O(N D^2) complexity where D is the
32
+ dimensionality of Q, K and V and N is the sequence length. Depending on the
33
+ feature map, however, the complexity of the attention might be limited.
34
+
35
+ Arguments
36
+ ---------
37
+ feature_map: callable, a callable that applies the feature map to the
38
+ last dimension of a tensor (default: elu(x)+1)
39
+ eps: float, a small number to ensure the numerical stability of the
40
+ denominator (default: 1e-6)
41
+ event_dispatcher: str or EventDispatcher instance to be used by this
42
+ module for dispatching events (default: the default
43
+ global dispatcher)
44
+ """
45
+ def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
46
+ event_dispatcher=""):
47
+ super(LinearAttention, self).__init__()
48
+ self.feature_map = (
49
+ feature_map(query_dimensions) if feature_map else
50
+ elu_feature_map(query_dimensions)
51
+ )
52
+ self.eps = eps
53
+ self.event_dispatcher = EventDispatcher.get(event_dispatcher)
54
+
55
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
56
+ key_lengths):
57
+ # Apply the feature map to the queries and keys
58
+ self.feature_map.new_feature_map(queries.device)
59
+ Q = self.feature_map.forward_queries(queries)
60
+ K = self.feature_map.forward_keys(keys)
61
+
62
+ # Apply the key padding mask and make sure that the attn_mask is
63
+ # all_ones
64
+ if not attn_mask.all_ones:
65
+ raise RuntimeError(("LinearAttention does not support arbitrary "
66
+ "attention masks"))
67
+ K = K * key_lengths.float_matrix[:, :, None, None]
68
+
69
+ # Compute the KV matrix, namely the dot product of keys and values so
70
+ # that we never explicitly compute the attention matrix and thus
71
+ # decrease the complexity
72
+ KV = torch.einsum("nshd,nshm->nhmd", K, values)
73
+
74
+ # Compute the normalizer
75
+ Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)
76
+
77
+ # Finally compute and return the new values
78
+ V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)
79
+
80
+ return V.contiguous()
81
+
82
+
83
+ # Register the attention implementation so that it becomes available in our
84
+ # builders
85
+ AttentionRegistry.register(
86
+ "linear", LinearAttention,
87
+ [
88
+ ("query_dimensions", Int),
89
+ ("feature_map", Optional(Callable)),
90
+ ("event_dispatcher", Optional(EventDispatcherInstance, ""))
91
+ ]
92
+ )
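A small usage sketch (not part of the committed diff) showing the tensor layout and masks this module expects; the `fast_transformers` import path is an assumption about how the vendored package is importable.

```python
import torch
# Assumed import paths for this vendored copy of the library.
from fast_transformers.attention.linear_attention import LinearAttention
from fast_transformers.masking import FullMask, LengthMask

N, L, H, E = 2, 100, 4, 32                        # batch, length, heads, head dim
attn = LinearAttention(query_dimensions=E)

q = torch.randn(N, L, H, E)
k = torch.randn(N, L, H, E)
v = torch.randn(N, L, H, E)
attn_mask = FullMask(L)                           # all-ones mask (required here)
lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))

out = attn(q, k, v, attn_mask, lengths, lengths)  # (N, L, H, E)
```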
smi-ted/inference/smi_ted_light/fast_transformers/attention/local_attention.py ADDED
@@ -0,0 +1,101 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>
4
+ #
5
+
6
+ """Implement local context attention."""
7
+
8
+ from math import sqrt
9
+
10
+ import torch
11
+ from torch.nn import Module, Dropout
12
+ from torch.nn import functional as F
13
+
14
+ from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
15
+ EventDispatcherInstance
16
+ from ..events import EventDispatcher
17
+ from ..local_product import local_dot_product, local_weighted_average
18
+
19
+
20
+ class LocalAttention(Module):
21
+ """Implement fast local attention where a query can only attend to
22
+ neighboring keys.
23
+
24
+ In this attention module the query Q_i can only attend to a key K_j if
25
+ |i-j| < local_context/2.
26
+
27
+ Arguments
28
+ ---------
29
+ local_context: The neighborhood to consider for local attention.
30
+ softmax_temp: The temperature to use for the softmax attention.
31
+ (default: 1/sqrt(d_keys) where d_keys is computed at
32
+ runtime)
33
+ attention_dropout: The dropout rate to apply to the attention
34
+ (default: 0.1)
35
+ event_dispatcher: str or EventDispatcher instance to be used by this
36
+ module for dispatching events (default: the default
37
+ global dispatcher)
38
+ """
39
+ def __init__(self, local_context, softmax_temp=None, attention_dropout=0.1,
40
+ event_dispatcher=""):
41
+ super(LocalAttention, self).__init__()
42
+ self.local_context = local_context
43
+ self.softmax_temp = softmax_temp
44
+ self.dropout = Dropout(attention_dropout)
45
+ self.event_dispatcher = EventDispatcher.get(event_dispatcher)
46
+
47
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
48
+ key_lengths):
49
+ """Implements the local attention.
50
+
51
+ The attn_mask can be anything but the only values that will be
52
+ considered will be the ones in the neighborhood of each query.
53
+
54
+ Arguments
55
+ ---------
56
+ queries: (N, L, H, E) The tensor containing the queries
57
+ keys: (N, S, H, E) The tensor containing the keys
58
+ values: (N, S, H, D) The tensor containing the values
59
+ attn_mask: An implementation of BaseMask that encodes where each
60
+ query can attend to
61
+ query_lengths: An implementation of BaseMask that encodes how
62
+ many queries each sequence in the batch consists of
63
+ key_lengths: An implementation of BaseMask that encodes how
64
+                         many keys each sequence in the batch consists of
65
+ """
66
+ # Extract some shapes and compute the temperature
67
+ N, L, H, E = queries.shape
68
+ _, S, _, D = values.shape
69
+ context = self.local_context
70
+ softmax_temp = self.softmax_temp or 1./sqrt(E)
71
+
72
+ # Permute the dimensions to NHLE instead of NLHE
73
+ queries = queries.permute(0, 2, 1, 3).contiguous()
74
+ keys = keys.permute(0, 2, 1, 3).contiguous()
75
+ values = values.permute(0, 2, 1, 3).contiguous()
76
+
77
+ QK = local_dot_product(
78
+ queries,
79
+ keys,
80
+ attn_mask.additive_matrix_finite,
81
+ key_lengths.lengths,
82
+ self.local_context
83
+ )
84
+ A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
85
+
86
+ V_new = local_weighted_average(A, values)
87
+
88
+ return V_new.permute(0, 2, 1, 3).contiguous()
89
+
90
+
91
+ # Register the attention implementation so that it becomes available in our
92
+ # builders
93
+ AttentionRegistry.register(
94
+ "local", LocalAttention,
95
+ [
96
+ ("local_context", Int),
97
+ ("softmax_temp", Optional(Float)),
98
+ ("attention_dropout", Optional(Float, 0.1)),
99
+ ("event_dispatcher", Optional(EventDispatcherInstance, ""))
100
+ ]
101
+ )
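A usage sketch (not part of the committed diff); the import paths are assumptions, and the forward pass needs the compiled local_product extension shipped with this package.

```python
import torch
# Assumed import paths for this vendored copy.
from fast_transformers.attention.local_attention import LocalAttention
from fast_transformers.masking import FullMask, LengthMask

N, L, H, E = 2, 64, 4, 32
attn = LocalAttention(local_context=16)           # each query sees ~16 neighbours

q = torch.randn(N, L, H, E)
k = torch.randn(N, L, H, E)
v = torch.randn(N, L, H, E)
lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))

out = attn(q, k, v, FullMask(L), lengths, lengths)  # (N, L, H, E)
```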
smi-ted/inference/smi_ted_light/fast_transformers/attention/reformer_attention.py ADDED
@@ -0,0 +1,166 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ """Implement the Reformer attention from the paper
8
+ "Reformer: The Efficient Transformer"."""
9
+
10
+ from math import sqrt
11
+
12
+ import torch
13
+ from torch.nn import Dropout, Module
14
+ from torch.nn.init import normal_
15
+
16
+ from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
17
+ Bool, EventDispatcherInstance
18
+ from ..events import EventDispatcher
19
+ from ..masking import FullMask
20
+
21
+
22
+ class ReformerAttention(Module):
23
+     """Implement the attention module of the paper "Reformer: The Efficient
24
+     Transformer"
25
+
26
+ Arguments
27
+ ---------
28
+ chunk_size : Chunk size for each block (default: 32)
29
+ bits : Number of bits for hashing (default: 8)
30
+ rounds : Number of rounds of attention computation (default: 4)
31
+         masked : If true, the query does not attend to itself (default: False)
32
+ softmax_temp: The temperature to use for the softmax attention.
33
+ (default: 1/sqrt(d_keys) where d_keys is computed at
34
+ runtime)
35
+ attention_dropout: The dropout rate to apply to the attention
36
+ (default: 0.1)
37
+ event_dispatcher: str or EventDispatcher instance to be used by this
38
+ module for dispatching events (default: the default
39
+ global dispatcher)
40
+ """
41
+
42
+ def __init__(self, chunk_size=32, bits=8, rounds=4, masked=False,
43
+ softmax_temp=None, attention_dropout=0.1,
44
+ event_dispatcher=""):
45
+ super(ReformerAttention, self).__init__()
46
+
47
+ self.chunk_size = chunk_size
48
+ self.bits = bits
49
+ self.rounds = rounds
50
+ self.masked = masked
51
+ self.softmax_temp = softmax_temp
52
+ self.dropout = Dropout(attention_dropout)
53
+ self.event_dispatcher = EventDispatcher.get(event_dispatcher)
54
+
55
+ def _normalize(self, x):
56
+ norms = torch.sqrt(torch.einsum("nlhe,nlhe->nlh", x, x))
57
+ x_normed = x / norms.unsqueeze(-1)
58
+ return x_normed
59
+
60
+ def _look_back(self, x):
61
+ xshape = x.shape
62
+
63
+ return torch.cat([
64
+ x.new_zeros((xshape[0], 1) + xshape[2:]),
65
+ torch.repeat_interleave(x, 2, dim=1)[:,:-1]
66
+ ], dim=1).view(xshape[0], xshape[1], 2*xshape[2], *xshape[3:])
67
+
68
+ def _reformer_round(self, Q, K, V, mask, softmax_temp):
69
+ # Hash the queries
70
+ N, L, H, E = Q.shape
71
+ planes = Q.new_empty(self.bits, E)
72
+ normal_(planes)
73
+ projected = torch.einsum("nlhe,be->nlhb", K, planes)
74
+ hashes = torch.argmax(
75
+ torch.cat([projected, -projected], dim=-1),
76
+ dim=-1
77
+ )
78
+
79
+ # Sort the queries in order to group them
80
+ group = torch.argsort(hashes, dim=1)
81
+
82
+ invert_group = torch.empty_like(group)
83
+ batch_indices = torch.arange(N, device=hashes.device).view(N, 1, 1)
84
+ sequence_indices = torch.arange(L, device=hashes.device).view(1, L, 1)
85
+ head_indices = torch.arange(H, device=hashes.device).view(1, 1, H)
86
+ invert_group[batch_indices, group, head_indices] = sequence_indices
87
+ group = group.view(N, -1, self.chunk_size, H)
88
+ invert_group = invert_group.view(N, -1, self.chunk_size, H)
89
+ batch_indices = batch_indices.unsqueeze(1)
90
+ head_indices = head_indices.unsqueeze(0)
91
+
92
+ # Reorder Q, V and mask
93
+ Q_grouped = Q[batch_indices, group, head_indices]
94
+ K_grouped = K[batch_indices, group, head_indices]
95
+ V_grouped = V[batch_indices, group, head_indices]
96
+ mask_grouped = mask[
97
+ batch_indices.unsqueeze(1),
98
+ group.unsqueeze(3),
99
+ self._look_back(group).unsqueeze(2)
100
+ ]
101
+
102
+ mask_grouped[:, 0, :, :Q_grouped.shape[2]] = float("-inf")
103
+
104
+ # When everything is masked just unmask everything because it doesn't
105
+ # matter what the output is at those positions
106
+ # This is to avoid inf/nans in the new values at masked positions
107
+ infmask = torch.isinf(mask_grouped)
108
+ infmask = torch.all(infmask, dim=3, keepdims=True)
109
+ mask_grouped = mask_grouped.masked_fill(infmask, 0.)
110
+
111
+ # Attention
112
+ K_grouped = self._look_back(K_grouped)
113
+ QQ = torch.einsum("nblhe,nbshe->nbhls", Q_grouped, K_grouped)
114
+ QQ = QQ + mask_grouped.permute(0, 1, 4, 2, 3)
115
+ A = torch.softmax(softmax_temp * QQ, dim=-1)
116
+ A = self.dropout(A)
117
+
118
+ # Values
119
+ V_grouped = self._look_back(V_grouped)
120
+ V_new = torch.einsum("nbhls,nbshe->nblhe", A, V_grouped)
121
+ V_new = V_new.contiguous().view(N, -1, H, E)
122
+ V_new = V_new[batch_indices, invert_group, head_indices]
123
+ V_new = V_new.contiguous().view(N, L, H, E)
124
+ return V_new
125
+
126
+ def forward(self, queries, keys, values, attn_mask, query_lengths,
127
+ key_lengths):
128
+ # Extract the dimensions of query, key, value
129
+ N, L, H, E = queries.shape
130
+
131
+ softmax_temp = self.softmax_temp or 1./sqrt(E)
132
+ # Create the mask
133
+ mask = key_lengths.additive_matrix.unsqueeze(1).expand(N, L, L)
134
+ if self.masked:
135
+ mask = mask + torch.eye(L, device=queries.device).unsqueeze(0)*float(-1e9)
136
+
137
+ if not attn_mask.all_ones:
138
+ mask = mask + attn_mask.additive_matrix.unsqueeze(0)
139
+ # Get normalized Queries as Keys
140
+ K = self._normalize(queries)
141
+ # Zero the masked out keys
142
+ K = K * key_lengths.float_matrix.view(N, L, 1, 1)
143
+
144
+ V_new = 0
145
+ factor = 1/self.rounds
146
+ for i in range(self.rounds):
147
+ V_new = V_new + \
148
+ factor * self._reformer_round(queries, K, values, mask, softmax_temp)
149
+
150
+ return V_new
151
+
152
+
153
+ # Register the attention implementation so that it becomes available in our
154
+ # builders
155
+ AttentionRegistry.register(
156
+ "reformer", ReformerAttention,
157
+ [
158
+ ("chunk_size", Optional(Int, 32)),
159
+ ("bits", Optional(Int, 63)),
160
+ ("rounds", Optional(Int, 4)),
161
+ ("masked", Optional(Bool, False)),
162
+ ("softmax_temp", Optional(Float)),
163
+ ("attention_dropout", Optional(Float, 0.1)),
164
+ ("event_dispatcher", Optional(EventDispatcherInstance, ""))
165
+ ]
166
+ )
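A usage sketch (not part of the committed diff), with the import paths assumed for this vendored copy. The sequence length should be a multiple of chunk_size, since each round works on fixed-size chunks.

```python
import torch
# Assumed import paths.
from fast_transformers.attention.reformer_attention import ReformerAttention
from fast_transformers.masking import FullMask, LengthMask

N, L, H, E = 2, 128, 4, 32                        # L is a multiple of chunk_size
attn = ReformerAttention(chunk_size=32, rounds=4)

q = torch.randn(N, L, H, E)
v = torch.randn(N, L, H, E)
lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))

# The keys argument is ignored internally: normalized queries are used as keys.
out = attn(q, q, v, FullMask(L), lengths, lengths)  # (N, L, H, E)
```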
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>
4
+ #
5
+
6
+ """Allow for the dynamic registration of new attention implementations.
7
+
8
+ This module provides a Registry implementation that other modules can use to
9
+ register attention implementations for the builders.
10
+ """
11
+
12
+ from .registry import \
13
+ AttentionRegistry, \
14
+ RecurrentAttentionRegistry, \
15
+ RecurrentCrossAttentionRegistry
16
+ from .spec import Spec, Choice, Optional, Int, Float, Bool, Callable, \
17
+ EventDispatcherInstance
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (786 Bytes).
 
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/registry.cpython-310.pyc ADDED
Binary file (2.27 kB).
 
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/spec.cpython-310.pyc ADDED
Binary file (4.73 kB).
 
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/registry.py ADDED
@@ -0,0 +1,61 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>
4
+ #
5
+
6
+
7
+ class Registry(object):
8
+ """Hold the available attention implementations and their required
9
+ parameters."""
10
+ def __init__(self):
11
+ self._classes = {}
12
+ self._class_params = {}
13
+ self._parameters = {}
14
+
15
+ def register(self, key, class_object, parameter_tuples):
16
+ # register the class if the key is new
17
+ if key in self._classes:
18
+ raise ValueError("{} is already registered".format(key))
19
+ self._classes[key] = class_object
20
+
21
+ # register the parameters
22
+ for parameter, spec in parameter_tuples:
23
+ if (
24
+ parameter in self._parameters and
25
+ self._parameters[parameter] != spec
26
+ ):
27
+ raise ValueError(("{} is already registered with "
28
+ "spec {!r} instead of {!r}").format(
29
+ parameter,
30
+ self._parameters[parameter],
31
+ spec
32
+ ))
33
+ self._parameters[parameter] = spec
34
+
35
+ # note which parameters are needed by this class
36
+ self._class_params[key] = [p for p, s in parameter_tuples]
37
+
38
+ def __contains__(self, key):
39
+ return key in self._classes
40
+
41
+ def __getitem__(self, key):
42
+ return self._classes[key], self._class_params[key]
43
+
44
+ @property
45
+ def keys(self):
46
+ return list(self._classes.keys())
47
+
48
+ def contains_parameter(self, key):
49
+ return key in self._parameters
50
+
51
+ def validate_parameter(self, key, value):
52
+ try:
53
+ return self._parameters[key].get(value)
54
+ except Exception as e:
55
+ raise ValueError(("Invalid value {!r} for "
56
+ "parameter {!r}").format(value, key)) from e
57
+
58
+
59
+ AttentionRegistry = Registry()
60
+ RecurrentAttentionRegistry = Registry()
61
+ RecurrentCrossAttentionRegistry = Registry()
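A sketch (not part of the committed diff) of how an attention module typically registers itself against one of these registries; `MyAttention` and its "window" parameter are hypothetical stand-ins, and the import path is an assumption.

```python
# Hypothetical registration sketch; MyAttention stands in for a real
# torch.nn.Module, and the spec types come from the sibling spec module.
from fast_transformers.attention_registry import (
    AttentionRegistry, Optional, Int, Float,
)

class MyAttention:  # placeholder for an actual attention implementation
    def __init__(self, window, softmax_temp=None, attention_dropout=0.1):
        self.window = window

AttentionRegistry.register(
    "my-window", MyAttention,
    [
        ("window", Int),                           # required parameter
        ("softmax_temp", Optional(Float)),         # optional, defaults to None
        ("attention_dropout", Optional(Float, 0.1)),
    ]
)
```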
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/spec.py ADDED
@@ -0,0 +1,126 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>
4
+ #
5
+
6
+ """Spec instances allow describing and checking the type and value of
7
+ parameters."""
8
+
9
+ from ..events import EventDispatcher
10
+
11
+
12
+ class Spec(object):
13
+ """Describe and validate a parameter type.
14
+
15
+ Arguments
16
+ ---------
17
+ predicate: A callable that checks if the value is acceptable and
18
+ returns its canonical value or raises ValueError.
19
+ name: A name to create a human readable description of the Spec
20
+ """
21
+ def __init__(self, predicate, name="CustomSpec"):
22
+ self._predicate = predicate
23
+ self._name = name
24
+
25
+ def __repr__(self):
26
+ return self._name
27
+
28
+ def check(self, x):
29
+ try:
30
+ self._predicate(x)
31
+ return True
32
+ except ValueError:
33
+ return False
34
+
35
+ def get(self, x):
36
+ return self._predicate(x)
37
+
38
+ def __eq__(self, y):
39
+ return self is y
40
+
41
+
42
+ class Choice(Spec):
43
+ """A parameter type for a set of options.
44
+
45
+ Arguments
46
+ ---------
47
+ choices: A set or list of possible values for this parameter
48
+ """
49
+ def __init__(self, choices):
50
+ self._choices = choices
51
+
52
+ def get(self, x):
53
+ if x in self._choices:
54
+ return x
55
+ raise ValueError("{!r} is not in {!r}".format(x, self._choices))
56
+
57
+ def __repr__(self):
58
+ return "Choice({!r})".format(self._choices)
59
+
60
+ def __eq__(self, x):
61
+ if isinstance(x, Choice):
62
+ return self._choices == x._choices
63
+ return False
64
+
65
+
66
+ class _Callable(Spec):
67
+ def __init__(self):
68
+ super(_Callable, self).__init__(None, "Callable")
69
+
70
+ def get(self, x):
71
+ if callable(x):
72
+ return x
73
+ raise ValueError("{!r} is not a callable".format(x))
74
+
75
+
76
+ class _EventDispatcherInstance(Spec):
77
+ def __init__(self):
78
+ super(_EventDispatcherInstance, self).__init__(
79
+ _EventDispatcherInstance._get_event_dispatcher,
80
+ "EventDispatcherInstance"
81
+ )
82
+
83
+ @staticmethod
84
+ def _get_event_dispatcher(x):
85
+ if isinstance(x, str):
86
+ return x
87
+ if isinstance(x, EventDispatcher):
88
+ return x
89
+ raise ValueError("{!r} is not an event dispatcher".format(x))
90
+
91
+
92
+ class Optional(Spec):
93
+ """Represent an optional parameter that can either have a value or it can
94
+ be None.
95
+
96
+ Arguments
97
+ ---------
98
+ spec: The spec for the value if it is not None
99
+ default: The returned value in case it is None
100
+ """
101
+ def __init__(self, spec, default=None):
102
+ self._other_spec = spec
103
+ self._default = default
104
+
105
+ def __repr__(self):
106
+ return "Optional[{!r}, {!r}]".format(self._other_spec, self._default)
107
+
108
+ def get(self, x):
109
+ if x is None:
110
+ return self._default
111
+ return self._other_spec.get(x)
112
+
113
+ def __eq__(self, x):
114
+ if isinstance(x, Optional):
115
+ return (
116
+ self._other_spec == x._other_spec and
117
+ self._default == x._default
118
+ )
119
+ return False
120
+
121
+
122
+ Int = Spec(int, "Int")
123
+ Float = Spec(float, "Float")
124
+ Bool = Spec(bool, "Bool")
125
+ Callable = _Callable()
126
+ EventDispatcherInstance = _EventDispatcherInstance()
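A few concrete checks (not part of the committed diff) illustrating how these spec objects normalise and validate values; the import path is an assumption about this vendored copy.

```python
# Illustration only: behaviour of the spec singletons defined above.
from fast_transformers.attention_registry import Int, Float, Optional, Choice

assert Int.get("8") == 8                          # coerced through int()
assert Float.get(1) == 1.0
assert Optional(Int, 10).get(None) == 10          # None falls back to the default
assert Choice(["relu", "gelu"]).get("gelu") == "gelu"
```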
smi-ted/inference/smi_ted_light/fast_transformers/builders/__init__.py ADDED
@@ -0,0 +1,59 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ """This module implements builders that simplify building complex transformer
8
+ architectures with different attention mechanisms.
9
+
10
+ The main idea is to facilitate the construction of various attention layers and
11
+ transformer encoder layers and simplify their assembly into one transformer
12
+ module. It also allows for flexibility in the scripts as many builder
13
+ parameters can correspond 1-1 with command line arguments.
14
+
15
+ Example usage:
16
+
17
+ builder = TransformerEncoderBuilder()
18
+ builder.n_layers = 12
19
+ builder.n_heads = 8
20
+ builder.feed_forward_dimensions = 1024
21
+ builder.query_dimensions = 64
22
+ builder.value_dimensions = 64
23
+ builder.dropout = 0.1
24
+ builder.attention_dropout = 0.1
25
+ builder.attention_type = "linear"
26
+ transformer = builder.get()
27
+ """
28
+
29
+ __all__ = [
30
+ "AttentionBuilder",
31
+ "RecurrentAttentionBuilder",
32
+ "RecurrentCrossAttentionBuilder"
33
+ ]
34
+
35
+ # Import the attention implementations so that they register themselves with
36
+ # the builder. Attention implementations external to the library should be
37
+ # imported before using the builders.
38
+ #
39
+ # TODO: Should this behaviour change? Namely, should all attention
40
+ # implementations be imported in order to be useable? This also allows
41
+ # using the library even partially built, for instance.
42
+ from ..attention import \
43
+ FullAttention, \
44
+ LinearAttention
45
+
46
+ del FullAttention, \
47
+ LinearAttention
48
+
49
+
50
+ from .attention_builders import \
51
+ AttentionBuilder, \
52
+ RecurrentAttentionBuilder, \
53
+ RecurrentCrossAttentionBuilder
54
+
55
+ from .transformer_builders import \
56
+ TransformerEncoderBuilder, \
57
+ RecurrentEncoderBuilder, \
58
+ TransformerDecoderBuilder, \
59
+ RecurrentDecoderBuilder
smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.46 kB).
 
smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/attention_builders.cpython-310.pyc ADDED
Binary file (6.49 kB).
 
smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/base.cpython-310.pyc ADDED
Binary file (2.3 kB).
 
smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/transformer_builders.cpython-310.pyc ADDED
Binary file (18.7 kB).
 
smi-ted/inference/smi_ted_light/fast_transformers/builders/attention_builders.py ADDED
@@ -0,0 +1,139 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>
4
+ #
5
+
6
+ from collections import defaultdict
7
+
8
+ from .base import BaseBuilder
9
+ from ..attention_registry import \
10
+ AttentionRegistry, \
11
+ RecurrentAttentionRegistry, \
12
+ RecurrentCrossAttentionRegistry
13
+
14
+
15
+ class BaseAttentionBuilder(BaseBuilder):
16
+ def __init__(self, registry):
17
+ self._registry = registry
18
+ self._parameters = defaultdict(lambda: None)
19
+
20
+ @property
21
+ def available_attentions(self):
22
+ """Return a list with the available attention implementations."""
23
+ return self._registry.keys
24
+
25
+ def validate_attention_type(self, attention_type):
26
+ """Parse the attention type according to the rules used by `get()` and
27
+ check if the requested attention is constructible."""
28
+ return all(
29
+ all(t in self._registry for t in a.split(","))
30
+ for a in attention_type.split(":")
31
+ )
32
+
33
+ def __setattr__(self, key, value):
34
+ # Make sure we have normal behaviour for the class members _registry
35
+ # and _parameters
36
+ if key in ["_registry", "_parameters"]:
37
+ return object.__setattr__(self, key, value)
38
+
39
+ # Assign everything else in the parameters dictionary
40
+ if not self._registry.contains_parameter(key):
41
+ raise AttributeError(("{!r} is not a valid attention "
42
+ "parameter name").format(key))
43
+ self._parameters[key] = self._registry.validate_parameter(key, value)
44
+
45
+ def __getattr__(self, key):
46
+ if key in self._parameters:
47
+ return self._parameters[key]
48
+ else:
49
+ raise AttributeError()
50
+
51
+ def __repr__(self):
52
+ return (
53
+ "{}.from_kwargs(\n".format(self.__class__.__name__) +
54
+ "\n".join([" {}={!r},".format(k, v)
55
+ for k, v in self._parameters.items()])[:-1] +
56
+ "\n)"
57
+ )
58
+
59
+ def get(self, attention_type):
60
+ """Construct the attention implementation object and return it.
61
+
62
+ The passed in attention_type argument defines the attention to be
63
+ created. It should be a string and in its simplest form it should
64
+ be one of the available choices from `available_attentions`.
65
+
66
+ However, to enable attention decoration, namely an attention
67
+ implementation augmenting the functionality of another implementation,
68
+ the attention type can be a colon separated list of compositions like
69
+ the following examples:
70
+
71
+ - 'att1' means instantiate att1
72
+ - 'att2:att1' means instantiate att1 and decorate it with att2
73
+ - 'att3:att1,att4' means instantiate att1 and att4 and decorate
74
+ them with att3
75
+
76
+ Arguments
77
+ ---------
78
+ attention_type: A string that contains one or more keys from
79
+ `available_attentions` separated with a colon to
80
+ denote the decoration pattern.
81
+ """
82
+ compositions = reversed(attention_type.split(":"))
83
+ attentions = []
84
+ for c in compositions:
85
+ attentions = [
86
+ self._construct_attention(t, attentions)
87
+ for t in c.split(",")
88
+ ]
89
+ if len(attentions) > 1:
90
+ raise ValueError(("Invalid attention_type argument "
91
+ "{!r}").format(attention_type))
92
+ return attentions[0]
93
+
94
+ def _construct_attention(self, attention_type, decorated=[]):
95
+ """Construct an attention implementation object.
96
+
97
+ Arguments
98
+ ---------
99
+ attention_type: A string that contains a single key from the
100
+ `available_attentions`
101
+ decorated: A list of attention implementations to pass as arguments
102
+ to be decorated
103
+ """
104
+ if attention_type not in self._registry:
105
+ raise ValueError(("Unknown attention type "
106
+ "{!r}").format(attention_type))
107
+
108
+ attention, parameters = self._registry[attention_type]
109
+ parameter_dictionary = {
110
+ p: self._registry.validate_parameter(p, self._parameters[p])
111
+ for p in parameters
112
+ }
113
+
114
+ return attention(*decorated, **parameter_dictionary)
115
+
116
+
117
+ class AttentionBuilder(BaseAttentionBuilder):
118
+ """Build attention implementations for batch sequence processing or
119
+ training."""
120
+ def __init__(self):
121
+ super(AttentionBuilder, self).__init__(AttentionRegistry)
122
+
123
+
124
+ class RecurrentAttentionBuilder(BaseAttentionBuilder):
125
+ """Build attention implementations for autoregressive sequence
126
+ processing."""
127
+ def __init__(self):
128
+ super(RecurrentAttentionBuilder, self).__init__(
129
+ RecurrentAttentionRegistry
130
+ )
131
+
132
+
133
+ class RecurrentCrossAttentionBuilder(BaseAttentionBuilder):
134
+ """Build attention implementations for autoregressive cross attention
135
+ computation."""
136
+ def __init__(self):
137
+ super(RecurrentCrossAttentionBuilder, self).__init__(
138
+ RecurrentCrossAttentionRegistry
139
+ )
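A usage sketch (not part of the committed diff), assuming the package is importable as `fast_transformers`; importing the builders package also imports, and therefore registers, the full and linear attention implementations.

```python
from fast_transformers.builders import AttentionBuilder

builder = AttentionBuilder.from_kwargs(
    query_dimensions=64,       # consumed by the linear attention feature map
    attention_dropout=0.1,     # consumed by implementations that use dropout
)
linear_attention = builder.get("linear")
# Composition keys such as "outer:inner" would wrap "inner" with "outer",
# provided both implementations have been imported and registered.
```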
smi-ted/inference/smi_ted_light/fast_transformers/builders/base.py ADDED
@@ -0,0 +1,67 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ """Provide a class for the others to inherit some useful functionality."""
8
+
9
+
10
+ class BaseBuilder(object):
11
+ @classmethod
12
+ def from_kwargs(cls, **kwargs):
13
+ """Construct a builder and set all the keyword arguments as parameters.
14
+
15
+ The keyword argument strict is passed to
16
+ BaseBuilder.from_dictionary separately.
17
+
18
+ See BaseBuilder.from_dictionary().
19
+ """
20
+ strict = kwargs.pop("strict", True)
21
+ return cls.from_dictionary(kwargs, strict=strict)
22
+
23
+ @classmethod
24
+ def from_namespace(cls, args, strict=False):
25
+ """Construct a builder from an argparse Namespace.
26
+
27
+ To be used for building transformers from command line arguments.
28
+
29
+ See BaseBuilder.from_dictionary().
30
+ """
31
+ return cls.from_dictionary(vars(args), strict=strict)
32
+
33
+ @classmethod
34
+ def from_dictionary(cls, dictionary, strict=True):
35
+ """Construct a builder and set all the parameters in the dictionary.
36
+
37
+ Given a dictionary
38
+
39
+ d = {"foo": "bar"}
40
+
41
+ then
42
+
43
+ builder = TransformerEncoderBuilder.from_dictionary(d)
44
+
45
+ is equivalent to
46
+
47
+ builder = TransformerEncoderBuilder()
48
+ builder.foo = "bar"
49
+
50
+ Arguments
51
+ ---------
52
+ dictionary: A dictionary of parameters to set to the builder.
53
+ strict: bool, If a key is not a parameter and strict is set to True
54
+ then a ValueError is raised, otherwise that dictionary key
55
+ is ignored (default: True)
56
+ """
57
+ builder = cls()
58
+ for k, v in dictionary.items():
59
+ try:
60
+ setattr(builder, k, v)
61
+ except AttributeError:
62
+ if strict:
63
+ raise ValueError(("The builder has no "
64
+ "parameter {!r}").format(k))
65
+ else:
66
+ continue
67
+ return builder
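A sketch (not part of the committed diff) of the three construction paths BaseBuilder provides, shown with the encoder builder from transformer_builders; the import path is an assumption.

```python
from argparse import Namespace
from fast_transformers.builders import TransformerEncoderBuilder

b1 = TransformerEncoderBuilder.from_kwargs(n_layers=2, n_heads=4)
b2 = TransformerEncoderBuilder.from_dictionary({"n_layers": 2}, strict=True)
# from_namespace uses strict=False, so unknown keys such as "seed" are ignored.
b3 = TransformerEncoderBuilder.from_namespace(Namespace(n_layers=2, seed=0))
```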
smi-ted/inference/smi_ted_light/fast_transformers/builders/transformer_builders.py ADDED
@@ -0,0 +1,550 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>
4
+ #
5
+
6
+ """Build complex transformer architectures for inference or training easily."""
7
+
8
+ from torch.nn import LayerNorm
9
+
10
+ from ..attention import AttentionLayer
11
+ from ..transformers import TransformerEncoder, TransformerEncoderLayer, \
12
+ TransformerDecoder, TransformerDecoderLayer
13
+ from ..recurrent.attention import \
14
+ RecurrentAttentionLayer, \
15
+ RecurrentCrossAttentionLayer
16
+ from ..recurrent.transformers import \
17
+ RecurrentTransformerEncoder, RecurrentTransformerEncoderLayer, \
18
+ RecurrentTransformerDecoder, RecurrentTransformerDecoderLayer
19
+ from .base import BaseBuilder
20
+ from .attention_builders import AttentionBuilder, RecurrentAttentionBuilder, \
21
+ RecurrentCrossAttentionBuilder
22
+
23
+
24
+ class BaseTransformerBuilder(BaseBuilder):
25
+ """Contains all the parameters for building a transformer other than the
26
+ attention part.
27
+
28
+ Classes extending the BaseTransformerBuilder should implement the `get()`
29
+ method that actually builds the transformer.
30
+ """
31
+ def __init__(self):
32
+ # transformer parameters
33
+ self._n_layers = 4
34
+ self._n_heads = 4
35
+ self._d_query = 64
36
+ self._d_value = 64
37
+ self._d_ff = 1024
38
+ self._dropout = 0.1
39
+ self._activation = "relu"
40
+ self._final_norm = True
41
+ self._event_dispatcher = "" # the default global dispatcher
42
+
43
+ @property
44
+ def n_layers(self):
45
+ """The number of transformer layers."""
46
+ return self._n_layers
47
+
48
+ @n_layers.setter
49
+ def n_layers(self, val):
50
+ self._n_layers = val
51
+
52
+ @property
53
+ def n_heads(self):
54
+ """The number of heads in each transformer layer."""
55
+ return self._n_heads
56
+
57
+ @n_heads.setter
58
+ def n_heads(self, val):
59
+ self._n_heads = val
60
+
61
+ @property
62
+ def feed_forward_dimensions(self):
63
+ """The dimensions of the fully connected layer in the transformer
64
+ layers."""
65
+ return self._d_ff
66
+
67
+ @feed_forward_dimensions.setter
68
+ def feed_forward_dimensions(self, val):
69
+ self._d_ff = val
70
+
71
+ @property
72
+ def query_dimensions(self):
73
+ """The dimensions of the queries and keys in each attention layer."""
74
+ return self._d_query
75
+
76
+ @query_dimensions.setter
77
+ def query_dimensions(self, val):
78
+ self._d_query = val
79
+
80
+ @property
81
+ def value_dimensions(self):
82
+ """The dimensions of the values in each attention layer."""
83
+ return self._d_value
84
+
85
+ @value_dimensions.setter
86
+ def value_dimensions(self, val):
87
+ self._d_value = val
88
+
89
+ @property
90
+ def dropout(self):
91
+ """The dropout rate to be applied in the transformer encoder layer."""
92
+ return self._dropout
93
+
94
+ @dropout.setter
95
+ def dropout(self, val):
96
+ self._dropout = val
97
+
98
+ @property
99
+ def activation(self):
100
+ """The activation function for the transformer layer.
101
+
102
+ One of {'relu', 'gelu'}.
103
+ """
104
+ return self._activation
105
+
106
+ @activation.setter
107
+ def activation(self, val):
108
+ activations = ["relu", "gelu"]
109
+ if val not in activations:
110
+             raise ValueError(("{!r} is not one of the available activation "
111
+ "types {!r}").format(val, activations))
112
+ self._activation = val
113
+
114
+ @property
115
+ def final_normalization(self):
116
+ """Whether to add LayerNorm as the final layer of the
117
+ TransformerEncoder."""
118
+ return self._final_norm
119
+
120
+ @final_normalization.setter
121
+ def final_normalization(self, val):
122
+ self._final_norm = bool(val)
123
+
124
+ @property
125
+ def event_dispatcher(self):
126
+ """The transformer event dispatcher either as a string or as an
127
+ EventDispatcher object."""
128
+ return self._event_dispatcher
129
+
130
+ @event_dispatcher.setter
131
+ def event_dispatcher(self, event_dispatcher):
132
+ self._event_dispatcher = event_dispatcher
133
+
134
+ def get(self):
135
+ """Build the transformer and return it."""
136
+ raise NotImplementedError()
137
+
138
+
139
+ class BaseTransformerEncoderBuilder(BaseTransformerBuilder):
140
+ """Implement the logic of building a transformer encoder but leave the
141
+ specific layers open for changing by the inheriting classes. This allows us
142
+ to reuse the logic for creating both the TransformerEncoder and the
143
+ RecurrentTransformerEncoder.
144
+
145
+ Inheriting classes should implement the following:
146
+
147
+ - _get_attention_builder()
148
+ - _get_attention_layer_class()
149
+ - _get_encoder_class()
150
+ - _get_encoder_layer_class()
151
+ """
152
+ def __init__(self):
153
+ super(BaseTransformerEncoderBuilder, self).__init__()
154
+ self._attention_builder = self._get_attention_builder()
155
+ self._attention_type = "full"
156
+
157
+ def _get_attention_builder(self):
158
+ """Return an instance of the appropriate attention builder."""
159
+ raise NotImplementedError()
160
+
161
+ def _get_attention_layer_class(self):
162
+ """Return the class for the layer that projects queries keys and
163
+ values."""
164
+ raise NotImplementedError()
165
+
166
+ def _get_encoder_class(self):
167
+ """Return the class for the transformer encoder."""
168
+ raise NotImplementedError()
169
+
170
+ def _get_encoder_layer_class(self):
171
+ """Return the class for the transformer encoder layer."""
172
+ raise NotImplementedError()
173
+
174
+ @property
175
+ def attention(self):
176
+ """The attention builder instance."""
177
+ return self._attention_builder
178
+
179
+ @property
180
+ def attention_type(self):
181
+ """The attention implementation chosen."""
182
+ return self._attention_type
183
+
184
+ @attention_type.setter
185
+ def attention_type(self, val):
186
+ if not self._attention_builder.validate_attention_type(val):
187
+ raise ValueError(("{!r} is not an available attention "
188
+ "type").format(val))
189
+ self._attention_type = val
190
+
191
+ def __setattr__(self, key, val):
192
+         # "protected" attributes are settable (probably from within the class)
193
+ if key[0] == "_":
194
+ return super().__setattr__(key, val)
195
+
196
+ # Existing attributes are settable but they might also be attention
197
+ # parameters so try that as well
198
+ fail_on_exception = True
199
+ if hasattr(self, key):
200
+ super().__setattr__(key, val)
201
+ fail_on_exception = False
202
+
203
+ # Non-existing "public" attributes may be attention parameters
204
+ try:
205
+ setattr(self._attention_builder, key, val)
206
+ except:
207
+ if fail_on_exception:
208
+ raise
209
+
210
+ def get(self):
211
+ """Build the transformer and return it."""
212
+ # Set the event dispatcher to the attention builder
213
+ self.attention.event_dispatcher = self.event_dispatcher
214
+
215
+ # Extract into local variables the classes to be used
216
+ Encoder = self._get_encoder_class()
217
+ EncoderLayer = self._get_encoder_layer_class()
218
+ Attention = self._get_attention_layer_class()
219
+
220
+ model_dimensions = self.value_dimensions*self.n_heads
221
+ return Encoder(
222
+ [
223
+ EncoderLayer(
224
+ Attention(
225
+ self.attention.get(self.attention_type),
226
+ model_dimensions,
227
+ self.n_heads,
228
+ d_keys=self.query_dimensions,
229
+ d_values=self.value_dimensions,
230
+ event_dispatcher=self.event_dispatcher
231
+ ),
232
+ model_dimensions,
233
+ self.feed_forward_dimensions,
234
+ self.dropout,
235
+ self.activation,
236
+ event_dispatcher=self.event_dispatcher
237
+ )
238
+ for _ in range(self.n_layers)
239
+ ],
240
+ (LayerNorm(model_dimensions) if self.final_normalization else None),
241
+ event_dispatcher=self.event_dispatcher
242
+ )
243
+
244
+
245
+ class TransformerEncoderBuilder(BaseTransformerEncoderBuilder):
246
+ """Build a batch transformer encoder for training or processing of
247
+ sequences all elements at a time.
248
+
249
+ Example usage:
250
+
251
+ builder = TransformerEncoderBuilder()
252
+ builder.n_layers = 12
253
+ builder.n_heads = 8
254
+ builder.feed_forward_dimensions = 1024
255
+ builder.query_dimensions = 64
256
+ builder.value_dimensions = 64
257
+ builder.dropout = 0.1
258
+ builder.attention_dropout = 0.1
259
+ builder.attention_type = "linear"
260
+ transformer = builder.get()
261
+ """
262
+ def _get_attention_builder(self):
263
+ """Return an instance of the appropriate attention builder."""
264
+ return AttentionBuilder()
265
+
266
+ def _get_attention_layer_class(self):
267
+ """Return the class for the layer that projects queries keys and
268
+ values."""
269
+ return AttentionLayer
270
+
271
+ def _get_encoder_class(self):
272
+ """Return the class for the transformer encoder."""
273
+ return TransformerEncoder
274
+
275
+ def _get_encoder_layer_class(self):
276
+ """Return the class for the transformer encoder layer."""
277
+ return TransformerEncoderLayer
278
+
279
+
280
+ class RecurrentEncoderBuilder(BaseTransformerEncoderBuilder):
281
+ """Build a transformer encoder for autoregressive processing of sequences.
282
+
283
+ Example usage:
284
+
285
+ builder = RecurrentEncoderBuilder()
286
+ builder.n_layers = 12
287
+ builder.n_heads = 8
288
+ builder.feed_forward_dimensions = 1024
289
+ builder.query_dimensions = 64
290
+ builder.value_dimensions = 64
291
+ builder.dropout = 0.1
292
+ builder.attention_dropout = 0.1
293
+ builder.attention_type = "linear"
294
+ transformer = builder.get()
295
+ """
296
+ def _get_attention_builder(self):
297
+ """Return an attention builder for recurrent attention."""
298
+ return RecurrentAttentionBuilder()
299
+
300
+ def _get_attention_layer_class(self):
301
+ """Return the class for the recurrent layer that projects queries keys
302
+ and values."""
303
+ return RecurrentAttentionLayer
304
+
305
+ def _get_encoder_class(self):
306
+ """Return the class for the recurrent transformer encoder."""
307
+ return RecurrentTransformerEncoder
308
+
309
+ def _get_encoder_layer_class(self):
310
+ """Return the class for the recurrent transformer encoder layer."""
311
+ return RecurrentTransformerEncoderLayer
312
+
313
+
314
+ class BaseTransformerDecoderBuilder(BaseTransformerBuilder):
315
+     """Similar to BaseTransformerEncoderBuilder, implement the logic of
316
+ building the transformer decoder without defining concrete layers.
317
+
318
+ Inheriting classes should implement the following:
319
+
320
+ - _get_self_attention_builder() and _get_cross_attention_builder()
321
+ - _get_self_attention_layer_class() and _get_cross_attention_layer_class()
322
+ - _get_decoder_class()
323
+ - _get_decoder_layer_class()
324
+ """
325
+ def __init__(self):
326
+ super(BaseTransformerDecoderBuilder, self).__init__()
327
+ self._self_attention_builder = self._get_self_attention_builder()
328
+ self._cross_attention_builder = self._get_cross_attention_builder()
329
+ self._self_attention_type = "full"
330
+ self._cross_attention_type = "full"
331
+
332
+ def _get_self_attention_builder(self):
333
+ """Return an instance of attention builder."""
334
+ raise NotImplementedError()
335
+
336
+ def _get_cross_attention_builder(self):
337
+ """Return an instance of attention builder."""
338
+ raise NotImplementedError()
339
+
340
+ def _get_self_attention_layer_class(self):
341
+ """Return a class to project the queries, keys and values to
342
+ multi-head versions."""
343
+ raise NotImplementedError()
344
+
345
+ def _get_cross_attention_layer_class(self):
346
+ """Return a class to project the queries, keys and values to
347
+ multi-head versions."""
348
+ raise NotImplementedError()
349
+
350
+ def _get_decoder_class(self):
351
+ """Return the class for the transformer decoder."""
352
+ raise NotImplementedError()
353
+
354
+ def _get_decoder_layer_class(self):
355
+ """Return the class for the transformer decoder layer."""
356
+ raise NotImplementedError()
357
+
358
+ @property
359
+ def self_attention(self):
360
+ """The attention builder instance that will be used for the self
361
+ attention modules."""
362
+ return self._self_attention_builder
363
+
364
+ @property
365
+ def self_attention_type(self):
366
+ """The attention implementation used for self attention."""
367
+ return self._self_attention_type
368
+
369
+ @self_attention_type.setter
370
+ def self_attention_type(self, val):
371
+ if not self._self_attention_builder.validate_attention_type(val):
372
+ raise ValueError(("{!r} is not an available self attention "
373
+ "type").format(val))
374
+ self._self_attention_type = val
375
+
376
+ @property
377
+ def cross_attention(self):
378
+ """The attention builder instance that will be used for the cross
379
+ attention modules."""
380
+ return self._cross_attention_builder
381
+
382
+ @property
383
+ def cross_attention_type(self):
384
+ """The attention implementation used for cross attention."""
385
+ return self._cross_attention_type
386
+
387
+ @cross_attention_type.setter
388
+ def cross_attention_type(self, val):
389
+ if not self._cross_attention_builder.validate_attention_type(val):
390
+ raise ValueError(("{!r} is not an available cross attention "
391
+ "type").format(val))
392
+ self._cross_attention_type = val
393
+
394
+ def __setattr__(self, key, val):
395
+         # "protected" attributes are settable (probably from within the class)
396
+ if key[0] == "_":
397
+ return super().__setattr__(key, val)
398
+
399
+ # Existing attributes are settable but they might also be attention
400
+ # parameters so try that as well
401
+ fail_on_exception = True
402
+ if hasattr(self, key):
403
+ super().__setattr__(key, val)
404
+ fail_on_exception = False
405
+
406
+ # Non-existing "public" attributes may be attention parameters
407
+ try:
408
+ setattr(self._self_attention_builder, key, val)
409
+ setattr(self._cross_attention_builder, key, val)
410
+ except:
411
+ if fail_on_exception:
412
+ raise
413
+
414
+ def get(self):
415
+ """Build the transformer and return it."""
416
+ # Set the event dispatcher to attention builders
417
+ self.self_attention.event_dispatcher = self.event_dispatcher
418
+ self.cross_attention.event_dispatcher = self.event_dispatcher
419
+
420
+ # Extract into local variables the classes to be used
421
+ Decoder = self._get_decoder_class()
422
+ DecoderLayer = self._get_decoder_layer_class()
423
+ SelfAttention = self._get_self_attention_layer_class()
424
+ CrossAttention = self._get_cross_attention_layer_class()
425
+
426
+ model_dimensions = self.value_dimensions*self.n_heads
427
+ return Decoder(
428
+ [
429
+ DecoderLayer(
430
+ SelfAttention(
431
+ self.self_attention.get(self.self_attention_type),
432
+ model_dimensions,
433
+ self.n_heads,
434
+ d_keys=self.query_dimensions,
435
+ d_values=self.value_dimensions,
436
+ event_dispatcher=self.event_dispatcher
437
+ ),
438
+ CrossAttention(
439
+ self.cross_attention.get(self.cross_attention_type),
440
+ model_dimensions,
441
+ self.n_heads,
442
+ d_keys=self.query_dimensions,
443
+ d_values=self.value_dimensions,
444
+ event_dispatcher=self.event_dispatcher
445
+ ),
446
+ model_dimensions,
447
+ self.feed_forward_dimensions,
448
+ self.dropout,
449
+ self.activation,
450
+ event_dispatcher=self.event_dispatcher
451
+ )
452
+ for _ in range(self.n_layers)
453
+ ],
454
+ (LayerNorm(model_dimensions) if self.final_normalization else None),
455
+ event_dispatcher=self.event_dispatcher
456
+ )
457
+
458
+
459
+ class TransformerDecoderBuilder(BaseTransformerDecoderBuilder):
460
+ """Build a transformer decoder for training or processing of sequences all
461
+ elements at a time.
462
+
463
+ Example usage:
464
+
465
+ builder = TransformerDecoderBuilder()
466
+ builder.n_layers = 12
467
+ builder.n_heads = 8
468
+ builder.feed_forward_dimensions = 1024
469
+ builder.query_dimensions = 64
470
+ builder.value_dimensions = 64
471
+ builder.dropout = 0.1
472
+ builder.attention_dropout = 0.1
473
+ builder.self_attention_type = "full"
474
+ builder.cross_attention_type = "full"
475
+ transformer = builder.get()
476
+ """
477
+ def _get_self_attention_builder(self):
478
+ """Return an attention builder for creating non-recurrent attention
479
+ variants."""
480
+ return AttentionBuilder()
481
+
482
+ def _get_cross_attention_builder(self):
483
+ """Return an attention builder for creating non-recurrent attention
484
+ variants."""
485
+ return AttentionBuilder()
486
+
487
+ def _get_self_attention_layer_class(self):
488
+ """Return the non-recurrent attention layer to project queries, keys
489
+ and values."""
490
+ return AttentionLayer
491
+
492
+ def _get_cross_attention_layer_class(self):
493
+ """Return the non-recurrent attention layer to project queries, keys
494
+ and values."""
495
+ return AttentionLayer
496
+
497
+ def _get_decoder_class(self):
498
+ """Return the transformer decoder class."""
499
+ return TransformerDecoder
500
+
501
+ def _get_decoder_layer_class(self):
502
+ """Return the transformer decoder layer class."""
503
+ return TransformerDecoderLayer
504
+
505
+
506
+ class RecurrentDecoderBuilder(BaseTransformerDecoderBuilder):
507
+ """Build a transformer decoder for processing of sequences in
508
+ autoregressive fashion.
509
+
510
+ Example usage:
511
+
512
+ builder = RecurrentDecoderBuilder()
513
+ builder.n_layers = 12
514
+ builder.n_heads = 8
515
+ builder.feed_forward_dimensions = 1024
516
+ builder.query_dimensions = 64
517
+ builder.value_dimensions = 64
518
+ builder.dropout = 0.1
519
+ builder.attention_dropout = 0.1
520
+ builder.self_attention_type = "full"
521
+ builder.cross_attention_type = "full"
522
+ transformer = builder.get()
523
+ """
524
+ def _get_self_attention_builder(self):
525
+         """Return an attention builder for creating recurrent attention
526
+ variants."""
527
+ return RecurrentAttentionBuilder()
528
+
529
+ def _get_cross_attention_builder(self):
530
+         """Return an attention builder for creating recurrent cross attention
531
+ variants."""
532
+ return RecurrentCrossAttentionBuilder()
533
+
534
+ def _get_self_attention_layer_class(self):
535
+         """Return the recurrent attention layer to project queries, keys
536
+ and values."""
537
+ return RecurrentAttentionLayer
538
+
539
+ def _get_cross_attention_layer_class(self):
540
+         """Return the recurrent cross attention layer to project queries, keys
541
+ and values."""
542
+ return RecurrentCrossAttentionLayer
543
+
544
+ def _get_decoder_class(self):
545
+ """Return the transformer decoder class."""
546
+ return RecurrentTransformerDecoder
547
+
548
+ def _get_decoder_layer_class(self):
549
+ """Return the transformer decoder layer class."""
550
+ return RecurrentTransformerDecoderLayer
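An end-to-end sketch (not part of the committed diff): build a small encoder with linear attention and run a batch through it. The import path is an assumption about how this vendored copy is exposed.

```python
import torch
from fast_transformers.builders import TransformerEncoderBuilder

model = TransformerEncoderBuilder.from_kwargs(
    n_layers=2,
    n_heads=4,
    query_dimensions=32,
    value_dimensions=32,
    feed_forward_dimensions=256,
    attention_type="linear",
).get()

x = torch.randn(8, 100, 4 * 32)   # (batch, length, n_heads * value_dimensions)
y = model(x)                      # same shape as x
```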
smi-ted/inference/smi_ted_light/fast_transformers/causal_product/__init__.py ADDED
@@ -0,0 +1,78 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+ import torch
8
+
9
+ from .causal_product_cpu import causal_dot_product as causal_dot_product_cpu, \
10
+ causal_dot_backward as causal_dot_backward_cpu
11
+
12
+ try:
13
+ from .causal_product_cuda import \
14
+ causal_dot_product as causal_dot_product_cuda, \
15
+ causal_dot_backward as causal_dot_backward_cuda
16
+ except ImportError:
17
+ causal_dot_product_cuda = causal_dot_backward_cuda = None
18
+
19
+
20
+ class CausalDotProduct(torch.autograd.Function):
21
+ """Compute the weighted sum of values but attending only to previous
22
+ values."""
23
+ dot = {
24
+ "cpu": causal_dot_product_cpu,
25
+ "cuda": causal_dot_product_cuda
26
+ }
27
+ dot_backward = {
28
+ "cpu": causal_dot_backward_cpu,
29
+ "cuda": causal_dot_backward_cuda
30
+ }
31
+
32
+ @staticmethod
33
+ def forward(ctx, Q, K, V):
34
+ # Save the inputs for the gradient computation
35
+ ctx.save_for_backward(Q, K, V)
36
+
37
+ # Create the output tensor
38
+ device = Q.device
39
+ N, H, L, _ = Q.shape
40
+ _, _, _, M = V.shape
41
+ product = torch.zeros((N, H, L, M), device=device)
42
+
43
+ # Actually perform the dot product
44
+ CausalDotProduct.dot[device.type](
45
+ Q.data,
46
+ K.data,
47
+ V.data,
48
+ product
49
+ )
50
+
51
+ return product
52
+
53
+ @staticmethod
54
+ def backward(ctx, grad_out):
55
+ # Extract the saved tensors
56
+ Q, K, V = ctx.saved_tensors
57
+
58
+ # Allocate memory for the gradients
59
+ grad_Q = torch.zeros_like(Q)
60
+ grad_K = torch.zeros_like(K)
61
+ grad_V = torch.zeros_like(V)
62
+
63
+ # Actually compute the gradients
64
+ CausalDotProduct.dot_backward[Q.device.type](
65
+ Q.data,
66
+ K.data,
67
+ V.data,
68
+ grad_out,
69
+ grad_Q,
70
+ grad_K,
71
+ grad_V
72
+ )
73
+
74
+ return grad_Q, grad_K, grad_V
75
+
76
+
77
+ # Alias the autograd functions to python style snake case naming
78
+ causal_dot_product = CausalDotProduct.apply
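A sketch (not part of the committed diff) of calling the wrapped kernel; it needs the compiled causal_product_cpu extension (the .so added below) or its CUDA counterpart, and the import path is an assumption.

```python
import torch
from fast_transformers.causal_product import causal_dot_product  # assumed path

N, H, L, E, M = 2, 4, 100, 32, 32
Q = torch.rand(N, H, L, E)         # non-negative, e.g. feature-mapped queries
K = torch.rand(N, H, L, E)
V = torch.randn(N, H, L, M)

out = causal_dot_product(Q, K, V)  # (N, H, L, M); position i only uses 0..i
```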
smi-ted/inference/smi_ted_light/fast_transformers/causal_product/causal_product_cpu.cpython-39-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84f32370e707beebd8fee88f356fb62721096142265895a5a8e9872063c04595
3
+ size 140928
smi-ted/inference/smi_ted_light/fast_transformers/clustering/__init__.py ADDED
File without changes
smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/__init__.py ADDED
@@ -0,0 +1,115 @@
1
+ #
2
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
3
+ # Written by Angelos Katharopoulos <[email protected]>,
4
+ # Apoorv Vyas <[email protected]>
5
+ #
6
+
7
+
8
+ import numpy as np
9
+
10
+ import torch
11
+
12
+ from .cluster_cpu import cluster as cluster_cpu
13
+ try:
14
+ from .cluster_cuda import cluster as cluster_gpu
15
+ except ImportError:
16
+ pass
17
+
18
+
19
+ def cluster(
20
+ hashes,
21
+ lengths,
22
+ groups=None,
23
+ counts=None,
24
+ centroids=None,
25
+ distances=None,
26
+ bitcounts=None,
27
+ clusters=30,
28
+ iterations=10,
29
+ bits=32
30
+ ):
31
+ """Cluster hashes using a few iterations of K-Means with hamming distance.
32
+
33
+ All the tensors default initialized to None are optional buffers to avoid
34
+ memory allocations. distances and bitcounts are only used by the CUDA
35
+ version of this call. clusters will be ignored if centroids is provided.
36
+
37
+ Arguments
38
+ ---------
39
+ hashes: A long tensor of shape (N, H, L) containing a hashcode for each
40
+ query.
41
+ lengths: An int tensor of shape (N,) containing the sequence length for
42
+ each sequence in hashes.
43
+ groups: An int tensor buffer of shape (N, H, L) contaning the cluster
44
+ in which the corresponding hash belongs to.
45
+ counts: An int tensor buffer of shape (N, H, K) containing the number
46
+ of elements in each cluster.
47
+ centroids: A long tensor buffer of shape (N, H, K) containing the
48
+ centroid for each cluster.
49
+ distances: An int tensor of shape (N, H, L) containing the distance to
50
+ the closest centroid for each hash.
51
+ bitcounts: An int tensor of shape (N, H, K, bits) containing the number
52
+ of elements that have 1 for a given bit.
53
+ clusters: The number of clusters to use for each sequence. It is
54
+ ignored if centroids is not None.
55
+ iterations: How many k-means iterations to perform.
56
+ bits: How many of the least-significant bits in hashes to consider.
57
+
58
+ Returns
59
+ -------
60
+ groups and counts as defined above.
61
+ """
62
+ device = hashes.device
63
+ N, H, L = hashes.shape
64
+
65
+ # Unfortunately cpu and gpu have different APIs so the entire call must be
66
+ # surrounded by an if-then-else
67
+ if device.type == "cpu":
68
+ if groups is None:
69
+ groups = torch.empty((N, H, L), dtype=torch.int32)
70
+ if centroids is None:
71
+ centroids = torch.empty((N, H, clusters), dtype=torch.int64)
72
+ centroids = hashes[:, :, np.random.choice(L, size=[clusters], replace=False)]
73
+ K = centroids.shape[2]
74
+ if counts is None:
75
+ counts = torch.empty((N, H, K), dtype=torch.int32)
76
+
77
+ cluster_cpu(
78
+ hashes, lengths,
79
+ centroids, groups, counts,
80
+ iterations, bits
81
+ )
82
+
83
+ return groups, counts
84
+
85
+ else:
86
+ if groups is None:
87
+ groups = torch.empty((N, H, L), dtype=torch.int32, device=device)
88
+ if centroids is None:
89
+ centroids = torch.empty((N, H, clusters), dtype=torch.int64,
90
+ device=device)
91
+ centroids = hashes[:, :, np.random.choice(L, size=[clusters], replace=False)]
92
+ K = centroids.numel() // N // H
93
+ #K = clusters
94
+ if counts is None:
95
+ counts = torch.empty((N, H, K), dtype=torch.int32, device=device)
96
+ if distances is None:
97
+ distances = torch.empty((N, H, L), dtype=torch.int32,
98
+ device=device)
99
+ if bitcounts is None:
100
+ bitcounts = torch.empty((N, H, K, bits), dtype=torch.int32,
101
+ device=device)
102
+ groups = groups.view(N, H, L)
103
+ counts = counts.view(N, H, K)
104
+ centroids = centroids.view(N, H, K)
105
+ distances = distances.view(N, H, L)
106
+ bitcounts = bitcounts.view(N, H, K, -1)
107
+
108
+ cluster_gpu(
109
+ hashes, lengths,
110
+ centroids, distances, bitcounts, groups, counts,
111
+ iterations, bits
112
+ )
113
+
114
+ return groups, counts
115
+
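A hypothetical call of the function above, assuming the compiled cluster_cpu extension is importable; in this repository the package is vendored under smi_ted_light, so the import path shown is illustrative:

    import torch
    from fast_transformers.clustering.hamming import cluster

    # N=2 sequences, H=4 heads, L=100 positions, 16-bit hash codes
    hashes = torch.randint(0, 2 ** 16, (2, 4, 100), dtype=torch.int64)
    lengths = torch.full((2,), 100, dtype=torch.int32)
    groups, counts = cluster(hashes, lengths, clusters=10, iterations=5, bits=16)
    print(groups.shape)   # (2, 4, 100): cluster index per position
    print(counts.shape)   # (2, 4, 10):  number of members per cluster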
smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/cluster_cpu.cpython-39-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f2bd8f761d6e1efdeea33665cad8702b5c07d1a0db728d19cf332c4383510d45
+ size 139824
smi-ted/inference/smi_ted_light/fast_transformers/events/__init__.py ADDED
@@ -0,0 +1,10 @@
+ #
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+ # Written by Angelos Katharopoulos <[email protected]>
+ #
+ 
+ """This module implements a basic event system that allows the transformer
+ internal components to make available any tensor with minimal overhead."""
+ 
+ from .event import Event, AttentionEvent, QKVEvent
+ from .event_dispatcher import EventDispatcher
smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (556 Bytes)
smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event.cpython-310.pyc ADDED
Binary file (2.21 kB)
smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event_dispatcher.cpython-310.pyc ADDED
Binary file (3.5 kB)
smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/filters.cpython-310.pyc ADDED
Binary file (5.82 kB)
smi-ted/inference/smi_ted_light/fast_transformers/events/event.py ADDED
@@ -0,0 +1,51 @@
+ #
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+ # Written by Angelos Katharopoulos <[email protected]>
+ #
+ 
+ 
+ class Event(object):
+     """The Event is the base class for all events that are dispatched from any
+     transformer module.
+ 
+     This class defines only the basic attributes of an event without any
+     payload.
+ 
+     Arguments
+     ---------
+         source: torch.nn.Module instance that dispatched this event
+     """
+     def __init__(self, source):
+         self.source = source
+ 
+ 
+ class AttentionEvent(Event):
+     """An event containing an attention matrix.
+ 
+     Arguments
+     ---------
+         source: torch.nn.Module instance that dispatched this event
+         attention_matrix: torch.tensor of the multihead attention matrix
+                           computed in the corresponding attention layer
+     """
+     def __init__(self, source, attention_matrix):
+         super(AttentionEvent, self).__init__(source)
+         self.attention_matrix = attention_matrix
+ 
+ 
+ class QKVEvent(Event):
+     """An event containing the queries, keys and values projected in their
+     multiple heads.
+ 
+     Arguments
+     ---------
+         source: torch.nn.Module instance that dispatched this event
+         queries: torch.tensor containing the queries in shape NLHE
+         keys: torch.tensor containing the keys in shape NSHE
+         values: torch.tensor containing the values in shape NSHD
+     """
+     def __init__(self, source, queries, keys, values):
+         super(QKVEvent, self).__init__(source)
+         self.queries = queries
+         self.keys = keys
+         self.values = values
@@ -0,0 +1,92 @@
 
+ #
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+ # Written by Angelos Katharopoulos <[email protected]>
+ #
+ 
+ from collections import OrderedDict
+ 
+ from .event import Event
+ from .filters import event_class
+ 
+ 
+ class EventDispatcher(object):
+     """An EventDispatcher is a simple way to implement an observer pattern for
+     loose coupling of components. In our case it is used so that the internals
+     of large neural networks can communicate with the outside world in an
+     agnostic and efficient way.
+ 
+     Example usage
+     -------------
+ 
+         from fast_transformers.events import EventDispatcher, AttentionEvent
+         from fast_transformers.events.filters import \
+             layer_name_contains
+ 
+         def attention_event_handler(event):
+             print(event.attention_matrix)
+ 
+         ed = EventDispatcher()
+         ed.listen(AttentionEvent, attention_event_handler)
+         ed.listen(
+             AttentionEvent & layer_name_contains("layers.12"),
+             attention_event_handler
+         )
+     """
+     _dispatchers = {}
+ 
+     def __init__(self):
+         self._listeners = OrderedDict()
+ 
+     def listen(self, event_filter, event_handler):
+         """Add an event handler for the events that pass the event filter.
+ 
+         Arguments
+         ---------
+             event_filter: callable or Event class to define for which events
+                           this handler will be called
+             event_handler: callable that accepts an instance of Event
+         """
+         if isinstance(event_filter, type) and issubclass(event_filter, Event):
+             event_filter = event_class(event_filter)
+ 
+         self._listeners[event_handler] = event_filter
+ 
+     def remove(self, event_handler):
+         """Remove the event_handler from the listeners so that no more events
+         are dispatched to this handler."""
+         self._listeners.pop(event_handler, None)
+ 
+     def clear(self):
+         """Remove all listeners from the event dispatcher."""
+         self._listeners.clear()
+ 
+     def dispatch(self, event):
+         """Dispatch an event to the listeners.
+ 
+         Arguments
+         ---------
+             event: Event instance
+         """
+         for event_handler, event_filter in self._listeners.items():
+             if event_filter(event):
+                 event_handler(event)
+ 
+     @classmethod
+     def get(cls, key=""):
+         """Factory method for creating global event dispatchers for loosely
+         coupling parts of a larger codebase.
+ 
+         Since global objects are a complete antipattern, we suggest that this
+         is only used to set a default value for an event dispatcher passed as
+         an argument.
+ 
+         Argument
+         --------
+             key: A key to uniquely identify a dispatcher or an instance of a
+                  dispatcher to be returned as is
+         """
+         if isinstance(key, cls):
+             return key
+         if key not in cls._dispatchers:
+             cls._dispatchers[key] = cls()
+         return cls._dispatchers[key]
@@ -0,0 +1,141 @@
 
+ #
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+ # Written by Angelos Katharopoulos <[email protected]>
+ #
+ 
+ """Define composable functions to filter events."""
+ 
+ import weakref
+ 
+ from .event import Event
+ 
+ 
+ class EventFilter(object):
+     """EventFilter instances are predicates (i.e. functions that return True or
+     False) to be used with an event dispatcher for filtering event
+     instances.
+ 
+     The main benefit from using raw functions is that an EventFilter composes
+     very easily using operators such as &, |, ~.
+ 
+     Example
+     --------
+ 
+         event_filter = AttentionEvent | layer_name_contains("layers.1")
+         event_filter = from_layer(transformer.layers[2].attention)
+         event_filter = (
+             AttentionEvent &
+             lambda ev: torch.isnan(ev.attention_matrix).any()
+         )
+     """
+     def __call__(self, event):
+         raise NotImplementedError()
+ 
+     def _to_event_filter(self, other):
+         if isinstance(other, EventFilter):
+             return other
+         if isinstance(other, type) and issubclass(other, Event):
+             return event_class(other)
+         if callable(other):
+             return CallableEventFilter(other)
+ 
+         return NotImplemented
+ 
+     def __and__(self, other):
+         other = self._to_event_filter(other)
+         if other is NotImplemented:
+             return other
+         return CallableEventFilter(lambda ev: self(ev) and other(ev))
+ 
+     def __rand__(self, other):
+         other = self._to_event_filter(other)
+         if other is NotImplemented:
+             return other
+         return CallableEventFilter(lambda ev: other(ev) and self(ev))
+ 
+     def __or__(self, other):
+         other = self._to_event_filter(other)
+         if other is NotImplemented:
+             return other
+         return CallableEventFilter(lambda ev: self(ev) or other(ev))
+ 
+     def __ror__(self, other):
+         other = self._to_event_filter(other)
+         if other is NotImplemented:
+             return other
+         return CallableEventFilter(lambda ev: other(ev) or self(ev))
+ 
+     def __invert__(self):
+         return CallableEventFilter(lambda ev: not self(ev))
+ 
+ 
+ class CallableEventFilter(EventFilter):
+     """Wrap a function with an EventFilter object."""
+     def __init__(self, event_filter):
+         self._event_filter = event_filter
+ 
+     def __call__(self, event):
+         return self._event_filter(event)
+ 
+ 
+ class LayerNameEventFilter(EventFilter):
+     """A LayerNameEventFilter allows to filter events based on a human readable
+     name of the layer that emitted them.
+ 
+     Note that LayerNameEventFilter keeps a weak reference to all modules which
+     means that it cannot be used to prevent modules from being garbage
+     collected.
+ 
+     Arguments
+     ---------
+         root: torch.nn.Module instance that represents the root container
+         name_filter: callable that returns True if the name matches the filter
+     """
+     def __init__(self, root, name_filter):
+         self._names = {
+             weakref.ref(m): n
+             for n, m in root.named_modules()
+         }
+         self._name_filter = name_filter
+ 
+     def __call__(self, event):
+         name = self._names.get(weakref.ref(event.source), None)
+         if name is None:
+             return False
+         return self._name_filter(name)
+ 
+ 
+ def event_class(klass):
+     """Select events that are instances of `klass`.
+ 
+     Arguments
+     ---------
+         klass: A class to check the event instance against
+ 
+     Returns
+     -------
+         An instance of EventFilter
+     """
+     return CallableEventFilter(lambda ev: isinstance(ev, klass))
+ 
+ 
+ def from_layer(layer):
+     """Select events that are dispatched from the `layer`.
+ 
+     Arguments
+     ---------
+         layer: An instance of torch.nn.Module to check against the event source
+ 
+     Returns
+     -------
+         An instance of EventFilter
+     """
+     return CallableEventFilter(lambda ev: ev.source is layer)
+ 
+ 
+ def layer_name_contains(root, name):
+     """Select events that contain `name` in their human readable name.
+ 
+     We use root.named_modules() to get human readable names for the layers.
+     """
+     return LayerNameEventFilter(root, lambda n: name in n)
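A short sketch of how these predicates compose; the model below is a stand-in and the handler is ours:

    import torch
    from fast_transformers.events import EventDispatcher, AttentionEvent
    from fast_transformers.events.filters import event_class, layer_name_contains

    model = torch.nn.Sequential(torch.nn.Linear(8, 8))   # placeholder for a transformer

    # React only to AttentionEvents emitted by layers whose name contains "0"
    event_filter = event_class(AttentionEvent) & layer_name_contains(model, "0")
    EventDispatcher.get().listen(
        event_filter,
        lambda ev: print(ev.attention_matrix.shape)
    )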
smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__init__.py ADDED
@@ -0,0 +1,12 @@
+ #
+ # Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+ # Written by Angelos Katharopoulos <[email protected]>
+ #
+ 
+ """Implementations of feature maps to be used with linear attention and causal
+ linear attention."""
+ 
+ 
+ from .base import elu_feature_map, ActivationFunctionFeatureMap
+ from .fourier_features import RandomFourierFeatures, Favor, \
+     SmoothedRandomFourierFeatures, GeneralizedRandomFeatures
smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (614 Bytes)
smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/base.cpython-310.pyc ADDED
Binary file (3.42 kB)