diff --git a/smi-ted/inference/smi_ted_light/.gitattributes b/smi-ted/inference/smi_ted_light/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..3dc3ae1ee5cb872025a64f68ae06558b5d680949 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/.gitattributes @@ -0,0 +1,2 @@ +smi-ted/inference/smi_ted_light/fast_transformers/**/*.so filter=lfs diff=lfs merge=lfs -text +*.so filter=lfs diff=lfs merge=lfs -text diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca9a3cb0a863fcd1b75fb1c54c3f72b1c9380fea --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Provide a library with fast transformer implementations.""" + +__author__ = "Angelos Katharopoulos, Apoorv Vyas" +__copyright__ = "Copyright (c) 2020 Idiap Research Institute" +__license__ = "MIT" +__maintainer__ = "Angelos Katharopoulos, Apoorv Vyas" +__email__ = "angelos.katharopoulos@idiap.ch, avyas@idiap.ch" +__url__ = "https://github.com/idiap/fast-transformers" +__version__ = "0.4.0" diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/aggregate/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/aggregate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd212eb9f7d68ebb8ef083be0be37c764f59a905 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/aggregate/__init__.py @@ -0,0 +1,128 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + + +import torch + +from .aggregate_cpu import aggregate as aggregate_cpu, \ + broadcast as broadcast_cpu +try: + from .aggregate_cuda import aggregate as aggregate_gpu, \ + broadcast as broadcast_gpu + from .clustered_aggregate_cuda import \ + clustered_broadcast as clustered_broadcast_gpu, \ + clustered_aggregate as clustered_aggregate_gpu + +except ImportError: + pass + + +def aggregate(X, G, F, Y=None): + device = X.device + if Y is None: + Y = torch.zeros( + F.shape + (X.shape[-1],), + device=device, + dtype=X.dtype + ) + else: + Y.zero_() + + if device.type == "cpu": + aggregate_cpu(X, G, F, Y) + else: + aggregate_gpu(X, G, F, Y) + + return Y + + +def broadcast(Y, G, F, X=None): + device = Y.device + if X is None: + X = torch.zeros( + G.shape + (Y.shape[-1],), + device=device, + dtype=Y.dtype + ) + + if device.type == "cpu": + broadcast_cpu(Y, G, F, X) + else: + broadcast_gpu(Y, G, F, X) + + return X + + +# Divide the cluster into groups of equal size +# as constrained by the shared memory +def set_group(C, E): + C_per_block = int(192 * 64 / (E+1)) + G_min = (C + C_per_block - 1) // C_per_block + for G in range(G_min, C+1): + if C % G == 0: + return G + + +def clustered_broadcast(Y, groups, counts, factors, X=None): + device = Y.device + if X is None: + X = torch.zeros( + groups.shape + (Y.shape[-1],), + device=device, + dtype=Y.dtype + ) + if device.type == "cpu": + broadcast_cpu(Y, groups, factors, X) + else: + N, H, C, E = Y.shape + _, _, L, _ = X.shape + + # Following are some booking keeping parameters to facilitate the + # broadcast kernel that takes advantage of clustering + # More information can be found in the cuda file + with torch.no_grad(): + threads = 256 + G = set_group(C, E) + 
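+            # group_counts: queries assigned to each of the G groups per (batch, head);
+            # block_counts: CUDA blocks of `threads` threads needed to cover each group;
+            # indx_maps: one int32 row of per-block index metadata for the kernel.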
group_counts = counts.view(N, H, G, -1).sum(-1) + block_counts = (group_counts + threads - 1) // threads + total_blocks = block_counts.sum().item() + indx_maps = torch.ones( + (total_blocks, 5), + device=X.device, + dtype=torch.int32 + ) + + clustered_broadcast_gpu( + Y, + groups, + factors, + X, + block_counts.int(), + group_counts.int(), + threads, + G, + total_blocks, + indx_maps + ) + return X + + +def clustered_aggregate(X, G, F, lengths, Y=None): + device = X.device + if Y is None: + Y = torch.zeros( + F.shape + (X.shape[-1],), + device=device, + dtype=X.dtype + ) + else: + Y.zero_() + + if device.type == "cpu": + aggregate_cpu(X, G, F, Y) + else: + clustered_aggregate_gpu(X, G, F, lengths, Y) + return Y diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/aggregate/aggregate_cpu.cpython-39-x86_64-linux-gnu.so b/smi-ted/inference/smi_ted_light/fast_transformers/aggregate/aggregate_cpu.cpython-39-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..0a581de61438b08346ef49d61dbd7f4e4ba84a82 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/aggregate/aggregate_cpu.cpython-39-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6bccb1a374d4649aaef6361cc41c9ffb471086464cc07a0d6d21c5b65adb0711 +size 138248 diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23767fba238c32c7f98e16a63cad73710834ebf7 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention/__init__.py @@ -0,0 +1,20 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implementations of different types of attention mechanisms.""" + + +from .attention_layer import AttentionLayer +from .full_attention import FullAttention +from .linear_attention import LinearAttention +#from .causal_linear_attention import CausalLinearAttention +#from .clustered_attention import ClusteredAttention +#from .improved_clustered_attention import ImprovedClusteredAttention +#from .reformer_attention import ReformerAttention +#from .conditional_full_attention import ConditionalFullAttention +#from .exact_topk_attention import ExactTopKAttention +#from .improved_clustered_causal_attention import ImprovedClusteredCausalAttention +#from .local_attention import LocalAttention diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/__init__.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b7068036c624f4de3b73f6264048755286110d9 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/attention_layer.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/attention_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4457b37564e30d4c82d0787182bb38c9ae56d7cf Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/attention_layer.cpython-310.pyc differ diff --git 
a/smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/full_attention.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/full_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b0799c1c43b534640086250dc81d71ff804afb3 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/full_attention.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/linear_attention.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/linear_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2007c964832e2ab1619ae08265cd5004f1fc86ad Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/attention/__pycache__/linear_attention.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/attention_layer.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention/attention_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..22e5eea16d89c68c501566fd3c005e44ec1550eb --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention/attention_layer.py @@ -0,0 +1,113 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""The base attention layer performs all the query key value projections and +output projections leaving the implementation of the attention to the inner +attention module. + +The transformer layers, however, are agnostic of the attention implementation +and any layer that implements the same interface can substitute for the +attention layer. +""" + +from torch.nn import Linear, Module + +from ..events import EventDispatcher, QKVEvent + + +class AttentionLayer(Module): + """Implement the attention layer. Namely project the inputs to multi-head + queries, keys and values, call the attention implementation and then + reproject the output. + + It can be thought of as a decorator (see decorator design patter) of an + attention layer. + + Arguments + --------- + attention: Specific inner attention implementation that just computes a + weighted average of values given a similarity of queries and + keys. 
+ d_model: The input feature dimensionality + n_heads: The number of heads for the multi head attention + d_keys: The dimensionality of the keys/queries + (default: d_model/n_heads) + d_values: The dimensionality of the values (default: d_model/n_heads) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None, event_dispatcher=""): + super(AttentionLayer, self).__init__() + + # Fill d_keys and d_values + d_keys = d_keys or (d_model//n_heads) + d_values = d_values or (d_model//n_heads) + + self.inner_attention = attention + self.query_projection = Linear(d_model, d_keys * n_heads) + self.key_projection = Linear(d_model, d_keys * n_heads) + self.value_projection = Linear(d_model, d_values * n_heads) + self.out_projection = Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + """Apply attention to the passed in queries/keys/values after + projecting them to multiple heads. + + In the argument description we make use of the following sizes + + - N: the batch size + - L: The maximum length of the queries + - S: The maximum length of the keys (the actual length per sequence + is given by the length mask) + - D: The input feature dimensionality passed in the constructor as + 'd_model' + + Arguments + --------- + queries: (N, L, D) The tensor containing the queries + keys: (N, S, D) The tensor containing the keys + values: (N, S, D) The tensor containing the values + attn_mask: An implementation of BaseMask that encodes where each + query can attend to + query_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + key_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + + Returns + ------- + The new value for each query as a tensor of shape (N, L, D). 
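+
+        Example
+        -------
+        A minimal usage sketch (assuming the package's FullAttention and the
+        FullMask/LengthMask helpers from fast_transformers.masking are
+        importable as shown):
+
+            import torch
+            from fast_transformers.attention import AttentionLayer, FullAttention
+            from fast_transformers.masking import FullMask, LengthMask
+
+            layer = AttentionLayer(FullAttention(), d_model=64, n_heads=4)
+            x = torch.rand(2, 10, 64)                    # (N, L, D)
+            attn_mask = FullMask(10)                     # every query may attend everywhere
+            lengths = LengthMask(torch.full((2,), 10))   # both sequences use all 10 positions
+            out = layer(x, x, x, attn_mask, lengths, lengths)   # -> (2, 10, 64)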
+ """ + # Extract the dimensions into local variables + N, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + # Project the queries/keys/values + queries = self.query_projection(queries).view(N, L, H, -1) + keys = self.key_projection(keys).view(N, S, H, -1) + values = self.value_projection(values).view(N, S, H, -1) + + # Let the world know of the qkv + self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values)) + + # Compute the attention + new_values = self.inner_attention( + queries, + keys, + values, + attn_mask, + query_lengths, + key_lengths + ).view(N, L, -1) + + # Project the output and return + return self.out_projection(new_values) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/causal_linear_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention/causal_linear_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..61d1a87f7257c84003b2c30348e2b8eec7d7bbbe --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention/causal_linear_attention.py @@ -0,0 +1,116 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implement causally masked linear attention.""" + +import torch +from torch.nn import Module + +from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \ + EventDispatcherInstance +from ..events import EventDispatcher +from ..causal_product import causal_dot_product +from ..feature_maps import elu_feature_map + + +def causal_linear(Q, K, V): + Q = Q.permute(0,2,1,3).contiguous() + K = K.permute(0,2,1,3).contiguous() + V = V.permute(0,2,1,3).contiguous() + V_new = causal_dot_product(Q, K, V) + return V_new.permute(0,2,1,3).contiguous() + + +class CausalLinearAttention(Module): + """Implement causally masked attention using dot product of feature maps in + O(N D^2) complexity. + + See fast_transformers.attention.linear_attention.LinearAttention for the + general concept of replacing the softmax with feature maps. In addition to + that, we also make use of the fact that causal masking is a triangular mask + which allows us to apply the masking and still compute the attention in O(N + D^2) complexity. 
+ + Arguments + --------- + feature_map: callable, a callable that applies the feature map to the + last dimension of a tensor (default: elu(x)+1) + eps: float, a small number to ensure the numerical stability of the + denominator (default: 1e-6) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, query_dimensions, feature_map=None, eps=1e-6, + event_dispatcher=""): + super(CausalLinearAttention, self).__init__() + self.feature_map = ( + feature_map(query_dimensions) if feature_map else + elu_feature_map(query_dimensions) + ) + self.eps = eps + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def _make_sizes_compatible(self, Q, K): + """Either slice or pad K in case that the sizes do not match between Q + and K.""" + N, L, H, E = Q.shape + _, S, _, _ = K.shape + if L == S: + return Q, K + + if L < S: + return Q, K[:, :L, :, :] + + if L > S: + return Q, torch.cat([K, K.new_zeros(N, L-S, H, E)], dim=1) + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + # Apply the feature map to the queries and keys + self.feature_map.new_feature_map(queries.device) + Q = self.feature_map.forward_queries(queries) + K = self.feature_map.forward_keys(keys) + + # Apply the key padding mask and make sure the attn_mask is a + # lower triangular causal mask + if not attn_mask.lower_triangular: + raise RuntimeError(("CausalLinearAttention only supports full " + "lower triangular masks")) + K = K * key_lengths.float_matrix[:, :, None, None] + + # Ensure that Q and K have compatible sizes for the following + # computations, namely L == S + Q, K = self._make_sizes_compatible(Q, K) + + # TODO: Shall we divide the Q and K with a relatively large number to + # avoid numerical instabilities in computing the denominator? + # We used to divide each with the max norm of all q and k but + # that seems relatively costly for a simple normalization. 
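+        # Note: K.cumsum(1) below accumulates Φ(K_j) along the sequence
+        # dimension, so Z holds 1 / (Φ(Q_i) · sum_{j<=i} Φ(K_j) + eps), while
+        # causal_linear returns the matching unnormalized numerator.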
+ + # Compute the normalizers + Z = 1/(torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps) + + # Compute the unnormalized result + V = causal_linear( + Q, + K, + values + ) + + return V * Z[:, :, :, None] + + +# Register the attention implementation so that it becomes available in our +# builders +AttentionRegistry.register( + "causal-linear", CausalLinearAttention, + [ + ("query_dimensions", Int), + ("feature_map", Optional(Callable)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/clustered_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention/clustered_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1150a654f80fed289ae5d82b289e3b86794ceec2 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention/clustered_attention.py @@ -0,0 +1,195 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implement clustered self attention.""" + +from math import sqrt + +import torch +import torch.autograd +from torch.nn import Dropout, Module +from torch.nn.init import normal_ + +from ..attention_registry import AttentionRegistry, Optional, Float, Int, \ + Bool, EventDispatcherInstance +from ..events import EventDispatcher +from ..masking import FullMask +from ..aggregate import clustered_aggregate, clustered_broadcast +from ..clustering.hamming import cluster +from ..hashing import compute_hashes + + +class _GroupQueries(torch.autograd.Function): + @staticmethod + def forward(ctx, Q, clusters, counts, lengths): + factors = 1./counts.float() + q_grouped = clustered_aggregate(Q, clusters, factors, lengths) + ctx.save_for_backward(clusters, counts, factors) + + return q_grouped + + @staticmethod + def backward(ctx, grad_q_grouped): + clusters, counts, factors = ctx.saved_tensors + grad_q = clustered_broadcast(grad_q_grouped, clusters, counts, factors) + + return grad_q, None, None, None + + +class _BroadcastValues(torch.autograd.Function): + @staticmethod + def forward(ctx, v_grouped, clusters, counts, lengths): + factors = torch.ones_like(counts, dtype=v_grouped.dtype) + V = clustered_broadcast(v_grouped, clusters, counts, factors) + ctx.save_for_backward(clusters, counts, factors, lengths) + + return V + + @staticmethod + def backward(ctx, grad_v): + clusters, counts, factors, lengths = ctx.saved_tensors + grad_v_grouped = clustered_aggregate(grad_v, clusters, factors, lengths) + + return grad_v_grouped, None, None, None + + +class ClusteredAttention(Module): + """Use LSH and clustering in the resulting Hamming space to group queries + that will have minimal L2 distance from each other. + + Given the queries, keys, and values as Q, K, and V respectively, we + first cluster the queries in "C" groups and compute the "C" query centroids + Q_c. + + We now use to the centroids Q_c to compute the attention using: + + V'_c = softmax(Q_c.mm(K.t()), dim=-1).mm(V). + + Now the computed values V'_c are "broadcasted" back to the query members + of the corresponding cluster. 
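+
+    Schematically (a sketch that ignores padding masks and the query sorting
+    performed for the kernels):
+
+        Q_c  = aggregate(Q, clusters)                    # (C, E) centroids
+        V'_c = softmax(Q_c.mm(K.t()), dim=-1).mm(V)      # (C, D) per-cluster values
+        V'   = broadcast(V'_c, clusters)                 # (L, D) copied to members
+
+    so the softmax is computed over C x S scores instead of L x S.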
+ + Arguments + --------- + clusters: How many clusters to group the queries into + iterations: The number of lloyd iterations to perform (default: 10) + bits: How many bits to use for the hash (default: 32) + hash_bias: If true, hamming distance proportional to L2 distance + If false, hamming distance proportional to cosine distance + (default: True) + softmax_temp: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.1) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, clusters, iterations=10, bits=32, + hash_bias=True, softmax_temp=None, attention_dropout=0.1, + event_dispatcher=""): + super(ClusteredAttention, self).__init__() + self.clusters = clusters + self.iterations = iterations + self.bits = bits + self.hash_bias = hash_bias + self.softmax_temp = softmax_temp + self.dropout = Dropout(attention_dropout) + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def _create_query_groups(self, Q, query_lengths): + N, H, L, E = Q.shape + + # Compute the hashes for all the queries + planes = Q.new_empty((self.bits, E+1)) + normal_(planes) + if not self.hash_bias: + planes[:, -1] = 0 + hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L) + + # Cluster the hashes and return the cluster index per query + clusters, counts = cluster( + hashes, + query_lengths._lengths.int(), + clusters=self.clusters, + iterations=self.iterations, + bits=self.bits + ) + sorted_clusters, sorted_indx = torch.sort(clusters, dim=-1) + return (sorted_clusters, counts), sorted_indx + + def _group_queries(self, Q, groups, lengths): + """Aggregate the Qs based on the index of cluster they belong to. Make + sure to allow for gradient propagation backwards from the grouped + queries to each query.""" + q_grouped = _GroupQueries.apply(Q, *groups, lengths) + return q_grouped + + def _broadcast_values(self, V, groups, lengths): + """Broadcast the values back to the correct positions but make sure + that the gradient flows properly.""" + V_new = _BroadcastValues.apply(V.contiguous(), *groups, lengths) + return V_new + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + # Make sure that there is no attention mask + assert attn_mask.all_ones, ("Clustered attention cannot use an " + "arbitrary attention mask.") + + queries = queries.permute(0,2,1,3).contiguous() + keys = keys.permute(0,2,1,3).contiguous() + values = values.permute(0,2,1,3).contiguous() + + N, H, L, E = queries.shape + _, _, S, D = values.shape + softmax_temp = self.softmax_temp or 1./sqrt(E) + + # Cluster the queries into groups + groups, sorted_indx = self._create_query_groups(queries, query_lengths) + # Re-organize queries so that first group belong to first cluster + # next to second cluster and so on. This improves kernel implementations. + # Note that this step is introduced after NeurIPS submission and + # now the complexity is O(N log(N)). + q_offset = torch.arange(N*H, device=queries.device).unsqueeze(-1) * L + q_flat = (sorted_indx.view(N*H, -1) + q_offset).reshape(-1) + s_queries = queries.reshape(-1, E).index_select(0, q_flat).view(N,H,L,E) + + # Aggregate the re-arranged queries. 
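+        # Q_grouped has shape (N, H, C, E): one aggregated query per cluster,
+        # so QK below is (N, H, C, S) rather than (N, H, L, S); the resulting
+        # per-cluster values are copied back to every member query by
+        # _broadcast_values further down.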
+ Q_grouped = self._group_queries(s_queries, groups, query_lengths._lengths.int()) + # Compute the attention + QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys) + QK = QK + key_lengths.additive_matrix[:, None, None, :] + A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1)) + V = torch.einsum("nhls,nhsd->nhld", A, values) + + # Broadcast grouped attention + V_broadcast = self._broadcast_values(V, groups, query_lengths._lengths.int()) + + # Reverse the previous mapping + rev_indx = torch.argsort(sorted_indx, dim=-1) + q_rev_flat = (rev_indx.view(N*H, -1) + q_offset).reshape(-1) + V_new = V_broadcast.reshape(-1, D).index_select(0, q_rev_flat).view(N,H,L,D) + V_new = V_new.permute(0, 2, 1, 3).contiguous() + return V_new + + + + +# Register the attention implementation so that it becomes available in our +# builders +AttentionRegistry.register( + "clustered", ClusteredAttention, + [ + ("clusters", Int), + ("iterations", Optional(Int, 10)), + ("bits", Optional(Int, 63)), + ("hash_bias", Optional(Bool, True)), + ("softmax_temp", Optional(Float)), + ("attention_dropout", Optional(Float, 0.1)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/conditional_full_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention/conditional_full_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..49542a5b3a037fe886ccfa0df309912f4e6ac6e8 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention/conditional_full_attention.py @@ -0,0 +1,66 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implement a self attention that delegates to full attention or another +attention depending on the input sequence length.""" + +import torch +from torch.nn import Module + +from ..attention_registry import AttentionRegistry, Optional, Int, Float, \ + EventDispatcherInstance +from ..events import EventDispatcher +from .full_attention import FullAttention + + +class ConditionalFullAttention(Module): + """"Delegate to full attention if the input sequence is short. + + Arguments + --------- + other_attention: Use the passed attention module if the sequence is + longer than 'length_limit'. + length_limit: An integer denoting the maximum sequence length to + consider. + softmax_temp: See fast_transformers.attention.full_attention. + attention_dropout: See fast_transformers.attention.full_attention. 
+ event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, other_attention, length_limit=512, softmax_temp=None, + attention_dropout=0.1, event_dispatcher=""): + super(ConditionalFullAttention, self).__init__() + self.full_attention = FullAttention(softmax_temp, attention_dropout) + self.other_attention = other_attention + self.length_limit = length_limit + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + # Extract some shapes to compare with the length limit + L = queries.shape[1] + S = values.shape[1] + + if L > self.length_limit or S > self.length_limit: + return self.other_attention(queries, keys, values, attn_mask, + query_lengths, key_lengths) + else: + return self.full_attention(queries, keys, values, attn_mask, + query_lengths, key_lengths) + + +# Register the attention implementation so that it becomes available in our +# builders +AttentionRegistry.register( + "conditional-full", ConditionalFullAttention, + [ + ("length_limit", Optional(Int, 512)), + ("softmax_temp", Optional(Float)), + ("attention_dropout", Optional(Float, 0.1)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/exact_topk_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention/exact_topk_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..c498121d40391656d11614cc8e86aec889f23500 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention/exact_topk_attention.py @@ -0,0 +1,88 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implement the oracle top-k attention. The top-k keys are exact ones. +MultiHeadAttention module. Note that this module is to be used in conjuction +with the AttentionLayer in order to work.""" + +from math import sqrt + +import torch +from torch.nn import Dropout, Module + +from ..attention_registry import AttentionRegistry, Optional, Int, Float, \ + EventDispatcherInstance +from ..events import EventDispatcher + + +class ExactTopKAttention(Module): + """Implement the oracle top-k softmax attention. + + Arguments + --------- + top-k: The top k keys to attend to (default: 32) + softmax_temp: The temperature to use for the softmax attention. 
+ (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.1) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, topk=32, softmax_temp=None, attention_dropout=0.1, + event_dispatcher=""): + super(ExactTopKAttention, self).__init__() + self.topk = topk + self.softmax_temp = softmax_temp + self.dropout = Dropout(attention_dropout) + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + # Extract some shapes and compute the temperature + N, L, H, E = queries.shape + _, S, _, D = values.shape + softmax_temp = self.softmax_temp or 1./sqrt(E) + + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nlhe,nshe->nhls", queries, keys) + topk = min(self.topk, S) + + if not attn_mask.all_ones: + QK = QK + attn_mask.additive_matrix + QK = QK + key_lengths.additive_matrix[:, None, None] + + topk_values, topk_idx = torch.topk(QK, topk, sorted=False, dim=-1) + mask = QK.new_ones(QK.shape) * float("-inf") + mask[ + torch.arange(N, device=QK.device).view(N, 1, 1, 1), + torch.arange(H, device=QK.device).view(1, H, 1, 1), + torch.arange(L, device=QK.device).view(1, 1, L, 1), + topk_idx, + ] = 0. + + QK = QK + mask + + # Compute the attention and the weighted average + A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1)) + V = torch.einsum("nhls,nshd->nlhd", A, values) + + # Make sure that what we return is contiguous + return V.contiguous() + + +# Register the attention implementation so that it becomes available in our +# builders +AttentionRegistry.register( + "exact-topk", ExactTopKAttention, + [ + ("topk", Optional(Int, 32)), + ("softmax_temp", Optional(Float)), + ("attention_dropout", Optional(Float, 0.1)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/full_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention/full_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..68764ee1718a04cd6b6dde5b891e4ca1e1b4d38d --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention/full_attention.py @@ -0,0 +1,95 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implement the full attention similar to the one implemented by PyTorch's +MultiHeadAttention module. Note that this module is to be used in conjuction +with the `fast_transformers.attention.attention_layer.AttentionLayer` in order +to work.""" + +from math import sqrt + +import torch +from torch.nn import Dropout, Module + +from ..attention_registry import AttentionRegistry, Optional, Float, \ + EventDispatcherInstance +from ..events import EventDispatcher, AttentionEvent + + +class FullAttention(Module): + """Implement the scaled dot product attention with softmax. + + Arguments + --------- + softmax_temp: The temperature to use for the softmax attention. 
+ (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.1) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, softmax_temp=None, attention_dropout=0.1, + event_dispatcher=""): + super(FullAttention, self).__init__() + self.softmax_temp = softmax_temp + self.dropout = Dropout(attention_dropout) + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + """Implements the multihead softmax attention. + + Arguments + --------- + queries: (N, L, H, E) The tensor containing the queries + keys: (N, S, H, E) The tensor containing the keys + values: (N, S, H, D) The tensor containing the values + attn_mask: An implementation of BaseMask that encodes where each + query can attend to + query_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + key_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + """ + # Extract some shapes and compute the temperature + N, L, H, E = queries.shape + _, S, _, D = values.shape + softmax_temp = self.softmax_temp or 1./sqrt(E) + + # Scale the queries instead of applying the softmax temperature to the + # dot products + queries = queries * softmax_temp + + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nlhe,nshe->nhls", queries, keys) + if not attn_mask.all_ones: + QK = QK + attn_mask.additive_matrix + if not key_lengths.all_ones: + QK = QK + key_lengths.additive_matrix[:, None, None] + + # Compute the attention and the weighted average + A = self.dropout(torch.softmax(QK, dim=-1)) + V = torch.einsum("nhls,nshd->nlhd", A, values) + + # Let the world know of the attention matrix + self.event_dispatcher.dispatch(AttentionEvent(self, A)) + + # Make sure that what we return is contiguous + return V.contiguous() + + +# Register the attention implementation so that it becomes available in our +# builders +AttentionRegistry.register( + "full", FullAttention, + [ + ("softmax_temp", Optional(Float)), + ("attention_dropout", Optional(Float, 0.1)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..32683c0e135599b26c9b4ff9c2a1f2a9e40407c4 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_attention.py @@ -0,0 +1,268 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implement improved clustered self attention.""" + +from math import sqrt + +import torch +import torch.autograd +from torch.nn import Dropout, Module +from torch.nn.init import normal_ + +from ..attention_registry import AttentionRegistry, Optional, Float, Int, \ + Bool, EventDispatcherInstance +from ..events import EventDispatcher +from ..masking import FullMask +from ..aggregate import clustered_aggregate, clustered_broadcast +from ..clustering.hamming import cluster +from ..hashing import compute_hashes +from ..sparse_product import sparse_dot_product, 
sparse_weighted_average +from ..sparse_product import clustered_sparse_dot_product, \ + clustered_sparse_weighted_average + + +class _GroupQueries(torch.autograd.Function): + @staticmethod + def forward(ctx, Q, clusters, counts, lengths): + factors = 1./counts.float() + q_grouped = clustered_aggregate(Q, clusters, factors, lengths) + ctx.save_for_backward(clusters, counts, factors) + + return q_grouped + + @staticmethod + def backward(ctx, grad_q_grouped): + clusters, counts, factors = ctx.saved_tensors + grad_q = clustered_broadcast(grad_q_grouped, clusters, counts, factors) + + return grad_q, None, None, None + + +class _BroadcastValues(torch.autograd.Function): + @staticmethod + def forward(ctx, v_grouped, clusters, counts, lengths): + factors = torch.ones_like(counts, dtype=v_grouped.dtype) + V = clustered_broadcast(v_grouped, clusters, counts, factors) + ctx.save_for_backward(clusters, counts, factors, lengths) + + return V + + @staticmethod + def backward(ctx, grad_v): + clusters, counts, factors, lengths = ctx.saved_tensors + grad_v_grouped = clustered_aggregate(grad_v, clusters, factors, lengths) + + return grad_v_grouped, None, None, None, None + + +class ImprovedClusteredAttention(Module): + """ + Immproved clustered attention approximation by recompution attention + for each query with the top-k keys for the corresponding cluster. + + Given the queries, keys, and values as Q, K, and V respectively, we + first cluster the queries in "C" groups and compute the "C" query centroids + Q_c. + + We now use to the centroids Q_c to identify the top-k keys with highest + dot products. + + Subsequently, for each query we compute the sparse dot product with + the corresponding top-k keys to improve the attention approximation. + + Arguments + --------- + clusters: How many clusters to group the queries into + iterations: The number of lloyd iterations to perform (default: 10) + bits: How many bits to use for the hash (default: 32) + hash_bias: If true, hamming distance proportional to L2 distance + If false, hamming distance proportional to cosine distance + (default: True) + topk: Number of top-k keys to for improved approximation (default: 32) + softmax_temp: The temperature to use for the softmax attention. 
+ (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.1) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, clusters, iterations=10, bits=32, + hash_bias=True, topk=32, softmax_temp=None, + attention_dropout=0.1, event_dispatcher=""): + super(ImprovedClusteredAttention, self).__init__() + self.clusters = clusters + self.iterations = iterations + self.bits = bits + self.hash_bias = hash_bias + self.topk = topk + self.softmax_temp = softmax_temp + self.dropout = Dropout(attention_dropout) + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def _create_query_groups(self, Q, query_lengths): + N, H, L, E = Q.shape + + # Compute the hashes for all the queries + planes = Q.new_empty((self.bits, E+1)) + normal_(planes) + if not self.hash_bias: + planes[:, -1] = 0 + hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L) + + # Cluster the hashes and return the cluster index per query + clusters, counts = cluster( + hashes, + query_lengths._lengths.int(), + clusters=self.clusters, + iterations=self.iterations, + bits=self.bits + ) + sorted_clusters, sorted_indx = torch.sort(clusters, dim=-1) + return (sorted_clusters, counts), sorted_indx + + def _topk_attention(self, Q, K, V, + clusters, counts, + topk, topk_values, + A_bottomk, softmax_temp, + query_lengths): + """Return the attention with just the topk heads.""" + # Extract some indices + N, H, L, E = Q.shape + _, _, S, _ = K.shape + _, _, C, k = topk.shape + + # We need to pass the output tensor to initialize to 0 + QK = clustered_sparse_dot_product( + Q, K, topk, + clusters, counts, + query_lengths._lengths.int() + ) + # We need to mask the topk dot products if topk > input_length + QK = QK.masked_fill( + torch.isinf(topk_values[:,0,0,:]).view(N, 1, 1, k), + float("-inf") + ) + A = torch.softmax(softmax_temp * QK, dim=-1) + assert A_bottomk.is_contiguous() + A_bottomk = clustered_broadcast( + A_bottomk.unsqueeze(3), + clusters, + counts, + torch.ones_like(counts, dtype=torch.float32) + ) + A = A * (1.0 - A_bottomk) + A = self.dropout(A) + assert A.is_contiguous() + V_new = clustered_sparse_weighted_average(A, V, topk, clusters, counts) + + return V_new + + def _broadcast_values(self, V, clusters, counts, lengths): + """Broadcast the values back to the correct positions but make sure + that the gradient flows properly.""" + V_new = _BroadcastValues.apply(V.contiguous(), clusters, counts, lengths) + return V_new + + def _bottomk_attention(self, QK, V, clusters, counts, query_lengths, topk, softmax_temp): + """Return the attention with just the bottomk keys.""" + N, H, C, S = QK.shape + + A = torch.softmax(softmax_temp * QK, dim=-1) + mask = QK.new_ones(QK.shape) + mask[ + torch.arange(N, device=QK.device).view(N, 1, 1, 1), + torch.arange(H, device=QK.device).view(1, H, 1, 1), + torch.arange(C, device=QK.device).view(1, 1, C, 1), + topk, + ] = 0 + A = A * mask + A_bottomk = A.sum(-1) + A = self.dropout(A) + # Compute the values + V_new = torch.einsum("nhls,nhse->nhle", A, V) + # Broadcast the values back depending on the groups + V_new = self._broadcast_values(V_new, clusters, counts, query_lengths._lengths.int()) + + return V_new, A_bottomk + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + # Make sure that there is no attention mask + assert attn_mask.all_ones, 
("Improved-clustered attention cannot " + "use an arbitrary attention mask.") + + queries = queries.permute(0,2,1,3).contiguous() + keys = keys.permute(0,2,1,3).contiguous() + values = values.permute(0,2,1,3).contiguous() + N, H, L, E = queries.shape + _, _, S, D = values.shape + softmax_temp = self.softmax_temp or 1./sqrt(E) + + # Cluster the queries into groups + groups, sorted_indx = self._create_query_groups(queries, query_lengths) + clusters, counts = groups + + # Re-organize queries so that first group belong to first cluster + # next to second cluster and so on. This improves kernel implementations. + # Note that this step is introduced after NeurIPS submission and + # now the complexity is O(N log(N)). + q_offset = torch.arange(N*H, device=queries.device).unsqueeze(-1) * L + q_flat = (sorted_indx.view(N*H, -1) + q_offset).reshape(-1) + s_queries = queries.reshape(-1, E).index_select(0, q_flat).view(N,H,L,E) + + # Aggregate the re-arranged queries. + Q_grouped = _GroupQueries.apply(s_queries, *groups, query_lengths.lengths.int()) + # Compute the attention + QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys) + QK = QK + key_lengths.additive_matrix[:, None, None, :] + topk_values, topk = torch.topk(QK, min(self.topk, S), sorted=False, dim=-1) + assert topk.is_contiguous() + + # Now compute the attention with only the bottom keys + V_bottomk, A_bottomk = self._bottomk_attention( + QK, values, + clusters, counts, + query_lengths, + topk, + softmax_temp + ) + + # Now compute the attention with only the top keys + V_topk = self._topk_attention( + s_queries, keys, values, + clusters, counts, + topk, topk_values, + A_bottomk, + softmax_temp, + query_lengths + ) + V_sorted_new = V_topk + V_bottomk + + # Reverse the previous mapping + sorted_rev_indx = torch.argsort(sorted_indx, dim=-1) + q_rev_flat = (sorted_rev_indx.view(N*H, -1) + q_offset).reshape(-1) + V_new = V_sorted_new.reshape(-1, D).index_select(0, q_rev_flat).view(N,H,L,D) + return V_new.permute(0, 2, 1, 3).contiguous() + + +# Register the attention implementation so that it becomes available in our +# builders +AttentionRegistry.register( + "improved-clustered", ImprovedClusteredAttention, + [ + ("clusters", Int), + ("iterations", Optional(Int, 10)), + ("bits", Optional(Int, 63)), + ("hash_bias", Optional(Bool, True)), + ("topk", Optional(Int, 32)), + ("softmax_temp", Optional(Float)), + ("attention_dropout", Optional(Float, 0.1)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_causal_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_causal_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ed84800ff9452889841da84ee0ae494e267ed2ee --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention/improved_clustered_causal_attention.py @@ -0,0 +1,257 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implement improved clustered causal self attention.""" + +from math import sqrt + +import torch +import torch.autograd +from torch.nn import Dropout, Module +from torch.nn.init import normal_ + +from ..attention_registry import AttentionRegistry, Optional, Float, Int, \ + Bool, EventDispatcherInstance +from ..events import EventDispatcher +from ..masking import FullMask +from ..aggregate import clustered_aggregate, clustered_broadcast +from 
..clustering.hamming import cluster +from ..hashing import compute_hashes +from ..sparse_product import sparse_dot_product, sparse_weighted_average +from ..sparse_product import clustered_sparse_dot_product, \ + clustered_sparse_weighted_average + + +class _GroupQueries(torch.autograd.Function): + @staticmethod + def forward(ctx, Q, clusters, counts, lengths): + factors = 1./counts.float() + q_grouped = clustered_aggregate(Q, clusters, factors, lengths) + ctx.save_for_backward(clusters, counts, factors) + + return q_grouped + + @staticmethod + def backward(ctx, grad_q_grouped): + clusters, counts, factors = ctx.saved_tensors + grad_q = clustered_broadcast(grad_q_grouped, clusters, counts, factors) + + return grad_q, None, None, None + + +class _BroadcastValues(torch.autograd.Function): + @staticmethod + def forward(ctx, v_grouped, clusters, counts, lengths): + factors = torch.ones_like(counts, dtype=v_grouped.dtype) + V = clustered_broadcast(v_grouped, clusters, counts, factors) + ctx.save_for_backward(clusters, counts, factors, lengths) + + return V + + @staticmethod + def backward(ctx, grad_v): + clusters, counts, factors, lengths = ctx.saved_tensors + grad_v_grouped = clustered_aggregate(grad_v, clusters, factors, lengths) + + return grad_v_grouped, None, None, None, None + + +class ImprovedClusteredCausalAttention(Module): + """ + Immproved clustered causal attention approximation by recomputing attention + for each query with the top-k keys for the corresponding cluster. + + Given the queries, keys, and values as Q, K, and V respectively, we + first cluster the queries in "C" groups and compute the "C" query centroids + Q_c. + + We now use to the centroids Q_c to identify the top-k keys with highest + dot products. + + Subsequently, for each query we compute the sparse dot product with + the corresponding top-k keys to improve the attention approximation. + + Key difference with improved clustered attention is that we only use + top-k keys with causal mask, we do not compute attention on the + bottom-k keys. + + Arguments + --------- + clusters: How many clusters to group the queries into + iterations: The number of lloyd iterations to perform (default: 10) + bits: How many bits to use for the hash (default: 32) + hash_bias: If true, hamming distance proportional to L2 distance + If false, hamming distance proportional to cosine distance + (default: True) + topk: Number of top-k keys to for improved approximation (default: 32) + softmax_temp: The temperature to use for the softmax attention. 
+ (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.1) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, clusters, iterations=10, bits=32, + hash_bias=True, topk=32, softmax_temp=None, + attention_dropout=0.1, event_dispatcher=""): + super(ImprovedClusteredCausalAttention, self).__init__() + self.clusters = clusters + self.iterations = iterations + self.bits = bits + self.hash_bias = hash_bias + self.topk = topk + self.softmax_temp = softmax_temp + self.dropout = Dropout(attention_dropout) + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def _create_query_groups(self, Q, query_lengths): + N, H, L, E = Q.shape + + # Compute the hashes for all the queries + planes = Q.new_empty((self.bits, E+1)) + normal_(planes) + if not self.hash_bias: + planes[:, -1] = 0 + hashes = compute_hashes(Q.view(N*H*L, E), planes).view(N, H, L) + + # Cluster the hashes and return the cluster index per query + clusters, counts = cluster( + hashes, + query_lengths.lengths.int(), + clusters=self.clusters, + iterations=self.iterations, + bits=self.bits + ) + sorted_clusters, sorted_indx = torch.sort(clusters, dim=-1) + return (sorted_clusters, counts), sorted_indx + + def _topk_attention(self, Q, K, V, + q_flat, q_rev_flat, + clusters, counts, + topk, topk_values, + softmax_temp, + query_lengths): + """Return the attention with just the topk heads.""" + # Extract some indices + N, H, L, E = Q.shape + _, _, S, _ = K.shape + _, _, C, k = topk.shape + + # We need to pass the output tensor to initialize to 0 + QK = clustered_sparse_dot_product( + Q, K, topk, + clusters, counts, + query_lengths.lengths.int() + ) + # We need to mask out the future + assert topk.is_contiguous() + topk_broadcast = clustered_broadcast( + topk.float(), + clusters, + counts, + torch.ones_like(counts, dtype=torch.float32) + ) + # Need to be careful here we changed the order of the keys the + # masking on future needs to be applied in the same way + seq_ids = torch.arange(L, device=QK.device).view(1, 1, L, 1).repeat(N, H, 1, 1) + # permute the ids in the same way as input so as to mask the right + # entries for each query + s_seq_ids = seq_ids.reshape(-1, 1).index_select(0, q_flat).view(N,H,L,1) + future_mask = topk_broadcast.long() > s_seq_ids + QK = QK.masked_fill( + future_mask, + float("-1e7") + ) + A = torch.softmax(softmax_temp * QK, dim=-1) + # Mask again to ensure no probabilities leak due to float(-1e7) + # Leakage could be very high as we use a small top-k + A = A * (1. 
- future_mask.float()) + A = self.dropout(A) + assert A.is_contiguous() + V_new = clustered_sparse_weighted_average(A, V, topk, clusters, counts) + + return V_new + + def _broadcast_values(self, V, clusters, counts, lengths): + """Broadcast the values back to the correct positions but make sure + that the gradient flows properly.""" + V_new = _BroadcastValues.apply(V.contiguous(), clusters, counts, lengths) + return V_new + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + + # Apply the key padding mask and make sure the attn_mask is a + # lower triangular causal mask + if not attn_mask.lower_triangular: + raise RuntimeError(("ImprovedClusteredCausalAttention only supports " + "lower triangular masks")) + queries = queries.permute(0,2,1,3).contiguous() + keys = keys.permute(0,2,1,3).contiguous() + values = values.permute(0,2,1,3).contiguous() + N, H, L, E = queries.shape + _, _, S, D = values.shape + softmax_temp = self.softmax_temp or 1./sqrt(E) + + # Cluster the queries into groups + groups, sorted_indx = self._create_query_groups(queries, query_lengths) + clusters, counts = groups + + # Re-organize queries so that first group belong to first cluster + # next to second cluster and so on. This improves kernel implementations. + # Note that this step is introduced after NeurIPS submission and + # now the complexity is O(N log(N)). + q_offset = torch.arange(N*H, device=queries.device).unsqueeze(-1) * L + q_flat = (sorted_indx.view(N*H, -1) + q_offset).reshape(-1) + s_queries = queries.reshape(-1, E).index_select(0, q_flat).view(N,H,L,E) + + # Aggregate the re-arranged queries. + Q_grouped = _GroupQueries.apply(s_queries, *groups, query_lengths.lengths.int()) + # Compute the attention + QK = torch.einsum("nhle,nhse->nhls", Q_grouped, keys) + QK = QK + key_lengths.additive_matrix[:, None, None, :] + # Set topk to minimum of key lengths if it is smaller than self.topk + cur_topk = min(self.topk, min(key_lengths.lengths).item()) + topk_values, topk = torch.topk(QK, cur_topk, sorted=False, dim=-1) + assert topk.is_contiguous() + + # Reverse mapping + sorted_rev_indx = torch.argsort(sorted_indx, dim=-1) + q_rev_flat = (sorted_rev_indx.view(N*H, -1) + q_offset).reshape(-1) + + # Compute the attention with only the top keys + V_topk = self._topk_attention( + s_queries, keys, values, + q_flat, q_rev_flat, + clusters, counts, + topk, topk_values, + softmax_temp, + query_lengths + ) + V_sorted_new = V_topk + + # Reverse the mapping to get correct values + V_new = V_sorted_new.reshape(-1, D).index_select(0, q_rev_flat).view(N,H,L,D) + return V_new.permute(0, 2, 1, 3).contiguous() + + +# Register the attention implementation so that it becomes available in our +# builders +AttentionRegistry.register( + "causal-improved-clustered", ImprovedClusteredCausalAttention, + [ + ("clusters", Int), + ("iterations", Optional(Int, 10)), + ("bits", Optional(Int, 63)), + ("hash_bias", Optional(Bool, True)), + ("topk", Optional(Int, 32)), + ("softmax_temp", Optional(Float)), + ("attention_dropout", Optional(Float, 0.1)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/linear_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention/linear_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..dca2898fdde0b17d44d1856f727e8b4216916eee --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention/linear_attention.py @@ -0,0 +1,92 
@@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implement unmasked linear attention.""" + +import torch +from torch.nn import Module + +from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \ + EventDispatcherInstance +from ..events import EventDispatcher +from ..feature_maps import elu_feature_map + + +class LinearAttention(Module): + """Implement unmasked attention using dot product of feature maps in + O(N D^2) complexity. + + Given the queries, keys and values as Q, K, V instead of computing + + V' = softmax(Q.mm(K.t()), dim=-1).mm(V), + + we make use of a feature map function Φ(.) and perform the following + computation + + V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V). + + The above can be computed in O(N D^2) complexity where D is the + dimensionality of Q, K and V and N is the sequence length. Depending on the + feature map, however, the complexity of the attention might be limited. + + Arguments + --------- + feature_map: callable, a callable that applies the feature map to the + last dimension of a tensor (default: elu(x)+1) + eps: float, a small number to ensure the numerical stability of the + denominator (default: 1e-6) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, query_dimensions, feature_map=None, eps=1e-6, + event_dispatcher=""): + super(LinearAttention, self).__init__() + self.feature_map = ( + feature_map(query_dimensions) if feature_map else + elu_feature_map(query_dimensions) + ) + self.eps = eps + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + # Apply the feature map to the queries and keys + self.feature_map.new_feature_map(queries.device) + Q = self.feature_map.forward_queries(queries) + K = self.feature_map.forward_keys(keys) + + # Apply the key padding mask and make sure that the attn_mask is + # all_ones + if not attn_mask.all_ones: + raise RuntimeError(("LinearAttention does not support arbitrary " + "attention masks")) + K = K * key_lengths.float_matrix[:, :, None, None] + + # Compute the KV matrix, namely the dot product of keys and values so + # that we never explicitly compute the attention matrix and thus + # decrease the complexity + KV = torch.einsum("nshd,nshm->nhmd", K, values) + + # Compute the normalizer + Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps) + + # Finally compute and return the new values + V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z) + + return V.contiguous() + + +# Register the attention implementation so that it becomes available in our +# builders +AttentionRegistry.register( + "linear", LinearAttention, + [ + ("query_dimensions", Int), + ("feature_map", Optional(Callable)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/local_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention/local_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..bc81747ed566e25046eced89b6fdd3fb2b244eb5 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention/local_attention.py @@ -0,0 +1,101 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Implement local context 
attention.""" + +from math import sqrt + +import torch +from torch.nn import Module, Dropout +from torch.nn import functional as F + +from ..attention_registry import AttentionRegistry, Optional, Int, Float, \ + EventDispatcherInstance +from ..events import EventDispatcher +from ..local_product import local_dot_product, local_weighted_average + + +class LocalAttention(Module): + """Implement fast local attention where a query can only attend to + neighboring keys. + + In this attention module the query Q_i can only attend to a key K_j if + |i-j| < local_context/2. + + Arguments + --------- + local_context: The neighborhood to consider for local attention. + softmax_temp: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.1) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, local_context, softmax_temp=None, attention_dropout=0.1, + event_dispatcher=""): + super(LocalAttention, self).__init__() + self.local_context = local_context + self.softmax_temp = softmax_temp + self.dropout = Dropout(attention_dropout) + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + """Implements the local attention. + + The attn_mask can be anything but the only values that will be + considered will be the ones in the neighborhood of each query. + + Arguments + --------- + queries: (N, L, H, E) The tensor containing the queries + keys: (N, S, H, E) The tensor containing the keys + values: (N, S, H, D) The tensor containing the values + attn_mask: An implementation of BaseMask that encodes where each + query can attend to + query_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + key_lengths: An implementation of BaseMask that encodes how + many queries each sequence in the batch consists of + """ + # Extract some shapes and compute the temperature + N, L, H, E = queries.shape + _, S, _, D = values.shape + context = self.local_context + softmax_temp = self.softmax_temp or 1./sqrt(E) + + # Permute the dimensions to NHLE instead of NLHE + queries = queries.permute(0, 2, 1, 3).contiguous() + keys = keys.permute(0, 2, 1, 3).contiguous() + values = values.permute(0, 2, 1, 3).contiguous() + + QK = local_dot_product( + queries, + keys, + attn_mask.additive_matrix_finite, + key_lengths.lengths, + self.local_context + ) + A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1)) + + V_new = local_weighted_average(A, values) + + return V_new.permute(0, 2, 1, 3).contiguous() + + +# Register the attention implementation so that it becomes available in our +# builders +AttentionRegistry.register( + "local", LocalAttention, + [ + ("local_context", Int), + ("softmax_temp", Optional(Float)), + ("attention_dropout", Optional(Float, 0.1)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention/reformer_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention/reformer_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..649f223466f559943ba878c569e3de1ebabf9762 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention/reformer_attention.py @@ -0,0 +1,166 @@ +# 
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implement the Reformer attention from the paper +"Reformer the efficient transformer".""" + +from math import sqrt + +import torch +from torch.nn import Dropout, Module +from torch.nn.init import normal_ + +from ..attention_registry import AttentionRegistry, Optional, Int, Float, \ + Bool, EventDispatcherInstance +from ..events import EventDispatcher +from ..masking import FullMask + + +class ReformerAttention(Module): + """Implement the attention module of the paper "Reformer the efficient + transformer" + + Arguments + --------- + chunk_size : Chunk size for each block (default: 32) + bits : Number of bits for hashing (default: 8) + rounds : Number of rounds of attention computation (default: 4) + masked : If true, the query does not attend to itsself (default: False) + softmax_temp: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.1) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + + def __init__(self, chunk_size=32, bits=8, rounds=4, masked=False, + softmax_temp=None, attention_dropout=0.1, + event_dispatcher=""): + super(ReformerAttention, self).__init__() + + self.chunk_size = chunk_size + self.bits = bits + self.rounds = rounds + self.masked = masked + self.softmax_temp = softmax_temp + self.dropout = Dropout(attention_dropout) + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def _normalize(self, x): + norms = torch.sqrt(torch.einsum("nlhe,nlhe->nlh", x, x)) + x_normed = x / norms.unsqueeze(-1) + return x_normed + + def _look_back(self, x): + xshape = x.shape + + return torch.cat([ + x.new_zeros((xshape[0], 1) + xshape[2:]), + torch.repeat_interleave(x, 2, dim=1)[:,:-1] + ], dim=1).view(xshape[0], xshape[1], 2*xshape[2], *xshape[3:]) + + def _reformer_round(self, Q, K, V, mask, softmax_temp): + # Hash the queries + N, L, H, E = Q.shape + planes = Q.new_empty(self.bits, E) + normal_(planes) + projected = torch.einsum("nlhe,be->nlhb", K, planes) + hashes = torch.argmax( + torch.cat([projected, -projected], dim=-1), + dim=-1 + ) + + # Sort the queries in order to group them + group = torch.argsort(hashes, dim=1) + + invert_group = torch.empty_like(group) + batch_indices = torch.arange(N, device=hashes.device).view(N, 1, 1) + sequence_indices = torch.arange(L, device=hashes.device).view(1, L, 1) + head_indices = torch.arange(H, device=hashes.device).view(1, 1, H) + invert_group[batch_indices, group, head_indices] = sequence_indices + group = group.view(N, -1, self.chunk_size, H) + invert_group = invert_group.view(N, -1, self.chunk_size, H) + batch_indices = batch_indices.unsqueeze(1) + head_indices = head_indices.unsqueeze(0) + + # Reorder Q, V and mask + Q_grouped = Q[batch_indices, group, head_indices] + K_grouped = K[batch_indices, group, head_indices] + V_grouped = V[batch_indices, group, head_indices] + mask_grouped = mask[ + batch_indices.unsqueeze(1), + group.unsqueeze(3), + self._look_back(group).unsqueeze(2) + ] + + mask_grouped[:, 0, :, :Q_grouped.shape[2]] = float("-inf") + + # When everything is masked just unmask everything because it doesn't + # matter what the output is at those positions + # This is to avoid inf/nans in the new values at masked positions + infmask = 
torch.isinf(mask_grouped) + infmask = torch.all(infmask, dim=3, keepdims=True) + mask_grouped = mask_grouped.masked_fill(infmask, 0.) + + # Attention + K_grouped = self._look_back(K_grouped) + QQ = torch.einsum("nblhe,nbshe->nbhls", Q_grouped, K_grouped) + QQ = QQ + mask_grouped.permute(0, 1, 4, 2, 3) + A = torch.softmax(softmax_temp * QQ, dim=-1) + A = self.dropout(A) + + # Values + V_grouped = self._look_back(V_grouped) + V_new = torch.einsum("nbhls,nbshe->nblhe", A, V_grouped) + V_new = V_new.contiguous().view(N, -1, H, E) + V_new = V_new[batch_indices, invert_group, head_indices] + V_new = V_new.contiguous().view(N, L, H, E) + return V_new + + def forward(self, queries, keys, values, attn_mask, query_lengths, + key_lengths): + # Extract the dimensions of query, key, value + N, L, H, E = queries.shape + + softmax_temp = self.softmax_temp or 1./sqrt(E) + # Create the mask + mask = key_lengths.additive_matrix.unsqueeze(1).expand(N, L, L) + if self.masked: + mask = mask + torch.eye(L, device=queries.device).unsqueeze(0)*float(-1e9) + + if not attn_mask.all_ones: + mask = mask + attn_mask.additive_matrix.unsqueeze(0) + # Get normalized Queries as Keys + K = self._normalize(queries) + # Zero the masked out keys + K = K * key_lengths.float_matrix.view(N, L, 1, 1) + + V_new = 0 + factor = 1/self.rounds + for i in range(self.rounds): + V_new = V_new + \ + factor * self._reformer_round(queries, K, values, mask, softmax_temp) + + return V_new + + +# Register the attention implementation so that it becomes available in our +# builders +AttentionRegistry.register( + "reformer", ReformerAttention, + [ + ("chunk_size", Optional(Int, 32)), + ("bits", Optional(Int, 63)), + ("rounds", Optional(Int, 4)), + ("masked", Optional(Bool, False)), + ("softmax_temp", Optional(Float)), + ("attention_dropout", Optional(Float, 0.1)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da4c16f5882d09469b9649565a73942e9047c099 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__init__.py @@ -0,0 +1,17 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Allow for the dynamic registration of new attention implementations. + +This module provides a Registry implementation that other modules can use to +register attention implementations for the builders. 
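Stepping back to the ReformerAttention implementation above: its core trick is hashing the normalized keys with random hyperplanes and sorting positions into equal-size chunks, so attention is only computed within (and one chunk back from) each bucket. The sketch below covers only that hash-and-sort step, mirroring the projection and argmax in _reformer_round; it does not reproduce the chunked attention, masking, or multi-round averaging.

    import torch

    def lsh_buckets(K, bits=8):
        # K: (N, L, H, E) normalised keys, as produced by _normalize() above.
        N, L, H, E = K.shape
        planes = torch.randn(bits, E)                                # random hyperplanes
        projected = torch.einsum("nlhe,be->nlhb", K, planes)         # (N, L, H, bits)
        # argmax over [p, -p] assigns each position to one of 2*bits angular buckets.
        buckets = torch.argmax(torch.cat([projected, -projected], dim=-1), dim=-1)
        # Sorting by bucket id groups similar keys; the real module then splits this
        # order into fixed-size chunks and attends within each chunk plus one chunk back.
        order = torch.argsort(buckets, dim=1)                        # (N, L, H)
        return buckets, order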
+""" + +from .registry import \ + AttentionRegistry, \ + RecurrentAttentionRegistry, \ + RecurrentCrossAttentionRegistry +from .spec import Spec, Choice, Optional, Int, Float, Bool, Callable, \ + EventDispatcherInstance diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/__init__.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b6b8fc6f91e9279f932f2d9d4adea1174524513 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/__init__.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/registry.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1418b46ec4f552946ba9eadfb8d4ec817d9aa8e6 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/registry.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/spec.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/spec.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..357ca1dfcb0cbffe9ec3c3ca32ee793ce958cbe4 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/__pycache__/spec.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/registry.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..bad90e86382a0d077b31b3546f0e230440cc3761 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/registry.py @@ -0,0 +1,61 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + + +class Registry(object): + """Hold the available attention implementations and their required + parameters.""" + def __init__(self): + self._classes = {} + self._class_params = {} + self._parameters = {} + + def register(self, key, class_object, parameter_tuples): + # register the class if the key is new + if key in self._classes: + raise ValueError("{} is already registered".format(key)) + self._classes[key] = class_object + + # register the parameters + for parameter, spec in parameter_tuples: + if ( + parameter in self._parameters and + self._parameters[parameter] != spec + ): + raise ValueError(("{} is already registered with " + "spec {!r} instead of {!r}").format( + parameter, + self._parameters[parameter], + spec + )) + self._parameters[parameter] = spec + + # note which parameters are needed by this class + self._class_params[key] = [p for p, s in parameter_tuples] + + def __contains__(self, key): + return key in self._classes + + def __getitem__(self, key): + return self._classes[key], self._class_params[key] + + @property + def keys(self): + return list(self._classes.keys()) + + def contains_parameter(self, key): + return key in self._parameters + + def validate_parameter(self, key, value): + try: + return self._parameters[key].get(value) + except Exception as e: + raise ValueError(("Invalid value {!r} for " + "parameter {!r}").format(value, key)) from e + 
+ +AttentionRegistry = Registry() +RecurrentAttentionRegistry = Registry() +RecurrentCrossAttentionRegistry = Registry() diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/spec.py b/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/spec.py new file mode 100644 index 0000000000000000000000000000000000000000..82c6b5c0f181224fc01536227d0c32f778ab2c5e --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/attention_registry/spec.py @@ -0,0 +1,126 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Spec instances allow to describe and check the type and value of +parameters.""" + +from ..events import EventDispatcher + + +class Spec(object): + """Describe and validate a parameter type. + + Arguments + --------- + predicate: A callable that checks if the value is acceptable and + returns its canonical value or raises ValueError. + name: A name to create a human readable description of the Spec + """ + def __init__(self, predicate, name="CustomSpec"): + self._predicate = predicate + self._name = name + + def __repr__(self): + return self._name + + def check(self, x): + try: + self._predicate(x) + return True + except ValueError: + return False + + def get(self, x): + return self._predicate(x) + + def __eq__(self, y): + return self is y + + +class Choice(Spec): + """A parameter type for a set of options. + + Arguments + --------- + choices: A set or list of possible values for this parameter + """ + def __init__(self, choices): + self._choices = choices + + def get(self, x): + if x in self._choices: + return x + raise ValueError("{!r} is not in {!r}".format(x, self._choices)) + + def __repr__(self): + return "Choice({!r})".format(self._choices) + + def __eq__(self, x): + if isinstance(x, Choice): + return self._choices == x._choices + return False + + +class _Callable(Spec): + def __init__(self): + super(_Callable, self).__init__(None, "Callable") + + def get(self, x): + if callable(x): + return x + raise ValueError("{!r} is not a callable".format(x)) + + +class _EventDispatcherInstance(Spec): + def __init__(self): + super(_EventDispatcherInstance, self).__init__( + _EventDispatcherInstance._get_event_dispatcher, + "EventDispatcherInstance" + ) + + @staticmethod + def _get_event_dispatcher(x): + if isinstance(x, str): + return x + if isinstance(x, EventDispatcher): + return x + raise ValueError("{!r} is not an event dispatcher".format(x)) + + +class Optional(Spec): + """Represent an optional parameter that can either have a value or it can + be None. 
+ + Arguments + --------- + spec: The spec for the value if it is not None + default: The returned value in case it is None + """ + def __init__(self, spec, default=None): + self._other_spec = spec + self._default = default + + def __repr__(self): + return "Optional[{!r}, {!r}]".format(self._other_spec, self._default) + + def get(self, x): + if x is None: + return self._default + return self._other_spec.get(x) + + def __eq__(self, x): + if isinstance(x, Optional): + return ( + self._other_spec == x._other_spec and + self._default == x._default + ) + return False + + +Int = Spec(int, "Int") +Float = Spec(float, "Float") +Bool = Spec(bool, "Bool") +Callable = _Callable() +EventDispatcherInstance = _EventDispatcherInstance() diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/builders/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/builders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea03845a2296f81c5e829c914e5fd46a6a23777c --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/builders/__init__.py @@ -0,0 +1,59 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""This module implements builders that simplify building complex transformer +architectures with different attention mechanisms. + +The main idea is to facilitate the construction of various attention layers and +transformer encoder layers and simplify their assembly into one transformer +module. It also allows for flexibility in the scripts as many builder +parameters can correspond 1-1 with command line arguments. + +Example usage: + + builder = TransformerEncoderBuilder() + builder.n_layers = 12 + builder.n_heads = 8 + builder.feed_forward_dimensions = 1024 + builder.query_dimensions = 64 + builder.value_dimensions = 64 + builder.dropout = 0.1 + builder.attention_dropout = 0.1 + builder.attention_type = "linear" + transformer = builder.get() +""" + +__all__ = [ + "AttentionBuilder", + "RecurrentAttentionBuilder", + "RecurrentCrossAttentionBuilder" +] + +# Import the attention implementations so that they register themselves with +# the builder. Attention implementations external to the library should be +# imported before using the builders. +# +# TODO: Should this behaviour change? Namely, should all attention +# implementations be imported in order to be useable? This also allows +# using the library even partially built, for instance. 
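Because every builder parameter is a plain attribute validated through the attention registry, the configuration shown in the docstring above can also be passed in one call via BaseBuilder.from_kwargs. A minimal sketch, assuming the package is installed under its upstream name fast_transformers:

    from fast_transformers.builders import TransformerEncoderBuilder

    transformer = TransformerEncoderBuilder.from_kwargs(
        n_layers=12,
        n_heads=8,
        feed_forward_dimensions=1024,
        query_dimensions=64,
        value_dimensions=64,
        dropout=0.1,
        attention_dropout=0.1,
        attention_type="linear",
    ).get()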
+from ..attention import \ + FullAttention, \ + LinearAttention + +del FullAttention, \ + LinearAttention + + +from .attention_builders import \ + AttentionBuilder, \ + RecurrentAttentionBuilder, \ + RecurrentCrossAttentionBuilder + +from .transformer_builders import \ + TransformerEncoderBuilder, \ + RecurrentEncoderBuilder, \ + TransformerDecoderBuilder, \ + RecurrentDecoderBuilder diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/__init__.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..479b3a17009febf6840e1f6dbfb577b24ca62d79 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/__init__.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/attention_builders.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/attention_builders.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..931d172bf605ce21a026a701cae64e7161ce9e7a Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/attention_builders.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/base.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1df851193737468700e484b1c54b87d99285014 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/base.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/transformer_builders.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/transformer_builders.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3483a8e6efce5028cab06c259977c2e6b4ca13a4 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/builders/__pycache__/transformer_builders.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/builders/attention_builders.py b/smi-ted/inference/smi_ted_light/fast_transformers/builders/attention_builders.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb820332d795c8c0cc3338709c7adae51156901 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/builders/attention_builders.py @@ -0,0 +1,139 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +from collections import defaultdict + +from .base import BaseBuilder +from ..attention_registry import \ + AttentionRegistry, \ + RecurrentAttentionRegistry, \ + RecurrentCrossAttentionRegistry + + +class BaseAttentionBuilder(BaseBuilder): + def __init__(self, registry): + self._registry = registry + self._parameters = defaultdict(lambda: None) + + @property + def available_attentions(self): + """Return a list with the available attention implementations.""" + return self._registry.keys + + def validate_attention_type(self, attention_type): + """Parse the attention type according to the rules used by `get()` and + check if the requested attention is constructible.""" + return all( + all(t in self._registry for t in a.split(",")) + for a in attention_type.split(":") + ) + + def 
__setattr__(self, key, value): + # Make sure we have normal behaviour for the class members _registry + # and _parameters + if key in ["_registry", "_parameters"]: + return object.__setattr__(self, key, value) + + # Assign everything else in the parameters dictionary + if not self._registry.contains_parameter(key): + raise AttributeError(("{!r} is not a valid attention " + "parameter name").format(key)) + self._parameters[key] = self._registry.validate_parameter(key, value) + + def __getattr__(self, key): + if key in self._parameters: + return self._parameters[key] + else: + raise AttributeError() + + def __repr__(self): + return ( + "{}.from_kwargs(\n".format(self.__class__.__name__) + + "\n".join([" {}={!r},".format(k, v) + for k, v in self._parameters.items()])[:-1] + + "\n)" + ) + + def get(self, attention_type): + """Construct the attention implementation object and return it. + + The passed in attention_type argument defines the attention to be + created. It should be a string and in its simplest form it should + be one of the available choices from `available_attentions`. + + However, to enable attention decoration, namely an attention + implementation augmenting the functionality of another implementation, + the attention type can be a colon separated list of compositions like + the following examples: + + - 'att1' means instantiate att1 + - 'att2:att1' means instantiate att1 and decorate it with att2 + - 'att3:att1,att4' means instantiate att1 and att4 and decorate + them with att3 + + Arguments + --------- + attention_type: A string that contains one or more keys from + `available_attentions` separated with a colon to + denote the decoration pattern. + """ + compositions = reversed(attention_type.split(":")) + attentions = [] + for c in compositions: + attentions = [ + self._construct_attention(t, attentions) + for t in c.split(",") + ] + if len(attentions) > 1: + raise ValueError(("Invalid attention_type argument " + "{!r}").format(attention_type)) + return attentions[0] + + def _construct_attention(self, attention_type, decorated=[]): + """Construct an attention implementation object. 
+ + Arguments + --------- + attention_type: A string that contains a single key from the + `available_attentions` + decorated: A list of attention implementations to pass as arguments + to be decorated + """ + if attention_type not in self._registry: + raise ValueError(("Unknown attention type " + "{!r}").format(attention_type)) + + attention, parameters = self._registry[attention_type] + parameter_dictionary = { + p: self._registry.validate_parameter(p, self._parameters[p]) + for p in parameters + } + + return attention(*decorated, **parameter_dictionary) + + +class AttentionBuilder(BaseAttentionBuilder): + """Build attention implementations for batch sequence processing or + training.""" + def __init__(self): + super(AttentionBuilder, self).__init__(AttentionRegistry) + + +class RecurrentAttentionBuilder(BaseAttentionBuilder): + """Build attention implementations for autoregressive sequence + processing.""" + def __init__(self): + super(RecurrentAttentionBuilder, self).__init__( + RecurrentAttentionRegistry + ) + + +class RecurrentCrossAttentionBuilder(BaseAttentionBuilder): + """Build attention implementations for autoregressive cross attention + computation.""" + def __init__(self): + super(RecurrentCrossAttentionBuilder, self).__init__( + RecurrentCrossAttentionRegistry + ) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/builders/base.py b/smi-ted/inference/smi_ted_light/fast_transformers/builders/base.py new file mode 100644 index 0000000000000000000000000000000000000000..da135a759eb13bc2745fb87976928bc84b9efc41 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/builders/base.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Provide a class for the others to inherit some useful functionality.""" + + +class BaseBuilder(object): + @classmethod + def from_kwargs(cls, **kwargs): + """Construct a builder and set all the keyword arguments as parameters. + + The keyword argument strict is passed to + BaseBuilder.from_dictionary separately. + + See BaseBuilder.from_dictionary(). + """ + strict = kwargs.pop("strict", True) + return cls.from_dictionary(kwargs, strict=strict) + + @classmethod + def from_namespace(cls, args, strict=False): + """Construct a builder from an argparse Namespace. + + To be used for building transformers from command line arguments. + + See BaseBuilder.from_dictionary(). + """ + return cls.from_dictionary(vars(args), strict=strict) + + @classmethod + def from_dictionary(cls, dictionary, strict=True): + """Construct a builder and set all the parameters in the dictionary. + + Given a dictionary + + d = {"foo": "bar"} + + then + + builder = TransformerEncoderBuilder.from_dictionary(d) + + is equivalent to + + builder = TransformerEncoderBuilder() + builder.foo = "bar" + + Arguments + --------- + dictionary: A dictionary of parameters to set to the builder. 
+ strict: bool, If a key is not a parameter and strict is set to True + then a ValueError is raised, otherwise that dictionary key + is ignored (default: True) + """ + builder = cls() + for k, v in dictionary.items(): + try: + setattr(builder, k, v) + except AttributeError: + if strict: + raise ValueError(("The builder has no " + "parameter {!r}").format(k)) + else: + continue + return builder diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/builders/transformer_builders.py b/smi-ted/inference/smi_ted_light/fast_transformers/builders/transformer_builders.py new file mode 100644 index 0000000000000000000000000000000000000000..256404fbf582f1e3720f5156793c2bf889f7d833 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/builders/transformer_builders.py @@ -0,0 +1,550 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Build complex transformer architectures for inference or training easily.""" + +from torch.nn import LayerNorm + +from ..attention import AttentionLayer +from ..transformers import TransformerEncoder, TransformerEncoderLayer, \ + TransformerDecoder, TransformerDecoderLayer +from ..recurrent.attention import \ + RecurrentAttentionLayer, \ + RecurrentCrossAttentionLayer +from ..recurrent.transformers import \ + RecurrentTransformerEncoder, RecurrentTransformerEncoderLayer, \ + RecurrentTransformerDecoder, RecurrentTransformerDecoderLayer +from .base import BaseBuilder +from .attention_builders import AttentionBuilder, RecurrentAttentionBuilder, \ + RecurrentCrossAttentionBuilder + + +class BaseTransformerBuilder(BaseBuilder): + """Contains all the parameters for building a transformer other than the + attention part. + + Classes extending the BaseTransformerBuilder should implement the `get()` + method that actually builds the transformer. + """ + def __init__(self): + # transformer parameters + self._n_layers = 4 + self._n_heads = 4 + self._d_query = 64 + self._d_value = 64 + self._d_ff = 1024 + self._dropout = 0.1 + self._activation = "relu" + self._final_norm = True + self._event_dispatcher = "" # the default global dispatcher + + @property + def n_layers(self): + """The number of transformer layers.""" + return self._n_layers + + @n_layers.setter + def n_layers(self, val): + self._n_layers = val + + @property + def n_heads(self): + """The number of heads in each transformer layer.""" + return self._n_heads + + @n_heads.setter + def n_heads(self, val): + self._n_heads = val + + @property + def feed_forward_dimensions(self): + """The dimensions of the fully connected layer in the transformer + layers.""" + return self._d_ff + + @feed_forward_dimensions.setter + def feed_forward_dimensions(self, val): + self._d_ff = val + + @property + def query_dimensions(self): + """The dimensions of the queries and keys in each attention layer.""" + return self._d_query + + @query_dimensions.setter + def query_dimensions(self, val): + self._d_query = val + + @property + def value_dimensions(self): + """The dimensions of the values in each attention layer.""" + return self._d_value + + @value_dimensions.setter + def value_dimensions(self, val): + self._d_value = val + + @property + def dropout(self): + """The dropout rate to be applied in the transformer encoder layer.""" + return self._dropout + + @dropout.setter + def dropout(self, val): + self._dropout = val + + @property + def activation(self): + """The activation function for the transformer layer. + + One of {'relu', 'gelu'}. 
+ """ + return self._activation + + @activation.setter + def activation(self, val): + activations = ["relu", "gelu"] + if val not in activations: + raise ValueError(("{!r} is not one of the availabel activation " + "types {!r}").format(val, activations)) + self._activation = val + + @property + def final_normalization(self): + """Whether to add LayerNorm as the final layer of the + TransformerEncoder.""" + return self._final_norm + + @final_normalization.setter + def final_normalization(self, val): + self._final_norm = bool(val) + + @property + def event_dispatcher(self): + """The transformer event dispatcher either as a string or as an + EventDispatcher object.""" + return self._event_dispatcher + + @event_dispatcher.setter + def event_dispatcher(self, event_dispatcher): + self._event_dispatcher = event_dispatcher + + def get(self): + """Build the transformer and return it.""" + raise NotImplementedError() + + +class BaseTransformerEncoderBuilder(BaseTransformerBuilder): + """Implement the logic of building a transformer encoder but leave the + specific layers open for changing by the inheriting classes. This allows us + to reuse the logic for creating both the TransformerEncoder and the + RecurrentTransformerEncoder. + + Inheriting classes should implement the following: + + - _get_attention_builder() + - _get_attention_layer_class() + - _get_encoder_class() + - _get_encoder_layer_class() + """ + def __init__(self): + super(BaseTransformerEncoderBuilder, self).__init__() + self._attention_builder = self._get_attention_builder() + self._attention_type = "full" + + def _get_attention_builder(self): + """Return an instance of the appropriate attention builder.""" + raise NotImplementedError() + + def _get_attention_layer_class(self): + """Return the class for the layer that projects queries keys and + values.""" + raise NotImplementedError() + + def _get_encoder_class(self): + """Return the class for the transformer encoder.""" + raise NotImplementedError() + + def _get_encoder_layer_class(self): + """Return the class for the transformer encoder layer.""" + raise NotImplementedError() + + @property + def attention(self): + """The attention builder instance.""" + return self._attention_builder + + @property + def attention_type(self): + """The attention implementation chosen.""" + return self._attention_type + + @attention_type.setter + def attention_type(self, val): + if not self._attention_builder.validate_attention_type(val): + raise ValueError(("{!r} is not an available attention " + "type").format(val)) + self._attention_type = val + + def __setattr__(self, key, val): + # "protected" attributes are settable (probably from withing the class) + if key[0] == "_": + return super().__setattr__(key, val) + + # Existing attributes are settable but they might also be attention + # parameters so try that as well + fail_on_exception = True + if hasattr(self, key): + super().__setattr__(key, val) + fail_on_exception = False + + # Non-existing "public" attributes may be attention parameters + try: + setattr(self._attention_builder, key, val) + except: + if fail_on_exception: + raise + + def get(self): + """Build the transformer and return it.""" + # Set the event dispatcher to the attention builder + self.attention.event_dispatcher = self.event_dispatcher + + # Extract into local variables the classes to be used + Encoder = self._get_encoder_class() + EncoderLayer = self._get_encoder_layer_class() + Attention = self._get_attention_layer_class() + + model_dimensions = 
self.value_dimensions*self.n_heads + return Encoder( + [ + EncoderLayer( + Attention( + self.attention.get(self.attention_type), + model_dimensions, + self.n_heads, + d_keys=self.query_dimensions, + d_values=self.value_dimensions, + event_dispatcher=self.event_dispatcher + ), + model_dimensions, + self.feed_forward_dimensions, + self.dropout, + self.activation, + event_dispatcher=self.event_dispatcher + ) + for _ in range(self.n_layers) + ], + (LayerNorm(model_dimensions) if self.final_normalization else None), + event_dispatcher=self.event_dispatcher + ) + + +class TransformerEncoderBuilder(BaseTransformerEncoderBuilder): + """Build a batch transformer encoder for training or processing of + sequences all elements at a time. + + Example usage: + + builder = TransformerEncoderBuilder() + builder.n_layers = 12 + builder.n_heads = 8 + builder.feed_forward_dimensions = 1024 + builder.query_dimensions = 64 + builder.value_dimensions = 64 + builder.dropout = 0.1 + builder.attention_dropout = 0.1 + builder.attention_type = "linear" + transformer = builder.get() + """ + def _get_attention_builder(self): + """Return an instance of the appropriate attention builder.""" + return AttentionBuilder() + + def _get_attention_layer_class(self): + """Return the class for the layer that projects queries keys and + values.""" + return AttentionLayer + + def _get_encoder_class(self): + """Return the class for the transformer encoder.""" + return TransformerEncoder + + def _get_encoder_layer_class(self): + """Return the class for the transformer encoder layer.""" + return TransformerEncoderLayer + + +class RecurrentEncoderBuilder(BaseTransformerEncoderBuilder): + """Build a transformer encoder for autoregressive processing of sequences. + + Example usage: + + builder = RecurrentEncoderBuilder() + builder.n_layers = 12 + builder.n_heads = 8 + builder.feed_forward_dimensions = 1024 + builder.query_dimensions = 64 + builder.value_dimensions = 64 + builder.dropout = 0.1 + builder.attention_dropout = 0.1 + builder.attention_type = "linear" + transformer = builder.get() + """ + def _get_attention_builder(self): + """Return an attention builder for recurrent attention.""" + return RecurrentAttentionBuilder() + + def _get_attention_layer_class(self): + """Return the class for the recurrent layer that projects queries keys + and values.""" + return RecurrentAttentionLayer + + def _get_encoder_class(self): + """Return the class for the recurrent transformer encoder.""" + return RecurrentTransformerEncoder + + def _get_encoder_layer_class(self): + """Return the class for the recurrent transformer encoder layer.""" + return RecurrentTransformerEncoderLayer + + +class BaseTransformerDecoderBuilder(BaseTransformerBuilder): + """Similar to BaseTransformerEncoderBuilder implement the logic of + building the transformer decoder without defining concrete layers. 
+ + Inheriting classes should implement the following: + + - _get_self_attention_builder() and _get_cross_attention_builder() + - _get_self_attention_layer_class() and _get_cross_attention_layer_class() + - _get_decoder_class() + - _get_decoder_layer_class() + """ + def __init__(self): + super(BaseTransformerDecoderBuilder, self).__init__() + self._self_attention_builder = self._get_self_attention_builder() + self._cross_attention_builder = self._get_cross_attention_builder() + self._self_attention_type = "full" + self._cross_attention_type = "full" + + def _get_self_attention_builder(self): + """Return an instance of attention builder.""" + raise NotImplementedError() + + def _get_cross_attention_builder(self): + """Return an instance of attention builder.""" + raise NotImplementedError() + + def _get_self_attention_layer_class(self): + """Return a class to project the queries, keys and values to + multi-head versions.""" + raise NotImplementedError() + + def _get_cross_attention_layer_class(self): + """Return a class to project the queries, keys and values to + multi-head versions.""" + raise NotImplementedError() + + def _get_decoder_class(self): + """Return the class for the transformer decoder.""" + raise NotImplementedError() + + def _get_decoder_layer_class(self): + """Return the class for the transformer decoder layer.""" + raise NotImplementedError() + + @property + def self_attention(self): + """The attention builder instance that will be used for the self + attention modules.""" + return self._self_attention_builder + + @property + def self_attention_type(self): + """The attention implementation used for self attention.""" + return self._self_attention_type + + @self_attention_type.setter + def self_attention_type(self, val): + if not self._self_attention_builder.validate_attention_type(val): + raise ValueError(("{!r} is not an available self attention " + "type").format(val)) + self._self_attention_type = val + + @property + def cross_attention(self): + """The attention builder instance that will be used for the cross + attention modules.""" + return self._cross_attention_builder + + @property + def cross_attention_type(self): + """The attention implementation used for cross attention.""" + return self._cross_attention_type + + @cross_attention_type.setter + def cross_attention_type(self, val): + if not self._cross_attention_builder.validate_attention_type(val): + raise ValueError(("{!r} is not an available cross attention " + "type").format(val)) + self._cross_attention_type = val + + def __setattr__(self, key, val): + # "protected" attributes are settable (probably from withing the class) + if key[0] == "_": + return super().__setattr__(key, val) + + # Existing attributes are settable but they might also be attention + # parameters so try that as well + fail_on_exception = True + if hasattr(self, key): + super().__setattr__(key, val) + fail_on_exception = False + + # Non-existing "public" attributes may be attention parameters + try: + setattr(self._self_attention_builder, key, val) + setattr(self._cross_attention_builder, key, val) + except: + if fail_on_exception: + raise + + def get(self): + """Build the transformer and return it.""" + # Set the event dispatcher to attention builders + self.self_attention.event_dispatcher = self.event_dispatcher + self.cross_attention.event_dispatcher = self.event_dispatcher + + # Extract into local variables the classes to be used + Decoder = self._get_decoder_class() + DecoderLayer = self._get_decoder_layer_class() + SelfAttention = 
self._get_self_attention_layer_class() + CrossAttention = self._get_cross_attention_layer_class() + + model_dimensions = self.value_dimensions*self.n_heads + return Decoder( + [ + DecoderLayer( + SelfAttention( + self.self_attention.get(self.self_attention_type), + model_dimensions, + self.n_heads, + d_keys=self.query_dimensions, + d_values=self.value_dimensions, + event_dispatcher=self.event_dispatcher + ), + CrossAttention( + self.cross_attention.get(self.cross_attention_type), + model_dimensions, + self.n_heads, + d_keys=self.query_dimensions, + d_values=self.value_dimensions, + event_dispatcher=self.event_dispatcher + ), + model_dimensions, + self.feed_forward_dimensions, + self.dropout, + self.activation, + event_dispatcher=self.event_dispatcher + ) + for _ in range(self.n_layers) + ], + (LayerNorm(model_dimensions) if self.final_normalization else None), + event_dispatcher=self.event_dispatcher + ) + + +class TransformerDecoderBuilder(BaseTransformerDecoderBuilder): + """Build a transformer decoder for training or processing of sequences all + elements at a time. + + Example usage: + + builder = TransformerDecoderBuilder() + builder.n_layers = 12 + builder.n_heads = 8 + builder.feed_forward_dimensions = 1024 + builder.query_dimensions = 64 + builder.value_dimensions = 64 + builder.dropout = 0.1 + builder.attention_dropout = 0.1 + builder.self_attention_type = "full" + builder.cross_attention_type = "full" + transformer = builder.get() + """ + def _get_self_attention_builder(self): + """Return an attention builder for creating non-recurrent attention + variants.""" + return AttentionBuilder() + + def _get_cross_attention_builder(self): + """Return an attention builder for creating non-recurrent attention + variants.""" + return AttentionBuilder() + + def _get_self_attention_layer_class(self): + """Return the non-recurrent attention layer to project queries, keys + and values.""" + return AttentionLayer + + def _get_cross_attention_layer_class(self): + """Return the non-recurrent attention layer to project queries, keys + and values.""" + return AttentionLayer + + def _get_decoder_class(self): + """Return the transformer decoder class.""" + return TransformerDecoder + + def _get_decoder_layer_class(self): + """Return the transformer decoder layer class.""" + return TransformerDecoderLayer + + +class RecurrentDecoderBuilder(BaseTransformerDecoderBuilder): + """Build a transformer decoder for processing of sequences in + autoregressive fashion. 
+ + Example usage: + + builder = RecurrentDecoderBuilder() + builder.n_layers = 12 + builder.n_heads = 8 + builder.feed_forward_dimensions = 1024 + builder.query_dimensions = 64 + builder.value_dimensions = 64 + builder.dropout = 0.1 + builder.attention_dropout = 0.1 + builder.self_attention_type = "full" + builder.cross_attention_type = "full" + transformer = builder.get() + """ + def _get_self_attention_builder(self): + """Return an attention builder for creating non-recurrent attention + variants.""" + return RecurrentAttentionBuilder() + + def _get_cross_attention_builder(self): + """Return an attention builder for creating non-recurrent attention + variants.""" + return RecurrentCrossAttentionBuilder() + + def _get_self_attention_layer_class(self): + """Return the non-recurrent attention layer to project queries, keys + and values.""" + return RecurrentAttentionLayer + + def _get_cross_attention_layer_class(self): + """Return the non-recurrent attention layer to project queries, keys + and values.""" + return RecurrentCrossAttentionLayer + + def _get_decoder_class(self): + """Return the transformer decoder class.""" + return RecurrentTransformerDecoder + + def _get_decoder_layer_class(self): + """Return the transformer decoder layer class.""" + return RecurrentTransformerDecoderLayer diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/causal_product/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/causal_product/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6662bae4026f3fde704eac754492aa2e424de7a --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/causal_product/__init__.py @@ -0,0 +1,78 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +import torch + +from .causal_product_cpu import causal_dot_product as causal_dot_product_cpu, \ + causal_dot_backward as causal_dot_backward_cpu + +try: + from .causal_product_cuda import \ + causal_dot_product as causal_dot_product_cuda, \ + causal_dot_backward as causal_dot_backward_cuda +except ImportError: + causal_dot_product_cuda = causal_dot_backward_cuda = None + + +class CausalDotProduct(torch.autograd.Function): + """Compute the weighted sum of values but attending only to previous + values.""" + dot = { + "cpu": causal_dot_product_cpu, + "cuda": causal_dot_product_cuda + } + dot_backward = { + "cpu": causal_dot_backward_cpu, + "cuda": causal_dot_backward_cuda + } + + @staticmethod + def forward(ctx, Q, K, V): + # Save the inputs for the gradient computation + ctx.save_for_backward(Q, K, V) + + # Create the output tensor + device = Q.device + N, H, L, _ = Q.shape + _, _, _, M = V.shape + product = torch.zeros((N, H, L, M), device=device) + + # Actually perform the dot product + CausalDotProduct.dot[device.type]( + Q.data, + K.data, + V.data, + product + ) + + return product + + @staticmethod + def backward(ctx, grad_out): + # Extract the saved tensors + Q, K, V = ctx.saved_tensors + + # Allocate memory for the gradients + grad_Q = torch.zeros_like(Q) + grad_K = torch.zeros_like(K) + grad_V = torch.zeros_like(V) + + # Actually compute the gradients + CausalDotProduct.dot_backward[Q.device.type]( + Q.data, + K.data, + V.data, + grad_out, + grad_Q, + grad_K, + grad_V + ) + + return grad_Q, grad_K, grad_V + + +# Alias the autograd functions to python style snake case naming +causal_dot_product = CausalDotProduct.apply diff --git 
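The CausalDotProduct autograd function above defers to compiled CPU/CUDA kernels. Under the assumption that those kernels compute the causally masked sum out[i] = sum over j <= i of (Q_i . K_j) V_j (the standard causal linear-attention numerator), a pure-PyTorch reference is sketched below; it is useful for shape checks and tests but needs O(L^2) memory, which is exactly what the kernels avoid.

    import torch

    def naive_causal_dot_product(Q, K, V):
        # Q, K: (N, H, L, E), V: (N, H, L, M), same layout as CausalDotProduct above.
        # out[:, :, i] = sum over j <= i of (Q_i . K_j) * V_j
        scores = torch.einsum("nhle,nhse->nhls", Q, K)             # (N, H, L, L)
        causal = torch.tril(torch.ones_like(scores[0, 0]))         # (L, L) lower-triangular
        return torch.einsum("nhls,nhsm->nhlm", scores * causal, V)

In causal linear attention this quantity is then divided by Q_i dotted with the running sum of keys to normalize the weights.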
a/smi-ted/inference/smi_ted_light/fast_transformers/causal_product/causal_product_cpu.cpython-39-x86_64-linux-gnu.so b/smi-ted/inference/smi_ted_light/fast_transformers/causal_product/causal_product_cpu.cpython-39-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..98a4ee7fdc844ac632bcf91898a46129d1011088 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/causal_product/causal_product_cpu.cpython-39-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84f32370e707beebd8fee88f356fb62721096142265895a5a8e9872063c04595 +size 140928 diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/clustering/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/clustering/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1a8a0fa86f7d508160017eda4f388ff40e569d6 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/__init__.py @@ -0,0 +1,115 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + + +import numpy as np + +import torch + +from .cluster_cpu import cluster as cluster_cpu +try: + from .cluster_cuda import cluster as cluster_gpu +except ImportError: + pass + + +def cluster( + hashes, + lengths, + groups=None, + counts=None, + centroids=None, + distances=None, + bitcounts=None, + clusters=30, + iterations=10, + bits=32 +): + """Cluster hashes using a few iterations of K-Means with hamming distance. + + All the tensors default initialized to None are optional buffers to avoid + memory allocations. distances and bitcounts are only used by the CUDA + version of this call. clusters will be ignored if centroids is provided. + + Arguments + --------- + hashes: A long tensor of shape (N, H, L) containing a hashcode for each + query. + lengths: An int tensor of shape (N,) containing the sequence length for + each sequence in hashes. + groups: An int tensor buffer of shape (N, H, L) contaning the cluster + in which the corresponding hash belongs to. + counts: An int tensor buffer of shape (N, H, K) containing the number + of elements in each cluster. + centroids: A long tensor buffer of shape (N, H, K) containing the + centroid for each cluster. + distances: An int tensor of shape (N, H, L) containing the distance to + the closest centroid for each hash. + bitcounts: An int tensor of shape (N, H, K, bits) containing the number + of elements that have 1 for a given bit. + clusters: The number of clusters to use for each sequence. It is + ignored if centroids is not None. + iterations: How many k-means iterations to perform. + bits: How many of the least-significant bits in hashes to consider. + + Returns + ------- + groups and counts as defined above. 
+ """ + device = hashes.device + N, H, L = hashes.shape + + # Unfortunately cpu and gpu have different APIs so the entire call must be + # surrounded by an if-then-else + if device.type == "cpu": + if groups is None: + groups = torch.empty((N, H, L), dtype=torch.int32) + if centroids is None: + centroids = torch.empty((N, H, clusters), dtype=torch.int64) + centroids = hashes[:, :, np.random.choice(L, size=[clusters], replace=False)] + K = centroids.shape[2] + if counts is None: + counts = torch.empty((N, H, K), dtype=torch.int32) + + cluster_cpu( + hashes, lengths, + centroids, groups, counts, + iterations, bits + ) + + return groups, counts + + else: + if groups is None: + groups = torch.empty((N, H, L), dtype=torch.int32, device=device) + if centroids is None: + centroids = torch.empty((N, H, clusters), dtype=torch.int64, + device=device) + centroids = hashes[:, :, np.random.choice(L, size=[clusters], replace=False)] + K = centroids.numel() // N // H + #K = clusters + if counts is None: + counts = torch.empty((N, H, K), dtype=torch.int32, device=device) + if distances is None: + distances = torch.empty((N, H, L), dtype=torch.int32, + device=device) + if bitcounts is None: + bitcounts = torch.empty((N, H, K, bits), dtype=torch.int32, + device=device) + groups = groups.view(N, H, L) + counts = counts.view(N, H, K) + centroids = centroids.view(N, H, K) + distances = distances.view(N, H, L) + bitcounts = bitcounts.view(N, H, K, -1) + + cluster_gpu( + hashes, lengths, + centroids, distances, bitcounts, groups, counts, + iterations, bits + ) + + return groups, counts + diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/cluster_cpu.cpython-39-x86_64-linux-gnu.so b/smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/cluster_cpu.cpython-39-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..8d827bb9d00d15432a66dd7573e28825c645b041 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/clustering/hamming/cluster_cpu.cpython-39-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2bd8f761d6e1efdeea33665cad8702b5c07d1a0db728d19cf332c4383510d45 +size 139824 diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/events/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/events/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c475ba666e0f390dad8fc53d8d86e3c20497836 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/events/__init__.py @@ -0,0 +1,10 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""This module implements a basic event system that allows the transformer +internal components to make available any tensor with minimal overhead.""" + +from .event import Event, AttentionEvent, QKVEvent +from .event_dispatcher import EventDispatcher diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/__init__.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afd9ade88fe830d74ad22c6a3f7a9ff0319b1967 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/__init__.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event.cpython-310.pyc 
b/smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc909bb1f411c5013ba8a32bf08f207e83a042d2 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event_dispatcher.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event_dispatcher.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a35cd27421d2fd85a6d593a2e3808c6262b3a4e Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/event_dispatcher.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/filters.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/filters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..769256db7025d96af4620a09678cd66f566ce498 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/events/__pycache__/filters.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/events/event.py b/smi-ted/inference/smi_ted_light/fast_transformers/events/event.py new file mode 100644 index 0000000000000000000000000000000000000000..1f6021b45195388b47e2491a3167b0332efebae6 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/events/event.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + + +class Event(object): + """The Event is the base class for all events that are dispatched from any + transformer module. + + This class defines only the basic attributes of an event without any + payload. + + Arguments + --------- + source: torch.nn.Module instance that dispatched this event + """ + def __init__(self, source): + self.source = source + + +class AttentionEvent(Event): + """An event containing an attention matrix. + + Arguments + --------- + source: torch.nn.Module instance that dispatched this event + attention_matrix: torch.tensor of the multihead attention matrix + computed in the corresponding attention layer + """ + def __init__(self, source, attention_matrix): + super(AttentionEvent, self).__init__(source) + self.attention_matrix = attention_matrix + + +class QKVEvent(Event): + """An event containing the queries, keys and values projected in their + multiple heads. 
+ + Arguments + --------- + source: torch.nn.Module instance that dispatched this event + queries: torch.tensor containing the queries in shape NLHE + keys: torch.tensor containing the keys in shape NSHE + values: torch.tensor containing the values in shape NSHD + """ + def __init__(self, source, queries, keys, values): + super(QKVEvent, self).__init__(source) + self.queries = queries + self.keys = keys + self.values = values diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/events/event_dispatcher.py b/smi-ted/inference/smi_ted_light/fast_transformers/events/event_dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..c5282f60d05205a33992084132807af21b5acc1f --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/events/event_dispatcher.py @@ -0,0 +1,92 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +from collections import OrderedDict + +from .event import Event +from .filters import event_class + + +class EventDispatcher(object): + """An EventDispatcher is a simple way to implement an observer pattern for + loose coupling of components. In our case it is used so that the internals + of large neural networks can communicate with the outside world in an + agnostic and efficient way. + + Example usage + ------------- + + from fast_transformers.events import EventDispatcher, AttentionEvent + from fast_transformers.events.filters import \ + layer_name_contains + + def attention_event_handler(event): + print(event.attention_matrix) + + ed = EventDispatcher() + ed.listen(AttentionEvent, attention_event_handler) + ed.listen( + AttentionEvent & layer_name_contains("layers.12"), + attention_event_handler + ) + """ + _dispatchers = {} + + def __init__(self): + self._listeners = OrderedDict() + + def listen(self, event_filter, event_handler): + """Add an event handler for the events that pass the event filter. + + Arguments + --------- + event_filter: callable or Event class to define for which events + this handler will be called + event_handler: callable that accepts an instance of Event + """ + if isinstance(event_filter, type) and issubclass(event_filter, Event): + event_filter = event_class(event_filter) + + self._listeners[event_handler] = event_filter + + def remove(self, event_handler): + """Remove the event_handler from the listeners so that no more events + are dispatched to this handler.""" + self._listeners.pop(event_handler, None) + + def clear(self): + """Remove all listeners from the event dispatcher.""" + self._listeners.clear() + + def dispatch(self, event): + """Dispatch an event to the listeners. + + Arguments + --------- + event: Event instance + """ + for event_handler, event_filter in self._listeners.items(): + if event_filter(event): + event_handler(event) + + @classmethod + def get(cls, key=""): + """Factory method for creating global event dispatchers for loosely + coupling parts of a larger codebase. + + Since global objects are a complete antipattern, we suggest that this + is only used to set a default value for an event dispatcher passed as + an argument. 
+ + Argument + -------- + key: A key to uniquely identify a dispatcher or an instance of a + dispatcher to be returned as is + """ + if isinstance(key, cls): + return key + if key not in cls._dispatchers: + cls._dispatchers[key] = cls() + return cls._dispatchers[key] diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/events/filters.py b/smi-ted/inference/smi_ted_light/fast_transformers/events/filters.py new file mode 100644 index 0000000000000000000000000000000000000000..d843d349c66e699796c60485dc507d02f234278f --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/events/filters.py @@ -0,0 +1,141 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Define composable functions to filter events.""" + +import weakref + +from .event import Event + + +class EventFilter(object): + """EventFilter instances are predicates (ie functions that return True or + False) to be used with an event dispatcher for filtering event + instances. + + The main benefit from using raw functions is that an EventFilter composes + very easily using operators such as &, |, ~. + + Example + -------- + + event_filter = AttentionEvent | layer_name_contains("layers.1") + event_filter = from_layer(transformer.layers[2].attention) + event_filter = ( + AttentionEvent & + lambda ev: torch.isnan(ev.attention_matrix).any() + ) + """ + def __call__(self, event): + raise NotImplementedError() + + def _to_event_filter(self, other): + if isinstance(other, EventFilter): + return other + if isinstance(other, type) and issubclass(other, Event): + return event_class(other) + if callable(other): + return CallableEventFilter(other) + + return NotImplemented + + def __and__(self, other): + other = self._to_event_filter(other) + if other is NotImplemented: + return other + return CallableEventFilter(lambda ev: self(ev) and other(ev)) + + def __rand__(self, other): + other = self._to_event_filter(other) + if other is NotImplemented: + return other + return CallableEventFilter(lambda ev: other(ev) and self(ev)) + + def __or__(self, other): + other = self._to_event_filter(other) + if other is NotImplemented: + return other + return CallableEventFilter(lambda ev: self(ev) or other(ev)) + + def __ror__(self, other): + other = self._to_event_filter(other) + if other is NotImplemented: + return other + return CallableEventFilter(lambda ev: other(ev) or self(ev)) + + def __invert__(self): + return CallableEventFilter(lambda ev: not self(ev)) + + +class CallableEventFilter(EventFilter): + """Wrap a function with an EventFilter object.""" + def __init__(self, event_filter): + self._event_filter = event_filter + + def __call__(self, event): + return self._event_filter(event) + + +class LayerNameEventFilter(EventFilter): + """A LayerNameEventFilter allows to filter events based on a human readable + name of the layer that emitted them. + + Note that LayerNameEventFilter keeps a weak reference to all modules which + means that it cannot be used to prevent modules from being garbage + collected. 
+ + Arguments + --------- + root: torch.nn.Module instance that represents the root container + name_filter: callable, that returns true if the name + """ + def __init__(self, root, name_filter): + self._names = { + weakref.ref(m): n + for n, m in root.named_modules() + } + self._name_filter = name_filter + + def __call__(self, event): + name = self._names.get(weakref.ref(event.source), None) + if name is None: + return False + return self._name_filter(name) + + +def event_class(klass): + """Select events that are instances of `klass`. + + Arguments + --------- + klass: A class to check the event instance against + + Returns + ------- + An instance of EventFilter + """ + return CallableEventFilter(lambda ev: isinstance(ev, klass)) + + +def from_layer(layer): + """Select events that are dispatched from the `layer`. + + Arguments + --------- + layer: An instance of torch.nn.Module to check against the event source + + Returns + ------- + An instance of EventFilter + """ + return CallableEventFilter(lambda ev: ev.source is layer) + + +def layer_name_contains(root, name): + """Select events that contain `name` in their human readable name. + + We use root.named_modules() to get human readable names for the layers. + """ + return LayerNameEventFilter(root, lambda n: name in n) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa4264cbee4c8bb23148d3e63cc9e9e70ad0af4d --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__init__.py @@ -0,0 +1,12 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Implementations of feature maps to be used with linear attention and causal +linear attention.""" + + +from .base import elu_feature_map, ActivationFunctionFeatureMap +from .fourier_features import RandomFourierFeatures, Favor, \ + SmoothedRandomFourierFeatures, GeneralizedRandomFeatures diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/__init__.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f1c6d02ca0bb38867ebc16f68fb0f614fe86d72 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/__init__.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/base.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2028ec40459322d210c15ff8cc3e041c703abb6f Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/base.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/fourier_features.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/fourier_features.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b55b34a5362349786641b6c7e1384c8e4af73d7 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/__pycache__/fourier_features.cpython-310.pyc differ diff --git 
a/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/base.py b/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/base.py new file mode 100644 index 0000000000000000000000000000000000000000..0abc562a92dcbc58a2933a55aad92d00f271d907 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/base.py @@ -0,0 +1,73 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Create the feature map interface and some commonly used feature maps. + +All attention implementations that expect a feature map shall receive a factory +function that returns a feature map instance when called with the query +dimensions. +""" + +from functools import partial + +import torch +from torch.nn import Module + + +class FeatureMap(Module): + """Define the FeatureMap interface.""" + def __init__(self, query_dims): + super().__init__() + self.query_dims = query_dims + + def new_feature_map(self, device): + """Create a new instance of this feature map. In particular, if it is a + random feature map sample new parameters.""" + raise NotImplementedError() + + def forward_queries(self, x): + """Encode the queries `x` using this feature map.""" + return self(x) + + def forward_keys(self, x): + """Encode the keys `x` using this feature map.""" + return self(x) + + def forward(self, x): + """Encode x using this feature map. For symmetric feature maps it + suffices to define this function, but for asymmetric feature maps one + needs to define the `forward_queries` and `forward_keys` functions.""" + raise NotImplementedError() + + @classmethod + def factory(cls, *args, **kwargs): + """Return a function that when called with the query dimensions returns + an instance of this feature map. + + It is inherited by the subclasses so it is available in all feature + maps. + """ + def inner(query_dims): + return cls(query_dims, *args, **kwargs) + return inner + + +class ActivationFunctionFeatureMap(FeatureMap): + """Define a feature map that is simply an element-wise activation + function.""" + def __init__(self, query_dims, activation_function): + super().__init__(query_dims) + self.activation_function = activation_function + + def new_feature_map(self, device): + return + + def forward(self, x): + return self.activation_function(x) + + +elu_feature_map = ActivationFunctionFeatureMap.factory( + lambda x: torch.nn.functional.elu(x) + 1 +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/fourier_features.py b/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/fourier_features.py new file mode 100644 index 0000000000000000000000000000000000000000..29446ac861b36af8455ceea50dec39b8b10d1c04 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/feature_maps/fourier_features.py @@ -0,0 +1,287 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Implement the positive orthogonal random features from the paper +"Rethinking Attention with Performers" https://arxiv.org/pdf/2009.14794.pdf +and the traditional random Fourier features that approximate the RBF kernel. +""" + +from math import sqrt, log +import warnings + +import torch + +from .base import FeatureMap + + +def orthogonal_random_matrix_(w): + """Initialize the matrix w in-place to compute orthogonal random features. 
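Before the Fourier-based maps that follow, a quick sketch of the factory convention from base.py above; the tensor shapes are illustrative per-head queries of dimension 32.

    import torch
    from fast_transformers.feature_maps import elu_feature_map, ActivationFunctionFeatureMap

    # elu_feature_map is a factory: calling it with the query dimensions
    # returns an ActivationFunctionFeatureMap instance.
    feature_map = elu_feature_map(32)
    feature_map.new_feature_map(torch.device("cpu"))  # no-op for deterministic maps

    q = torch.rand(2, 100, 4, 32)                     # (N, L, H, E)
    phi_q = feature_map.forward_queries(q)            # elu(q) + 1, same shape

    # a custom symmetric map only needs an element-wise activation
    relu_map = ActivationFunctionFeatureMap.factory(torch.relu)
    phi_k = relu_map(32).forward_keys(q)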
+ + The matrix is initialized such that its columns are orthogonal to each + other (in groups of size `rows`) and their norms is drawn from the + chi-square distribution with `rows` degrees of freedom (namely the norm of + a `rows`-dimensional vector distributed as N(0, I)). + + Arguments + --------- + w: float tensor of size (rows, columns) + """ + rows, columns = w.shape + start = 0 + while start < columns: + end = min(start+rows, columns) + block = torch.randn(rows, rows, device=w.device) + norms = torch.sqrt(torch.einsum("ab,ab->a", block, block)) + Q, _ = torch.qr(block) + w[:, start:end] = ( + Q[:, :end-start] * norms[None, :end-start] + ) + start += rows + + +class RandomFourierFeatures(FeatureMap): + """Random Fourier Features for the RBF kernel according to [1]. + + [1]: "Weighted Sums of Random Kitchen Sinks: Replacing minimization with + randomization in learning" by A. Rahimi and Benjamin Recht. + + Arguments + --------- + query_dimensions: int, The input query dimensions in order to sample + the noise matrix + n_dims: int, The size of the feature map (should be divisible by 2) + (default: query_dimensions) + softmax_temp: float, The temerature for the Gaussian kernel + approximation exp(-t * |x-y|^2) + (default: 1/sqrt(query_dimensions)) + orthogonal: bool, When True the random matrix is initialized for + orthogonal random features to reduce the approximation + variance (default: False) + redraw: int, Redraw the random matrix every 'redraw' times + (default: 1) + deterministic_eval: bool, Only redraw the random matrix during training + (default: False) + """ + def __init__(self, query_dimensions, n_dims=None, softmax_temp=None, + orthogonal=False, redraw=1, deterministic_eval=False): + super(RandomFourierFeatures, self).__init__(query_dimensions) + + self.n_dims = n_dims or query_dimensions + self.query_dimensions = query_dimensions + self.orthogonal = orthogonal + self.softmax_temp = ( + 1/sqrt(query_dimensions) if softmax_temp is None + else softmax_temp + ) + self.redraw = redraw + self.deterministic_eval = deterministic_eval + + # Make a buffer for storing the sampled omega + self.register_buffer( + "omega", + torch.zeros(self.query_dimensions, self.n_dims//2) + ) + self._calls = -1 + + def new_feature_map(self, device): + # If we are not training skip the generation of a new feature map + if self.deterministic_eval and not self.training: + return + + # Only redraw the new feature map every self.redraw times + self._calls += 1 + if (self._calls % self.redraw) != 0: + return + + omega = torch.zeros( + self.query_dimensions, + self.n_dims//2, + device=device + ) + if self.orthogonal: + orthogonal_random_matrix_(omega) + else: + omega.normal_() + self.register_buffer("omega", omega) + + def forward(self, x): + x = x * sqrt(self.softmax_temp) + u = x.unsqueeze(-2).matmul(self.omega).squeeze(-2) + phi = torch.cat([torch.cos(u), torch.sin(u)], dim=-1) + return phi * sqrt(2/self.n_dims) + + +class SmoothedRandomFourierFeatures(RandomFourierFeatures): + """Simply add a constant value to the dot product in order to avoid + possible numerical instabilities when the feature map is slightly + negative. + + Implements K(x, y) = exp(-|x-y|^2) + s. 
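A standalone sketch of RandomFourierFeatures above. In practice the factory form (RandomFourierFeatures.factory(...)) is handed to a linear attention module, which calls new_feature_map once per sequence, but the shapes are easiest to see directly; all sizes below are illustrative.

    import torch
    from fast_transformers.feature_maps import RandomFourierFeatures

    fm = RandomFourierFeatures(query_dimensions=32, n_dims=64, orthogonal=True)
    fm.new_feature_map(torch.device("cpu"))  # samples a fresh omega matrix

    q = torch.rand(2, 100, 4, 32)            # (N, L, H, E)
    k = torch.rand(2, 100, 4, 32)
    phi_q = fm.forward_queries(q)            # (2, 100, 4, 64) cos/sin features
    phi_k = fm.forward_keys(k)

    # inner products of the features approximate the RBF kernel between q and k
    approx_kernel = torch.einsum("nlhd,nshd->nlsh", phi_q, phi_k)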
+ + Arguments + --------- + query_dimensions: int, The input query dimensions in order to sample + the noise matrix + n_dims: int, The size of the feature map (should be divisible by 2) + (default: query_dimensions) + softmax_temp: float, The temerature for the Gaussian kernel + approximation exp(-t * |x-y|^2) + (default: 1/sqrt(query_dimensions)) + orthogonal: bool, When True the random matrix is initialized for + orthogonal random features to reduce the approximation + variance (default: False) + smoothing: float, The smoothing parameter to add to the dot product. + redraw: int, Redraw the random matrix every 'redraw' times + (default: 1) + deterministic_eval: bool, Only redraw the random matrix during training + (default: False) + """ + def __init__(self, query_dimensions, n_dims=None, softmax_temp=None, + orthogonal=False, smoothing=1.0, redraw=1, + deterministic_eval=False): + super(SmoothedRandomFourierFeatures, self).__init__( + query_dimensions, + n_dims=query_dimensions-1 if n_dims is None else n_dims-1, + softmax_temp=softmax_temp, + orthogonal=orthogonal, + redraw=redraw, + deterministic_eval=deterministic_eval + ) + self.smoothing = smoothing + + def forward(self, x): + y = super().forward(x) + smoothing = torch.full( + y.shape[:-1] + (1,), + self.smoothing, + dtype=y.dtype, + device=y.device + ) + return torch.cat([y, smoothing], dim=-1) + + +class Favor(RandomFourierFeatures): + """Positive orthogonal random features that approximate the softmax kernel. + + Basically implementation of Lemma 1 from "Rethinking Attention with + Performers". + + Arguments + --------- + query_dimensions: int, The input query dimensions in order to sample + the noise matrix + n_dims: int, The size of the feature map (should be divisible by 2) + (default: query_dimensions) + softmax_temp: float, The temerature for the softmax approximation + (default: 1/sqrt(query_dimensions)) + orthogonal: bool, If set to true then the random matrix should be + orthogonal which results in lower approximation variance + (default: True) + stabilize: bool, If set to True subtract the max norm from the + exponentials to make sure that there are no infinities. It + is equivalent to a robust implementation of softmax where + the max is subtracted before the exponentiation. + (default: False) + redraw: int, Redraw the random matrix every 'redraw' times + (default: 1) + deterministic_eval: bool, Only redraw the random matrix during training + (default: False) + """ + def __init__(self, query_dimensions, n_dims=None, softmax_temp=None, + orthogonal=True, stabilize=False, redraw=1, + deterministic_eval=False): + super(Favor, self).__init__(query_dimensions, n_dims=n_dims, + softmax_temp=softmax_temp, + orthogonal=orthogonal, redraw=redraw, + deterministic_eval=deterministic_eval) + self.stabilize = stabilize + + def _check_sequence_length(self, x): + """Check that the 2nd dimension is larger than the 3rd as a heuristic + that the sequence length will be larger than the number of heads. If + not simply warn of a possible bug.""" + if len(x.shape) != 4: + warnings.warn(("Favor.stabilize is set to True but the input " + "feature does not have the shape (N, L, H, D) " + "which may result in unexpected behaviour")) + + if x.shape[1] < x.shape[2]: + warnings.warn(("Favor.stabilize is set to True but the 2nd " + "dimension of the input is smaller than the 3rd " + "which could indicate that the sequence length and " + "the heads are flipped. This may result in incorrect " + "behaviour. 
The shape of the input is " + "{!r}.").format(x.shape)) + + def forward(self, x): + x = x * sqrt(self.softmax_temp) + norm_x_squared = torch.einsum("...d,...d->...", x, x).unsqueeze(-1) + u = x.unsqueeze(-2).matmul(self.omega).squeeze(-2) + + # Compute the offset for the exponential such that h(x) is multiplied + # in logspace. In particular, we multiply with exp(-norm_x_squared/2) + # and 1/sqrt(self.n_dims) + offset = norm_x_squared * 0.5 + 0.5 * log(self.n_dims) + + # If stabilize is True then add the max norm per sequence in order to + # ensure that exp_u1 and exp_u2 will be <1. + # + # NOTE: This is the only part of this feature map that assumes the + # 2nd dimension is the sequence length. We call the + # _check_sequence_length dimension function to be able to catch + # some possible bugs ahead of time. + if self.stabilize: + self._check_sequence_length(norm_x_squared) + offset = offset + norm_x_squared.max(1, keepdim=True)[0] + + exp_u1 = torch.exp(u - offset) + exp_u2 = torch.exp(-u - offset) + phi = torch.cat([exp_u1, exp_u2], dim=-1) + + return phi + + +class GeneralizedRandomFeatures(RandomFourierFeatures): + """Implements the generalized random Fourier features from Performers. + + It computes φ(χ) = [f(ω_1 χ), f(ω_2 χ), ..., f(ω_n χ)] where f(.) is the + passed in `kernel_fn`. + + Arguments + --------- + query_dimensions: int, The input query dimensions in order to sample + the noise matrix + n_dims: int, The size of the feature map (default: query_dimensions) + softmax_temp: float, A normalizer for the dot products that is + multiplied to the input features before the feature map + application (default: 1.0) + orthogonal: bool, If set to true then the random matrix should be + orthogonal which results in lower approximation variance + (default: True) + kernel_fn: callable, defines the f used for the feature map. 
+ (default: relu) + redraw: int, Redraw the random matrix every 'redraw' times + (default: 1) + deterministic_eval: bool, Only redraw the random matrix during training + (default: False) + """ + def __init__(self, query_dimensions, n_dims=None, softmax_temp=1.0, + orthogonal=True, kernel_fn=torch.relu, redraw=1, + deterministic_eval=False): + super(GeneralizedRandomFeatures, self).__init__( + query_dimensions, + n_dims=2*query_dimensions if n_dims is None else 2*n_dims, + softmax_temp=softmax_temp, + orthogonal=orthogonal, + redraw=redraw, + deterministic_eval=deterministic_eval + ) + self.kernel_fn = kernel_fn + + def forward(self, x): + if self.softmax_temp != 1.0: + x = x * sqrt(self.softmax_temp) + u = x.unsqueeze(-2).matmul(self.omega).squeeze(-2) + return self.kernel_fn(u) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/hashing/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/hashing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c47f632859a34a7bb264c75e615c16f5b33e6926 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/hashing/__init__.py @@ -0,0 +1,31 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + + +import torch + +from .hash_cpu import compute_hashes as compute_hashes_cpu +try: + from .hash_cuda import compute_hashes as compute_hashes_cuda +except ImportError: + pass + + +def compute_hashes(X, A, H=None): + device = X.device + if H is None: + H = torch.zeros(len(X), dtype=torch.int64, device=device) + else: + H.zero_() + if A.shape[1] != X.shape[1] + 1: + raise ValueError("The hash requires a bias") + + if device.type == "cpu": + compute_hashes_cpu(X, A, H) + else: + compute_hashes_cuda(X, A, H) + + return H diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/hashing/hash_cpu.cpython-39-x86_64-linux-gnu.so b/smi-ted/inference/smi_ted_light/fast_transformers/hashing/hash_cpu.cpython-39-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..48f5f56bccfc8d69c30e4c9c5262e89db173fded --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/hashing/hash_cpu.cpython-39-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e4ddf5d319a08cd80b1cf5332799389911582f918c3eca85c4aab68afd465bf5 +size 133880 diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/local_product/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/local_product/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fdee2c9116bd6755af287b9ae8a32345ceafa6ab --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/local_product/__init__.py @@ -0,0 +1,97 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +import torch + +from .local_product_cpu import local_dot_product as local_dot_product_cpu, \ + local_dot_backward as local_dot_backward_cpu, \ + local_weighted_average as local_weighted_average_cpu, \ + local_weighted_average_backward as local_weighted_average_backward_cpu + +try: + from .local_product_cuda import \ + local_dot_product as local_dot_product_cuda, \ + local_dot_backward as local_dot_backward_cuda, \ + local_weighted_average as local_weighted_average_cuda, \ + local_weighted_average_backward as local_weighted_average_backward_cuda +except ImportError: + local_dot_product_cuda = None + local_dot_backward_cuda = None + 
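Returning briefly to the feature maps defined above: Favor is used the same way as RandomFourierFeatures but approximates the softmax kernel, which is why it is the map normally paired with (causal) linear attention. A small sketch with illustrative shapes:

    import torch
    from fast_transformers.feature_maps import Favor

    fm = Favor(query_dimensions=32, n_dims=64)  # positive orthogonal random features
    fm.new_feature_map(torch.device("cpu"))

    q = torch.rand(2, 100, 4, 32)               # (N, L, H, E)
    k = torch.rand(2, 100, 4, 32)
    phi_q = fm.forward_queries(q)               # (2, 100, 4, 64), all positive
    phi_k = fm.forward_keys(k)

    # linear attention then computes phi_q @ (phi_k^T V) instead of softmax(QK^T) V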
local_weighted_average_cuda = None + local_weighted_average_backward_cuda = None + + +class LocalDotProduct(torch.autograd.Function): + """Compute the dot product of the queries and keys but only consider a + local neighborhood of each query.""" + dot = { + "cpu": local_dot_product_cpu, + "cuda": local_dot_product_cuda + } + dot_backward = { + "cpu": local_dot_backward_cpu, + "cuda": local_dot_backward_cuda + } + + @staticmethod + def forward(ctx, queries, keys, attn_mask, key_lengths, local_context): + # Save the inputs for the gradient computation + ctx.save_for_backward(queries, keys, key_lengths) + ctx.local_context = local_context + + return LocalDotProduct.dot[queries.device.type]( + queries, + keys, + attn_mask, + key_lengths, + local_context + ) + + @staticmethod + def backward(ctx, grad_input): + queries, keys, key_lengths = ctx.saved_tensors + local_context = ctx.local_context + + grads = LocalDotProduct.dot_backward[queries.device.type]( + queries, + keys, + key_lengths, + grad_input, + local_context + ) + + # plus 3 None for masks and local_context + return grads + (None, None, None) + + +class LocalWeightedAverage(torch.autograd.Function): + """Compute the weighted average of the values with the local attention.""" + avg = { + "cpu": local_weighted_average_cpu, + "cuda": local_weighted_average_cuda + } + avg_backward = { + "cpu": local_weighted_average_backward_cpu, + "cuda": local_weighted_average_backward_cuda + } + + @staticmethod + def forward(ctx, A, V): + # Save the inputs for the gradient computation + ctx.save_for_backward(A, V) + + return LocalWeightedAverage.avg[A.device.type](A, V) + + @staticmethod + def backward(ctx, grad_input): + A, V = ctx.saved_tensors + return LocalWeightedAverage.avg_backward[A.device.type]( + A, V, grad_input + ) + + +# Alias the autograd functions to python style snake case naming +local_dot_product = LocalDotProduct.apply +local_weighted_average = LocalWeightedAverage.apply diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/local_product/local_product_cpu.cpython-39-x86_64-linux-gnu.so b/smi-ted/inference/smi_ted_light/fast_transformers/local_product/local_product_cpu.cpython-39-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..6d86eaf1afa8e0e969186c4ea4ffa3dec03d8d3e --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/local_product/local_product_cpu.cpython-39-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0be2baf179a6639bb75a6e3c0ed67206089457501b654c639df11e27d69f9d6d +size 158272 diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/masking.py b/smi-ted/inference/smi_ted_light/fast_transformers/masking.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf2007ee188e679418eb13327e3f842fa7d87d7 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/masking.py @@ -0,0 +1,206 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Create types of masks to be used in various places in transformers. + +- Full mask (any key masked for any query) +- Length mask (masking out everything after a length) +- Triangular causal mask (mask any key succeeding the query) + +All mask implementations should provide a single interface to be used by the +transformer layers and the attention layers. + +NOTE: In all cases the value 1 or True signifies what should be kept and not + what should be deleted/masked. 
+""" + +import torch + + +class BaseMask(object): + @property + def bool_matrix(self): + """Return a bool (uint8) matrix with 1s to all places that should be + kept.""" + raise NotImplementedError() + + @property + def float_matrix(self): + """Return the bool matrix as a float to be used as a multiplicative + mask for non softmax attentions.""" + if not hasattr(self, "_float_matrix"): + with torch.no_grad(): + self._float_matrix = self.bool_matrix.float() + return self._float_matrix + + @property + def lengths(self): + """If the matrix is of the following form + + 1 1 1 0 0 0 0 + 1 0 0 0 0 0 0 + 1 1 0 0 0 0 0 + + then return it as a vector of integers + + 3 1 2. + """ + if not hasattr(self, "_lengths"): + with torch.no_grad(): + lengths = self.bool_matrix.long().sum(dim=-1) + # make sure that the mask starts with 1s and continues with 0s + # this should be changed to something more efficient, however, + # I chose simplicity over efficiency since the LengthMask class + # will be used anyway (and the result is cached) + m = self.bool_matrix.view(-1, self.shape[-1]) + for i, l in enumerate(lengths.view(-1)): + if not torch.all(m[i, :l]): + raise ValueError("The mask is not a length mask") + self._lengths = lengths + return self._lengths + + @property + def shape(self): + """Return the shape of the boolean mask.""" + return self.bool_matrix.shape + + @property + def additive_matrix(self): + """Return a float matrix to be added to an attention matrix before + softmax.""" + if not hasattr(self, "_additive_matrix"): + with torch.no_grad(): + self._additive_matrix = torch.log(self.bool_matrix.float()) + return self._additive_matrix + + @property + def additive_matrix_finite(self): + """Same as additive_matrix but with -1e24 instead of infinity.""" + if not hasattr(self, "_additive_matrix_finite"): + with torch.no_grad(): + self._additive_matrix_finite = ( + (~self.bool_matrix).float() * (-1e24) + ) + return self._additive_matrix_finite + + @property + def all_ones(self): + """Return true if the mask is all ones.""" + if not hasattr(self, "_all_ones"): + with torch.no_grad(): + self._all_ones = torch.all(self.bool_matrix) + return self._all_ones + + @property + def lower_triangular(self): + """Return true if the attention is a triangular causal mask.""" + if not hasattr(self, "_lower_triangular"): + self._lower_triangular = False + with torch.no_grad(): + try: + lengths = self.lengths + if len(lengths.shape) == 1: + target = torch.arange( + 1, + len(lengths)+1, + device=lengths.device + ) + self._lower_triangular = torch.all(lengths == target) + except ValueError: + pass + return self._lower_triangular + + +class FullMask(BaseMask): + """Thin wrapper over a pytorch tensor that provides the BaseMask + interface. + + The arguments can be given both by keyword arguments and positional + arguments. To imitate function overloading, the constructor checks the type + of the first argument and if it is a tensor it treats it as the mask. + otherwise it assumes that it was the N argument. + + Arguments + --------- + mask: The mask as a PyTorch tensor. + N: The rows of the all True mask to be created if the mask argument is + not provided. + M: The columns of the all True mask to be created if the mask argument + is not provided. If N is given M defaults to N. 
+ device: The device to create the mask in (defaults to cpu) + """ + def __init__(self, mask=None, N=None, M=None, device="cpu"): + # mask is a tensor so we ignore N and M + if mask is not None and isinstance(mask, torch.Tensor): + if mask.dtype != torch.bool: + raise ValueError("FullMask expects the mask to be bool") + with torch.no_grad(): + self._mask = mask.clone() + return + + # mask is an integer, N is an integer and M is None so assume they were + # passed as N, M + if mask is not None and M is None and isinstance(mask, int): + M = N + N = mask + + if N is not None: + M = M or N + with torch.no_grad(): + self._mask = torch.ones(N, M, dtype=torch.bool, device=device) + self._all_ones = True + return + + raise ValueError("Either mask or N should be provided") + + @property + def bool_matrix(self): + return self._mask + + +class LengthMask(BaseMask): + """Provide a BaseMask interface for lengths. Mostly to be used with + sequences of different lengths. + + Arguments + --------- + lengths: The lengths as a PyTorch long tensor + max_len: The maximum length for the mask (defaults to lengths.max()) + device: The device to be used for creating the masks (defaults to + lengths.device) + """ + def __init__(self, lengths, max_len=None, device=None): + self._device = device or lengths.device + with torch.no_grad(): + self._lengths = lengths.clone().to(self._device) + self._max_len = max_len or self._lengths.max() + + self._bool_matrix = None + self._all_ones = torch.all(self._lengths == self._max_len).item() + + @property + def bool_matrix(self): + if self._bool_matrix is None: + with torch.no_grad(): + indices = torch.arange(self._max_len, device=self._device) + self._bool_matrix = ( + indices.view(1, -1) < self._lengths.view(-1, 1) + ) + return self._bool_matrix + + +class TriangularCausalMask(LengthMask): + """A square matrix with everything masked out above the diagonal. 
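A short sketch of the two concrete masks defined above, with values chosen only for illustration:

    import torch
    from fast_transformers.masking import FullMask, LengthMask

    # an all-True 4x6 attention mask: every key is visible to every query
    attn_mask = FullMask(4, 6)
    attn_mask.bool_matrix.shape   # torch.Size([4, 6])

    # per-sample lengths for a padded batch with maximum length 5
    length_mask = LengthMask(torch.tensor([5, 3, 1]), max_len=5)
    length_mask.bool_matrix       # row i starts with lengths[i] True values
    length_mask.float_matrix      # multiplicative form for linear attention
    length_mask.additive_matrix   # 0 / -inf form to add before a softmax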
+ + Arguments + --------- + N: The size of the matrix + device: The device to create the mask in (defaults to cpu) + """ + def __init__(self, N, device="cpu"): + lengths = torch.arange(1, N+1, device=device) + super(TriangularCausalMask, self).__init__(lengths, N, device) + self._lower_triangular = True diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..047609dca866a62474e4bca4303e9aefb7a6388f --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/__init__.py @@ -0,0 +1,7 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implementations of transformers as recurrent functions.""" diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/__pycache__/__init__.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b792592201f282a0b3c0bb7922611b313e9f5e5 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/__pycache__/__init__.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/__pycache__/_utils.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/__pycache__/_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0878f6fa079ff25ca4b5a68038d3bcd91283e382 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/__pycache__/_utils.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/__pycache__/transformers.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/__pycache__/transformers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0509f7251b68dc4639dd0cc56999dc80e7fe3f34 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/__pycache__/transformers.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/_utils.py b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a98be6614cb42996adb3e60d122f3ddd6e7f74b1 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/_utils.py @@ -0,0 +1,16 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +import warnings + + +def check_state(state=None, memory=None): + if memory is not None: + warnings.warn(("'memory' is deprecated for recurrent transformers " + " and will be removed in the future, use 'state' " + "instead"), DeprecationWarning) + if state is None: + state = memory + return state diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e72596ca313624b17f714b8a20ed009365c01f --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/__init__.py @@ -0,0 +1,16 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas 
+# + +"""Implementations of different types of autoregressive attention +mechanisms for self attention and cross attention.""" + +from .self_attention.attention_layer import RecurrentAttentionLayer +from .self_attention.full_attention import RecurrentFullAttention +from .self_attention.linear_attention import RecurrentLinearAttention + +from .cross_attention.attention_layer import RecurrentCrossAttentionLayer +from .cross_attention.full_attention import RecurrentCrossFullAttention +from .cross_attention.linear_attention import RecurrentCrossLinearAttention diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/__pycache__/__init__.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bbdbd8ca7d07762b25497db42aacd5949f80f6d Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f3dae4993e2ffafa3940d1e25b18709f64396a6 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__init__.py @@ -0,0 +1,30 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Autoregressive implementations for cross attention as a recurrent module. + +The attention implementations in this module expect one input for query and a +sequence of inputs for keys and values. The sequence for the keys and values is +fixed for all queries. 
+ +Example +-------- + + import torch + + from fast_transformers.recurrent.attention import \ + RecurrentCrossAttentionLayer, RecurrentCrossFullAttention + + att = RecurrentCrossAttentionLayer(RecurrentCrossFullAttention(), 16, 4) + state = None + x = torch.rand(8, 16) + memory = torch.rand(8, 64, 16) + for i in range(10): + x, state = att(x, memory, memory, state=state) +""" + +from .attention_layer import RecurrentCrossAttentionLayer +from .full_attention import RecurrentCrossFullAttention +from .linear_attention import RecurrentCrossLinearAttention diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__pycache__/__init__.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6385e876aea455a77a07391cc3bcacec3e602e17 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__pycache__/attention_layer.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__pycache__/attention_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c621f9bbf7b99e3eb95e868a584c9f49359f6bd Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__pycache__/attention_layer.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__pycache__/full_attention.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__pycache__/full_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf56b90ca7c15bff9c0679507f90be62daf310a6 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__pycache__/full_attention.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__pycache__/linear_attention.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__pycache__/linear_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ab430ab082fc25c23c7b49567f936522532eaa9 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/__pycache__/linear_attention.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/attention_layer.py b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/attention_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..f83a0798def79f79be5a32f03c794a1788264cb5 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/attention_layer.py @@ -0,0 +1,105 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Similar to the corresponding module in fast_transformers.attention, this +module performs all the query, key, value projections and output projections +leaving the implementation of the attention to the inner attention 
module. + +The crucial difference with respect to the self attention recurrent module +(fast_transformers.recurrent.attention.RecurrentAttentionLayer) is that it +doesn't recompute the projections for the keys and values if the state is not +None. +""" + +from torch.nn import Linear, Module + +from ....events import EventDispatcher + + +class RecurrentCrossAttentionLayer(Module): + """See fast_transformers.attention.attention_layer.AttentionLayer . + + The differences with the aforementioned module as well as the + RecurrentAttentionLayer are that this module projects the query every time + and the keys and values only the first time they are provided. + + Arguments + --------- + attention: Specific inner attention implementation that just computes a + weighted average of values given a similarity of queries and + keys. + d_model: The input feature dimensionality + n_heads: The number of heads for the multi head attention + d_keys: The dimensionality of the keys/queries + (default: d_model/n_heads) + d_values: The dimensionality of the values (default: d_model/n_heads) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None, event_dispatcher=""): + super(RecurrentCrossAttentionLayer, self).__init__() + + # Fill d_keys and d_values + d_keys = d_keys or (d_model//n_heads) + d_values = d_values or (d_model//n_heads) + + self.inner_attention = attention + self.query_projection = Linear(d_model, d_keys * n_heads) + self.key_projection = Linear(d_model, d_keys * n_heads) + self.value_projection = Linear(d_model, d_values * n_heads) + self.out_projection = Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, query, keys, values, key_lengths, state=None): + """Attend to the keys and values based on the passed in query. 
+ + In the argument description we make use of the following sizes + + - N: the batch size + - S: the sequence length of the keys and values + - D: The input feature dimensionality passed in the constructor as + 'd_model' + + Argument + -------- + query: (N, D) The tensor containing the queries + keys: (N, S, D) The tensor containing the keys + values: (N, S, D) The tensor containing the values + key_lengths: A fast_transformers.masking.BaseMask implementation + that defines the length of each key/value sequence + state: The state varies depending on the inner attention + implementation, but if it is not None then the keys and + values are ignored + """ + #Extract some shapes + N, _ = query.shape + H = self.n_heads + + # Project the query + query = self.query_projection(query).view(N, H, -1) + + # Project the keys and values if there is no state + if state is None: + _, S, _ = keys.shape + keys = self.key_projection(keys).view(N, S, H, -1) + values = self.value_projection(values).view(N, S, H, -1) + else: + keys = None + values = None + + new_value, state = self.inner_attention( + query, + keys, + values, + key_lengths, + state=state + ) + new_value = new_value.view(N, -1) + + # Project the output and return + return self.out_projection(new_value), state diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/full_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/full_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8652322ea024535be8f14ba64a8c43ad7e568569 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/full_attention.py @@ -0,0 +1,75 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Implement the typical softmax attention as a recurrent cross attention +module to speed up autoregressive decoding.""" + +from math import sqrt + +import torch +from torch.nn import Dropout, Module + +from ....attention_registry import RecurrentCrossAttentionRegistry, Optional, \ + Float, EventDispatcherInstance +from ....events import EventDispatcher, AttentionEvent + + +class RecurrentCrossFullAttention(Module): + """Implement autoregressive softmax cross attention as a recurrent + module. + + Arguments + --------- + softmax_temp: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.1) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + + def __init__(self, softmax_temp=None, attention_dropout=0.1, + event_dispatcher=""): + super(RecurrentCrossFullAttention, self).__init__() + self.softmax_temp = softmax_temp + self.dropout = Dropout(attention_dropout) + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, query, keys, values, key_lengths, state=None): + # Extract some shapes and compute the temperature + N, H, E = query.shape + softmax_temp = self.softmax_temp or 1. 
/ sqrt(E) + + # Extract the keys and values either from the arguments or the state + if state is not None: + keys, values = state + + # Compute the unnormalized attention and apply the key length mask + QK = torch.einsum("nhe,nshe->nsh", query, keys) + QK = QK + key_lengths.additive_matrix[:, :, None] + + # Compute the attention and the weighted average + A = self.dropout(torch.softmax(softmax_temp * QK, dim=1)) + V = torch.einsum("nsh,nshd->nhd", A, values) + + # Let the world know of the attention matrix + self.event_dispatcher.dispatch(AttentionEvent(self, A)) + + # Make sure that we return a contiguous value + return V.contiguous(), [keys, values] + + +# Register the attention implementation so that it becomes available in our +# builders +RecurrentCrossAttentionRegistry.register( + "full", RecurrentCrossFullAttention, + [ + ("softmax_temp", Optional(Float)), + ("attention_dropout", Optional(Float, 0.1)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/linear_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/linear_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..c1a043cbe025cbd25dfeeae93298f04f9ecfd3fa --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/cross_attention/linear_attention.py @@ -0,0 +1,79 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Implement unmasked linear attention as a recurrent cross attention module to +speed up autoregressive decoding.""" + +import torch +from torch.nn import Module + +from ....attention_registry import RecurrentCrossAttentionRegistry, Optional, Int, \ + Callable, EventDispatcherInstance +from ....events import EventDispatcher +from ....feature_maps import elu_feature_map + + +class RecurrentCrossLinearAttention(Module): + """Implement autoregressive linear cross attention as a recurrent + module. + + See fast_transformers.attention.linear_attention.LinearAttention . + + Arguments + --------- + feature_map: callable, a callable that applies the feature map to the + last dimension of a tensor (default: elu(x)+1) + eps: float, a small number to ensure the numerical stability of the + denominator (default: 1e-6) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, query_dimensions, feature_map=None, eps=1e-6, + event_dispatcher=""): + super(RecurrentCrossLinearAttention, self).__init__() + self.feature_map = ( + feature_map(query_dimensions) if feature_map else + elu_feature_map(query_dimensions) + ) + self.eps = eps + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, query, keys, values, key_lengths, state=None): + # If this is a new sequence re initialize the feature map + if state is None: + self.feature_map.new_feature_map(query.device) + + # Compute the feature representation of the query + Q = self.feature_map.forward_queries(query) + + # If the state is not given compute the key-value matrix and the + # normalizers, namely compute whatever is needed in order to attend to + # keys and values with a given query. 
+ if state is None: + K = self.feature_map.forward_keys(keys) + K = K * key_lengths.float_matrix[:, :, None, None] + S = torch.einsum("nshd,nshm->nhmd", K, values) + Z = K.sum(dim=1) + else: + S, Z = state + + # Given S and Z now we can efficiently compute the new value + QZ = 1/(torch.einsum("nhd,nhd->nh", Q, Z)+self.eps) + V = torch.einsum("nhd,nhmd,nh->nhm", Q, S, QZ) + + return V.contiguous(), [S, Z] + + +# Register the attention implementation so that it becomes available in our +# builders +RecurrentCrossAttentionRegistry.register( + "linear", RecurrentCrossLinearAttention, + [ + ("query_dimensions", Int), + ("feature_map", Optional(Callable)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5dbf0adcad7c8710a3b9b8c1722e64b1346653f --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__init__.py @@ -0,0 +1,30 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Autoregressive implementations for self attention as a recurrent module. + +The attention implementations in this module expect one input for query, one +for key and one for value and attend to all the keys and values seen so far. No +masking is necessary as an implicit lower triangular attention mask is assumed +in all cases. + +Example +------- + + import torch + + from fast_transformers.recurrent.attention import \ + RecurrentAttentionLayer, RecurrentFullAttention + + att = RecurrentAttentionLayer(RecurrentFullAttention(), 16, 4) + state = None + x = torch.rand(8, 16) + for i in range(10): + x, state = att(x, x, x, state=state) +""" + +from .attention_layer import RecurrentAttentionLayer +from .full_attention import RecurrentFullAttention +from .linear_attention import RecurrentLinearAttention diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__pycache__/__init__.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9556414427ae41e7ced7583fe8cb278b445fdbd7 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__pycache__/attention_layer.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__pycache__/attention_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32e01003c254147b8d2b2a92d3a3f7fcb7897bf2 Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__pycache__/attention_layer.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__pycache__/full_attention.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__pycache__/full_attention.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9735eff462aa65e9c4238d73cba8bae1165b233f Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__pycache__/full_attention.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__pycache__/linear_attention.cpython-310.pyc b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__pycache__/linear_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4451cf0b29feb0b4eaa49a2bb012501bf4f9619a Binary files /dev/null and b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/__pycache__/linear_attention.cpython-310.pyc differ diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/attention_layer.py b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/attention_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..3d63c48a03a24d9f6ec78728d1120d03f31e82a7 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/attention_layer.py @@ -0,0 +1,96 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Similar to the corresponding module in fast_transformers.attention, this +module performs all the query, key, value projections and output projections +leaving the implementation of the attention to the inner attention module.""" + +from torch.nn import Linear, Module + +from ....events import EventDispatcher +from ..._utils import check_state + + +class RecurrentAttentionLayer(Module): + """See fast_transformers.attention.attention_layer.AttentionLayer. + + The only difference with the corresponding module is that this projects + only one input and then calls the inner attention with the provided + previous state. + + Arguments + --------- + attention: Specific inner attention implementation that just computes a + weighted average of values given a similarity of queries and + keys. + d_model: The input feature dimensionality + n_heads: The number of heads for the multi head attention + d_keys: The dimensionality of the keys/queries + (default: d_model/n_heads) + d_values: The dimensionality of the values (default: d_model/n_heads) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None, event_dispatcher=""): + super(RecurrentAttentionLayer, self).__init__() + + # Fill d_keys and d_values + d_keys = d_keys or (d_model//n_heads) + d_values = d_values or (d_model//n_heads) + + self.inner_attention = attention + self.query_projection = Linear(d_model, d_keys * n_heads) + self.key_projection = Linear(d_model, d_keys * n_heads) + self.value_projection = Linear(d_model, d_values * n_heads) + self.out_projection = Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, query, key, value, state=None, memory=None): + """Apply attention to the passed in query/key/value after projecting + them to multiple heads. 
+ + In the argument description we make use of the following sizes + + - N: the batch size + - D: The input feature dimensionality passed in the constructor as + 'd_model' + + Arguments + --------- + query: (N, D) The tensor containing the queries + key: (N, D) The tensor containing the keys + value: (N, D) The tensor containing the values + state: The state varies depending on the inner attention implementation + memory: **Deprecated** and replaced by state + + Returns + ------- + The new value for each query as a tensor of shape (N, D). + """ + # Normalize the state/memory + state = check_state(state, memory) + + # Project the queries/keys/values + query = self.query_projection(query) + key = self.key_projection(key) + value = self.value_projection(value) + + # Reshape them into many heads and compute the attention + N, D = query.shape + H = self.n_heads + new_value, state = self.inner_attention( + query.view(N, H, -1), + key.view(N, H, -1), + value.view(N, H, -1), + state + ) + new_value = new_value.view(N, -1) + + # Project the output and return + return self.out_projection(new_value), state diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/full_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/full_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..95d1b3d9d7ea7a2a5910980956efac93acc8686c --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/full_attention.py @@ -0,0 +1,83 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implement the typical softmax attention as a recurrent module to speed up +autoregressive inference. See fast_transformers.attention.full_attention .""" + +from math import sqrt + +import torch +from torch.nn import Dropout, Module + +from ....attention_registry import RecurrentAttentionRegistry, Optional, \ + Float, EventDispatcherInstance +from ....events import EventDispatcher, AttentionEvent +from ..._utils import check_state + + +class RecurrentFullAttention(Module): + """Implement the full softmax attention as a recurrent module. + + Arguments + --------- + softmax_temp: The temperature to use for the softmax attention. 
+ (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.1) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, softmax_temp=None, attention_dropout=0.1, + event_dispatcher=""): + super(RecurrentFullAttention, self).__init__() + self.softmax_temp = softmax_temp + self.dropout = Dropout(attention_dropout) + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, query, key, value, state=None, memory=None): + # Normalize state/memory + state = check_state(state, memory) + + # Extract some shapes and compute the temperature + N, H, E = query.shape + _, _, D = value.shape + softmax_temp = self.softmax_temp or 1./sqrt(E) + + # Aggregate the list of keys and values + if state is not None: + keys, values = state + keys = torch.cat([keys, key[:, :, None]], dim=2) + values = torch.cat([values, value[:, :, None]], dim=2) + else: + keys = key[:, :, None] + values = value[:, :, None] + + # Compute the unnormalized attention + QK = torch.einsum("nhe,nhse->nhs", query, keys) + + # Compute the attention and the weighted average + A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1)) + V = torch.einsum("nhs,nhsd->nhd", A, values).contiguous() + + # Let the world know of the attention matrix + self.event_dispatcher.dispatch(AttentionEvent(self, A)) + + # Make sure that what we return is contiguous + return V, [keys, values] + + +# Register the attention implementation so that it becomes available in our +# builders +RecurrentAttentionRegistry.register( + "full", RecurrentFullAttention, + [ + ("softmax_temp", Optional(Float)), + ("attention_dropout", Optional(Float, 0.1)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/linear_attention.py b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/linear_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d35cec1e622e012214f1a8f73d97af231322c5 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/attention/self_attention/linear_attention.py @@ -0,0 +1,110 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implement the causally masked linear attention as a recurrent model.""" + +import torch +from torch.nn import Module + +from ....attention_registry import RecurrentAttentionRegistry, Optional, Int, \ + Callable, EventDispatcherInstance +from ....events import EventDispatcher +from ....feature_maps import elu_feature_map +from ..._utils import check_state + + +class RecurrentLinearAttention(Module): + """Implement fast_transformers.attention.causal_linear_attention as a + fixed-dimensional state recurrent model. + + See fast_transformers.attention.linear_attention and + fast_transformers.attention.causal_linear_attention for the general concept + of replacing the softmax with feature maps. 
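A sketch of the fixed-size state this module carries between decoding steps; the dimensions are illustrative.

    import torch
    from fast_transformers.recurrent.attention import RecurrentLinearAttention

    N, H, D, M = 2, 4, 16, 16          # batch, heads, key dim, value dim
    att = RecurrentLinearAttention(query_dimensions=D)

    state = None
    for _ in range(10):                # one token at a time
        q, k = torch.rand(N, H, D), torch.rand(N, H, D)
        v = torch.rand(N, H, M)
        out, state = att(q, k, v, state=state)

    Si, Zi = state                     # running sums of key-value outer products
    assert Si.shape == (N, H, D, M) and Zi.shape == (N, H, D)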
+ + Arguments + --------- + feature_map: callable, a callable that applies the feature map to the + last dimension of a tensor (default: elu(x)+1) + eps: float, a small number to ensure the numerical stability of the + denominator (default: 1e-6) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, query_dimensions, feature_map=None, eps=1e-6, + event_dispatcher=""): + super(RecurrentLinearAttention, self).__init__() + self.feature_map = ( + feature_map(query_dimensions) if feature_map else + elu_feature_map(query_dimensions) + ) + self.eps = eps + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, query, key, value, state=None, memory=None): + # Normalize state/memory + state = check_state(state, memory) + + # If this is a new sequence reinitialize the feature map + if state is None: + self.feature_map.new_feature_map(query.device) + + # Apply the feature map to the query and key + Q = self.feature_map.forward_queries(query) + K = self.feature_map.forward_keys(key) + + # Extract some shapes + N, H, D = Q.shape + _, _, M = value.shape + + # Extract the memory or initialize it + if state is None: + Si = query.new_zeros((N, H, D, M)) + Zi = query.new_zeros((N, H, D)) + else: + Si, Zi = state + + # Ensure the batch size did not change + if len(Si) != N: + raise ValueError("The batch size changed during iteration") + + # Update the internal state + # + # NOTE: The if clause is added due to GitHub PR #10. Simply using the + # following two lines does not perform the operation in place which + # means it is slower for inference. + if K.grad_fn is not None or value.grad_fn is not None: + Zi = Zi + K + Si = Si + torch.einsum("nhd,nhm->nhdm", K, value) + else: + Zi += K + Si += torch.einsum("nhd,nhm->nhdm", K, value) + + # Compute the output + Z = 1. / (torch.einsum("nhd,nhd->nh", Q, Zi) + self.eps) + V = torch.einsum("nhd,nhdm,nh->nhm", Q, Si, Z) + + return V, [Si, Zi] + + +# Register the attention implementation so that it becomes available in our +# builders +RecurrentAttentionRegistry.register( + "linear", RecurrentLinearAttention, + [ + ("query_dimensions", Int), + ("feature_map", Optional(Callable)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) +RecurrentAttentionRegistry.register( + "causal-linear", RecurrentLinearAttention, + [ + ("query_dimensions", Int), + ("feature_map", Optional(Callable)), + ("event_dispatcher", Optional(EventDispatcherInstance, "")) + ] +) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/transformers.py b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2448aec48acb425227423897ec20ab7dc3872e --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/recurrent/transformers.py @@ -0,0 +1,279 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implement transformer encoders and decoders as RNNs that will be used with +different recurrent attention mechanisms. + +In all cases there exists no sequence dimension and the shapes are batch x +heads x dims. + +This module's interface is designed with the linear attention in mind. The +interface is subject to change given the implementation of other recurrent +attentions. 
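+
+Example (a minimal sketch; `attn` stands in for any recurrent attention
+layer with the interface described above, e.g. one produced by the
+recurrent builders, whose exact construction is assumed here):
+
+    layer = RecurrentTransformerEncoderLayer(attn, d_model=512)
+    encoder = RecurrentTransformerEncoder([layer])
+
+    state = None
+    for x_t in inputs.unbind(1):    # inputs: (N, L, 512) -> x_t: (N, 512)
+        y_t, state = encoder(x_t, state=state)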
+""" + +import warnings + +import torch +from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList +import torch.nn.functional as F + +from ..events import EventDispatcher +from ..masking import LengthMask +from ._utils import check_state + + +class RecurrentTransformerEncoderLayer(Module): + """Attention to the previous inputs and feed forward with skip connections. + + This transformer encoder layer is the recurrent dual of + fast_transformers.transformers.TransformerEncoderLayer . The results should + be identical given the same inputs and a lower triangular mask. + + Arguments + --------- + attention: The attention implementation to use given as a nn.Module + d_model: The input feature dimensionality + d_ff: The dimensionality of the intermediate features after the + attention (default: d_model*4) + dropout: The dropout rate to apply to the intermediate features + (default: 0.1) + activation: {'relu', 'gelu'} Which activation to use for the feed + forward part of the layer (default: relu) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, attention, d_model, d_ff=None, dropout=0.1, + activation="relu", event_dispatcher=""): + super(RecurrentTransformerEncoderLayer, self).__init__() + d_ff = d_ff or 4*d_model + self.attention = attention + self.linear1 = Linear(d_model, d_ff) + self.linear2 = Linear(d_ff, d_model) + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.dropout = Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, x, state=None, memory=None): + """Apply the transformer encoder to the input x using the provided + memory. + + Arguments + --------- + x: The input features of shape (N, E) where N is the batch size and + E is d_model passed in the constructor + state: The state can vary depending on the attention implementation + memory: **Deprecated** name for the state argument + """ + # Normalize the state name + state = check_state(state, memory) + + # Run the self attention and add it to the input + x2, state = self.attention(x, x, x, state) + x = x + self.dropout(x2) + + # Run the fully connected part of the layer + y = x = self.norm1(x) + y = self.dropout(self.activation(self.linear1(y))) + y = self.dropout(self.linear2(y)) + + return self.norm2(x+y), state + + +class RecurrentTransformerEncoder(Module): + """RecurrentTransformerEncoder is a sequence of + RecurrentTransformerEncoderLayer instances. + + RecurrentTransformerEncoder keeps a separate state per + RecurrentTransformerEncoderLayer. + + Arguments + --------- + layers: list, RecurrentTransformerEncoderLayer instances or instances + that implement the same interface + norm_layer: A normalization layer to be applied to the final output + (default: None which means no normalization) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, layers, norm_layer=None, event_dispatcher=""): + super(RecurrentTransformerEncoder, self).__init__() + self.layers = ModuleList(layers) + self.norm = norm_layer + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, x, state=None, memory=None): + """Apply all recurrent transformer layers to the input x using the + provided state. 
+ + Arguments + --------- + x: The input features of shape (N, E) where N is the batch size and + E is d_model passed in the constructor of each recurrent + transformer encoder layer + state: A list of objects to be passed to each recurrent + transformer encoder layer + memory: **Deprecated** name for the state argument + """ + # Initialize the memory to None if not given + state = check_state(state, memory) + if state is None: + state = [None]*len(self.layers) + + # Apply all the transformers + for i, layer in enumerate(self.layers): + x, s = layer(x, state[i]) + state[i] = s + + # Apply the normalization if needed + if self.norm is not None: + x = self.norm(x) + + return x, state + + +class RecurrentTransformerDecoderLayer(Module): + """Attention to the previous inputs and a preprocessed memory. + + This transformer decoder layer is the recurrent dual of + fast_transformers.transformers.TransformerDecoderLayer . The results should + be identical given the same inputs and a lower triangular mask for x_mask. + + Arguments + --------- + self_attention: The attention implementation to use for self attention + given as a nn.Module + cross_attention: The attention implementation to use for cross + attention given as a nn.Module + d_model: The input feature dimensionality + d_ff: The dimensionality of the intermediate features after the + attention (default: d_model*4) + dropout: The dropout rate to apply to the intermediate features + (default: 0.1) + activation: {'relu', 'gelu'} Which activation to use for the feed + forward part of the layer (default: relu) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, self_attention, cross_attention, d_model, d_ff=None, + dropout=0.1, activation="relu", event_dispatcher=""): + super(RecurrentTransformerDecoderLayer, self).__init__() + d_ff = d_ff or 4*d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.linear1 = Linear(d_model, d_ff) + self.linear2 = Linear(d_ff, d_model) + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.norm3 = LayerNorm(d_model) + self.dropout = Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, x, memory, memory_length_mask=None, state=None): + """Apply the transformer decoder to the input x and also attend to + memory. + + Note the memory mask is assumed to be a full mask. + + Arguments + --------- + x: The input features of shape (N, E) where N is the batch size and + E is d_model passed in the constructor + memory: A sequence of features (N, S, E) that the input will attend + to. S is the sequence length and E is the same as for x. + memory_length_mask: An implementation of a BaseMask that encodes + how many elements each memory sequence in the + batch consists of. + state: The state varies depending on the attention implementations + but it allows for recurrent implementation. 
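+
+        Example (a sketch with hypothetical tensors tgt of shape (N, T, E)
+        and memory of shape (N, S, E); `layer` is an instance of this
+        class):
+
+            state = None
+            for x_t in tgt.unbind(1):
+                y_t, state = layer(x_t, memory, state=state)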
+ """ + # Normalize the mask + N = x.shape[0] + L = memory.shape[1] + memory_length_mask = memory_length_mask or \ + LengthMask(x.new_full((N,), L, dtype=torch.int64)) + + # Extract the individual states for the self attention and the cross + # attention + self_state, cross_state = state or [None, None] + + # First apply the self attention and add it to the input + x2, self_state = self.self_attention(x, x, x, state=self_state) + x = self.norm1(x + self.dropout(x2)) + + # Secondly apply the cross attention and add it to the previous output + x2, cross_state = self.cross_attention( + x, memory, memory, memory_length_mask, state=cross_state + ) + x = self.norm2(x + self.dropout(x2)) + + # Finally run the fully connected part of the layer + y = x + y = self.dropout(self.activation(self.linear1(y))) + y = self.dropout(self.linear2(y)) + + return self.norm3(x+y), [self_state, cross_state] + + +class RecurrentTransformerDecoder(Module): + """RecurrentTransformerDecoder is little more than a sequence of + RecurrentTransformerDecoderLayer instances. + + Simlar to the recurrent encoder a separate state is kept per decoder layer. + + Arguments + --------- + layers: list, RecurrentTransformerDecoderLayer instances or instances + that implement the same interface + norm_layer: A normalization layer to be applied to the final output + (default: None which means no normalization) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, layers, norm_layer=None, event_dispatcher=""): + super(RecurrentTransformerDecoder, self).__init__() + self.layers = ModuleList(layers) + self.norm = norm_layer + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, x, memory, memory_length_mask=None, state=None): + """Apply all recurrent transformer layers to the input x using the + provided state. + + Arguments + --------- + x: The input features of shape (N, E) where N is the batch size and + E is d_model passed in the constructor + memory: A sequence of features (N, S, E) that the input will attend + to. S is the sequence length and E is the same as for x. 
+ memory_length_mask: An implementation of a BaseMask that encodes + how many elements each memory sequence in the + batch consists of + state: A list of objects to be passed to each recurrent + transformer decoder layer + """ + # Initialize the state to None if not given + if state is None: + state = [None]*len(self.layers) + + # Apply all the transformers + for i, layer in enumerate(self.layers): + x, s = layer(x, memory, memory_length_mask=memory_length_mask, + state=state[i]) + state[i] = s + + # Apply the normalization if needed + if self.norm is not None: + x = self.norm(x) + + return x, state diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/sparse_product/__init__.py b/smi-ted/inference/smi_ted_light/fast_transformers/sparse_product/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e353f7a05a2cb950e144f3b58457601e3d13941a --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/sparse_product/__init__.py @@ -0,0 +1,399 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + + +import torch + +from .sparse_product_cpu import \ + sparse_dot_product as sparse_dot_product_cpu, \ + sparse_dot_backward as sparse_dot_backward_cpu, \ + sparse_weighted_average as sparse_weighted_average_cpu, \ + sparse_weighted_average_backward as sparse_weighted_average_backward_cpu +try: + from .sparse_product_cuda import \ + sparse_dot_product as sparse_dot_product_cuda, \ + sparse_dot_backward as sparse_dot_backward_cuda, \ + sparse_weighted_average as sparse_weighted_average_cuda, \ + sparse_weighted_average_backward as \ + sparse_weighted_average_backward_cuda +except ImportError: + sparse_dot_product_cuda = None + sparse_dot_backward_cuda = None + sparse_weighted_average_cuda = None + sparse_weighted_average_backward_cuda = None + +from .clustered_sparse_product_cpu import \ + clustered_sparse_dot_product as clustered_sparse_dot_product_cpu, \ + clustered_sparse_dot_backward as clustered_sparse_dot_backward_cpu, \ + clustered_sparse_weighted_average as \ + clustered_sparse_weighted_average_cpu, \ + clustered_sparse_weighted_average_backward as \ + clustered_sparse_weighted_average_backward_cpu + +try: + from .clustered_sparse_product_cuda import \ + clustered_sparse_dot_product as clustered_sparse_dot_product_cuda, \ + clustered_sparse_dot_backward as clustered_sparse_dot_backward_cuda, \ + clustered_sparse_weighted_average as \ + clustered_sparse_weighted_average_cuda, \ + clustered_sparse_weighted_average_backward as \ + clustered_sparse_weighted_average_backward_cuda +except ImportError: + clustered_sparse_dot_product_cuda = None + clustered_sparse_dot_backward_cuda = None + clustered_sparse_weighted_average_cuda = None + clustered_sparse_weighted_average_backward_cuda = None + + +class SparseDotProduct(torch.autograd.Function): + """Compute the dot products only at the positions specified by topk.""" + dot = { + "cpu": sparse_dot_product_cpu, + "cuda": sparse_dot_product_cuda + } + dot_backward = { + "cpu": sparse_dot_backward_cpu, + "cuda": sparse_dot_backward_cuda + } + + @staticmethod + def forward(ctx, Q, K, topk): + # Save the inputs to compute the gradient + ctx.save_for_backward(Q, K, topk) + + # Create the output tensor + device = Q.device + N, H, L, E = Q.shape + _, _, _, k = topk.shape + product = torch.empty((N, H, L, k), device=device) + + # Actually perform the dot product + SparseDotProduct.dot[device.type](Q, K, topk, product) + + return product + + 
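+    # For reference (a sketch, not executed here): the sparse product above
+    # is equivalent to forming the dense attention scores and gathering the
+    # top-k columns,
+    #
+    #     scores = torch.einsum("nhle,nhse->nhls", Q, K)
+    #     product = scores.gather(-1, topk.long())
+    #
+    # but it avoids materialising the full (N, H, L, L) score matrix.
+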
@staticmethod + def backward(ctx, grad_output): + # Extract the saved tensors and allocate memory for the gradients + Q, K, topk = ctx.saved_tensors + grad_Q = torch.zeros_like(Q) + grad_K = torch.zeros_like(K) + + SparseDotProduct.dot_backward[Q.device.type]( + Q, + K, + topk, + grad_output, + grad_Q, + grad_K + ) + + return grad_Q, grad_K, None + + +class SparseWeightedAverage(torch.autograd.Function): + """Compute the weighted average only for the topk values.""" + avg = { + "cpu": sparse_weighted_average_cpu, + "cuda": sparse_weighted_average_cuda + } + avg_backward = { + "cpu": sparse_weighted_average_backward_cpu, + "cuda": sparse_weighted_average_backward_cuda + } + + @staticmethod + def forward(ctx, weights, values, topk): + # Save the tensors to compute the gradient + ctx.save_for_backward(weights, values, topk) + + # Allocate the output tensor + N, H, L, _ = weights.shape + _, _, _, E = values.shape + output = values.new_zeros(N, H, L, E) + + # Compute the average + SparseWeightedAverage.avg[weights.device.type]( + weights, + values, + topk, + output + ) + + return output + + @staticmethod + def backward(ctx, grad_output): + # Extract the saved tensors and allocate memory for the gradients + weights, values, topk = ctx.saved_tensors + grad_weights = torch.zeros_like(weights) + grad_values = torch.zeros_like(values) + + if grad_output.stride()[-1] != 1: + grad_output = grad_output.contiguous() + + SparseWeightedAverage.avg_backward[weights.device.type]( + weights, + values, + topk, + grad_output, + grad_weights, + grad_values + ) + + return grad_weights, grad_values, None + + +class ClusteredSparseDotProduct(torch.autograd.Function): + """Compute the dot products only at the positions specified by topk.""" + dot = { + "cpu": clustered_sparse_dot_product_cpu, + "cuda": clustered_sparse_dot_product_cuda + } + dot_backward = { + "cpu": clustered_sparse_dot_backward_cpu, + "cuda": clustered_sparse_dot_backward_cuda + } + + @staticmethod + def forward(ctx, Q, K, topk, groups, counts, lengths): + # Save the inputs to compute the gradient + ctx.save_for_backward(Q, K, topk, groups, counts) + + device = Q.device + N, H, L, E = Q.shape + _, _, C, k = topk.shape + + # Create the output tensor + product = torch.zeros((N, H, L, k), device=device) + + # Unfortunately the cpu and gpu interfaces are different so + # the entire call is surrounded by if-else block + if device.type == "cpu": + ClusteredSparseDotProduct.dot[device.type]( + Q, + K, + groups, + topk, + product + ) + + else: + # Allocate bookkeeping parameters to facilitate the kernel + with torch.no_grad(): + Q_pb = 16 + block_counts = (counts + Q_pb - 1) // Q_pb + block_counts = block_counts.int() + block_counts_cumsum = block_counts.view(-1).cumsum(-1).view(N, H, C).int() + indx_maps = torch.ones( + (block_counts.sum(), 4), + device=Q.device, + dtype=torch.int32 + ) + counts_cumsum = counts.cumsum(-1).int() + total_blocks = block_counts.sum().item() + + # Actually perform the dot product + ClusteredSparseDotProduct.dot[device.type]( + Q, + K, + topk.int(), + counts_cumsum - counts, + counts_cumsum, + block_counts, + block_counts_cumsum, + total_blocks, + indx_maps, + product + ) + + return product + + @staticmethod + def backward(ctx, grad_output): + Q, K, topk, groups, counts = ctx.saved_tensors + device = Q.device + # Extract the saved tensors and allocate memory for the gradients + grad_Q = torch.zeros_like(Q) + grad_K = torch.zeros_like(K) + + # Unfortunately the cpu and gpu interfaces are different so + # the entire call is 
surrounded by if-else block + if device.type == "cpu": + ClusteredSparseDotProduct.dot_backward[Q.device.type]( + Q, + K, + groups, + topk, + grad_output, + grad_Q, + grad_K + ) + + else: + N, H, L, E = Q.shape + _, _, C, k = topk.shape + # Allocate bookkeeping parameters to facilitate the kernel + with torch.no_grad(): + Q_pb = 16 + block_counts = (counts + Q_pb - 1) // Q_pb + block_counts = block_counts.int() + block_counts_cumsum = block_counts.view(-1).cumsum(-1).view(N, H, C).int() + indx_maps = torch.ones( + (block_counts.sum(), 4), + device=Q.device, + dtype=torch.int32 + ) + + counts_cumsum = counts.cumsum(-1).int() + total_blocks = block_counts.sum().item() + + # Actually perform the backward pass + ClusteredSparseDotProduct.dot_backward[Q.device.type]( + Q, + K, + groups.int(), + topk.int(), + grad_output, + grad_Q, + grad_K, + counts_cumsum - counts, + counts_cumsum, + block_counts, + block_counts_cumsum, + total_blocks, + indx_maps + ) + + return grad_Q, grad_K, None, None, None, None, None + + +class ClusteredSparseWeightedAverage(torch.autograd.Function): + """Compute the weighted average only for the topk values.""" + avg = { + "cpu": clustered_sparse_weighted_average_cpu, + "cuda": clustered_sparse_weighted_average_cuda + } + avg_backward = { + "cpu": clustered_sparse_weighted_average_backward_cpu, + "cuda": clustered_sparse_weighted_average_backward_cuda + } + + @staticmethod + def forward(ctx, weights, values, topk, groups, counts): + # Save the tensors to compute the gradient + ctx.save_for_backward(weights, values, topk, groups, counts) + + # Allocate the output tensor + N, H, L, _ = weights.shape + _, _, _, E = values.shape + _, _, C, _ = topk.shape + output = values.new_zeros(N, H, L, E) + device = weights.device + + if device.type == "cpu": + # Compute the average + ClusteredSparseWeightedAverage.avg[weights.device.type]( + weights, + values, + groups, + topk, + output + ) + else: + # Bookkeeping parameters to facilitate the GPU cuda kernel + with torch.no_grad(): + Q_pb = 16 + block_counts = (counts + Q_pb - 1) // Q_pb + block_counts = block_counts.int() + block_counts_cumsum = block_counts.view(-1).cumsum(-1).view(N, H, C).int() + indx_maps = torch.ones( + (block_counts.sum(), 4), + device=weights.device, + dtype=torch.int32 + ) + counts_cumsum = counts.cumsum(-1).int() + total_blocks = block_counts.sum().item() + + # Compute the average + ClusteredSparseWeightedAverage.avg[device.type]( + weights, + values, + topk.int(), + output, + counts_cumsum - counts, + counts_cumsum, + block_counts, + block_counts_cumsum, + total_blocks, + indx_maps + ) + + return output + + @staticmethod + def backward(ctx, grad_output): + # Extract the saved tensors and allocate memory for the gradients + weights, values, topk, groups, counts = ctx.saved_tensors + grad_weights = torch.zeros_like(weights) + grad_values = torch.zeros_like(values) + + if grad_output.stride()[-1] != 1: + grad_output = grad_output.contiguous() + + device = weights.device + if device.type == "cpu": + ClusteredSparseWeightedAverage.avg_backward[weights.device.type]( + weights, + values, + groups, + topk, + grad_output, + grad_weights, + grad_values + ) + else: + # Bookkeeping parameters to facilitate the cuda kernel + with torch.no_grad(): + N, H, C = counts.shape + Q_pb = 16 + block_counts = (counts + Q_pb - 1) // Q_pb + block_counts = block_counts.int() + block_counts_cumsum = block_counts.view(-1).cumsum(-1).view(N, H, C).int() + + indx_maps = torch.ones( + (block_counts.sum(), 4), + device=weights.device, + 
dtype=torch.int32 + ) + counts_cumsum = counts.cumsum(-1).int() + total_blocks = block_counts.sum().item() + + # Do sparse weighted average backward pass + ClusteredSparseWeightedAverage.avg_backward[device.type]( + weights, + values, + topk.int(), + grad_output, + grad_weights, + grad_values, + counts_cumsum - counts, + counts_cumsum, + block_counts, + block_counts_cumsum, + total_blocks, + indx_maps + ) + return grad_weights, grad_values, None, None, None, None + + +# Alias the autograd functions to python style snake case naming +clustered_sparse_dot_product = ClusteredSparseDotProduct.apply +clustered_sparse_weighted_average = ClusteredSparseWeightedAverage.apply + +# Alias the autograd functions to python style snake case naming +sparse_dot_product = SparseDotProduct.apply +sparse_weighted_average = SparseWeightedAverage.apply diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/sparse_product/clustered_sparse_product_cpu.cpython-39-x86_64-linux-gnu.so b/smi-ted/inference/smi_ted_light/fast_transformers/sparse_product/clustered_sparse_product_cpu.cpython-39-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..01ed893b3c82e657cddce51cd9a3b239e158cdda --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/sparse_product/clustered_sparse_product_cpu.cpython-39-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1db5513ec7fa249afd2d85bdfbfc7fde7c1399a6e17edb0275a72433f87ab69e +size 146680 diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/sparse_product/sparse_product_cpu.cpython-39-x86_64-linux-gnu.so b/smi-ted/inference/smi_ted_light/fast_transformers/sparse_product/sparse_product_cpu.cpython-39-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..ca198eb0a50f4ab653c101e0e31093ba6d0bad34 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/sparse_product/sparse_product_cpu.cpython-39-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1cd628d0dd6d707f26a216bbc844a90481fba85c3dbd1e91cae0942153c6dde4 +size 145952 diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/transformers.py b/smi-ted/inference/smi_ted_light/fast_transformers/transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..e726edba31f20751c4291125bbf12e3ac8f55c1d --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/transformers.py @@ -0,0 +1,294 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""Implement transformer encoders and decoders that are going to be used with +different attention mechanisms. + +In all cases the batch dimension is first and the sequence dimension is second. +""" + +import torch +from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList +import torch.nn.functional as F + +from .events import EventDispatcher +from .masking import FullMask, LengthMask + + +class TransformerEncoderLayer(Module): + """Self attention and feed forward network with skip connections. + + This transformer encoder layer implements the same encoder layer as + PyTorch but is a bit more open for extension by receiving the attention + implementation as a constructor argument. 
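+
+    Example (a minimal sketch; AttentionLayer and FullAttention come from
+    fast_transformers.attention and their constructor signatures are
+    assumed here):
+
+        attn = AttentionLayer(FullAttention(), 512, 8)   # d_model, n_heads
+        layer = TransformerEncoderLayer(attn, d_model=512)
+        y = layer(x)    # x: (N, L, 512) -> y: (N, L, 512)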
+ + Arguments + --------- + attention: The attention implementation to use given as a nn.Module + d_model: The input feature dimensionality + d_ff: The dimensionality of the intermediate features after the + attention (default: d_model*4) + dropout: The dropout rate to apply to the intermediate features + (default: 0.1) + activation: {'relu', 'gelu'} Which activation to use for the feed + forward part of the layer (default: relu) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, attention, d_model, d_ff=None, dropout=0.1, + activation="relu", event_dispatcher=""): + super(TransformerEncoderLayer, self).__init__() + d_ff = d_ff or 4*d_model + self.attention = attention + self.linear1 = Linear(d_model, d_ff) + self.linear2 = Linear(d_ff, d_model) + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.dropout = Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, x, attn_mask=None, length_mask=None): + """Apply the transformer encoder to the input x. + + Arguments + --------- + x: The input features of shape (N, L, E) where N is the batch size, + L is the sequence length (padded) and E is d_model passed in the + constructor. + attn_mask: An implementation of fast_transformers.masking.BaseMask + that encodes where each element of x can attend to. + length_mask: An implementation of + fast_transformers.masking.BaseMask that encodes how + many elements each sequence in the batch consists of. + """ + # Normalize the masks + N = x.shape[0] + L = x.shape[1] + attn_mask = attn_mask or FullMask(L, device=x.device) + length_mask = length_mask or \ + LengthMask(x.new_full((N,), L, dtype=torch.int64)) + + # Run self attention and add it to the input + x = x + self.dropout(self.attention( + x, x, x, + attn_mask=attn_mask, + query_lengths=length_mask, + key_lengths=length_mask + )) + + # Run the fully connected part of the layer + y = x = self.norm1(x) + y = self.dropout(self.activation(self.linear1(y))) + y = self.dropout(self.linear2(y)) + + return self.norm2(x+y) + + +class TransformerEncoder(Module): + """TransformerEncoder is little more than a sequence of transformer encoder + layers. + + It contains an optional final normalization layer as well as the ability to + create the masks once and save some computation. + + Arguments + --------- + layers: list, TransformerEncoderLayer instances or instances that + implement the same interface. + norm_layer: A normalization layer to be applied to the final output + (default: None which means no normalization) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, layers, norm_layer=None, event_dispatcher=""): + super(TransformerEncoder, self).__init__() + self.layers = ModuleList(layers) + self.norm = norm_layer + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, x, attn_mask=None, length_mask=None): + """Apply all transformer encoder layers to the input x. + + Arguments + --------- + x: The input features of shape (N, L, E) where N is the batch size, + L is the sequence length (padded) and E is d_model passed in the + constructor of each transformer encoder layer. 
+ attn_mask: An implementation of fast_transformers.masking.BaseMask + that encodes where each element of x can attend to. + length_mask: An implementation of + fast_transformers.masking.BaseMask that encodes how + many elements each sequence in the batch consists of. + """ + # Normalize the masks + N = x.shape[0] + L = x.shape[1] + attn_mask = attn_mask or FullMask(L, device=x.device) + length_mask = length_mask or \ + LengthMask(x.new_full((N,), L, dtype=torch.int64)) + + # Apply all the transformers + for layer in self.layers: + x = layer(x, attn_mask=attn_mask, length_mask=length_mask) + + # Apply the normalization if needed + if self.norm is not None: + x = self.norm(x) + + return x + + +class TransformerDecoderLayer(Module): + """The decoder layer from "Attention Is All You Need". + + Similar to the encoder layer, this layer implements the decoder that + PyTorch implements but can be used with any attention implementation + because it receives the attention layers as constructor arguments. + + Arguments + --------- + self_attention: The attention implementation to use for self attention + given as a nn.Module + cross_attention: The attention implementation to use for cross + attention given as a nn.Module + d_model: The input feature dimensionality + d_ff: The dimensionality of the intermediate features after the + attention (default: d_model*4) + dropout: The dropout rate to apply to the intermediate features + (default: 0.1) + activation: {'relu', 'gelu'} Which activation to use for the feed + forward part of the layer (default: relu) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, self_attention, cross_attention, d_model, d_ff=None, + dropout=0.1, activation="relu", event_dispatcher=""): + super(TransformerDecoderLayer, self).__init__() + d_ff = d_ff or 4*d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.linear1 = Linear(d_model, d_ff) + self.linear2 = Linear(d_ff, d_model) + self.norm1 = LayerNorm(d_model) + self.norm2 = LayerNorm(d_model) + self.norm3 = LayerNorm(d_model) + self.dropout = Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, x, memory, x_mask=None, x_length_mask=None, + memory_mask=None, memory_length_mask=None): + """Apply the transformer decoder to the input x using the memory + `memory`. + + Arguments + --------- + x: The input features of shape (N, L, E) where N is the batch size, + L is the sequence length (padded) and E should be the same as + the d_model passed in the constructor. + memory: The memory features of shape (N, L', E) where N is the + batch size, L' is the memory's sequence length (padded) and + E should be the same as the d_model. + x_mask: An implementation of fast_transformers.masking.BaseMask + that encodes where each element of x can attend to in x. + Namely the self attention mask. + x_length_mask: An implementation of a BaseMask that encodes how + many elements each sequence in the batch consists + of. + memory_mask: An implementation of BaseMask that encodes where each + element of x can attend to in the memory. Namely the + cross attention mask. + memory_length_mask: An implementation of a BaseMask that encodes how + many elements each memory sequence in the batch + consists of. 
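+
+        Example (sketch): if no masks are given, the defaults created below
+        are full masks, so every position of x attends to all of x and to
+        all of memory:
+
+            y = layer(x, memory)    # x: (N, L, E), memory: (N, S, E)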
+ """ + # Normalize the masks + N = x.shape[0] + L = x.shape[1] + L_prime = memory.shape[1] + x_mask = x_mask or FullMask(L, device=x.device) + x_length_mask = x_length_mask or \ + LengthMask(x.new_full((N,), L, dtype=torch.int64)) + memory_mask = memory_mask or FullMask(L, L_prime, device=x.device) + memory_length_mask = memory_length_mask or \ + LengthMask(x.new_full((N,), L_prime, dtype=torch.int64)) + + # First apply the self attention and add it to the input + x = x + self.dropout(self.self_attention( + x, x, x, + attn_mask=x_mask, + query_lengths=x_length_mask, + key_lengths=x_length_mask + )) + x = self.norm1(x) + + # Secondly apply the cross attention and add it to the previous output + x = x + self.dropout(self.cross_attention( + x, memory, memory, + attn_mask=memory_mask, + query_lengths=x_length_mask, + key_lengths=memory_length_mask + )) + + # Finally run the fully connected part of the layer + y = x = self.norm2(x) + y = self.dropout(self.activation(self.linear1(y))) + y = self.dropout(self.linear2(y)) + + return self.norm3(x+y) + + +class TransformerDecoder(Module): + """TransformerDecoder is little more than a sequence of transformer decoder + layers. + + It contains an optional final normalization layer as well as the ability to + create the masks once and save some computation. + + Arguments + ---------- + layers: list, TransformerDecoderLayer instances or instances that + implement the same interface + norm_layer: A normalization layer to be applied to the final output + (default: None which means no normalization) + event_dispatcher: str or EventDispatcher instance to be used by this + module for dispatching events (default: the default + global dispatcher) + """ + def __init__(self, layers, norm_layer=None, event_dispatcher=""): + super(TransformerDecoder, self).__init__() + self.layers = ModuleList(layers) + self.norm = norm_layer + self.event_dispatcher = EventDispatcher.get(event_dispatcher) + + def forward(self, x, memory, x_mask=None, x_length_mask=None, + memory_mask=None, memory_length_mask=None): + # Normalize the masks + N = x.shape[0] + L = x.shape[1] + L_prime = memory.shape[1] + x_mask = x_mask or FullMask(L, device=x.device) + x_length_mask = x_length_mask or \ + LengthMask(x.new_full((N,), L, dtype=torch.int64)) + memory_mask = memory_mask or FullMask(L, L_prime, device=x.device) + memory_length_mask = memory_length_mask or \ + LengthMask(x.new_full((N,), L_prime, dtype=torch.int64)) + + # Apply all the transformer decoders + for layer in self.layers: + x = layer(x, memory, x_mask=x_mask, x_length_mask=x_length_mask, + memory_mask=memory_mask, + memory_length_mask=memory_length_mask) + + # Apply the normalization if needed + if self.norm is not None: + x = self.norm(x) + + return x diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/utils.py b/smi-ted/inference/smi_ted_light/fast_transformers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dd4be2fd15b54fa5dbc9d504b4fe783c3fa58e69 --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/utils.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos +# + +"""Boilerplate code for dealing with fast_transformers modules.""" + + +def make_mirror(src_module, dst_module): + """Sets the parameters of src_module to dst_module so that they share the + same parameters. + + Most noteable usecase is to make a recurrent transformer mirror of a batch + transformer for fast inference. 
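+
+    Example (a minimal sketch; the two modules are assumed to have been
+    built from the same configuration so that their parameter names match):
+
+        make_mirror(batch_model, recurrent_model)
+        # recurrent_model now shares batch_model's parameter tensors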
+ + Arguments + --------- + src_module: Module to take the parameters from + dst_module: Module to set the parameters to + + Returns + ------- + None, it changes dst_module in place + """ + def setattr_recursive(mod, key, value): + key, *next_key = key.split(".", maxsplit=1) + if not next_key: + setattr(mod, key, value) + else: + setattr_recursive(getattr(mod, key), next_key[0], value) + + for name, param in src_module.named_parameters(): + setattr_recursive(dst_module, name, param) diff --git a/smi-ted/inference/smi_ted_light/fast_transformers/weight_mapper.py b/smi-ted/inference/smi_ted_light/fast_transformers/weight_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..c0dfca47c6a3e2e03d79cd9e65f8f6fcda51e07f --- /dev/null +++ b/smi-ted/inference/smi_ted_light/fast_transformers/weight_mapper.py @@ -0,0 +1,273 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +"""The weight mapper module provides a utility to load transformer model +weights from other implementations to a fast_transformers model. + +NOTE: This API is lkely to change in the future as we collect more information + regarding how people use it. +""" + +import re + + +class MappingRule(object): + """A mapping rule can be applied to a key and value and it returns new keys + and values to be added in the state dict.""" + def matches(self, key): + """Check whether this mapping rule should be applied to this key.""" + raise NotImplementedError() + + def apply(self, key, value): + """Apply the rule and map the key to a new one.""" + raise NotImplementedError() + + +class IdentityRule(MappingRule): + """The identity rule matches all keys and returns them as is.""" + def matches(self, key): + return True + + def apply(self, key, value): + return [(key, value)] + + +class NotRule(MappingRule): + """Decorate a MappingRule by using a logical not for the matches function + and identity for the apply.""" + def __init__(self, rule): + self.rule = rule + + def matches(self, key): + return not self.rule.matches(key) + + def apply(self, key, value): + return [(key, value)] + +class OrRule(MappingRule): + """Decorate some MappingRules using the logical or to create a matches + function that returns True if any of the rules matches. In case of a match + apply all of the rules.""" + def __init__(self, *rules): + self.rules = rules + + def matches(self, key): + return any(r.matches(key) for r in self.rules) + + def apply(self, key, value): + items = [(key, value)] + for r in self.rules: + items = [ + r.apply(k, v) + for k, v in items + ] + return items + + +class RegexRule(MappingRule): + """Apply a regex search and replace on a key. + + Arguments + --------- + search: str, the regex pattern to search and replace + replace: str or callable, the replacement for every occurence of the + search pattern. If it is a callable it should follow the rules + of python re.sub(). 
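+
+    Example (a sketch with a made-up key and an arbitrary value w):
+
+        rule = RegexRule(r"^encoder\.", "")
+        rule.apply("encoder.layers.0.norm1.weight", w)
+        # -> [("layers.0.norm1.weight", w)]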
+ """ + def __init__(self, search, replace): + self.pattern = re.compile(search) + self.replace = replace + + def matches(self, key): + return self.pattern.search(key) is not None + + def apply(self, key, value): + return [(self.pattern.sub(self.replace, key), value)] + + +class PytorchAttentionWeightsRule(MappingRule): + """Map the merged MultiheadAttention weights to the corresponding keys and + values.""" + def __init__(self): + self.weight_pattern = "self_attn.in_proj_weight" + self.bias_pattern = "self_attn.in_proj_bias" + + def matches(self, key): + return ( + self.weight_pattern in key or + self.bias_pattern in key + ) + + def apply(self, key, value): + N = value.shape[0] + if self.weight_pattern in key: + return [ + ( + key.replace( + self.weight_pattern, + "attention.query_projection.weight" + ), + value[:N//3] + ), + ( + key.replace( + self.weight_pattern, + "attention.key_projection.weight" + ), + value[N//3:2*N//3] + ), + ( + key.replace( + self.weight_pattern, + "attention.value_projection.weight" + ), + value[2*N//3:] + ) + ] + if self.bias_pattern in key: + return [ + ( + key.replace( + self.bias_pattern, + "attention.query_projection.bias" + ), + value[:N//3] + ), + ( + key.replace( + self.bias_pattern, + "attention.key_projection.bias" + ), + value[N//3:2*N//3] + ), + ( + key.replace( + self.bias_pattern, + "attention.value_projection.bias" + ), + value[2*N//3:] + ) + ] + + +class SimpleMapper(object): + """Map keys of a state dict to other keys. + + Arguments + --------- + rules: A list of mapping rules to apply to the keys (default: []). + add_identity: bool, if set to True add a catch all identity rule as the + final rule (default: True). + """ + def __init__(self, rules=[], add_identity=True): + self._rules = rules + if add_identity: + self._rules.append(IdentityRule()) + + def map(self, state_dict): + new_state = {} + for k, v in state_dict.items(): + for rule in self._rules: + if rule.matches(k): + for nk, nv in rule.apply(k, v): + new_state[nk] = nv + break + return new_state + + @classmethod + def load_file(cls, filepath, model_root=None, map_location=None, + **other_args): + """Load the file and apply the weight map. + + The model root the key that contains the state dict to be mapped. + + Arguments + --------- + filepath: The file containing the saved state. + model_root: The key for the state dict to be mapped, if None assume + it is the top level dictionary (default: None). + map_location: The parameter is passed to torch.load . + other_args: The parameter dict is passed to torch.load because it + expects a similar dictionary of arguments to pass to + pickle.load. + """ + state = torch.load(filepath, map_location=map_location, **other_args) + if model_root is None: + state = cls().map(state) + else: + state[model_root] = cls().map(state[model_root]) + + return state + + +class PytorchMapper(SimpleMapper): + """Map a Pytorch transformer encoder state dict to a fast transformers + one.""" + def __init__(self): + super(PytorchMapper, self).__init__([ + PytorchAttentionWeightsRule(), + RegexRule( + r"layers\.(\d+)\.self_attn\.([a-z]+)_proj(ection)?\.", + r"layers.\1.attention.\2_projection." 
+ ), + NotRule(OrRule( + RegexRule( + r"\.softmax_temp$", + r"" + ) + )) + ], add_identity=False) + + +class HugginfaceBertEncoderMapper(SimpleMapper): + """Map the weights of a model that uses a BertEncoder to our fast + transformers.""" + RULES = [ + RegexRule( + r"layer\.(\d+)\.attention\.self\.(query|key|value)", + r"layers.\1.attention.\2_projection" + ), + RegexRule( + r"layer\.(\d+)\.attention\.output\.dense", + r"layers.\1.attention.out_projection" + ), + RegexRule( + r"layer\.(\d+)\.attention\.output\.LayerNorm", + r"layers.\1.norm1" + ), + RegexRule( + r"layer\.(\d+)\.intermediate\.dense", + r"layers.\1.linear1" + ), + RegexRule( + r"layer\.(\d+)\.output\.dense", + r"layers.\1.linear2" + ), + RegexRule( + r"layer\.(\d+)\.output\.LayerNorm", + r"layers.\1.norm2" + ) + ] + + def __init__(self): + super(HugginfaceBertEncoderMapper, self).__init__(self.RULES) + + +class LongformerMapper(SimpleMapper): + """Map the longformer weights to our fast transformers. + + NOTE: The projections for the global attention are ignored. + """ + def __init__(self): + super(LongformerMapper, self).__init__( + HugginfaceBertEncoderMapper.RULES + [ + NotRule(RegexRule( + r"layer\.(\d+)\.attention\.self\.(query|key|value)_global", + "" + )) + ], + add_identity=False + ) diff --git a/smi-ted/inference/smi_ted_light/load.py b/smi-ted/inference/smi_ted_light/load.py index b82fd26c6fbeb8746e198a7ef0ecdaad4f384410..095d5a09fb0f865217c11dc6804c1a6774af6b18 100644 --- a/smi-ted/inference/smi_ted_light/load.py +++ b/smi-ted/inference/smi_ted_light/load.py @@ -6,13 +6,13 @@ import torch.nn.functional as F import torch.backends.cudnn as cudnn # Transformers -from fast_transformers.attention import AttentionLayer -from fast_transformers.events import QKVEvent -from fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer -from fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder -from fast_transformers.builders.attention_builders import AttentionBuilder -from fast_transformers.feature_maps import GeneralizedRandomFeatures -from fast_transformers.masking import LengthMask +from .fast_transformers.attention import AttentionLayer +from .fast_transformers.events import QKVEvent +from .fast_transformers.transformers import TransformerEncoder, TransformerEncoderLayer +from .fast_transformers.builders.transformer_builders import BaseTransformerEncoderBuilder +from .fast_transformers.builders.attention_builders import AttentionBuilder +from .fast_transformers.feature_maps import GeneralizedRandomFeatures +from .fast_transformers.masking import LengthMask from transformers import BertTokenizer from huggingface_hub import hf_hub_download