Commit b60e08a
Enzo Reis de Oliveira committed
1 Parent(s): 30f063f

Fixing again

This view is limited to 50 files because it contains too many changes. See raw diff.
- smi-ted/inference/smi_ted_light/.gitattributes +2 -0
- smi-ted/inference/smi_ted_light/fast_transformers/__init__.py +15 -0
- smi-ted/inference/smi_ted_light/fast_transformers/aggregate/__init__.py +128 -0
- smi-ted/inference/smi_ted_light/fast_transformers/aggregate/aggregate_cpu.cpython-39-x86_64-linux-gnu.so +3 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/__init__.py +20 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/__init__.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/attention_layer.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/full_attention.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/linear_attention.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/attention_layer.py +113 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/causal_linear_attention.py +116 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/clustered_attention.py +195 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/conditional_full_attention.py +66 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/exact_topk_attention.py +88 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/full_attention.py +95 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_attention.py +268 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_causal_attention.py +257 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/linear_attention.py +92 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/local_attention.py +101 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention/reformer_attention.py +166 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__init__.py +17 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/__init__.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/registry.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/spec.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/registry.py +61 -0
- smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/spec.py +126 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/__init__.py +59 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/__init__.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/attention_builders.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/base.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/transformer_builders.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/attention_builders.py +139 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/base.py +67 -0
- smi-ted/inference/smi_ted_light/fast_transformers/builders/transformer_builders.py +550 -0
- smi-ted/inference/smi_ted_light/fast_transformers/causal_product/__init__.py +78 -0
- smi-ted/inference/smi_ted_light/fast_transformers/causal_product/causal_product_cpu.cpython-39-x86_64-linux-gnu.so +3 -0
- smi-ted/inference/smi_ted_light/fast_transformers/clustering/__init__.py +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/__init__.py +115 -0
- smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/cluster_cpu.cpython-39-x86_64-linux-gnu.so +3 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/__init__.py +10 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/__init__.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event_dispatcher.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/filters.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/event.py +51 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/event_dispatcher.py +92 -0
- smi-ted/inference/smi_ted_light/fast_transformers/events/filters.py +141 -0
- smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__init__.py +12 -0
- smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/__init__.cpython-310.pyc +0 -0
- smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/base.cpython-310.pyc +0 -0
smi-ted/inference/smi_ted_light/.gitattributes
ADDED
@@ -0,0 +1,2 @@
+smi-ted/inference/smi_ted_light/fast_transformers/**/*.so filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
smi-ted/inference/smi_ted_light/fast_transformers/__init__.py
ADDED
@@ -0,0 +1,15 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <[email protected]>,
+# Apoorv Vyas <[email protected]>
+#
+
+"""Provide a library with fast transformer implementations."""
+
+__author__ = "Angelos Katharopoulos, Apoorv Vyas"
+__copyright__ = "Copyright (c) 2020 Idiap Research Institute"
+__license__ = "MIT"
+__maintainer__ = "Angelos Katharopoulos, Apoorv Vyas"
+__email__ = "[email protected], [email protected]"
+__url__ = "https://github.com/idiap/fast-transformers"
+__version__ = "0.4.0"
smi-ted/inference/smi_ted_light/fast_transformers/aggregate/__init__.py
ADDED
@@ -0,0 +1,128 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <[email protected]>,
+# Apoorv Vyas <[email protected]>
+#
+
+
+import torch
+
+from .aggregate_cpu import aggregate as aggregate_cpu, \
+    broadcast as broadcast_cpu
+try:
+    from .aggregate_cuda import aggregate as aggregate_gpu, \
+        broadcast as broadcast_gpu
+    from .clustered_aggregate_cuda import \
+        clustered_broadcast as clustered_broadcast_gpu, \
+        clustered_aggregate as clustered_aggregate_gpu
+
+except ImportError:
+    pass
+
+
+def aggregate(X, G, F, Y=None):
+    device = X.device
+    if Y is None:
+        Y = torch.zeros(
+            F.shape + (X.shape[-1],),
+            device=device,
+            dtype=X.dtype
+        )
+    else:
+        Y.zero_()
+
+    if device.type == "cpu":
+        aggregate_cpu(X, G, F, Y)
+    else:
+        aggregate_gpu(X, G, F, Y)
+
+    return Y
+
+
+def broadcast(Y, G, F, X=None):
+    device = Y.device
+    if X is None:
+        X = torch.zeros(
+            G.shape + (Y.shape[-1],),
+            device=device,
+            dtype=Y.dtype
+        )
+
+    if device.type == "cpu":
+        broadcast_cpu(Y, G, F, X)
+    else:
+        broadcast_gpu(Y, G, F, X)
+
+    return X
+
+
+# Divide the cluster into groups of equal size
+# as constrained by the shared memory
+def set_group(C, E):
+    C_per_block = int(192 * 64 / (E+1))
+    G_min = (C + C_per_block - 1) // C_per_block
+    for G in range(G_min, C+1):
+        if C % G == 0:
+            return G
+
+
+def clustered_broadcast(Y, groups, counts, factors, X=None):
+    device = Y.device
+    if X is None:
+        X = torch.zeros(
+            groups.shape + (Y.shape[-1],),
+            device=device,
+            dtype=Y.dtype
+        )
+    if device.type == "cpu":
+        broadcast_cpu(Y, groups, factors, X)
+    else:
+        N, H, C, E = Y.shape
+        _, _, L, _ = X.shape
+
+        # Following are some booking keeping parameters to facilitate the
+        # broadcast kernel that takes advantage of clustering
+        # More information can be found in the cuda file
+        with torch.no_grad():
+            threads = 256
+            G = set_group(C, E)
+            group_counts = counts.view(N, H, G, -1).sum(-1)
+            block_counts = (group_counts + threads - 1) // threads
+            total_blocks = block_counts.sum().item()
+            indx_maps = torch.ones(
+                (total_blocks, 5),
+                device=X.device,
+                dtype=torch.int32
+            )
+
+            clustered_broadcast_gpu(
+                Y,
+                groups,
+                factors,
+                X,
+                block_counts.int(),
+                group_counts.int(),
+                threads,
+                G,
+                total_blocks,
+                indx_maps
+            )
+    return X
+
+
+def clustered_aggregate(X, G, F, lengths, Y=None):
+    device = X.device
+    if Y is None:
+        Y = torch.zeros(
+            F.shape + (X.shape[-1],),
+            device=device,
+            dtype=X.dtype
+        )
+    else:
+        Y.zero_()
+
+    if device.type == "cpu":
+        aggregate_cpu(X, G, F, Y)
+    else:
+        clustered_aggregate_gpu(X, G, F, lengths, Y)
+    return Y
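The functions above only allocate the output and dispatch to compiled CPU/CUDA kernels, so the grouping semantics are easy to miss in the diff. The sketch below is not part of the commit: it is a pure-PyTorch illustration of the intended sum-per-group / copy-back-per-group behaviour under simplified, assumed shapes (a single head, `X` of shape `(N, L, E)`, integer group ids `G` of shape `(N, L)` and per-group factors `F` of shape `(N, C)`); the `*_reference` names are hypothetical.

```python
import torch

def aggregate_reference(X, G, F):
    # Sum the members of each group, then scale by the group factor
    # (e.g. F = 1/counts turns the sums into per-group means).
    N, L, E = X.shape
    C = F.shape[1]
    Y = torch.zeros(N, C, E, dtype=X.dtype)
    for n in range(N):
        Y[n].index_add_(0, G[n], X[n])
    return Y * F.unsqueeze(-1)

def broadcast_reference(Y, G, F):
    # Copy every group's row back to its members, scaled by the group factor.
    idx = G.unsqueeze(-1).expand(-1, -1, Y.shape[-1])
    return (Y * F.unsqueeze(-1)).gather(1, idx)

X = torch.randn(2, 6, 4)
G = torch.tensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 2, 2, 2]])
counts = torch.zeros(2, 3).scatter_add_(1, G, torch.ones(2, 6))
means = aggregate_reference(X, G, 1.0 / counts)         # (2, 3, 4) group centroids
back = broadcast_reference(means, G, torch.ones(2, 3))  # (2, 6, 4) per-position copies
```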
smi-ted/inference/smi_ted_light/fast_transformers/aggregate/aggregate_cpu.cpython-39-x86_64-linux-gnu.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6bccb1a374d4649aaef6361cc41c9ffb471086464cc07a0d6d21c5b65adb0711
+size 138248
smi-ted/inference/smi_ted_light/fast_transformers/attention/__init__.py
ADDED
@@ -0,0 +1,20 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <[email protected]>,
+# Apoorv Vyas <[email protected]>
+#
+
+"""Implementations of different types of attention mechanisms."""
+
+
+from .attention_layer import AttentionLayer
+from .full_attention import FullAttention
+from .linear_attention import LinearAttention
+#from .causal_linear_attention import CausalLinearAttention
+#from .clustered_attention import ClusteredAttention
+#from .improved_clustered_attention import ImprovedClusteredAttention
+#from .reformer_attention import ReformerAttention
+#from .conditional_full_attention import ConditionalFullAttention
+#from .exact_topk_attention import ExactTopKAttention
+#from .improved_clustered_causal_attention import ImprovedClusteredCausalAttention
+#from .local_attention import LocalAttention
smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (502 Bytes)
smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/attention_layer.cpython-310.pyc
ADDED
Binary file (4.14 kB)
smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/full_attention.cpython-310.pyc
ADDED
Binary file (3.32 kB)
smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/linear_attention.cpython-310.pyc
ADDED
Binary file (2.96 kB)
smi-ted/inference/smi_ted_light/fast_transformers/attention/attention_layer.py
ADDED
@@ -0,0 +1,113 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <[email protected]>,
+# Apoorv Vyas <[email protected]>
+#
+
+"""The base attention layer performs all the query key value projections and
+output projections leaving the implementation of the attention to the inner
+attention module.
+
+The transformer layers, however, are agnostic of the attention implementation
+and any layer that implements the same interface can substitute for the
+attention layer.
+"""
+
+from torch.nn import Linear, Module
+
+from ..events import EventDispatcher, QKVEvent
+
+
+class AttentionLayer(Module):
+    """Implement the attention layer. Namely project the inputs to multi-head
+    queries, keys and values, call the attention implementation and then
+    reproject the output.
+
+    It can be thought of as a decorator (see decorator design patter) of an
+    attention layer.
+
+    Arguments
+    ---------
+        attention: Specific inner attention implementation that just computes a
+                   weighted average of values given a similarity of queries and
+                   keys.
+        d_model: The input feature dimensionality
+        n_heads: The number of heads for the multi head attention
+        d_keys: The dimensionality of the keys/queries
+                (default: d_model/n_heads)
+        d_values: The dimensionality of the values (default: d_model/n_heads)
+        event_dispatcher: str or EventDispatcher instance to be used by this
+                          module for dispatching events (default: the default
+                          global dispatcher)
+    """
+    def __init__(self, attention, d_model, n_heads, d_keys=None,
+                 d_values=None, event_dispatcher=""):
+        super(AttentionLayer, self).__init__()
+
+        # Fill d_keys and d_values
+        d_keys = d_keys or (d_model//n_heads)
+        d_values = d_values or (d_model//n_heads)
+
+        self.inner_attention = attention
+        self.query_projection = Linear(d_model, d_keys * n_heads)
+        self.key_projection = Linear(d_model, d_keys * n_heads)
+        self.value_projection = Linear(d_model, d_values * n_heads)
+        self.out_projection = Linear(d_values * n_heads, d_model)
+        self.n_heads = n_heads
+        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
+
+    def forward(self, queries, keys, values, attn_mask, query_lengths,
+                key_lengths):
+        """Apply attention to the passed in queries/keys/values after
+        projecting them to multiple heads.
+
+        In the argument description we make use of the following sizes
+
+            - N: the batch size
+            - L: The maximum length of the queries
+            - S: The maximum length of the keys (the actual length per sequence
+              is given by the length mask)
+            - D: The input feature dimensionality passed in the constructor as
+              'd_model'
+
+        Arguments
+        ---------
+            queries: (N, L, D) The tensor containing the queries
+            keys: (N, S, D) The tensor containing the keys
+            values: (N, S, D) The tensor containing the values
+            attn_mask: An implementation of BaseMask that encodes where each
+                       query can attend to
+            query_lengths: An implementation of BaseMask that encodes how
+                           many queries each sequence in the batch consists of
+            key_lengths: An implementation of BaseMask that encodes how
+                         many queries each sequence in the batch consists of
+
+        Returns
+        -------
+            The new value for each query as a tensor of shape (N, L, D).
+        """
+        # Extract the dimensions into local variables
+        N, L, _ = queries.shape
+        _, S, _ = keys.shape
+        H = self.n_heads
+
+        # Project the queries/keys/values
+        queries = self.query_projection(queries).view(N, L, H, -1)
+        keys = self.key_projection(keys).view(N, S, H, -1)
+        values = self.value_projection(values).view(N, S, H, -1)
+
+        # Let the world know of the qkv
+        self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values))
+
+        # Compute the attention
+        new_values = self.inner_attention(
+            queries,
+            keys,
+            values,
+            attn_mask,
+            query_lengths,
+            key_lengths
+        ).view(N, L, -1)
+
+        # Project the output and return
+        return self.out_projection(new_values)
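For orientation, here is a hedged usage sketch (not part of the commit) showing how this layer is driven with an inner attention module. It assumes the vendored package is importable as `fast_transformers` and that its `masking` module keeps the upstream `FullMask`/`LengthMask` helpers referenced elsewhere in this diff.

```python
import torch
from fast_transformers.attention import AttentionLayer, FullAttention
from fast_transformers.masking import FullMask, LengthMask

N, L, d_model, n_heads = 2, 16, 64, 4
layer = AttentionLayer(FullAttention(), d_model, n_heads)

x = torch.randn(N, L, d_model)
attn_mask = FullMask(L)                                        # every query may attend to every key
lengths = LengthMask(torch.full((N,), L, dtype=torch.int64))   # all sequences are full length

out = layer(x, x, x, attn_mask, lengths, lengths)              # self-attention
print(out.shape)                                               # torch.Size([2, 16, 64])
```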
smi-ted/inference/smi_ted_light/fast_transformers/attention/causal_linear_attention.py
ADDED
@@ -0,0 +1,116 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <[email protected]>,
+# Apoorv Vyas <[email protected]>
+#
+
+"""Implement causally masked linear attention."""
+
+import torch
+from torch.nn import Module
+
+from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \
+    EventDispatcherInstance
+from ..events import EventDispatcher
+from ..causal_product import causal_dot_product
+from ..feature_maps import elu_feature_map
+
+
+def causal_linear(Q, K, V):
+    Q = Q.permute(0,2,1,3).contiguous()
+    K = K.permute(0,2,1,3).contiguous()
+    V = V.permute(0,2,1,3).contiguous()
+    V_new = causal_dot_product(Q, K, V)
+    return V_new.permute(0,2,1,3).contiguous()
+
+
+class CausalLinearAttention(Module):
+    """Implement causally masked attention using dot product of feature maps in
+    O(N D^2) complexity.
+
+    See fast_transformers.attention.linear_attention.LinearAttention for the
+    general concept of replacing the softmax with feature maps. In addition to
+    that, we also make use of the fact that causal masking is a triangular mask
+    which allows us to apply the masking and still compute the attention in O(N
+    D^2) complexity.
+
+    Arguments
+    ---------
+        feature_map: callable, a callable that applies the feature map to the
+                     last dimension of a tensor (default: elu(x)+1)
+        eps: float, a small number to ensure the numerical stability of the
+             denominator (default: 1e-6)
+        event_dispatcher: str or EventDispatcher instance to be used by this
+                          module for dispatching events (default: the default
+                          global dispatcher)
+    """
+    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
+                 event_dispatcher=""):
+        super(CausalLinearAttention, self).__init__()
+        self.feature_map = (
+            feature_map(query_dimensions) if feature_map else
+            elu_feature_map(query_dimensions)
+        )
+        self.eps = eps
+        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
+
+    def _make_sizes_compatible(self, Q, K):
+        """Either slice or pad K in case that the sizes do not match between Q
+        and K."""
+        N, L, H, E = Q.shape
+        _, S, _, _ = K.shape
+        if L == S:
+            return Q, K
+
+        if L < S:
+            return Q, K[:, :L, :, :]
+
+        if L > S:
+            return Q, torch.cat([K, K.new_zeros(N, L-S, H, E)], dim=1)
+
+    def forward(self, queries, keys, values, attn_mask, query_lengths,
+                key_lengths):
+        # Apply the feature map to the queries and keys
+        self.feature_map.new_feature_map(queries.device)
+        Q = self.feature_map.forward_queries(queries)
+        K = self.feature_map.forward_keys(keys)
+
+        # Apply the key padding mask and make sure the attn_mask is a
+        # lower triangular causal mask
+        if not attn_mask.lower_triangular:
+            raise RuntimeError(("CausalLinearAttention only supports full "
+                                "lower triangular masks"))
+        K = K * key_lengths.float_matrix[:, :, None, None]
+
+        # Ensure that Q and K have compatible sizes for the following
+        # computations, namely L == S
+        Q, K = self._make_sizes_compatible(Q, K)
+
+        # TODO: Shall we divide the Q and K with a relatively large number to
+        #       avoid numerical instabilities in computing the denominator?
+        #       We used to divide each with the max norm of all q and k but
+        #       that seems relatively costly for a simple normalization.
+
+        # Compute the normalizers
+        Z = 1/(torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps)
+
+        # Compute the unnormalized result
+        V = causal_linear(
+            Q,
+            K,
+            values
+        )
+
+        return V * Z[:, :, :, None]
+
+
+# Register the attention implementation so that it becomes available in our
+# builders
+AttentionRegistry.register(
+    "causal-linear", CausalLinearAttention,
+    [
+        ("query_dimensions", Int),
+        ("feature_map", Optional(Callable)),
+        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
+    ]
+)
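`causal_dot_product` comes from a compiled extension, so as a reading aid here is a reference-only sketch of the quantity it computes, written as an explicit prefix sum over the sequence dimension (shapes `(N, H, L, E)` for Q/K and `(N, H, L, D)` for V, i.e. after the permutes in `causal_linear`). It materialises an `(N, H, L, E, D)` tensor, so it is only useful for checking small inputs.

```python
import torch

def causal_dot_product_reference(Q, K, V):
    # S_l = sum_{i <= l} K_i V_i^T, then output_l = Q_l S_l -- the numerator of
    # causally masked linear attention that the compiled kernel computes without
    # materialising the intermediate outer products.
    context = torch.einsum("nhle,nhld->nhled", K, V).cumsum(dim=2)
    return torch.einsum("nhle,nhled->nhld", Q, context)
```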
smi-ted/inference/smi_ted_light/fast_transformers/attention/clustered_attention.py
ADDED
@@ -0,0 +1,195 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <[email protected]>,
+# Apoorv Vyas <[email protected]>
+#
+
+"""Implement clustered self attention."""
+
+from math import sqrt
+
+import torch
+import torch.autograd
+from torch.nn import Dropout, Module
+from torch.nn.init import normal_
+
+from ..attention_registry import AttentionRegistry, Optional, Float, Int, \
+    Bool, EventDispatcherInstance
+from ..events import EventDispatcher
+from ..masking import FullMask
+from ..aggregate import clustered_aggregate, clustered_broadcast
+from ..clustering.hamming import cluster
+from ..hashing import compute_hashes
+
+
+class _GroupQueries(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, Q, clusters, counts, lengths):
+        factors = 1./counts.float()
+        q_grouped = clustered_aggregate(Q, clusters, factors, lengths)
+        ctx.save_for_backward(clusters, counts, factors)
+
+        return q_grouped
+
+    @staticmethod
+    def backward(ctx, grad_q_grouped):
+        clusters, counts, factors = ctx.saved_tensors
+        grad_q = clustered_broadcast(grad_q_grouped, clusters, counts, factors)
+
+        return grad_q, None, None, None
+
+
+class _BroadcastValues(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, v_grouped, clusters, counts, lengths):
+        factors = torch.ones_like(counts, dtype=v_grouped.dtype)
+        V = clustered_broadcast(v_grouped, clusters, counts, factors)
+        ctx.save_for_backward(clusters, counts, factors, lengths)
+
+        return V
+
+    @staticmethod
+    def backward(ctx, grad_v):
+        clusters, counts, factors, lengths = ctx.saved_tensors
+        grad_v_grouped = clustered_aggregate(grad_v, clusters, factors, lengths)
+
+        return grad_v_grouped, None, None, None
+
+
+class ClusteredAttention(Module):
+    """Use LSH and clustering in the resulting Hamming space to group queries
+    that will have minimal L2 distance from each other.
+
+    Given the queries, keys, and values as Q, K, and V respectively, we
+    first cluster the queries in "C" groups and compute the "C" query centroids
+    Q_c.
+
+    We now use to the centroids Q_c to compute the attention using:
+
+        V'_c = softmax(Q_c.mm(K.t()), dim=-1).mm(V).
+
+    Now the computed values V'_c are "broadcasted" back to the query members
+    of the corresponding cluster.
+
+    Arguments
+    ---------
+        clusters: How many clusters to group the queries into
+        iterations: The number of lloyd iterations to perform (default: 10)
+        bits: How many bits to use for the hash (default: 32)
+        hash_bias: If true, hamming distance proportional to L2 distance
+                   If false, hamming distance proportional to cosine distance
+                   (default: True)
+        softmax_temp: The temperature to use for the softmax attention.
+                      (default: 1/sqrt(d_keys) where d_keys is computed at
+                      runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.1)
+        event_dispatcher: str or EventDispatcher instance to be used by this
+                          module for dispatching events (default: the default
+                          global dispatcher)
+    """
+    def __init__(self, clusters, iterations=10, bits=32,
+                 hash_bias=True, softmax_temp=None, attention_dropout=0.1,
+                 event_dispatcher=""):
+        super(ClusteredAttention, self).__init__()
+        self.clusters = clusters
+        self.iterations = iterations
+        self.bits = bits
+        self.hash_bias = hash_bias
+        self.softmax_temp = softmax_temp
+        self.dropout = Dropout(attention_dropout)
+        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
+
+    def _create_query_groups(self, Q, query_lengths):
+        N, H, L, E = Q.shape
+
+        # Compute the hashes for all the queries
+        planes = Q.new_empty((self.bits, E+1))
+        normal_(planes)
+        if not self.hash_bias:
+            planes[:, -1] = 0
+        hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L)
+
+        # Cluster the hashes and return the cluster index per query
+        clusters, counts = cluster(
+            hashes,
+            query_lengths._lengths.int(),
+            clusters=self.clusters,
+            iterations=self.iterations,
+            bits=self.bits
+        )
+        sorted_clusters, sorted_indx = torch.sort(clusters, dim=-1)
+        return (sorted_clusters, counts), sorted_indx
+
+    def _group_queries(self, Q, groups, lengths):
+        """Aggregate the Qs based on the index of cluster they belong to. Make
+        sure to allow for gradient propagation backwards from the grouped
+        queries to each query."""
+        q_grouped = _GroupQueries.apply(Q, *groups, lengths)
+        return q_grouped
+
+    def _broadcast_values(self, V, groups, lengths):
+        """Broadcast the values back to the correct positions but make sure
+        that the gradient flows properly."""
+        V_new = _BroadcastValues.apply(V.contiguous(), *groups, lengths)
+        return V_new
+
+    def forward(self, queries, keys, values, attn_mask, query_lengths,
+                key_lengths):
+        # Make sure that there is no attention mask
+        assert attn_mask.all_ones, ("Clustered attention cannot use an "
+                                    "arbitrary attention mask.")
+
+        queries = queries.permute(0,2,1,3).contiguous()
+        keys = keys.permute(0,2,1,3).contiguous()
+        values = values.permute(0,2,1,3).contiguous()
+
+        N, H, L, E = queries.shape
+        _, _, S, D = values.shape
+        softmax_temp = self.softmax_temp or 1./sqrt(E)
+
+        # Cluster the queries into groups
+        groups, sorted_indx = self._create_query_groups(queries, query_lengths)
+        # Re-organize queries so that first group belong to first cluster
+        # next to second cluster and so on. This improves kernel implementations.
+        # Note that this step is introduced after NeurIPS submission and
+        # now the complexity is O(N log(N)).
+        q_offset = torch.arange(N*H, device=queries.device).unsqueeze(-1) * L
+        q_flat = (sorted_indx.view(N*H, -1) + q_offset).reshape(-1)
+        s_queries = queries.reshape(-1, E).index_select(0, q_flat).view(N,H,L,E)
+
+        # Aggregate the re-arranged queries.
+        Q_grouped = self._group_queries(s_queries, groups, query_lengths._lengths.int())
+        # Compute the attention
+        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys)
+        QK = QK + key_lengths.additive_matrix[:, None, None, :]
+        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
+        V = torch.einsum("nhls,nhsd->nhld", A, values)
+
+        # Broadcast grouped attention
+        V_broadcast = self._broadcast_values(V, groups, query_lengths._lengths.int())
+
+        # Reverse the previous mapping
+        rev_indx = torch.argsort(sorted_indx, dim=-1)
+        q_rev_flat = (rev_indx.view(N*H, -1) + q_offset).reshape(-1)
+        V_new = V_broadcast.reshape(-1, D).index_select(0, q_rev_flat).view(N,H,L,D)
+        V_new = V_new.permute(0, 2, 1, 3).contiguous()
+        return V_new
+
+
+
+
+# Register the attention implementation so that it becomes available in our
+# builders
+AttentionRegistry.register(
+    "clustered", ClusteredAttention,
+    [
+        ("clusters", Int),
+        ("iterations", Optional(Int, 10)),
+        ("bits", Optional(Int, 63)),
+        ("hash_bias", Optional(Bool, True)),
+        ("softmax_temp", Optional(Float)),
+        ("attention_dropout", Optional(Float, 0.1)),
+        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
+    ]
+)
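The docstring's approximation V'_c = softmax(Q_c K^T) V can be written out densely in a few lines. The snippet below (single head, made-up shapes, random cluster assignment) is only an illustration of the centroid-then-broadcast idea; the module above does the same thing with hashing, Lloyd iterations and the clustered kernels.

```python
import torch

L, S, E, C = 8, 10, 4, 3
Q, K, V = torch.randn(L, E), torch.randn(S, E), torch.randn(S, E)
assignment = torch.randint(0, C, (L,))                 # cluster id per query

counts = torch.zeros(C).index_add_(0, assignment, torch.ones(L)).clamp(min=1)
Q_c = torch.zeros(C, E).index_add_(0, assignment, Q) / counts[:, None]   # centroids

A_c = torch.softmax(Q_c @ K.t() / E ** 0.5, dim=-1)    # one attention row per centroid
V_c = A_c @ V
V_approx = V_c[assignment]                             # broadcast back to each query
```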
smi-ted/inference/smi_ted_light/fast_transformers/attention/conditional_full_attention.py
ADDED
@@ -0,0 +1,66 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <[email protected]>,
+# Apoorv Vyas <[email protected]>
+#
+
+"""Implement a self attention that delegates to full attention or another
+attention depending on the input sequence length."""
+
+import torch
+from torch.nn import Module
+
+from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
+    EventDispatcherInstance
+from ..events import EventDispatcher
+from .full_attention import FullAttention
+
+
+class ConditionalFullAttention(Module):
+    """"Delegate to full attention if the input sequence is short.
+
+    Arguments
+    ---------
+        other_attention: Use the passed attention module if the sequence is
+                         longer than 'length_limit'.
+        length_limit: An integer denoting the maximum sequence length to
+                      consider.
+        softmax_temp: See fast_transformers.attention.full_attention.
+        attention_dropout: See fast_transformers.attention.full_attention.
+        event_dispatcher: str or EventDispatcher instance to be used by this
+                          module for dispatching events (default: the default
+                          global dispatcher)
+    """
+    def __init__(self, other_attention, length_limit=512, softmax_temp=None,
+                 attention_dropout=0.1, event_dispatcher=""):
+        super(ConditionalFullAttention, self).__init__()
+        self.full_attention = FullAttention(softmax_temp, attention_dropout)
+        self.other_attention = other_attention
+        self.length_limit = length_limit
+        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
+
+    def forward(self, queries, keys, values, attn_mask, query_lengths,
+                key_lengths):
+        # Extract some shapes to compare with the length limit
+        L = queries.shape[1]
+        S = values.shape[1]
+
+        if L > self.length_limit or S > self.length_limit:
+            return self.other_attention(queries, keys, values, attn_mask,
+                                        query_lengths, key_lengths)
+        else:
+            return self.full_attention(queries, keys, values, attn_mask,
+                                       query_lengths, key_lengths)
+
+
+# Register the attention implementation so that it becomes available in our
+# builders
+AttentionRegistry.register(
+    "conditional-full", ConditionalFullAttention,
+    [
+        ("length_limit", Optional(Int, 512)),
+        ("softmax_temp", Optional(Float)),
+        ("attention_dropout", Optional(Float, 0.1)),
+        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
+    ]
+)
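A possible way to instantiate this wrapper (a sketch, assuming the vendored package is importable as `fast_transformers`): exact softmax attention for sequences up to the default 512 tokens, `LinearAttention` beyond that.

```python
from fast_transformers.attention import LinearAttention
from fast_transformers.attention.conditional_full_attention import \
    ConditionalFullAttention

query_dimensions = 64 // 4   # e.g. d_model // n_heads
attention = ConditionalFullAttention(LinearAttention(query_dimensions),
                                     length_limit=512)
```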
smi-ted/inference/smi_ted_light/fast_transformers/attention/exact_topk_attention.py
ADDED
@@ -0,0 +1,88 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <[email protected]>,
+# Apoorv Vyas <[email protected]>
+#
+
+"""Implement the oracle top-k attention. The top-k keys are exact ones.
+MultiHeadAttention module. Note that this module is to be used in conjuction
+with the AttentionLayer in order to work."""
+
+from math import sqrt
+
+import torch
+from torch.nn import Dropout, Module
+
+from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
+    EventDispatcherInstance
+from ..events import EventDispatcher
+
+
+class ExactTopKAttention(Module):
+    """Implement the oracle top-k softmax attention.
+
+    Arguments
+    ---------
+        top-k: The top k keys to attend to (default: 32)
+        softmax_temp: The temperature to use for the softmax attention.
+                      (default: 1/sqrt(d_keys) where d_keys is computed at
+                      runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.1)
+        event_dispatcher: str or EventDispatcher instance to be used by this
+                          module for dispatching events (default: the default
+                          global dispatcher)
+    """
+    def __init__(self, topk=32, softmax_temp=None, attention_dropout=0.1,
+                 event_dispatcher=""):
+        super(ExactTopKAttention, self).__init__()
+        self.topk = topk
+        self.softmax_temp = softmax_temp
+        self.dropout = Dropout(attention_dropout)
+        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
+
+    def forward(self, queries, keys, values, attn_mask, query_lengths,
+                key_lengths):
+        # Extract some shapes and compute the temperature
+        N, L, H, E = queries.shape
+        _, S, _, D = values.shape
+        softmax_temp = self.softmax_temp or 1./sqrt(E)
+
+        # Compute the unnormalized attention and apply the masks
+        QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
+        topk = min(self.topk, S)
+
+        if not attn_mask.all_ones:
+            QK = QK + attn_mask.additive_matrix
+        QK = QK + key_lengths.additive_matrix[:, None, None]
+
+        topk_values, topk_idx = torch.topk(QK, topk, sorted=False, dim=-1)
+        mask = QK.new_ones(QK.shape) * float("-inf")
+        mask[
+            torch.arange(N, device=QK.device).view(N, 1, 1, 1),
+            torch.arange(H, device=QK.device).view(1, H, 1, 1),
+            torch.arange(L, device=QK.device).view(1, 1, L, 1),
+            topk_idx,
+        ] = 0.
+
+        QK = QK + mask
+
+        # Compute the attention and the weighted average
+        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
+        V = torch.einsum("nhls,nshd->nlhd", A, values)
+
+        # Make sure that what we return is contiguous
+        return V.contiguous()
+
+
+# Register the attention implementation so that it becomes available in our
+# builders
+AttentionRegistry.register(
+    "exact-topk", ExactTopKAttention,
+    [
+        ("topk", Optional(Int, 32)),
+        ("softmax_temp", Optional(Float)),
+        ("attention_dropout", Optional(Float, 0.1)),
+        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
+    ]
+)
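The advanced-indexing trick above builds an additive mask that is 0 at the top-k positions and -inf everywhere else. The stand-alone snippet below (not code from the commit) reproduces the same pattern on a plain 2-D score matrix.

```python
import torch

scores = torch.randn(2, 5)
k = 2
topk_values, topk_idx = torch.topk(scores, k, dim=-1, sorted=False)
mask = torch.full_like(scores, float("-inf")).scatter(-1, topk_idx, 0.0)
attn = torch.softmax(scores + mask, dim=-1)   # zero weight outside the top-k
```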
smi-ted/inference/smi_ted_light/fast_transformers/attention/full_attention.py
ADDED
@@ -0,0 +1,95 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <[email protected]>,
+# Apoorv Vyas <[email protected]>
+#
+
+"""Implement the full attention similar to the one implemented by PyTorch's
+MultiHeadAttention module. Note that this module is to be used in conjuction
+with the `fast_transformers.attention.attention_layer.AttentionLayer` in order
+to work."""
+
+from math import sqrt
+
+import torch
+from torch.nn import Dropout, Module
+
+from ..attention_registry import AttentionRegistry, Optional, Float, \
+    EventDispatcherInstance
+from ..events import EventDispatcher, AttentionEvent
+
+
+class FullAttention(Module):
+    """Implement the scaled dot product attention with softmax.
+
+    Arguments
+    ---------
+        softmax_temp: The temperature to use for the softmax attention.
+                      (default: 1/sqrt(d_keys) where d_keys is computed at
+                      runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.1)
+        event_dispatcher: str or EventDispatcher instance to be used by this
+                          module for dispatching events (default: the default
+                          global dispatcher)
+    """
+    def __init__(self, softmax_temp=None, attention_dropout=0.1,
+                 event_dispatcher=""):
+        super(FullAttention, self).__init__()
+        self.softmax_temp = softmax_temp
+        self.dropout = Dropout(attention_dropout)
+        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
+
+    def forward(self, queries, keys, values, attn_mask, query_lengths,
+                key_lengths):
+        """Implements the multihead softmax attention.
+
+        Arguments
+        ---------
+            queries: (N, L, H, E) The tensor containing the queries
+            keys: (N, S, H, E) The tensor containing the keys
+            values: (N, S, H, D) The tensor containing the values
+            attn_mask: An implementation of BaseMask that encodes where each
+                       query can attend to
+            query_lengths: An implementation of BaseMask that encodes how
+                           many queries each sequence in the batch consists of
+            key_lengths: An implementation of BaseMask that encodes how
+                         many queries each sequence in the batch consists of
+        """
+        # Extract some shapes and compute the temperature
+        N, L, H, E = queries.shape
+        _, S, _, D = values.shape
+        softmax_temp = self.softmax_temp or 1./sqrt(E)
+
+        # Scale the queries instead of applying the softmax temperature to the
+        # dot products
+        queries = queries * softmax_temp
+
+        # Compute the unnormalized attention and apply the masks
+        QK = torch.einsum("nlhe,nshe->nhls", queries, keys)
+        if not attn_mask.all_ones:
+            QK = QK + attn_mask.additive_matrix
+        if not key_lengths.all_ones:
+            QK = QK + key_lengths.additive_matrix[:, None, None]
+
+        # Compute the attention and the weighted average
+        A = self.dropout(torch.softmax(QK, dim=-1))
+        V = torch.einsum("nhls,nshd->nlhd", A, values)
+
+        # Let the world know of the attention matrix
+        self.event_dispatcher.dispatch(AttentionEvent(self, A))
+
+        # Make sure that what we return is contiguous
+        return V.contiguous()
+
+
+# Register the attention implementation so that it becomes available in our
+# builders
+AttentionRegistry.register(
+    "full", FullAttention,
+    [
+        ("softmax_temp", Optional(Float)),
+        ("attention_dropout", Optional(Float, 0.1)),
+        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
+    ]
+)
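For reference, with an all-ones attention mask, full-length sequences and dropout disabled, the forward pass above reduces to plain per-head scaled dot-product attention. This short sanity-check sketch (not part of the commit) mirrors its einsum contractions on random tensors.

```python
import torch

N, L, S, H, E = 2, 4, 6, 3, 8
q = torch.randn(N, L, H, E)
k = torch.randn(N, S, H, E)
v = torch.randn(N, S, H, E)

QK = torch.einsum("nlhe,nshe->nhls", q / E ** 0.5, k)   # scale queries, like softmax_temp
A = torch.softmax(QK, dim=-1)
out = torch.einsum("nhls,nshd->nlhd", A, v)             # (N, L, H, E)
```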
smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_attention.py
ADDED
@@ -0,0 +1,268 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <[email protected]>,
+# Apoorv Vyas <[email protected]>
+#
+
+"""Implement improved clustered self attention."""
+
+from math import sqrt
+
+import torch
+import torch.autograd
+from torch.nn import Dropout, Module
+from torch.nn.init import normal_
+
+from ..attention_registry import AttentionRegistry, Optional, Float, Int, \
+    Bool, EventDispatcherInstance
+from ..events import EventDispatcher
+from ..masking import FullMask
+from ..aggregate import clustered_aggregate, clustered_broadcast
+from ..clustering.hamming import cluster
+from ..hashing import compute_hashes
+from ..sparse_product import sparse_dot_product, sparse_weighted_average
+from ..sparse_product import clustered_sparse_dot_product, \
+    clustered_sparse_weighted_average
+
+
+class _GroupQueries(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, Q, clusters, counts, lengths):
+        factors = 1./counts.float()
+        q_grouped = clustered_aggregate(Q, clusters, factors, lengths)
+        ctx.save_for_backward(clusters, counts, factors)
+
+        return q_grouped
+
+    @staticmethod
+    def backward(ctx, grad_q_grouped):
+        clusters, counts, factors = ctx.saved_tensors
+        grad_q = clustered_broadcast(grad_q_grouped, clusters, counts, factors)
+
+        return grad_q, None, None, None
+
+
+class _BroadcastValues(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, v_grouped, clusters, counts, lengths):
+        factors = torch.ones_like(counts, dtype=v_grouped.dtype)
+        V = clustered_broadcast(v_grouped, clusters, counts, factors)
+        ctx.save_for_backward(clusters, counts, factors, lengths)
+
+        return V
+
+    @staticmethod
+    def backward(ctx, grad_v):
+        clusters, counts, factors, lengths = ctx.saved_tensors
+        grad_v_grouped = clustered_aggregate(grad_v, clusters, factors, lengths)
+
+        return grad_v_grouped, None, None, None, None
+
+
+class ImprovedClusteredAttention(Module):
+    """
+    Immproved clustered attention approximation by recompution attention
+    for each query with the top-k keys for the corresponding cluster.
+
+    Given the queries, keys, and values as Q, K, and V respectively, we
+    first cluster the queries in "C" groups and compute the "C" query centroids
+    Q_c.
+
+    We now use to the centroids Q_c to identify the top-k keys with highest
+    dot products.
+
+    Subsequently, for each query we compute the sparse dot product with
+    the corresponding top-k keys to improve the attention approximation.
+
+    Arguments
+    ---------
+        clusters: How many clusters to group the queries into
+        iterations: The number of lloyd iterations to perform (default: 10)
+        bits: How many bits to use for the hash (default: 32)
+        hash_bias: If true, hamming distance proportional to L2 distance
+                   If false, hamming distance proportional to cosine distance
+                   (default: True)
+        topk: Number of top-k keys to for improved approximation (default: 32)
+        softmax_temp: The temperature to use for the softmax attention.
+                      (default: 1/sqrt(d_keys) where d_keys is computed at
+                      runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.1)
+        event_dispatcher: str or EventDispatcher instance to be used by this
+                          module for dispatching events (default: the default
+                          global dispatcher)
+    """
+    def __init__(self, clusters, iterations=10, bits=32,
+                 hash_bias=True, topk=32, softmax_temp=None,
+                 attention_dropout=0.1, event_dispatcher=""):
+        super(ImprovedClusteredAttention, self).__init__()
+        self.clusters = clusters
+        self.iterations = iterations
+        self.bits = bits
+        self.hash_bias = hash_bias
+        self.topk = topk
+        self.softmax_temp = softmax_temp
+        self.dropout = Dropout(attention_dropout)
+        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
+
+    def _create_query_groups(self, Q, query_lengths):
+        N, H, L, E = Q.shape
+
+        # Compute the hashes for all the queries
+        planes = Q.new_empty((self.bits, E+1))
+        normal_(planes)
+        if not self.hash_bias:
+            planes[:, -1] = 0
+        hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L)
+
+        # Cluster the hashes and return the cluster index per query
+        clusters, counts = cluster(
+            hashes,
+            query_lengths._lengths.int(),
+            clusters=self.clusters,
+            iterations=self.iterations,
+            bits=self.bits
+        )
+        sorted_clusters, sorted_indx = torch.sort(clusters, dim=-1)
+        return (sorted_clusters, counts), sorted_indx
+
+    def _topk_attention(self, Q, K, V,
+                        clusters, counts,
+                        topk, topk_values,
+                        A_bottomk, softmax_temp,
+                        query_lengths):
+        """Return the attention with just the topk heads."""
+        # Extract some indices
+        N, H, L, E = Q.shape
+        _, _, S, _ = K.shape
+        _, _, C, k = topk.shape
+
+        # We need to pass the output tensor to initialize to 0
+        QK = clustered_sparse_dot_product(
+            Q, K, topk,
+            clusters, counts,
+            query_lengths._lengths.int()
+        )
+        # We need to mask the topk dot products if topk > input_length
+        QK = QK.masked_fill(
+            torch.isinf(topk_values[:,0,0,:]).view(N, 1, 1, k),
+            float("-inf")
+        )
+        A = torch.softmax(softmax_temp * QK, dim=-1)
+        assert A_bottomk.is_contiguous()
+        A_bottomk = clustered_broadcast(
+            A_bottomk.unsqueeze(3),
+            clusters,
+            counts,
+            torch.ones_like(counts, dtype=torch.float32)
+        )
+        A = A * (1.0 - A_bottomk)
+        A = self.dropout(A)
+        assert A.is_contiguous()
+        V_new = clustered_sparse_weighted_average(A, V, topk, clusters, counts)
+
+        return V_new
+
+    def _broadcast_values(self, V, clusters, counts, lengths):
+        """Broadcast the values back to the correct positions but make sure
+        that the gradient flows properly."""
+        V_new = _BroadcastValues.apply(V.contiguous(), clusters, counts, lengths)
+        return V_new
+
+    def _bottomk_attention(self, QK, V, clusters, counts, query_lengths, topk, softmax_temp):
+        """Return the attention with just the bottomk keys."""
+        N, H, C, S = QK.shape
+
+        A = torch.softmax(softmax_temp * QK, dim=-1)
+        mask = QK.new_ones(QK.shape)
+        mask[
+            torch.arange(N, device=QK.device).view(N, 1, 1, 1),
+            torch.arange(H, device=QK.device).view(1, H, 1, 1),
+            torch.arange(C, device=QK.device).view(1, 1, C, 1),
+            topk,
+        ] = 0
+        A = A * mask
+        A_bottomk = A.sum(-1)
+        A = self.dropout(A)
+        # Compute the values
+        V_new = torch.einsum("nhls,nhse->nhle", A, V)
+        # Broadcast the values back depending on the groups
+        V_new = self._broadcast_values(V_new, clusters, counts, query_lengths._lengths.int())
+
+        return V_new, A_bottomk
+
+    def forward(self, queries, keys, values, attn_mask, query_lengths,
+                key_lengths):
+        # Make sure that there is no attention mask
+        assert attn_mask.all_ones, ("Improved-clustered attention cannot "
+                                    "use an arbitrary attention mask.")
+
+        queries = queries.permute(0,2,1,3).contiguous()
+        keys = keys.permute(0,2,1,3).contiguous()
+        values = values.permute(0,2,1,3).contiguous()
+        N, H, L, E = queries.shape
+        _, _, S, D = values.shape
+        softmax_temp = self.softmax_temp or 1./sqrt(E)
+
+        # Cluster the queries into groups
+        groups, sorted_indx = self._create_query_groups(queries, query_lengths)
+        clusters, counts = groups
+
+        # Re-organize queries so that first group belong to first cluster
+        # next to second cluster and so on. This improves kernel implementations.
+        # Note that this step is introduced after NeurIPS submission and
+        # now the complexity is O(N log(N)).
+        q_offset = torch.arange(N*H, device=queries.device).unsqueeze(-1) * L
+        q_flat = (sorted_indx.view(N*H, -1) + q_offset).reshape(-1)
+        s_queries = queries.reshape(-1, E).index_select(0, q_flat).view(N,H,L,E)
+
+        # Aggregate the re-arranged queries.
+        Q_grouped = _GroupQueries.apply(s_queries, *groups, query_lengths.lengths.int())
+        # Compute the attention
+        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys)
+        QK = QK + key_lengths.additive_matrix[:, None, None, :]
+        topk_values, topk = torch.topk(QK, min(self.topk, S), sorted=False, dim=-1)
+        assert topk.is_contiguous()
+
+        # Now compute the attention with only the bottom keys
+        V_bottomk, A_bottomk = self._bottomk_attention(
+            QK, values,
+            clusters, counts,
+            query_lengths,
+            topk,
+            softmax_temp
+        )
+
+        # Now compute the attention with only the top keys
+        V_topk = self._topk_attention(
+            s_queries, keys, values,
+            clusters, counts,
+            topk, topk_values,
+            A_bottomk,
+            softmax_temp,
+            query_lengths
+        )
+        V_sorted_new = V_topk + V_bottomk
+
+        # Reverse the previous mapping
+        sorted_rev_indx = torch.argsort(sorted_indx, dim=-1)
+        q_rev_flat = (sorted_rev_indx.view(N*H, -1) + q_offset).reshape(-1)
+        V_new = V_sorted_new.reshape(-1, D).index_select(0, q_rev_flat).view(N,H,L,D)
+        return V_new.permute(0, 2, 1, 3).contiguous()
+
+
+# Register the attention implementation so that it becomes available in our
+# builders
+AttentionRegistry.register(
+    "improved-clustered", ImprovedClusteredAttention,
+    [
+        ("clusters", Int),
+        ("iterations", Optional(Int, 10)),
+        ("bits", Optional(Int, 63)),
+        ("hash_bias", Optional(Bool, True)),
+        ("topk", Optional(Int, 32)),
+        ("softmax_temp", Optional(Float)),
+        ("attention_dropout", Optional(Float, 0.1)),
+        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
+    ]
+)
smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_causal_attention.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#

"""Implement improved clustered causal self attention."""

from math import sqrt

import torch
import torch.autograd
from torch.nn import Dropout, Module
from torch.nn.init import normal_

from ..attention_registry import AttentionRegistry, Optional, Float, Int, \
    Bool, EventDispatcherInstance
from ..events import EventDispatcher
from ..masking import FullMask
from ..aggregate import clustered_aggregate, clustered_broadcast
from ..clustering.hamming import cluster
from ..hashing import compute_hashes
from ..sparse_product import sparse_dot_product, sparse_weighted_average
from ..sparse_product import clustered_sparse_dot_product, \
    clustered_sparse_weighted_average


class _GroupQueries(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, clusters, counts, lengths):
        factors = 1./counts.float()
        q_grouped = clustered_aggregate(Q, clusters, factors, lengths)
        ctx.save_for_backward(clusters, counts, factors)

        return q_grouped

    @staticmethod
    def backward(ctx, grad_q_grouped):
        clusters, counts, factors = ctx.saved_tensors
        grad_q = clustered_broadcast(grad_q_grouped, clusters, counts, factors)

        return grad_q, None, None, None


class _BroadcastValues(torch.autograd.Function):
    @staticmethod
    def forward(ctx, v_grouped, clusters, counts, lengths):
        factors = torch.ones_like(counts, dtype=v_grouped.dtype)
        V = clustered_broadcast(v_grouped, clusters, counts, factors)
        ctx.save_for_backward(clusters, counts, factors, lengths)

        return V

    @staticmethod
    def backward(ctx, grad_v):
        clusters, counts, factors, lengths = ctx.saved_tensors
        grad_v_grouped = clustered_aggregate(grad_v, clusters, factors, lengths)

        return grad_v_grouped, None, None, None, None


class ImprovedClusteredCausalAttention(Module):
    """
    Improved clustered causal attention approximation that recomputes the
    attention for each query with the top-k keys of the corresponding cluster.

    Given the queries, keys, and values as Q, K, and V respectively, we
    first cluster the queries into "C" groups and compute the "C" query
    centroids Q_c.

    We then use the centroids Q_c to identify the top-k keys with the highest
    dot products.

    Subsequently, for each query we compute the sparse dot product with
    the corresponding top-k keys to improve the attention approximation.

    The key difference from improved clustered attention is that, because of
    the causal mask, we only use the top-k keys and do not compute attention
    on the bottom-k keys.

    Arguments
    ---------
        clusters: How many clusters to group the queries into
        iterations: The number of Lloyd iterations to perform (default: 10)
        bits: How many bits to use for the hash (default: 32)
        hash_bias: If true, hamming distance proportional to L2 distance
                   If false, hamming distance proportional to cosine distance
                   (default: True)
        topk: Number of top-k keys to use for the improved approximation
              (default: 32)
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, clusters, iterations=10, bits=32,
                 hash_bias=True, topk=32, softmax_temp=None,
                 attention_dropout=0.1, event_dispatcher=""):
        super(ImprovedClusteredCausalAttention, self).__init__()
        self.clusters = clusters
        self.iterations = iterations
        self.bits = bits
        self.hash_bias = hash_bias
        self.topk = topk
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def _create_query_groups(self, Q, query_lengths):
        N, H, L, E = Q.shape

        # Compute the hashes for all the queries
        planes = Q.new_empty((self.bits, E+1))
        normal_(planes)
        if not self.hash_bias:
            planes[:, -1] = 0
        hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L)

        # Cluster the hashes and return the cluster index per query
        clusters, counts = cluster(
            hashes,
            query_lengths.lengths.int(),
            clusters=self.clusters,
            iterations=self.iterations,
            bits=self.bits
        )
        sorted_clusters, sorted_indx = torch.sort(clusters, dim=-1)
        return (sorted_clusters, counts), sorted_indx

    def _topk_attention(self, Q, K, V,
                        q_flat, q_rev_flat,
                        clusters, counts,
                        topk, topk_values,
                        softmax_temp,
                        query_lengths):
        """Return the attention with just the top-k keys."""
        # Extract some indices
        N, H, L, E = Q.shape
        _, _, S, _ = K.shape
        _, _, C, k = topk.shape

        # We need to pass the output tensor to initialize to 0
        QK = clustered_sparse_dot_product(
            Q, K, topk,
            clusters, counts,
            query_lengths.lengths.int()
        )
        # We need to mask out the future
        assert topk.is_contiguous()
        topk_broadcast = clustered_broadcast(
            topk.float(),
            clusters,
            counts,
            torch.ones_like(counts, dtype=torch.float32)
        )
        # Be careful here: the queries were re-ordered, so the masking of the
        # future needs to be applied in the same (permuted) order
        seq_ids = torch.arange(L, device=QK.device).view(1, 1, L, 1).repeat(N, H, 1, 1)
        # permute the ids in the same way as the input so as to mask the right
        # entries for each query
        s_seq_ids = seq_ids.reshape(-1, 1).index_select(0, q_flat).view(N, H, L, 1)
        future_mask = topk_broadcast.long() > s_seq_ids
        QK = QK.masked_fill(
            future_mask,
            float("-1e7")
        )
        A = torch.softmax(softmax_temp * QK, dim=-1)
        # Mask again to ensure no probabilities leak due to float(-1e7)
        # Leakage could be very high as we use a small top-k
        A = A * (1. - future_mask.float())
        A = self.dropout(A)
        assert A.is_contiguous()
        V_new = clustered_sparse_weighted_average(A, V, topk, clusters, counts)

        return V_new

    def _broadcast_values(self, V, clusters, counts, lengths):
        """Broadcast the values back to the correct positions but make sure
        that the gradient flows properly."""
        V_new = _BroadcastValues.apply(V.contiguous(), clusters, counts, lengths)
        return V_new

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):

        # Apply the key padding mask and make sure the attn_mask is a
        # lower triangular causal mask
        if not attn_mask.lower_triangular:
            raise RuntimeError(("ImprovedClusteredCausalAttention only supports "
                                "lower triangular masks"))
        queries = queries.permute(0, 2, 1, 3).contiguous()
        keys = keys.permute(0, 2, 1, 3).contiguous()
        values = values.permute(0, 2, 1, 3).contiguous()
        N, H, L, E = queries.shape
        _, _, S, D = values.shape
        softmax_temp = self.softmax_temp or 1./sqrt(E)

        # Cluster the queries into groups
        groups, sorted_indx = self._create_query_groups(queries, query_lengths)
        clusters, counts = groups

        # Re-organize the queries so that those belonging to the first cluster
        # come first, those of the second cluster next and so on. This improves
        # kernel implementations.
        # Note that this step is introduced after NeurIPS submission and
        # now the complexity is O(N log(N)).
        q_offset = torch.arange(N*H, device=queries.device).unsqueeze(-1) * L
        q_flat = (sorted_indx.view(N*H, -1) + q_offset).reshape(-1)
        s_queries = queries.reshape(-1, E).index_select(0, q_flat).view(N, H, L, E)

        # Aggregate the re-arranged queries.
        Q_grouped = _GroupQueries.apply(s_queries, *groups, query_lengths.lengths.int())
        # Compute the attention
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys)
        QK = QK + key_lengths.additive_matrix[:, None, None, :]
        # Cap topk at the minimum key length if it is smaller than self.topk
        cur_topk = min(self.topk, min(key_lengths.lengths).item())
        topk_values, topk = torch.topk(QK, cur_topk, sorted=False, dim=-1)
        assert topk.is_contiguous()

        # Reverse mapping
        sorted_rev_indx = torch.argsort(sorted_indx, dim=-1)
        q_rev_flat = (sorted_rev_indx.view(N*H, -1) + q_offset).reshape(-1)

        # Compute the attention with only the top keys
        V_topk = self._topk_attention(
            s_queries, keys, values,
            q_flat, q_rev_flat,
            clusters, counts,
            topk, topk_values,
            softmax_temp,
            query_lengths
        )
        V_sorted_new = V_topk

        # Reverse the mapping to get the values in the correct order
        V_new = V_sorted_new.reshape(-1, D).index_select(0, q_rev_flat).view(N, H, L, D)
        return V_new.permute(0, 2, 1, 3).contiguous()


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register(
    "causal-improved-clustered", ImprovedClusteredCausalAttention,
    [
        ("clusters", Int),
        ("iterations", Optional(Int, 10)),
        ("bits", Optional(Int, 63)),
        ("hash_bias", Optional(Bool, True)),
        ("topk", Optional(Int, 32)),
        ("softmax_temp", Optional(Float)),
        ("attention_dropout", Optional(Float, 0.1)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
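The causal masking in _topk_attention compares the broadcast top-k key indices against each query's (permuted) position. The following standalone sketch illustrates that check with a single head and hand-picked indices; the topk tensor and sizes are hypothetical and only assume PyTorch.

import torch

L, k = 6, 3
# Hypothetical top-k key indices per query position (already broadcast per query)
topk = torch.tensor([[4, 1, 5],
                     [0, 2, 1],
                     [3, 0, 5],
                     [2, 3, 1],
                     [5, 4, 0],
                     [1, 2, 3]])
positions = torch.arange(L).view(L, 1)   # query position i

# A selected key j is "in the future" when j > i, mirroring the check above
future_mask = topk > positions
scores = torch.randn(L, k).masked_fill(future_mask, float("-1e7"))
A = torch.softmax(scores, dim=-1)
A = A * (~future_mask).float()           # second mask, as in _topk_attention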
smi-ted/inference/smi_ted_light/fast_transformers/attention/linear_attention.py
ADDED
@@ -0,0 +1,92 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#

"""Implement unmasked linear attention."""

import torch
from torch.nn import Module

from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \
    EventDispatcherInstance
from ..events import EventDispatcher
from ..feature_maps import elu_feature_map


class LinearAttention(Module):
    """Implement unmasked attention using dot product of feature maps in
    O(N D^2) complexity.

    Given the queries, keys and values as Q, K, V instead of computing

        V' = softmax(Q.mm(K.t()), dim=-1).mm(V),

    we make use of a feature map function Φ(.) and perform the following
    computation

        V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V).

    The above can be computed in O(N D^2) complexity where D is the
    dimensionality of Q, K and V and N is the sequence length. Depending on the
    feature map, however, the complexity of the attention might be limited.

    Arguments
    ---------
        feature_map: callable, a callable that applies the feature map to the
                     last dimension of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
                 event_dispatcher=""):
        super(LinearAttention, self).__init__()
        self.feature_map = (
            feature_map(query_dimensions) if feature_map else
            elu_feature_map(query_dimensions)
        )
        self.eps = eps
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Apply the feature map to the queries and keys
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Apply the key padding mask and make sure that the attn_mask is
        # all_ones
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary "
                                "attention masks"))
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Compute the KV matrix, namely the dot product of keys and values so
        # that we never explicitly compute the attention matrix and thus
        # decrease the complexity
        KV = torch.einsum("nshd,nshm->nhmd", K, values)

        # Compute the normalizer
        Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)

        # Finally compute and return the new values
        V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

        return V.contiguous()


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register(
    "linear", LinearAttention,
    [
        ("query_dimensions", Int),
        ("feature_map", Optional(Callable)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
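A minimal sketch of why the O(N D^2) path above matches the quadratic formulation: Φ(Q)(Φ(K)ᵀV) equals (Φ(Q)Φ(K)ᵀ)V by associativity. It assumes only PyTorch and uses relu(x)+1 as a stand-in feature map; tensor shapes are illustrative.

import torch

N, L, S, H, D, M = 2, 5, 7, 3, 16, 8
Q = torch.relu(torch.randn(N, L, H, D)) + 1    # stand-in for elu(x)+1 feature map
K = torch.relu(torch.randn(N, S, H, D)) + 1
values = torch.randn(N, S, H, M)

# O(N*D^2) path used by LinearAttention
KV = torch.einsum("nshd,nshm->nhmd", K, values)
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + 1e-6)
V_fast = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

# Equivalent quadratic path: explicit (unnormalized) attention matrix
A = torch.einsum("nlhd,nshd->nhls", Q, K)
A = A / (A.sum(dim=-1, keepdim=True) + 1e-6)
V_slow = torch.einsum("nhls,nshm->nlhm", A, values)

assert torch.allclose(V_fast, V_slow, atol=1e-4, rtol=1e-4)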
smi-ted/inference/smi_ted_light/fast_transformers/attention/local_attention.py
ADDED
@@ -0,0 +1,101 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>
#

"""Implement local context attention."""

from math import sqrt

import torch
from torch.nn import Module, Dropout
from torch.nn import functional as F

from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
    EventDispatcherInstance
from ..events import EventDispatcher
from ..local_product import local_dot_product, local_weighted_average


class LocalAttention(Module):
    """Implement fast local attention where a query can only attend to
    neighboring keys.

    In this attention module the query Q_i can only attend to a key K_j if
    |i-j| < local_context/2.

    Arguments
    ---------
        local_context: The neighborhood to consider for local attention.
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, local_context, softmax_temp=None, attention_dropout=0.1,
                 event_dispatcher=""):
        super(LocalAttention, self).__init__()
        self.local_context = local_context
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        """Implement the local attention.

        The attn_mask can be anything, but the only values that will be
        considered are the ones in the neighborhood of each query.

        Arguments
        ---------
            queries: (N, L, H, E) The tensor containing the queries
            keys: (N, S, H, E) The tensor containing the keys
            values: (N, S, H, D) The tensor containing the values
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            query_lengths: An implementation of BaseMask that encodes how
                           many queries each sequence in the batch consists of
            key_lengths: An implementation of BaseMask that encodes how
                         many keys each sequence in the batch consists of
        """
        # Extract some shapes and compute the temperature
        N, L, H, E = queries.shape
        _, S, _, D = values.shape
        context = self.local_context
        softmax_temp = self.softmax_temp or 1./sqrt(E)

        # Permute the dimensions to NHLE instead of NLHE
        queries = queries.permute(0, 2, 1, 3).contiguous()
        keys = keys.permute(0, 2, 1, 3).contiguous()
        values = values.permute(0, 2, 1, 3).contiguous()

        QK = local_dot_product(
            queries,
            keys,
            attn_mask.additive_matrix_finite,
            key_lengths.lengths,
            self.local_context
        )
        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))

        V_new = local_weighted_average(A, values)

        return V_new.permute(0, 2, 1, 3).contiguous()


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register(
    "local", LocalAttention,
    [
        ("local_context", Int),
        ("softmax_temp", Optional(Float)),
        ("attention_dropout", Optional(Float, 0.1)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
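local_dot_product and local_weighted_average are the library's optimized kernels; the neighborhood rule they implement (|i-j| < local_context/2) can be reproduced densely in a few lines of plain PyTorch. This is only an equivalent-mask sketch for intuition, not the optimized implementation.

import torch

L, context = 8, 5
i = torch.arange(L).view(L, 1)
j = torch.arange(L).view(1, L)

# Neighborhood mask: query i may attend to key j only if |i - j| < context / 2
neighborhood = (i - j).abs() < context / 2

scores = torch.randn(L, L).masked_fill(~neighborhood, float("-inf"))
A = torch.softmax(scores, dim=-1)   # each row only mixes keys in the local band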
smi-ted/inference/smi_ted_light/fast_transformers/attention/reformer_attention.py
ADDED
@@ -0,0 +1,166 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#

"""Implement the Reformer attention from the paper
"Reformer: The Efficient Transformer"."""

from math import sqrt

import torch
from torch.nn import Dropout, Module
from torch.nn.init import normal_

from ..attention_registry import AttentionRegistry, Optional, Int, Float, \
    Bool, EventDispatcherInstance
from ..events import EventDispatcher
from ..masking import FullMask


class ReformerAttention(Module):
    """Implement the attention module of the paper "Reformer: The Efficient
    Transformer".

    Arguments
    ---------
        chunk_size: Chunk size for each block (default: 32)
        bits: Number of bits for hashing (default: 8)
        rounds: Number of rounds of attention computation (default: 4)
        masked: If true, the query does not attend to itself (default: False)
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(self, chunk_size=32, bits=8, rounds=4, masked=False,
                 softmax_temp=None, attention_dropout=0.1,
                 event_dispatcher=""):
        super(ReformerAttention, self).__init__()

        self.chunk_size = chunk_size
        self.bits = bits
        self.rounds = rounds
        self.masked = masked
        self.softmax_temp = softmax_temp
        self.dropout = Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def _normalize(self, x):
        norms = torch.sqrt(torch.einsum("nlhe,nlhe->nlh", x, x))
        x_normed = x / norms.unsqueeze(-1)
        return x_normed

    def _look_back(self, x):
        xshape = x.shape

        return torch.cat([
            x.new_zeros((xshape[0], 1) + xshape[2:]),
            torch.repeat_interleave(x, 2, dim=1)[:, :-1]
        ], dim=1).view(xshape[0], xshape[1], 2*xshape[2], *xshape[3:])

    def _reformer_round(self, Q, K, V, mask, softmax_temp):
        # Hash the queries
        N, L, H, E = Q.shape
        planes = Q.new_empty(self.bits, E)
        normal_(planes)
        projected = torch.einsum("nlhe,be->nlhb", K, planes)
        hashes = torch.argmax(
            torch.cat([projected, -projected], dim=-1),
            dim=-1
        )

        # Sort the queries in order to group them
        group = torch.argsort(hashes, dim=1)

        invert_group = torch.empty_like(group)
        batch_indices = torch.arange(N, device=hashes.device).view(N, 1, 1)
        sequence_indices = torch.arange(L, device=hashes.device).view(1, L, 1)
        head_indices = torch.arange(H, device=hashes.device).view(1, 1, H)
        invert_group[batch_indices, group, head_indices] = sequence_indices
        group = group.view(N, -1, self.chunk_size, H)
        invert_group = invert_group.view(N, -1, self.chunk_size, H)
        batch_indices = batch_indices.unsqueeze(1)
        head_indices = head_indices.unsqueeze(0)

        # Reorder Q, V and mask
        Q_grouped = Q[batch_indices, group, head_indices]
        K_grouped = K[batch_indices, group, head_indices]
        V_grouped = V[batch_indices, group, head_indices]
        mask_grouped = mask[
            batch_indices.unsqueeze(1),
            group.unsqueeze(3),
            self._look_back(group).unsqueeze(2)
        ]

        mask_grouped[:, 0, :, :Q_grouped.shape[2]] = float("-inf")

        # When everything is masked just unmask everything because it doesn't
        # matter what the output is at those positions
        # This is to avoid inf/nans in the new values at masked positions
        infmask = torch.isinf(mask_grouped)
        infmask = torch.all(infmask, dim=3, keepdims=True)
        mask_grouped = mask_grouped.masked_fill(infmask, 0.)

        # Attention
        K_grouped = self._look_back(K_grouped)
        QQ = torch.einsum("nblhe,nbshe->nbhls", Q_grouped, K_grouped)
        QQ = QQ + mask_grouped.permute(0, 1, 4, 2, 3)
        A = torch.softmax(softmax_temp * QQ, dim=-1)
        A = self.dropout(A)

        # Values
        V_grouped = self._look_back(V_grouped)
        V_new = torch.einsum("nbhls,nbshe->nblhe", A, V_grouped)
        V_new = V_new.contiguous().view(N, -1, H, E)
        V_new = V_new[batch_indices, invert_group, head_indices]
        V_new = V_new.contiguous().view(N, L, H, E)
        return V_new

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Extract the dimensions of query, key, value
        N, L, H, E = queries.shape

        softmax_temp = self.softmax_temp or 1./sqrt(E)
        # Create the mask
        mask = key_lengths.additive_matrix.unsqueeze(1).expand(N, L, L)
        if self.masked:
            mask = mask + torch.eye(L, device=queries.device).unsqueeze(0)*float(-1e9)

        if not attn_mask.all_ones:
            mask = mask + attn_mask.additive_matrix.unsqueeze(0)
        # Get the normalized queries as keys
        K = self._normalize(queries)
        # Zero the masked out keys
        K = K * key_lengths.float_matrix.view(N, L, 1, 1)

        V_new = 0
        factor = 1/self.rounds
        for i in range(self.rounds):
            V_new = V_new + \
                factor * self._reformer_round(queries, K, values, mask, softmax_temp)

        return V_new


# Register the attention implementation so that it becomes available in our
# builders
AttentionRegistry.register(
    "reformer", ReformerAttention,
    [
        ("chunk_size", Optional(Int, 32)),
        ("bits", Optional(Int, 63)),
        ("rounds", Optional(Int, 4)),
        ("masked", Optional(Bool, False)),
        ("softmax_temp", Optional(Float)),
        ("attention_dropout", Optional(Float, 0.1)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)
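The hashing step in _reformer_round is an angular LSH: project onto random planes and take the argmax over the projections and their negations. A self-contained sketch of that bucket assignment, assuming only PyTorch and illustrative shapes:

import torch

N, L, H, E, bits = 2, 16, 4, 32, 8
x = torch.randn(N, L, H, E)
x = x / x.norm(dim=-1, keepdim=True)      # unit-norm vectors, as in _normalize

planes = torch.randn(bits, E)
projected = torch.einsum("nlhe,be->nlhb", x, planes)
# Angular LSH: the bucket is the index of the largest of the 2*bits projections
buckets = torch.argmax(torch.cat([projected, -projected], dim=-1), dim=-1)
# buckets has shape (N, L, H) with values in [0, 2*bits); nearby vectors tend
# to land in the same bucket and are then attended to within the same chunk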
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__init__.py
ADDED
@@ -0,0 +1,17 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>
#

"""Allow for the dynamic registration of new attention implementations.

This module provides a Registry implementation that other modules can use to
register attention implementations for the builders.
"""

from .registry import \
    AttentionRegistry, \
    RecurrentAttentionRegistry, \
    RecurrentCrossAttentionRegistry
from .spec import Spec, Choice, Optional, Int, Float, Bool, Callable, \
    EventDispatcherInstance
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (786 Bytes).
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/registry.cpython-310.pyc
ADDED
Binary file (2.27 kB).
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/spec.cpython-310.pyc
ADDED
Binary file (4.73 kB).
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/registry.py
ADDED
@@ -0,0 +1,61 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>
#


class Registry(object):
    """Hold the available attention implementations and their required
    parameters."""
    def __init__(self):
        self._classes = {}
        self._class_params = {}
        self._parameters = {}

    def register(self, key, class_object, parameter_tuples):
        # register the class if the key is new
        if key in self._classes:
            raise ValueError("{} is already registered".format(key))
        self._classes[key] = class_object

        # register the parameters
        for parameter, spec in parameter_tuples:
            if (
                parameter in self._parameters and
                self._parameters[parameter] != spec
            ):
                raise ValueError(("{} is already registered with "
                                  "spec {!r} instead of {!r}").format(
                                      parameter,
                                      self._parameters[parameter],
                                      spec
                                  ))
            self._parameters[parameter] = spec

        # note which parameters are needed by this class
        self._class_params[key] = [p for p, s in parameter_tuples]

    def __contains__(self, key):
        return key in self._classes

    def __getitem__(self, key):
        return self._classes[key], self._class_params[key]

    @property
    def keys(self):
        return list(self._classes.keys())

    def contains_parameter(self, key):
        return key in self._parameters

    def validate_parameter(self, key, value):
        try:
            return self._parameters[key].get(value)
        except Exception as e:
            raise ValueError(("Invalid value {!r} for "
                              "parameter {!r}").format(value, key)) from e


AttentionRegistry = Registry()
RecurrentAttentionRegistry = Registry()
RecurrentCrossAttentionRegistry = Registry()
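Registering a new attention implementation follows the same pattern as the attention modules above. A hedged sketch, assuming the vendored package is importable as fast_transformers (adjust the import path to wherever this copy of the library lives); MyAttention and the "my-attention" key are hypothetical.

from torch.nn import Module

from fast_transformers.attention_registry import (
    AttentionRegistry, Optional, Float, EventDispatcherInstance
)


class MyAttention(Module):
    """A do-nothing attention used only to illustrate registration."""
    def __init__(self, softmax_temp=None, event_dispatcher=""):
        super().__init__()
        self.softmax_temp = softmax_temp

    def forward(self, queries, keys, values, attn_mask, query_lengths, key_lengths):
        return values


# Register under a new key; reusing an existing key raises ValueError, and
# reusing a parameter name with a different spec also raises ValueError.
AttentionRegistry.register(
    "my-attention", MyAttention,
    [
        ("softmax_temp", Optional(Float)),
        ("event_dispatcher", Optional(EventDispatcherInstance, ""))
    ]
)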
smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/spec.py
ADDED
@@ -0,0 +1,126 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>
#

"""Spec instances allow describing and checking the type and value of
parameters."""

from ..events import EventDispatcher


class Spec(object):
    """Describe and validate a parameter type.

    Arguments
    ---------
        predicate: A callable that checks if the value is acceptable and
                   returns its canonical value or raises ValueError.
        name: A name to create a human readable description of the Spec
    """
    def __init__(self, predicate, name="CustomSpec"):
        self._predicate = predicate
        self._name = name

    def __repr__(self):
        return self._name

    def check(self, x):
        try:
            self._predicate(x)
            return True
        except ValueError:
            return False

    def get(self, x):
        return self._predicate(x)

    def __eq__(self, y):
        return self is y


class Choice(Spec):
    """A parameter type for a set of options.

    Arguments
    ---------
        choices: A set or list of possible values for this parameter
    """
    def __init__(self, choices):
        self._choices = choices

    def get(self, x):
        if x in self._choices:
            return x
        raise ValueError("{!r} is not in {!r}".format(x, self._choices))

    def __repr__(self):
        return "Choice({!r})".format(self._choices)

    def __eq__(self, x):
        if isinstance(x, Choice):
            return self._choices == x._choices
        return False


class _Callable(Spec):
    def __init__(self):
        super(_Callable, self).__init__(None, "Callable")

    def get(self, x):
        if callable(x):
            return x
        raise ValueError("{!r} is not a callable".format(x))


class _EventDispatcherInstance(Spec):
    def __init__(self):
        super(_EventDispatcherInstance, self).__init__(
            _EventDispatcherInstance._get_event_dispatcher,
            "EventDispatcherInstance"
        )

    @staticmethod
    def _get_event_dispatcher(x):
        if isinstance(x, str):
            return x
        if isinstance(x, EventDispatcher):
            return x
        raise ValueError("{!r} is not an event dispatcher".format(x))


class Optional(Spec):
    """Represent an optional parameter that can either have a value or be
    None.

    Arguments
    ---------
        spec: The spec for the value if it is not None
        default: The returned value in case it is None
    """
    def __init__(self, spec, default=None):
        self._other_spec = spec
        self._default = default

    def __repr__(self):
        return "Optional[{!r}, {!r}]".format(self._other_spec, self._default)

    def get(self, x):
        if x is None:
            return self._default
        return self._other_spec.get(x)

    def __eq__(self, x):
        if isinstance(x, Optional):
            return (
                self._other_spec == x._other_spec and
                self._default == x._default
            )
        return False


Int = Spec(int, "Int")
Float = Spec(float, "Float")
Bool = Spec(bool, "Bool")
Callable = _Callable()
EventDispatcherInstance = _EventDispatcherInstance()
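A few one-liners showing how the specs behave, assuming the same vendored import path as above; the concrete values are only illustrative.

from fast_transformers.attention_registry import Optional, Int, Choice

Optional(Int, 10).get(None)            # -> 10 (a missing value falls back to the default)
Optional(Int, 10).get(25)              # -> 25 (validated by the wrapped Int spec)
Choice(["relu", "gelu"]).get("relu")   # -> "relu"; any other value raises ValueError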
smi-ted/inference/smi_ted_light/fast_transformers/builders/__init__.py
ADDED
@@ -0,0 +1,59 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#

"""This module implements builders that simplify building complex transformer
architectures with different attention mechanisms.

The main idea is to facilitate the construction of various attention layers and
transformer encoder layers and simplify their assembly into one transformer
module. It also allows for flexibility in the scripts as many builder
parameters can correspond 1-1 with command line arguments.

Example usage:

    builder = TransformerEncoderBuilder()
    builder.n_layers = 12
    builder.n_heads = 8
    builder.feed_forward_dimensions = 1024
    builder.query_dimensions = 64
    builder.value_dimensions = 64
    builder.dropout = 0.1
    builder.attention_dropout = 0.1
    builder.attention_type = "linear"
    transformer = builder.get()
"""

__all__ = [
    "AttentionBuilder",
    "RecurrentAttentionBuilder",
    "RecurrentCrossAttentionBuilder"
]

# Import the attention implementations so that they register themselves with
# the builder. Attention implementations external to the library should be
# imported before using the builders.
#
# TODO: Should this behaviour change? Namely, should all attention
#       implementations be imported in order to be usable? This also allows
#       using the library even when it is only partially built, for instance.
from ..attention import \
    FullAttention, \
    LinearAttention

del FullAttention, \
    LinearAttention


from .attention_builders import \
    AttentionBuilder, \
    RecurrentAttentionBuilder, \
    RecurrentCrossAttentionBuilder

from .transformer_builders import \
    TransformerEncoderBuilder, \
    RecurrentEncoderBuilder, \
    TransformerDecoderBuilder, \
    RecurrentDecoderBuilder
smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.46 kB).
smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/attention_builders.cpython-310.pyc
ADDED
Binary file (6.49 kB).
smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/base.cpython-310.pyc
ADDED
Binary file (2.3 kB).
smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/transformer_builders.cpython-310.pyc
ADDED
Binary file (18.7 kB).
smi-ted/inference/smi_ted_light/fast_transformers/builders/attention_builders.py
ADDED
@@ -0,0 +1,139 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>
#

from collections import defaultdict

from .base import BaseBuilder
from ..attention_registry import \
    AttentionRegistry, \
    RecurrentAttentionRegistry, \
    RecurrentCrossAttentionRegistry


class BaseAttentionBuilder(BaseBuilder):
    def __init__(self, registry):
        self._registry = registry
        self._parameters = defaultdict(lambda: None)

    @property
    def available_attentions(self):
        """Return a list with the available attention implementations."""
        return self._registry.keys

    def validate_attention_type(self, attention_type):
        """Parse the attention type according to the rules used by `get()` and
        check if the requested attention is constructible."""
        return all(
            all(t in self._registry for t in a.split(","))
            for a in attention_type.split(":")
        )

    def __setattr__(self, key, value):
        # Make sure we have normal behaviour for the class members _registry
        # and _parameters
        if key in ["_registry", "_parameters"]:
            return object.__setattr__(self, key, value)

        # Assign everything else in the parameters dictionary
        if not self._registry.contains_parameter(key):
            raise AttributeError(("{!r} is not a valid attention "
                                  "parameter name").format(key))
        self._parameters[key] = self._registry.validate_parameter(key, value)

    def __getattr__(self, key):
        if key in self._parameters:
            return self._parameters[key]
        else:
            raise AttributeError()

    def __repr__(self):
        return (
            "{}.from_kwargs(\n".format(self.__class__.__name__) +
            "\n".join(["    {}={!r},".format(k, v)
                       for k, v in self._parameters.items()])[:-1] +
            "\n)"
        )

    def get(self, attention_type):
        """Construct the attention implementation object and return it.

        The passed in attention_type argument defines the attention to be
        created. It should be a string and in its simplest form it should
        be one of the available choices from `available_attentions`.

        However, to enable attention decoration, namely an attention
        implementation augmenting the functionality of another implementation,
        the attention type can be a colon separated list of compositions like
        the following examples:

            - 'att1' means instantiate att1
            - 'att2:att1' means instantiate att1 and decorate it with att2
            - 'att3:att1,att4' means instantiate att1 and att4 and decorate
              them with att3

        Arguments
        ---------
            attention_type: A string that contains one or more keys from
                            `available_attentions` separated with a colon to
                            denote the decoration pattern.
        """
        compositions = reversed(attention_type.split(":"))
        attentions = []
        for c in compositions:
            attentions = [
                self._construct_attention(t, attentions)
                for t in c.split(",")
            ]
        if len(attentions) > 1:
            raise ValueError(("Invalid attention_type argument "
                              "{!r}").format(attention_type))
        return attentions[0]

    def _construct_attention(self, attention_type, decorated=[]):
        """Construct an attention implementation object.

        Arguments
        ---------
            attention_type: A string that contains a single key from the
                            `available_attentions`
            decorated: A list of attention implementations to pass as arguments
                       to be decorated
        """
        if attention_type not in self._registry:
            raise ValueError(("Unknown attention type "
                              "{!r}").format(attention_type))

        attention, parameters = self._registry[attention_type]
        parameter_dictionary = {
            p: self._registry.validate_parameter(p, self._parameters[p])
            for p in parameters
        }

        return attention(*decorated, **parameter_dictionary)


class AttentionBuilder(BaseAttentionBuilder):
    """Build attention implementations for batch sequence processing or
    training."""
    def __init__(self):
        super(AttentionBuilder, self).__init__(AttentionRegistry)


class RecurrentAttentionBuilder(BaseAttentionBuilder):
    """Build attention implementations for autoregressive sequence
    processing."""
    def __init__(self):
        super(RecurrentAttentionBuilder, self).__init__(
            RecurrentAttentionRegistry
        )


class RecurrentCrossAttentionBuilder(BaseAttentionBuilder):
    """Build attention implementations for autoregressive cross attention
    computation."""
    def __init__(self):
        super(RecurrentCrossAttentionBuilder, self).__init__(
            RecurrentCrossAttentionRegistry
        )
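A short usage sketch for the builder above, assuming the full vendored package is importable as fast_transformers and that the "linear" and "local" keys registered by the modules earlier in this commit are available; parameter values are illustrative.

from fast_transformers.builders import AttentionBuilder

builder = AttentionBuilder.from_kwargs(
    query_dimensions=64,      # used by the linear attention
    local_context=16,         # used by the local attention
    attention_dropout=0.1,
)

# Each get() call only consumes the parameters its attention type declared
# when it registered itself with the AttentionRegistry.
linear_att = builder.get("linear")
local_att = builder.get("local")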
smi-ted/inference/smi_ted_light/fast_transformers/builders/base.py
ADDED
@@ -0,0 +1,67 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <[email protected]>,
# Apoorv Vyas <[email protected]>
#

"""Provide a class for the others to inherit some useful functionality."""


class BaseBuilder(object):
    @classmethod
    def from_kwargs(cls, **kwargs):
        """Construct a builder and set all the keyword arguments as parameters.

        The keyword argument strict is passed to
        BaseBuilder.from_dictionary separately.

        See BaseBuilder.from_dictionary().
        """
        strict = kwargs.pop("strict", True)
        return cls.from_dictionary(kwargs, strict=strict)

    @classmethod
    def from_namespace(cls, args, strict=False):
        """Construct a builder from an argparse Namespace.

        To be used for building transformers from command line arguments.

        See BaseBuilder.from_dictionary().
        """
        return cls.from_dictionary(vars(args), strict=strict)

    @classmethod
    def from_dictionary(cls, dictionary, strict=True):
        """Construct a builder and set all the parameters in the dictionary.

        Given a dictionary

            d = {"foo": "bar"}

        then

            builder = TransformerEncoderBuilder.from_dictionary(d)

        is equivalent to

            builder = TransformerEncoderBuilder()
            builder.foo = "bar"

        Arguments
        ---------
            dictionary: A dictionary of parameters to set to the builder.
            strict: bool, If a key is not a parameter and strict is set to True
                    then a ValueError is raised, otherwise that dictionary key
                    is ignored (default: True)
        """
        builder = cls()
        for k, v in dictionary.items():
            try:
                setattr(builder, k, v)
            except AttributeError:
                if strict:
                    raise ValueError(("The builder has no "
                                      "parameter {!r}").format(k))
                else:
                    continue
        return builder
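A hedged sketch of from_namespace, which the docstring above describes for command-line driven construction. It assumes the full vendored package is importable as fast_transformers; the argument names other than the builder parameters are hypothetical.

import argparse

from fast_transformers.builders import TransformerEncoderBuilder

parser = argparse.ArgumentParser()
parser.add_argument("--n_layers", type=int, default=4)
parser.add_argument("--n_heads", type=int, default=4)
parser.add_argument("--unrelated_flag", default="ignored")
args = parser.parse_args([])

# strict=False (the default for from_namespace) silently skips keys that are
# not builder parameters, such as unrelated_flag above.
builder = TransformerEncoderBuilder.from_namespace(args)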
smi-ted/inference/smi_ted_light/fast_transformers/builders/transformer_builders.py
ADDED
@@ -0,0 +1,550 @@
1 |
+
#
|
2 |
+
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
# Written by Angelos Katharopoulos <[email protected]>
|
4 |
+
#
|
5 |
+
|
6 |
+
"""Build complex transformer architectures for inference or training easily."""
|
7 |
+
|
8 |
+
from torch.nn import LayerNorm
|
9 |
+
|
10 |
+
from ..attention import AttentionLayer
|
11 |
+
from ..transformers import TransformerEncoder, TransformerEncoderLayer, \
|
12 |
+
TransformerDecoder, TransformerDecoderLayer
|
13 |
+
from ..recurrent.attention import \
|
14 |
+
RecurrentAttentionLayer, \
|
15 |
+
RecurrentCrossAttentionLayer
|
16 |
+
from ..recurrent.transformers import \
|
17 |
+
RecurrentTransformerEncoder, RecurrentTransformerEncoderLayer, \
|
18 |
+
RecurrentTransformerDecoder, RecurrentTransformerDecoderLayer
|
19 |
+
from .base import BaseBuilder
|
20 |
+
from .attention_builders import AttentionBuilder, RecurrentAttentionBuilder, \
|
21 |
+
RecurrentCrossAttentionBuilder
|
22 |
+
|
23 |
+
|
24 |
+
class BaseTransformerBuilder(BaseBuilder):
|
25 |
+
"""Contains all the parameters for building a transformer other than the
|
26 |
+
attention part.
|
27 |
+
|
28 |
+
Classes extending the BaseTransformerBuilder should implement the `get()`
|
29 |
+
method that actually builds the transformer.
|
30 |
+
"""
|
31 |
+
def __init__(self):
|
32 |
+
# transformer parameters
|
33 |
+
self._n_layers = 4
|
34 |
+
self._n_heads = 4
|
35 |
+
self._d_query = 64
|
36 |
+
self._d_value = 64
|
37 |
+
self._d_ff = 1024
|
38 |
+
self._dropout = 0.1
|
39 |
+
self._activation = "relu"
|
40 |
+
self._final_norm = True
|
41 |
+
self._event_dispatcher = "" # the default global dispatcher
|
42 |
+
|
43 |
+
@property
|
44 |
+
def n_layers(self):
|
45 |
+
"""The number of transformer layers."""
|
46 |
+
return self._n_layers
|
47 |
+
|
48 |
+
@n_layers.setter
|
49 |
+
def n_layers(self, val):
|
50 |
+
self._n_layers = val
|
51 |
+
|
52 |
+
@property
|
53 |
+
def n_heads(self):
|
54 |
+
"""The number of heads in each transformer layer."""
|
55 |
+
return self._n_heads
|
56 |
+
|
57 |
+
@n_heads.setter
|
58 |
+
def n_heads(self, val):
|
59 |
+
self._n_heads = val
|
60 |
+
|
61 |
+
@property
|
62 |
+
def feed_forward_dimensions(self):
|
63 |
+
"""The dimensions of the fully connected layer in the transformer
|
64 |
+
layers."""
|
65 |
+
return self._d_ff
|
66 |
+
|
67 |
+
@feed_forward_dimensions.setter
|
68 |
+
def feed_forward_dimensions(self, val):
|
69 |
+
self._d_ff = val
|
70 |
+
|
71 |
+
@property
|
72 |
+
def query_dimensions(self):
|
73 |
+
"""The dimensions of the queries and keys in each attention layer."""
|
74 |
+
return self._d_query
|
75 |
+
|
76 |
+
@query_dimensions.setter
|
77 |
+
def query_dimensions(self, val):
|
78 |
+
self._d_query = val
|
79 |
+
|
80 |
+
@property
|
81 |
+
def value_dimensions(self):
|
82 |
+
"""The dimensions of the values in each attention layer."""
|
83 |
+
return self._d_value
|
84 |
+
|
85 |
+
@value_dimensions.setter
|
86 |
+
def value_dimensions(self, val):
|
87 |
+
self._d_value = val
|
88 |
+
|
89 |
+
@property
|
90 |
+
def dropout(self):
|
91 |
+
"""The dropout rate to be applied in the transformer encoder layer."""
|
92 |
+
return self._dropout
|
93 |
+
|
94 |
+
@dropout.setter
|
95 |
+
def dropout(self, val):
|
96 |
+
self._dropout = val
|
97 |
+
|
98 |
+
@property
|
99 |
+
def activation(self):
|
100 |
+
"""The activation function for the transformer layer.
|
101 |
+
|
102 |
+
One of {'relu', 'gelu'}.
|
103 |
+
"""
|
104 |
+
return self._activation
|
105 |
+
|
106 |
+
@activation.setter
|
107 |
+
def activation(self, val):
|
108 |
+
activations = ["relu", "gelu"]
|
109 |
+
if val not in activations:
|
110 |
+
raise ValueError(("{!r} is not one of the availabel activation "
|
111 |
+
"types {!r}").format(val, activations))
|
112 |
+
self._activation = val
|
113 |
+
|
114 |
+
@property
|
115 |
+
def final_normalization(self):
|
116 |
+
"""Whether to add LayerNorm as the final layer of the
|
117 |
+
TransformerEncoder."""
|
118 |
+
return self._final_norm
|
119 |
+
|
120 |
+
@final_normalization.setter
|
121 |
+
def final_normalization(self, val):
|
122 |
+
self._final_norm = bool(val)
|
123 |
+
|
124 |
+
@property
|
125 |
+
def event_dispatcher(self):
|
126 |
+
"""The transformer event dispatcher either as a string or as an
|
127 |
+
EventDispatcher object."""
|
128 |
+
return self._event_dispatcher
|
129 |
+
|
130 |
+
@event_dispatcher.setter
|
131 |
+
def event_dispatcher(self, event_dispatcher):
|
132 |
+
self._event_dispatcher = event_dispatcher
|
133 |
+
|
134 |
+
def get(self):
|
135 |
+
"""Build the transformer and return it."""
|
136 |
+
raise NotImplementedError()
|
137 |
+
|
138 |
+
|
139 |
+
class BaseTransformerEncoderBuilder(BaseTransformerBuilder):
|
140 |
+
"""Implement the logic of building a transformer encoder but leave the
|
141 |
+
specific layers open for changing by the inheriting classes. This allows us
|
142 |
+
to reuse the logic for creating both the TransformerEncoder and the
|
143 |
+
RecurrentTransformerEncoder.
|
144 |
+
|
145 |
+
Inheriting classes should implement the following:
|
146 |
+
|
147 |
+
- _get_attention_builder()
|
148 |
+
- _get_attention_layer_class()
|
149 |
+
- _get_encoder_class()
|
150 |
+
- _get_encoder_layer_class()
|
151 |
+
"""
|
152 |
+
def __init__(self):
|
153 |
+
super(BaseTransformerEncoderBuilder, self).__init__()
|
154 |
+
self._attention_builder = self._get_attention_builder()
|
155 |
+
self._attention_type = "full"
|
156 |
+
|
157 |
+
def _get_attention_builder(self):
|
158 |
+
"""Return an instance of the appropriate attention builder."""
|
159 |
+
raise NotImplementedError()
|
160 |
+
|
161 |
+
def _get_attention_layer_class(self):
|
162 |
+
"""Return the class for the layer that projects queries keys and
|
163 |
+
values."""
|
164 |
+
raise NotImplementedError()
|
165 |
+
|
166 |
+
def _get_encoder_class(self):
|
167 |
+
"""Return the class for the transformer encoder."""
|
168 |
+
raise NotImplementedError()
|
169 |
+
|
170 |
+
def _get_encoder_layer_class(self):
|
171 |
+
"""Return the class for the transformer encoder layer."""
|
172 |
+
raise NotImplementedError()
|
173 |
+
|
174 |
+
@property
|
175 |
+
def attention(self):
|
176 |
+
"""The attention builder instance."""
|
177 |
+
return self._attention_builder
|
178 |
+
|
179 |
+
@property
|
180 |
+
def attention_type(self):
|
181 |
+
"""The attention implementation chosen."""
|
182 |
+
return self._attention_type
|
183 |
+
|
184 |
+
@attention_type.setter
|
185 |
+
def attention_type(self, val):
|
186 |
+
if not self._attention_builder.validate_attention_type(val):
|
187 |
+
raise ValueError(("{!r} is not an available attention "
|
188 |
+
"type").format(val))
|
189 |
+
self._attention_type = val
|
190 |
+
|
191 |
+
def __setattr__(self, key, val):
|
192 |
+
# "protected" attributes are settable (probably from withing the class)
|
193 |
+
if key[0] == "_":
|
194 |
+
return super().__setattr__(key, val)
|
195 |
+
|
196 |
+
# Existing attributes are settable but they might also be attention
|
197 |
+
# parameters so try that as well
|
198 |
+
fail_on_exception = True
|
199 |
+
if hasattr(self, key):
|
200 |
+
super().__setattr__(key, val)
|
201 |
+
fail_on_exception = False
|
202 |
+
|
203 |
+
# Non-existing "public" attributes may be attention parameters
|
204 |
+
try:
|
205 |
+
setattr(self._attention_builder, key, val)
|
206 |
+
except:
|
207 |
+
if fail_on_exception:
|
208 |
+
raise
|
209 |
+
|
210 |
+
def get(self):
|
211 |
+
"""Build the transformer and return it."""
|
212 |
+
# Set the event dispatcher to the attention builder
|
213 |
+
self.attention.event_dispatcher = self.event_dispatcher
|
214 |
+
|
215 |
+
# Extract into local variables the classes to be used
|
216 |
+
Encoder = self._get_encoder_class()
|
217 |
+
EncoderLayer = self._get_encoder_layer_class()
|
218 |
+
Attention = self._get_attention_layer_class()
|
219 |
+
|
220 |
+
model_dimensions = self.value_dimensions*self.n_heads
|
221 |
+
return Encoder(
|
222 |
+
[
|
223 |
+
EncoderLayer(
|
224 |
+
Attention(
|
225 |
+
self.attention.get(self.attention_type),
|
226 |
+
model_dimensions,
|
227 |
+
self.n_heads,
|
228 |
+
d_keys=self.query_dimensions,
|
229 |
+
d_values=self.value_dimensions,
|
230 |
+
event_dispatcher=self.event_dispatcher
|
231 |
+
),
|
232 |
+
model_dimensions,
|
233 |
+
self.feed_forward_dimensions,
|
234 |
+
self.dropout,
|
235 |
+
self.activation,
|
236 |
+
event_dispatcher=self.event_dispatcher
|
237 |
+
)
|
238 |
+
for _ in range(self.n_layers)
|
239 |
+
],
|
240 |
+
(LayerNorm(model_dimensions) if self.final_normalization else None),
|
241 |
+
event_dispatcher=self.event_dispatcher
|
242 |
+
)
|
243 |
+
|
244 |
+
|
245 |
+
class TransformerEncoderBuilder(BaseTransformerEncoderBuilder):
|
246 |
+
"""Build a batch transformer encoder for training or processing of
|
247 |
+
sequences, all elements at a time.
|
248 |
+
|
249 |
+
Example usage:
|
250 |
+
|
251 |
+
builder = TransformerEncoderBuilder()
|
252 |
+
builder.n_layers = 12
|
253 |
+
builder.n_heads = 8
|
254 |
+
builder.feed_forward_dimensions = 1024
|
255 |
+
builder.query_dimensions = 64
|
256 |
+
builder.value_dimensions = 64
|
257 |
+
builder.dropout = 0.1
|
258 |
+
builder.attention_dropout = 0.1
|
259 |
+
builder.attention_type = "linear"
|
260 |
+
transformer = builder.get()
|
261 |
+
"""
|
262 |
+
def _get_attention_builder(self):
|
263 |
+
"""Return an instance of the appropriate attention builder."""
|
264 |
+
return AttentionBuilder()
|
265 |
+
|
266 |
+
def _get_attention_layer_class(self):
|
267 |
+
"""Return the class for the layer that projects queries keys and
|
268 |
+
values."""
|
269 |
+
return AttentionLayer
|
270 |
+
|
271 |
+
def _get_encoder_class(self):
|
272 |
+
"""Return the class for the transformer encoder."""
|
273 |
+
return TransformerEncoder
|
274 |
+
|
275 |
+
def _get_encoder_layer_class(self):
|
276 |
+
"""Return the class for the transformer encoder layer."""
|
277 |
+
return TransformerEncoderLayer
|
278 |
+
|
279 |
+
|
280 |
+
class RecurrentEncoderBuilder(BaseTransformerEncoderBuilder):
|
281 |
+
"""Build a transformer encoder for autoregressive processing of sequences.
|
282 |
+
|
283 |
+
Example usage:
|
284 |
+
|
285 |
+
builder = RecurrentEncoderBuilder()
|
286 |
+
builder.n_layers = 12
|
287 |
+
builder.n_heads = 8
|
288 |
+
builder.feed_forward_dimensions = 1024
|
289 |
+
builder.query_dimensions = 64
|
290 |
+
builder.value_dimensions = 64
|
291 |
+
builder.dropout = 0.1
|
292 |
+
builder.attention_dropout = 0.1
|
293 |
+
builder.attention_type = "linear"
|
294 |
+
transformer = builder.get()
|
295 |
+
"""
|
296 |
+
def _get_attention_builder(self):
|
297 |
+
"""Return an attention builder for recurrent attention."""
|
298 |
+
return RecurrentAttentionBuilder()
|
299 |
+
|
300 |
+
def _get_attention_layer_class(self):
|
301 |
+
"""Return the class for the recurrent layer that projects queries keys
|
302 |
+
and values."""
|
303 |
+
return RecurrentAttentionLayer
|
304 |
+
|
305 |
+
def _get_encoder_class(self):
|
306 |
+
"""Return the class for the recurrent transformer encoder."""
|
307 |
+
return RecurrentTransformerEncoder
|
308 |
+
|
309 |
+
def _get_encoder_layer_class(self):
|
310 |
+
"""Return the class for the recurrent transformer encoder layer."""
|
311 |
+
return RecurrentTransformerEncoderLayer
|
312 |
+
|
313 |
+
|
314 |
+
class BaseTransformerDecoderBuilder(BaseTransformerBuilder):
|
315 |
+
"""Similar to BaseTransformerEncoderBuilder implement the logic of
|
316 |
+
building the transformer decoder without defining concrete layers.
|
317 |
+
|
318 |
+
Inheriting classes should implement the following:
|
319 |
+
|
320 |
+
- _get_self_attention_builder() and _get_cross_attention_builder()
|
321 |
+
- _get_self_attention_layer_class() and _get_cross_attention_layer_class()
|
322 |
+
- _get_decoder_class()
|
323 |
+
- _get_decoder_layer_class()
|
324 |
+
"""
|
325 |
+
def __init__(self):
|
326 |
+
super(BaseTransformerDecoderBuilder, self).__init__()
|
327 |
+
self._self_attention_builder = self._get_self_attention_builder()
|
328 |
+
self._cross_attention_builder = self._get_cross_attention_builder()
|
329 |
+
self._self_attention_type = "full"
|
330 |
+
self._cross_attention_type = "full"
|
331 |
+
|
332 |
+
def _get_self_attention_builder(self):
|
333 |
+
"""Return an instance of attention builder."""
|
334 |
+
raise NotImplementedError()
|
335 |
+
|
336 |
+
def _get_cross_attention_builder(self):
|
337 |
+
"""Return an instance of attention builder."""
|
338 |
+
raise NotImplementedError()
|
339 |
+
|
340 |
+
def _get_self_attention_layer_class(self):
|
341 |
+
"""Return a class to project the queries, keys and values to
|
342 |
+
multi-head versions."""
|
343 |
+
raise NotImplementedError()
|
344 |
+
|
345 |
+
def _get_cross_attention_layer_class(self):
|
346 |
+
"""Return a class to project the queries, keys and values to
|
347 |
+
multi-head versions."""
|
348 |
+
raise NotImplementedError()
|
349 |
+
|
350 |
+
def _get_decoder_class(self):
|
351 |
+
"""Return the class for the transformer decoder."""
|
352 |
+
raise NotImplementedError()
|
353 |
+
|
354 |
+
def _get_decoder_layer_class(self):
|
355 |
+
"""Return the class for the transformer decoder layer."""
|
356 |
+
raise NotImplementedError()
|
357 |
+
|
358 |
+
@property
|
359 |
+
def self_attention(self):
|
360 |
+
"""The attention builder instance that will be used for the self
|
361 |
+
attention modules."""
|
362 |
+
return self._self_attention_builder
|
363 |
+
|
364 |
+
@property
|
365 |
+
def self_attention_type(self):
|
366 |
+
"""The attention implementation used for self attention."""
|
367 |
+
return self._self_attention_type
|
368 |
+
|
369 |
+
@self_attention_type.setter
|
370 |
+
def self_attention_type(self, val):
|
371 |
+
if not self._self_attention_builder.validate_attention_type(val):
|
372 |
+
raise ValueError(("{!r} is not an available self attention "
|
373 |
+
"type").format(val))
|
374 |
+
self._self_attention_type = val
|
375 |
+
|
376 |
+
@property
|
377 |
+
def cross_attention(self):
|
378 |
+
"""The attention builder instance that will be used for the cross
|
379 |
+
attention modules."""
|
380 |
+
return self._cross_attention_builder
|
381 |
+
|
382 |
+
@property
|
383 |
+
def cross_attention_type(self):
|
384 |
+
"""The attention implementation used for cross attention."""
|
385 |
+
return self._cross_attention_type
|
386 |
+
|
387 |
+
@cross_attention_type.setter
|
388 |
+
def cross_attention_type(self, val):
|
389 |
+
if not self._cross_attention_builder.validate_attention_type(val):
|
390 |
+
raise ValueError(("{!r} is not an available cross attention "
|
391 |
+
"type").format(val))
|
392 |
+
self._cross_attention_type = val
|
393 |
+
|
394 |
+
def __setattr__(self, key, val):
|
395 |
+
# "protected" attributes are settable (probably from withing the class)
|
396 |
+
if key[0] == "_":
|
397 |
+
return super().__setattr__(key, val)
|
398 |
+
|
399 |
+
# Existing attributes are settable but they might also be attention
|
400 |
+
# parameters so try that as well
|
401 |
+
fail_on_exception = True
|
402 |
+
if hasattr(self, key):
|
403 |
+
super().__setattr__(key, val)
|
404 |
+
fail_on_exception = False
|
405 |
+
|
406 |
+
# Non-existing "public" attributes may be attention parameters
|
407 |
+
try:
|
408 |
+
setattr(self._self_attention_builder, key, val)
|
409 |
+
setattr(self._cross_attention_builder, key, val)
|
410 |
+
except:
|
411 |
+
if fail_on_exception:
|
412 |
+
raise
|
413 |
+
|
414 |
+
def get(self):
|
415 |
+
"""Build the transformer and return it."""
|
416 |
+
# Set the event dispatcher to attention builders
|
417 |
+
self.self_attention.event_dispatcher = self.event_dispatcher
|
418 |
+
self.cross_attention.event_dispatcher = self.event_dispatcher
|
419 |
+
|
420 |
+
# Extract into local variables the classes to be used
|
421 |
+
Decoder = self._get_decoder_class()
|
422 |
+
DecoderLayer = self._get_decoder_layer_class()
|
423 |
+
SelfAttention = self._get_self_attention_layer_class()
|
424 |
+
CrossAttention = self._get_cross_attention_layer_class()
|
425 |
+
|
426 |
+
model_dimensions = self.value_dimensions*self.n_heads
|
427 |
+
return Decoder(
|
428 |
+
[
|
429 |
+
DecoderLayer(
|
430 |
+
SelfAttention(
|
431 |
+
self.self_attention.get(self.self_attention_type),
|
432 |
+
model_dimensions,
|
433 |
+
self.n_heads,
|
434 |
+
d_keys=self.query_dimensions,
|
435 |
+
d_values=self.value_dimensions,
|
436 |
+
event_dispatcher=self.event_dispatcher
|
437 |
+
),
|
438 |
+
CrossAttention(
|
439 |
+
self.cross_attention.get(self.cross_attention_type),
|
440 |
+
model_dimensions,
|
441 |
+
self.n_heads,
|
442 |
+
d_keys=self.query_dimensions,
|
443 |
+
d_values=self.value_dimensions,
|
444 |
+
event_dispatcher=self.event_dispatcher
|
445 |
+
),
|
446 |
+
model_dimensions,
|
447 |
+
self.feed_forward_dimensions,
|
448 |
+
self.dropout,
|
449 |
+
self.activation,
|
450 |
+
event_dispatcher=self.event_dispatcher
|
451 |
+
)
|
452 |
+
for _ in range(self.n_layers)
|
453 |
+
],
|
454 |
+
(LayerNorm(model_dimensions) if self.final_normalization else None),
|
455 |
+
event_dispatcher=self.event_dispatcher
|
456 |
+
)
|
457 |
+
|
458 |
+
|
459 |
+
class TransformerDecoderBuilder(BaseTransformerDecoderBuilder):
|
460 |
+
"""Build a transformer decoder for training or processing of sequences all
|
461 |
+
elements at a time.
|
462 |
+
|
463 |
+
Example usage:
|
464 |
+
|
465 |
+
builder = TransformerDecoderBuilder()
|
466 |
+
builder.n_layers = 12
|
467 |
+
builder.n_heads = 8
|
468 |
+
builder.feed_forward_dimensions = 1024
|
469 |
+
builder.query_dimensions = 64
|
470 |
+
builder.value_dimensions = 64
|
471 |
+
builder.dropout = 0.1
|
472 |
+
builder.attention_dropout = 0.1
|
473 |
+
builder.self_attention_type = "full"
|
474 |
+
builder.cross_attention_type = "full"
|
475 |
+
transformer = builder.get()
|
476 |
+
"""
|
477 |
+
def _get_self_attention_builder(self):
|
478 |
+
"""Return an attention builder for creating non-recurrent attention
|
479 |
+
variants."""
|
480 |
+
return AttentionBuilder()
|
481 |
+
|
482 |
+
def _get_cross_attention_builder(self):
|
483 |
+
"""Return an attention builder for creating non-recurrent attention
|
484 |
+
variants."""
|
485 |
+
return AttentionBuilder()
|
486 |
+
|
487 |
+
def _get_self_attention_layer_class(self):
|
488 |
+
"""Return the non-recurrent attention layer to project queries, keys
|
489 |
+
and values."""
|
490 |
+
return AttentionLayer
|
491 |
+
|
492 |
+
def _get_cross_attention_layer_class(self):
|
493 |
+
"""Return the non-recurrent attention layer to project queries, keys
|
494 |
+
and values."""
|
495 |
+
return AttentionLayer
|
496 |
+
|
497 |
+
def _get_decoder_class(self):
|
498 |
+
"""Return the transformer decoder class."""
|
499 |
+
return TransformerDecoder
|
500 |
+
|
501 |
+
def _get_decoder_layer_class(self):
|
502 |
+
"""Return the transformer decoder layer class."""
|
503 |
+
return TransformerDecoderLayer
|
504 |
+
|
505 |
+
|
506 |
+
class RecurrentDecoderBuilder(BaseTransformerDecoderBuilder):
|
507 |
+
"""Build a transformer decoder for processing of sequences in
|
508 |
+
autoregressive fashion.
|
509 |
+
|
510 |
+
Example usage:
|
511 |
+
|
512 |
+
builder = RecurrentDecoderBuilder()
|
513 |
+
builder.n_layers = 12
|
514 |
+
builder.n_heads = 8
|
515 |
+
builder.feed_forward_dimensions = 1024
|
516 |
+
builder.query_dimensions = 64
|
517 |
+
builder.value_dimensions = 64
|
518 |
+
builder.dropout = 0.1
|
519 |
+
builder.attention_dropout = 0.1
|
520 |
+
builder.self_attention_type = "full"
|
521 |
+
builder.cross_attention_type = "full"
|
522 |
+
transformer = builder.get()
|
523 |
+
"""
|
524 |
+
def _get_self_attention_builder(self):
|
525 |
+
"""Return an attention builder for creating non-recurrent attention
|
526 |
+
variants."""
|
527 |
+
return RecurrentAttentionBuilder()
|
528 |
+
|
529 |
+
def _get_cross_attention_builder(self):
|
530 |
+
"""Return an attention builder for creating non-recurrent attention
|
531 |
+
variants."""
|
532 |
+
return RecurrentCrossAttentionBuilder()
|
533 |
+
|
534 |
+
def _get_self_attention_layer_class(self):
|
535 |
+
"""Return the non-recurrent attention layer to project queries, keys
|
536 |
+
and values."""
|
537 |
+
return RecurrentAttentionLayer
|
538 |
+
|
539 |
+
def _get_cross_attention_layer_class(self):
|
540 |
+
"""Return the non-recurrent attention layer to project queries, keys
|
541 |
+
and values."""
|
542 |
+
return RecurrentCrossAttentionLayer
|
543 |
+
|
544 |
+
def _get_decoder_class(self):
|
545 |
+
"""Return the transformer decoder class."""
|
546 |
+
return RecurrentTransformerDecoder
|
547 |
+
|
548 |
+
def _get_decoder_layer_class(self):
|
549 |
+
"""Return the transformer decoder layer class."""
|
550 |
+
return RecurrentTransformerDecoderLayer
|
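Note (not part of the diff): a minimal sanity check of the builders above, assuming this vendored copy is importable as `fast_transformers` and that the "linear" attention type is registered by AttentionBuilder:

import torch
from fast_transformers.builders import TransformerEncoderBuilder  # import path is an assumption

builder = TransformerEncoderBuilder()
builder.n_layers = 2
builder.n_heads = 4
builder.feed_forward_dimensions = 256
builder.query_dimensions = 16
builder.value_dimensions = 16
builder.attention_type = "linear"

encoder = builder.get()

# get() uses value_dimensions * n_heads as the model width
x = torch.randn(8, 100, 16 * 4)   # (batch, length, model dimensions)
y = encoder(x)                    # -> torch.Size([8, 100, 64])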
smi-ted/inference/smi_ted_light/fast_transformers/causal_product/__init__.py
ADDED
@@ -0,0 +1,78 @@
1 |
+
#
|
2 |
+
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
# Written by Angelos Katharopoulos <[email protected]>,
|
4 |
+
# Apoorv Vyas <[email protected]>
|
5 |
+
#
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from .causal_product_cpu import causal_dot_product as causal_dot_product_cpu, \
|
10 |
+
causal_dot_backward as causal_dot_backward_cpu
|
11 |
+
|
12 |
+
try:
|
13 |
+
from .causal_product_cuda import \
|
14 |
+
causal_dot_product as causal_dot_product_cuda, \
|
15 |
+
causal_dot_backward as causal_dot_backward_cuda
|
16 |
+
except ImportError:
|
17 |
+
causal_dot_product_cuda = causal_dot_backward_cuda = None
|
18 |
+
|
19 |
+
|
20 |
+
class CausalDotProduct(torch.autograd.Function):
|
21 |
+
"""Compute the weighted sum of values but attending only to previous
|
22 |
+
values."""
|
23 |
+
dot = {
|
24 |
+
"cpu": causal_dot_product_cpu,
|
25 |
+
"cuda": causal_dot_product_cuda
|
26 |
+
}
|
27 |
+
dot_backward = {
|
28 |
+
"cpu": causal_dot_backward_cpu,
|
29 |
+
"cuda": causal_dot_backward_cuda
|
30 |
+
}
|
31 |
+
|
32 |
+
@staticmethod
|
33 |
+
def forward(ctx, Q, K, V):
|
34 |
+
# Save the inputs for the gradient computation
|
35 |
+
ctx.save_for_backward(Q, K, V)
|
36 |
+
|
37 |
+
# Create the output tensor
|
38 |
+
device = Q.device
|
39 |
+
N, H, L, _ = Q.shape
|
40 |
+
_, _, _, M = V.shape
|
41 |
+
product = torch.zeros((N, H, L, M), device=device)
|
42 |
+
|
43 |
+
# Actually perform the dot product
|
44 |
+
CausalDotProduct.dot[device.type](
|
45 |
+
Q.data,
|
46 |
+
K.data,
|
47 |
+
V.data,
|
48 |
+
product
|
49 |
+
)
|
50 |
+
|
51 |
+
return product
|
52 |
+
|
53 |
+
@staticmethod
|
54 |
+
def backward(ctx, grad_out):
|
55 |
+
# Extract the saved tensors
|
56 |
+
Q, K, V = ctx.saved_tensors
|
57 |
+
|
58 |
+
# Allocate memory for the gradients
|
59 |
+
grad_Q = torch.zeros_like(Q)
|
60 |
+
grad_K = torch.zeros_like(K)
|
61 |
+
grad_V = torch.zeros_like(V)
|
62 |
+
|
63 |
+
# Actually compute the gradients
|
64 |
+
CausalDotProduct.dot_backward[Q.device.type](
|
65 |
+
Q.data,
|
66 |
+
K.data,
|
67 |
+
V.data,
|
68 |
+
grad_out,
|
69 |
+
grad_Q,
|
70 |
+
grad_K,
|
71 |
+
grad_V
|
72 |
+
)
|
73 |
+
|
74 |
+
return grad_Q, grad_K, grad_V
|
75 |
+
|
76 |
+
|
77 |
+
# Alias the autograd functions to python style snake case naming
|
78 |
+
causal_dot_product = CausalDotProduct.apply
|
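For context, `causal_dot_product(Q, K, V)` returns, at every position i, the sum over j <= i of (Q_i . K_j) V_j; the feature map and the normalization of causal linear attention are applied by the caller. A rough sketch of a self-check against a naive O(L^2) reference, assuming the compiled `causal_product_cpu` extension loads:

import torch
from fast_transformers.causal_product import causal_dot_product  # import path is an assumption

N, H, L, E, M = 2, 4, 16, 8, 8
Q, K, V = torch.rand(N, H, L, E), torch.rand(N, H, L, E), torch.rand(N, H, L, M)

out = causal_dot_product(Q, K, V)

# Naive reference: compute all dot products and mask out future positions
scores = torch.einsum("nhle,nhse->nhls", Q, K).tril()
ref = torch.einsum("nhls,nhsm->nhlm", scores, V)
assert torch.allclose(out, ref, atol=1e-4)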
smi-ted/inference/smi_ted_light/fast_transformers/causal_product/causal_product_cpu.cpython-39-x86_64-linux-gnu.so
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:84f32370e707beebd8fee88f356fb62721096142265895a5a8e9872063c04595
|
3 |
+
size 140928
|
smi-ted/inference/smi_ted_light/fast_transformers/clustering/__init__.py
ADDED
File without changes
|
smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/__init__.py
ADDED
@@ -0,0 +1,115 @@
1 |
+
#
|
2 |
+
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
# Written by Angelos Katharopoulos <[email protected]>,
|
4 |
+
# Apoorv Vyas <[email protected]>
|
5 |
+
#
|
6 |
+
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .cluster_cpu import cluster as cluster_cpu
|
13 |
+
try:
|
14 |
+
from .cluster_cuda import cluster as cluster_gpu
|
15 |
+
except ImportError:
|
16 |
+
pass
|
17 |
+
|
18 |
+
|
19 |
+
def cluster(
|
20 |
+
hashes,
|
21 |
+
lengths,
|
22 |
+
groups=None,
|
23 |
+
counts=None,
|
24 |
+
centroids=None,
|
25 |
+
distances=None,
|
26 |
+
bitcounts=None,
|
27 |
+
clusters=30,
|
28 |
+
iterations=10,
|
29 |
+
bits=32
|
30 |
+
):
|
31 |
+
"""Cluster hashes using a few iterations of K-Means with hamming distance.
|
32 |
+
|
33 |
+
All the tensors that default to None are optional buffers used to avoid
|
34 |
+
memory allocations. distances and bitcounts are only used by the CUDA
|
35 |
+
version of this call. clusters will be ignored if centroids is provided.
|
36 |
+
|
37 |
+
Arguments
|
38 |
+
---------
|
39 |
+
hashes: A long tensor of shape (N, H, L) containing a hashcode for each
|
40 |
+
query.
|
41 |
+
lengths: An int tensor of shape (N,) containing the sequence length for
|
42 |
+
each sequence in hashes.
|
43 |
+
groups: An int tensor buffer of shape (N, H, L) containing the cluster
|
44 |
+
to which the corresponding hash belongs.
|
45 |
+
counts: An int tensor buffer of shape (N, H, K) containing the number
|
46 |
+
of elements in each cluster.
|
47 |
+
centroids: A long tensor buffer of shape (N, H, K) containing the
|
48 |
+
centroid for each cluster.
|
49 |
+
distances: An int tensor of shape (N, H, L) containing the distance to
|
50 |
+
the closest centroid for each hash.
|
51 |
+
bitcounts: An int tensor of shape (N, H, K, bits) containing the number
|
52 |
+
of elements that have 1 for a given bit.
|
53 |
+
clusters: The number of clusters to use for each sequence. It is
|
54 |
+
ignored if centroids is not None.
|
55 |
+
iterations: How many k-means iterations to perform.
|
56 |
+
bits: How many of the least-significant bits in hashes to consider.
|
57 |
+
|
58 |
+
Returns
|
59 |
+
-------
|
60 |
+
groups and counts as defined above.
|
61 |
+
"""
|
62 |
+
device = hashes.device
|
63 |
+
N, H, L = hashes.shape
|
64 |
+
|
65 |
+
# Unfortunately cpu and gpu have different APIs so the entire call must be
|
66 |
+
# surrounded by an if-then-else
|
67 |
+
if device.type == "cpu":
|
68 |
+
if groups is None:
|
69 |
+
groups = torch.empty((N, H, L), dtype=torch.int32)
|
70 |
+
if centroids is None:
|
71 |
+
centroids = torch.empty((N, H, clusters), dtype=torch.int64)
|
72 |
+
centroids = hashes[:, :, np.random.choice(L, size=[clusters], replace=False)]
|
73 |
+
K = centroids.shape[2]
|
74 |
+
if counts is None:
|
75 |
+
counts = torch.empty((N, H, K), dtype=torch.int32)
|
76 |
+
|
77 |
+
cluster_cpu(
|
78 |
+
hashes, lengths,
|
79 |
+
centroids, groups, counts,
|
80 |
+
iterations, bits
|
81 |
+
)
|
82 |
+
|
83 |
+
return groups, counts
|
84 |
+
|
85 |
+
else:
|
86 |
+
if groups is None:
|
87 |
+
groups = torch.empty((N, H, L), dtype=torch.int32, device=device)
|
88 |
+
if centroids is None:
|
89 |
+
centroids = torch.empty((N, H, clusters), dtype=torch.int64,
|
90 |
+
device=device)
|
91 |
+
centroids = hashes[:, :, np.random.choice(L, size=[clusters], replace=False)]
|
92 |
+
K = centroids.numel() // N // H
|
93 |
+
#K = clusters
|
94 |
+
if counts is None:
|
95 |
+
counts = torch.empty((N, H, K), dtype=torch.int32, device=device)
|
96 |
+
if distances is None:
|
97 |
+
distances = torch.empty((N, H, L), dtype=torch.int32,
|
98 |
+
device=device)
|
99 |
+
if bitcounts is None:
|
100 |
+
bitcounts = torch.empty((N, H, K, bits), dtype=torch.int32,
|
101 |
+
device=device)
|
102 |
+
groups = groups.view(N, H, L)
|
103 |
+
counts = counts.view(N, H, K)
|
104 |
+
centroids = centroids.view(N, H, K)
|
105 |
+
distances = distances.view(N, H, L)
|
106 |
+
bitcounts = bitcounts.view(N, H, K, -1)
|
107 |
+
|
108 |
+
cluster_gpu(
|
109 |
+
hashes, lengths,
|
110 |
+
centroids, distances, bitcounts, groups, counts,
|
111 |
+
iterations, bits
|
112 |
+
)
|
113 |
+
|
114 |
+
return groups, counts
|
115 |
+
|
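A small smoke test for `cluster` (illustrative only, assuming the compiled `cluster_cpu` extension is importable): group random 31-bit hash codes into 10 clusters on the CPU.

import torch
from fast_transformers.clustering.hamming import cluster  # import path is an assumption

N, H, L = 2, 4, 128
hashes = torch.randint(0, 2**31, (N, H, L), dtype=torch.int64)  # one hash code per query
lengths = torch.full((N,), L, dtype=torch.int32)                # every sequence is full length

groups, counts = cluster(hashes, lengths, clusters=10, iterations=5, bits=31)
# groups: (2, 4, 128) cluster id per query; counts sums to L for every (n, h)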
smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/cluster_cpu.cpython-39-x86_64-linux-gnu.so
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f2bd8f761d6e1efdeea33665cad8702b5c07d1a0db728d19cf332c4383510d45
|
3 |
+
size 139824
|
smi-ted/inference/smi_ted_light/fast_transformers/events/__init__.py
ADDED
@@ -0,0 +1,10 @@
1 |
+
#
|
2 |
+
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
# Written by Angelos Katharopoulos <[email protected]>
|
4 |
+
#
|
5 |
+
|
6 |
+
"""This module implements a basic event system that allows the transformer
|
7 |
+
internal components to make available any tensor with minimal overhead."""
|
8 |
+
|
9 |
+
from .event import Event, AttentionEvent, QKVEvent
|
10 |
+
from .event_dispatcher import EventDispatcher
|
smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (556 Bytes).
|
|
smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event.cpython-310.pyc
ADDED
Binary file (2.21 kB).
|
|
smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event_dispatcher.cpython-310.pyc
ADDED
Binary file (3.5 kB).
|
|
smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/filters.cpython-310.pyc
ADDED
Binary file (5.82 kB).
|
|
smi-ted/inference/smi_ted_light/fast_transformers/events/event.py
ADDED
@@ -0,0 +1,51 @@
1 |
+
#
|
2 |
+
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
# Written by Angelos Katharopoulos <[email protected]>
|
4 |
+
#
|
5 |
+
|
6 |
+
|
7 |
+
class Event(object):
|
8 |
+
"""The Event is the base class for all events that are dispatched from any
|
9 |
+
transformer module.
|
10 |
+
|
11 |
+
This class defines only the basic attributes of an event without any
|
12 |
+
payload.
|
13 |
+
|
14 |
+
Arguments
|
15 |
+
---------
|
16 |
+
source: torch.nn.Module instance that dispatched this event
|
17 |
+
"""
|
18 |
+
def __init__(self, source):
|
19 |
+
self.source = source
|
20 |
+
|
21 |
+
|
22 |
+
class AttentionEvent(Event):
|
23 |
+
"""An event containing an attention matrix.
|
24 |
+
|
25 |
+
Arguments
|
26 |
+
---------
|
27 |
+
source: torch.nn.Module instance that dispatched this event
|
28 |
+
attention_matrix: torch.tensor of the multihead attention matrix
|
29 |
+
computed in the corresponding attention layer
|
30 |
+
"""
|
31 |
+
def __init__(self, source, attention_matrix):
|
32 |
+
super(AttentionEvent, self).__init__(source)
|
33 |
+
self.attention_matrix = attention_matrix
|
34 |
+
|
35 |
+
|
36 |
+
class QKVEvent(Event):
|
37 |
+
"""An event containing the queries, keys and values projected in their
|
38 |
+
multiple heads.
|
39 |
+
|
40 |
+
Arguments
|
41 |
+
---------
|
42 |
+
source: torch.nn.Module instance that dispatched this event
|
43 |
+
queries: torch.tensor containing the queries in shape NLHE
|
44 |
+
keys: torch.tensor containing the keys in shape NSHE
|
45 |
+
values: torch.tensor containing the values in shape NSHD
|
46 |
+
"""
|
47 |
+
def __init__(self, source, queries, keys, values):
|
48 |
+
super(QKVEvent, self).__init__(source)
|
49 |
+
self.queries = queries
|
50 |
+
self.keys = keys
|
51 |
+
self.values = values
|
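Since Event subclasses are plain data carriers, adding a custom payload is a two-line exercise; a hypothetical example (names are illustrative, not part of the package):

from fast_transformers.events import Event, EventDispatcher  # import path is an assumption

class NormEvent(Event):
    """Hypothetical event carrying the norm of an intermediate tensor."""
    def __init__(self, source, norm):
        super().__init__(source)
        self.norm = norm

# Inside a module's forward one could then call:
# EventDispatcher.get().dispatch(NormEvent(self, x.norm()))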
smi-ted/inference/smi_ted_light/fast_transformers/events/event_dispatcher.py
ADDED
@@ -0,0 +1,92 @@
1 |
+
#
|
2 |
+
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
# Written by Angelos Katharopoulos <[email protected]>
|
4 |
+
#
|
5 |
+
|
6 |
+
from collections import OrderedDict
|
7 |
+
|
8 |
+
from .event import Event
|
9 |
+
from .filters import event_class
|
10 |
+
|
11 |
+
|
12 |
+
class EventDispatcher(object):
|
13 |
+
"""An EventDispatcher is a simple way to implement an observer pattern for
|
14 |
+
loose coupling of components. In our case it is used so that the internals
|
15 |
+
of large neural networks can communicate with the outside world in an
|
16 |
+
agnostic and efficient way.
|
17 |
+
|
18 |
+
Example usage
|
19 |
+
-------------
|
20 |
+
|
21 |
+
from fast_transformers.events import EventDispatcher, AttentionEvent
|
22 |
+
from fast_transformers.events.filters import \
|
23 |
+
layer_name_contains
|
24 |
+
|
25 |
+
def attention_event_handler(event):
|
26 |
+
print(event.attention_matrix)
|
27 |
+
|
28 |
+
ed = EventDispatcher()
|
29 |
+
ed.listen(AttentionEvent, attention_event_handler)
|
30 |
+
ed.listen(
|
31 |
+
AttentionEvent & layer_name_contains("layers.12"),
|
32 |
+
attention_event_handler
|
33 |
+
)
|
34 |
+
"""
|
35 |
+
_dispatchers = {}
|
36 |
+
|
37 |
+
def __init__(self):
|
38 |
+
self._listeners = OrderedDict()
|
39 |
+
|
40 |
+
def listen(self, event_filter, event_handler):
|
41 |
+
"""Add an event handler for the events that pass the event filter.
|
42 |
+
|
43 |
+
Arguments
|
44 |
+
---------
|
45 |
+
event_filter: callable or Event class to define for which events
|
46 |
+
this handler will be called
|
47 |
+
event_handler: callable that accepts an instance of Event
|
48 |
+
"""
|
49 |
+
if isinstance(event_filter, type) and issubclass(event_filter, Event):
|
50 |
+
event_filter = event_class(event_filter)
|
51 |
+
|
52 |
+
self._listeners[event_handler] = event_filter
|
53 |
+
|
54 |
+
def remove(self, event_handler):
|
55 |
+
"""Remove the event_handler from the listeners so that no more events
|
56 |
+
are dispatched to this handler."""
|
57 |
+
self._listeners.pop(event_handler, None)
|
58 |
+
|
59 |
+
def clear(self):
|
60 |
+
"""Remove all listeners from the event dispatcher."""
|
61 |
+
self._listeners.clear()
|
62 |
+
|
63 |
+
def dispatch(self, event):
|
64 |
+
"""Dispatch an event to the listeners.
|
65 |
+
|
66 |
+
Arguments
|
67 |
+
---------
|
68 |
+
event: Event instance
|
69 |
+
"""
|
70 |
+
for event_handler, event_filter in self._listeners.items():
|
71 |
+
if event_filter(event):
|
72 |
+
event_handler(event)
|
73 |
+
|
74 |
+
@classmethod
|
75 |
+
def get(cls, key=""):
|
76 |
+
"""Factory method for creating global event dispatchers for loosely
|
77 |
+
coupling parts of a larger codebase.
|
78 |
+
|
79 |
+
Since global objects are a complete antipattern, we suggest that this
|
80 |
+
is only used to set a default value for an event dispatcher passed as
|
81 |
+
an argument.
|
82 |
+
|
83 |
+
Argument
|
84 |
+
--------
|
85 |
+
key: A key to uniquely identify a dispatcher or an instance of a
|
86 |
+
dispatcher to be returned as is
|
87 |
+
"""
|
88 |
+
if isinstance(key, cls):
|
89 |
+
return key
|
90 |
+
if key not in cls._dispatchers:
|
91 |
+
cls._dispatchers[key] = cls()
|
92 |
+
return cls._dispatchers[key]
|
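Combining the dispatcher with the events above, the usual pattern is to capture tensors during a forward pass; a sketch, assuming the attention modules in this package dispatch AttentionEvent as they do upstream:

from fast_transformers.events import EventDispatcher, AttentionEvent  # import path is an assumption

attention_matrices = []

def save_attention(event):
    # For full attention the matrix has shape (N, H, L, S)
    attention_matrices.append(event.attention_matrix.detach().cpu())

# "" is the default key used when modules are constructed without a dispatcher
EventDispatcher.get("").listen(AttentionEvent, save_attention)
# ... run a forward pass ...
EventDispatcher.get("").remove(save_attention)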
smi-ted/inference/smi_ted_light/fast_transformers/events/filters.py
ADDED
@@ -0,0 +1,141 @@
1 |
+
#
|
2 |
+
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
# Written by Angelos Katharopoulos <[email protected]>
|
4 |
+
#
|
5 |
+
|
6 |
+
"""Define composable functions to filter events."""
|
7 |
+
|
8 |
+
import weakref
|
9 |
+
|
10 |
+
from .event import Event
|
11 |
+
|
12 |
+
|
13 |
+
class EventFilter(object):
|
14 |
+
"""EventFilter instances are predicates (ie functions that return True or
|
15 |
+
False) to be used with an event dispatcher for filtering event
|
16 |
+
instances.
|
17 |
+
|
18 |
+
The main benefit from using raw functions is that an EventFilter composes
|
19 |
+
very easily using operators such as &, |, ~.
|
20 |
+
|
21 |
+
Example
|
22 |
+
--------
|
23 |
+
|
24 |
+
event_filter = AttentionEvent | layer_name_contains(transformer, "layers.1")
|
25 |
+
event_filter = from_layer(transformer.layers[2].attention)
|
26 |
+
event_filter = (
|
27 |
+
AttentionEvent &
|
28 |
+
lambda ev: torch.isnan(ev.attention_matrix).any()
|
29 |
+
)
|
30 |
+
"""
|
31 |
+
def __call__(self, event):
|
32 |
+
raise NotImplementedError()
|
33 |
+
|
34 |
+
def _to_event_filter(self, other):
|
35 |
+
if isinstance(other, EventFilter):
|
36 |
+
return other
|
37 |
+
if isinstance(other, type) and issubclass(other, Event):
|
38 |
+
return event_class(other)
|
39 |
+
if callable(other):
|
40 |
+
return CallableEventFilter(other)
|
41 |
+
|
42 |
+
return NotImplemented
|
43 |
+
|
44 |
+
def __and__(self, other):
|
45 |
+
other = self._to_event_filter(other)
|
46 |
+
if other is NotImplemented:
|
47 |
+
return other
|
48 |
+
return CallableEventFilter(lambda ev: self(ev) and other(ev))
|
49 |
+
|
50 |
+
def __rand__(self, other):
|
51 |
+
other = self._to_event_filter(other)
|
52 |
+
if other is NotImplemented:
|
53 |
+
return other
|
54 |
+
return CallableEventFilter(lambda ev: other(ev) and self(ev))
|
55 |
+
|
56 |
+
def __or__(self, other):
|
57 |
+
other = self._to_event_filter(other)
|
58 |
+
if other is NotImplemented:
|
59 |
+
return other
|
60 |
+
return CallableEventFilter(lambda ev: self(ev) or other(ev))
|
61 |
+
|
62 |
+
def __ror__(self, other):
|
63 |
+
other = self._to_event_filter(other)
|
64 |
+
if other is NotImplemented:
|
65 |
+
return other
|
66 |
+
return CallableEventFilter(lambda ev: other(ev) or self(ev))
|
67 |
+
|
68 |
+
def __invert__(self):
|
69 |
+
return CallableEventFilter(lambda ev: not self(ev))
|
70 |
+
|
71 |
+
|
72 |
+
class CallableEventFilter(EventFilter):
|
73 |
+
"""Wrap a function with an EventFilter object."""
|
74 |
+
def __init__(self, event_filter):
|
75 |
+
self._event_filter = event_filter
|
76 |
+
|
77 |
+
def __call__(self, event):
|
78 |
+
return self._event_filter(event)
|
79 |
+
|
80 |
+
|
81 |
+
class LayerNameEventFilter(EventFilter):
|
82 |
+
"""A LayerNameEventFilter allows to filter events based on a human readable
|
83 |
+
name of the layer that emitted them.
|
84 |
+
|
85 |
+
Note that LayerNameEventFilter keeps a weak reference to all modules which
|
86 |
+
means that it cannot be used to prevent modules from being garbage
|
87 |
+
collected.
|
88 |
+
|
89 |
+
Arguments
|
90 |
+
---------
|
91 |
+
root: torch.nn.Module instance that represents the root container
|
92 |
+
name_filter: callable that returns True for the layer names to select
|
93 |
+
"""
|
94 |
+
def __init__(self, root, name_filter):
|
95 |
+
self._names = {
|
96 |
+
weakref.ref(m): n
|
97 |
+
for n, m in root.named_modules()
|
98 |
+
}
|
99 |
+
self._name_filter = name_filter
|
100 |
+
|
101 |
+
def __call__(self, event):
|
102 |
+
name = self._names.get(weakref.ref(event.source), None)
|
103 |
+
if name is None:
|
104 |
+
return False
|
105 |
+
return self._name_filter(name)
|
106 |
+
|
107 |
+
|
108 |
+
def event_class(klass):
|
109 |
+
"""Select events that are instances of `klass`.
|
110 |
+
|
111 |
+
Arguments
|
112 |
+
---------
|
113 |
+
klass: A class to check the event instance against
|
114 |
+
|
115 |
+
Returns
|
116 |
+
-------
|
117 |
+
An instance of EventFilter
|
118 |
+
"""
|
119 |
+
return CallableEventFilter(lambda ev: isinstance(ev, klass))
|
120 |
+
|
121 |
+
|
122 |
+
def from_layer(layer):
|
123 |
+
"""Select events that are dispatched from the `layer`.
|
124 |
+
|
125 |
+
Arguments
|
126 |
+
---------
|
127 |
+
layer: An instance of torch.nn.Module to check against the event source
|
128 |
+
|
129 |
+
Returns
|
130 |
+
-------
|
131 |
+
An instance of EventFilter
|
132 |
+
"""
|
133 |
+
return CallableEventFilter(lambda ev: ev.source is layer)
|
134 |
+
|
135 |
+
|
136 |
+
def layer_name_contains(root, name):
|
137 |
+
"""Select events that contain `name` in their human readable name.
|
138 |
+
|
139 |
+
We use root.named_modules() to get human readable names for the layers.
|
140 |
+
"""
|
141 |
+
return LayerNameEventFilter(root, lambda n: name in n)
|
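The filters compose directly with the dispatcher; for example, to react only to attention matrices from specific layers of a model `net` (a sketch; `net` is assumed to be built from the encoder classes above, which expose their layers as `net.layers`):

from fast_transformers.events import EventDispatcher, AttentionEvent        # import paths are assumptions
from fast_transformers.events.filters import from_layer, layer_name_contains

ed = EventDispatcher.get("")

# Only events emitted by this exact module instance
ed.listen(AttentionEvent & from_layer(net.layers[0].attention),
          lambda ev: print("layer 0:", ev.attention_matrix.shape))

# Or match by (partial) module name as reported by net.named_modules()
ed.listen(AttentionEvent & layer_name_contains(net, "layers.1"),
          lambda ev: print("layer 1:", ev.attention_matrix.shape))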
smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__init__.py
ADDED
@@ -0,0 +1,12 @@
1 |
+
#
|
2 |
+
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
# Written by Angelos Katharopoulos <[email protected]>
|
4 |
+
#
|
5 |
+
|
6 |
+
"""Implementations of feature maps to be used with linear attention and causal
|
7 |
+
linear attention."""
|
8 |
+
|
9 |
+
|
10 |
+
from .base import elu_feature_map, ActivationFunctionFeatureMap
|
11 |
+
from .fourier_features import RandomFourierFeatures, Favor, \
|
12 |
+
SmoothedRandomFourierFeatures, GeneralizedRandomFeatures
|
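For orientation: a feature map phi replaces the softmax kernel so that attention becomes phi(Q) (phi(K)^T V) with an O(L) cost in the sequence length; elu_feature_map above corresponds to phi(x) = elu(x) + 1. A hand-rolled sketch of the resulting non-causal linear attention (an illustration, not the package's implementation):

import torch

def phi_elu(x):
    # The default feature map used by linear attention: phi(x) = elu(x) + 1
    return torch.nn.functional.elu(x) + 1

def linear_attention(Q, K, V, eps=1e-6):
    # Q, K: (N, H, L, E), V: (N, H, L, M)
    Q, K = phi_elu(Q), phi_elu(K)
    KV = torch.einsum("nhse,nhsm->nhem", K, V)                  # summarize keys/values once
    Z = 1.0 / (torch.einsum("nhle,nhe->nhl", Q, K.sum(dim=2)) + eps)
    return torch.einsum("nhle,nhem,nhl->nhlm", Q, KV, Z)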
smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (614 Bytes).
|
|
smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/base.cpython-310.pyc
ADDED
Binary file (3.42 kB).
|
|