from fairseq.modules import TransformerSentenceEncoderLayer
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention


class SparseTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
    """
    Implements a Sparse Transformer Encoder Layer (see SparseMultiheadAttention)
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        export: bool = False,
        is_bidirectional: bool = True,
        stride: int = 32,
        expressivity: int = 8,
    ) -> None:
        super().__init__(
            embedding_dim,
            ffn_embedding_dim,
            num_attention_heads,
            dropout,
            attention_dropout,
            activation_dropout,
            activation_fn,
            export,
        )

        # Swap the dense self-attention created by the parent layer for the
        # fixed-pattern sparse variant; the remaining sublayers (FFN, layer
        # norms, dropout) are reused from TransformerSentenceEncoderLayer.
        self.self_attn = SparseMultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            add_bias_kv=False,
            add_zero_attn=False,
            self_attention=True,
            is_bidirectional=is_bidirectional,
            stride=stride,
            expressivity=expressivity,
        )
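
# A minimal usage sketch, not part of the original module: instantiate the
# layer with its defaults and run a forward pass. The forward signature and
# the (seq_len, batch, embedding_dim) input layout are assumed to be inherited
# unchanged from TransformerSentenceEncoderLayer; the concrete shapes below
# are illustrative and chosen so the sequence length is a multiple of the
# sparse-attention stride.
if __name__ == "__main__":
    import torch

    layer = SparseTransformerSentenceEncoderLayer(
        embedding_dim=768,
        num_attention_heads=8,
        is_bidirectional=True,
        stride=32,
        expressivity=8,
    )
    x = torch.randn(64, 2, 768)  # (seq_len, batch, embedding_dim)
    out, attn = layer(x)
    print(out.shape)  # expected: torch.Size([64, 2, 768])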