""" Embeddings module """ | |
import math | |
import warnings | |
import torch | |
import torch.nn as nn | |
from onmt.modules.util_class import Elementwise | |
class SequenceTooLongError(Exception): | |
pass | |


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for non-recurrent neural networks.

    Implementation based on "Attention Is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    Args:
        dropout (float): dropout probability
        dim (int): embedding size
        max_len (int): maximum sequence length that can be encoded
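
    Example (a minimal usage sketch; the dropout value, shapes and the
    names ``pe``, ``emb``, ``out`` are illustrative)::

        pe = PositionalEncoding(dropout=0.1, dim=512)
        emb = torch.zeros(20, 4, 512)  # (seq_len, batch, dim) word vectors
        out = pe(emb)                  # same shape: scaled, encoded, dropout applied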
""" | |

    def __init__(self, dropout, dim, max_len=5000):
        if dim % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(dim))
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
                             -(math.log(10000.0) / dim)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(1)
        super(PositionalEncoding, self).__init__()
        self.register_buffer('pe', pe)
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):
        """Embed inputs.

        Args:
            emb (FloatTensor): Sequence of word vectors
                ``(seq_len, batch_size, self.dim)``
            step (int or NoneType): If stepwise (``seq_len = 1``), use
                the encoding for this position.
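
        Example (a minimal stepwise sketch, assuming a
        :class:`PositionalEncoding` instance ``pe`` with ``dim=512``;
        shapes are illustrative)::

            tgt_emb = torch.zeros(1, 4, 512)  # (1, batch, dim), one position
            out = pe(tgt_emb, step=3)         # adds the encoding of position 3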
""" | |
        emb = emb * math.sqrt(self.dim)
        step = step or 0
        if self.pe.size(0) < step + emb.size(0):
            raise SequenceTooLongError(
                f"Sequence length {emb.size(0) + step} exceeds the "
                f"PositionalEncoding limit of {self.pe.size(0)}. "
                f"See the max_len argument."
            )
        emb = emb + self.pe[step:emb.size(0)+step]
        emb = self.dropout(emb)
        return emb


class Embeddings(nn.Module):
    """Word embeddings for encoder/decoder.

    Additionally includes ability to add sparse input features
    based on "Linguistic Input Features Improve Neural Machine Translation"
    :cite:`sennrich2016linguistic`.

    .. mermaid::

       graph LR
          A[Input]
          C[Feature 1 Lookup]
          A-->B[Word Lookup]
          A-->C
          A-->D[Feature N Lookup]
          B-->E[MLP/Concat]
          C-->E
          D-->E
          E-->F[Output]

    Args:
        word_vec_size (int): size of the word embedding vectors.
        word_padding_idx (int): padding index for words in the embeddings.
        feat_padding_idx (List[int]): padding index for a list of features
            in the embeddings.
        word_vocab_size (int): size of the word vocabulary.
        feat_vocab_sizes (List[int], optional): list of vocabulary sizes
            for each feature.
        position_encoding (bool): see :class:`~onmt.modules.PositionalEncoding`
        feat_merge (string): merge action for the features embeddings:
            concat, sum or mlp.
        feat_vec_exponent (float): when using `-feat_merge concat`, feature
            embedding size is N^feat_vec_exponent, where N is the
            number of values the feature takes (e.g. a feature with 100
            values gets an embedding of size ``int(100 ** 0.7) == 25``).
        feat_vec_size (int): embedding dimension for features when using
            `-feat_merge mlp`
        dropout (float): dropout probability.
        sparse (bool): if True, use sparse gradients for the embedding
            look-up tables.
        freeze_word_vecs (bool): freeze weights of word vectors.
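
    Example (a minimal sketch with one extra feature; the vocabulary
    sizes, padding indices and shapes below are illustrative)::

        emb = Embeddings(word_vec_size=512,
                         word_vocab_size=10000,
                         word_padding_idx=1,
                         feat_merge='concat',
                         feat_vocab_sizes=[40],
                         feat_padding_idx=[1])
        # emb.embedding_size is 512 + int(40 ** 0.7) = 525
        src = torch.zeros(20, 8, 2, dtype=torch.long)  # (len, batch, word + 1 feat)
        out = emb(src)                                 # (20, 8, emb.embedding_size)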
""" | |

    def __init__(self, word_vec_size,
                 word_vocab_size,
                 word_padding_idx,
                 position_encoding=False,
                 feat_merge="concat",
                 feat_vec_exponent=0.7,
                 feat_vec_size=-1,
                 feat_padding_idx=[],
                 feat_vocab_sizes=[],
                 dropout=0,
                 sparse=False,
                 freeze_word_vecs=False):
        self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent,
                            feat_vec_size, feat_padding_idx)

        if feat_padding_idx is None:
            feat_padding_idx = []
        self.word_padding_idx = word_padding_idx

        self.word_vec_size = word_vec_size

        # Dimensions and padding for constructing the word embedding matrix
        vocab_sizes = [word_vocab_size]
        emb_dims = [word_vec_size]
        pad_indices = [word_padding_idx]

        # Dimensions and padding for feature embedding matrices
        # (these have no effect if feat_vocab_sizes is empty)
        if feat_merge == 'sum':
            feat_dims = [word_vec_size] * len(feat_vocab_sizes)
        elif feat_vec_size > 0:
            feat_dims = [feat_vec_size] * len(feat_vocab_sizes)
        else:
            feat_dims = [int(vocab ** feat_vec_exponent)
                         for vocab in feat_vocab_sizes]
        vocab_sizes.extend(feat_vocab_sizes)
        emb_dims.extend(feat_dims)
        pad_indices.extend(feat_padding_idx)

        # The embedding matrix look-up tables. The first look-up table
        # is for words. Subsequent ones are for features, if any exist.
        emb_params = zip(vocab_sizes, emb_dims, pad_indices)
        embeddings = [nn.Embedding(vocab, dim, padding_idx=pad, sparse=sparse)
                      for vocab, dim, pad in emb_params]
        emb_luts = Elementwise(feat_merge, embeddings)

        # The final output size of word + feature vectors. This can vary
        # from the word vector size if and only if features are defined.
        # This is the attribute you should access if you need to know
        # how big your embeddings are going to be.
        self.embedding_size = (sum(emb_dims) if feat_merge == 'concat'
                               else word_vec_size)

        # The sequence of operations that converts the input sequence
        # into a sequence of embeddings. At minimum this consists of
        # looking up the embeddings for each word and feature in the
        # input. Model parameters may require the sequence to contain
        # additional operations as well.
        super(Embeddings, self).__init__()
        self.make_embedding = nn.Sequential()
        self.make_embedding.add_module('emb_luts', emb_luts)

        if feat_merge == 'mlp' and len(feat_vocab_sizes) > 0:
            in_dim = sum(emb_dims)
            mlp = nn.Sequential(nn.Linear(in_dim, word_vec_size), nn.ReLU())
            self.make_embedding.add_module('mlp', mlp)

        self.position_encoding = position_encoding

        if self.position_encoding:
            pe = PositionalEncoding(dropout, self.embedding_size)
            self.make_embedding.add_module('pe', pe)

        if freeze_word_vecs:
            self.word_lut.weight.requires_grad = False

    def _validate_args(self, feat_merge, feat_vocab_sizes, feat_vec_exponent,
                       feat_vec_size, feat_padding_idx):
        if feat_merge == "sum":
            # features must use word_vec_size
            if feat_vec_exponent != 0.7:
                warnings.warn("Merging with sum, but got non-default "
                              "feat_vec_exponent. It will be unused.")
            if feat_vec_size != -1:
                warnings.warn("Merging with sum, but got non-default "
                              "feat_vec_size. It will be unused.")
        elif feat_vec_size > 0:
            # features will use feat_vec_size
            if feat_vec_exponent != -1:
                warnings.warn("Not merging with sum and positive "
                              "feat_vec_size, but got non-default "
                              "feat_vec_exponent. It will be unused.")
        else:
            if feat_vec_exponent <= 0:
                raise ValueError("Using feat_vec_exponent to determine "
                                 "feature vec size, but got feat_vec_exponent "
                                 "less than or equal to 0.")
        n_feats = len(feat_vocab_sizes)
        if n_feats != len(feat_padding_idx):
            raise ValueError("Got unequal number of feat_vocab_sizes and "
                             "feat_padding_idx ({:d} != {:d})".format(
                                 n_feats, len(feat_padding_idx)))

    @property
    def word_lut(self):
        """Word look-up table."""
        return self.make_embedding[0][0]

    @property
    def emb_luts(self):
        """Embedding look-up tables (words and features)."""
        return self.make_embedding[0]

    def load_pretrained_vectors(self, emb_file):
        """Load in pretrained embeddings.

        Args:
            emb_file (str): path to torch serialized embeddings
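
        Example (a minimal sketch; the file name, shape and the instance
        ``emb`` are illustrative)::

            # a (word_vocab_size, dim) FloatTensor saved with torch.save()
            torch.save(torch.randn(10000, 512), 'glove_vectors.pt')
            emb.load_pretrained_vectors('glove_vectors.pt')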
""" | |
        if emb_file:
            pretrained = torch.load(emb_file)
            pretrained_vec_size = pretrained.size(1)
            if self.word_vec_size > pretrained_vec_size:
                # model vectors are wider: fill the leading columns and keep
                # the randomly initialized remainder
                self.word_lut.weight.data[:, :pretrained_vec_size] = pretrained
            elif self.word_vec_size < pretrained_vec_size:
                # model vectors are narrower: truncate the pretrained ones
                self.word_lut.weight.data \
                    .copy_(pretrained[:, :self.word_vec_size])
            else:
                self.word_lut.weight.data.copy_(pretrained)

    def forward(self, source, step=None):
        """Computes the embeddings for words and features.

        Args:
            source (LongTensor): index tensor ``(len, batch, nfeat)``
            step (int or NoneType): forwarded to the positional encoding,
                if any, for stepwise decoding.

        Returns:
            FloatTensor: Word embeddings ``(len, batch, embedding_size)``
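
        Example (a minimal stepwise sketch, assuming an :class:`Embeddings`
        instance ``emb`` built with ``position_encoding=True`` and no extra
        features; shapes are illustrative)::

            tgt = torch.zeros(1, 8, 1, dtype=torch.long)  # one target position
            out = emb(tgt, step=5)                        # (1, 8, embedding_size)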
""" | |
        if self.position_encoding:
            # only the last module in the pipeline (the PositionalEncoding)
            # accepts the step argument, so feed it through module by module
            for i, module in enumerate(self.make_embedding._modules.values()):
                if i == len(self.make_embedding._modules.values()) - 1:
                    source = module(source, step=step)
                else:
                    source = module(source)
        else:
            source = self.make_embedding(source)

        return source

    def update_dropout(self, dropout):
        if self.position_encoding:
            # the dropout lives inside the PositionalEncoding module; access
            # it by name so an optional 'mlp' stage cannot shift its index
            self.make_embedding.pe.dropout.p = dropout