File size: 10,055 Bytes
5e9bd47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 |
""" Embeddings module """
import math
import warnings
import torch
import torch.nn as nn
from onmt.modules.util_class import Elementwise
class SequenceTooLongError(Exception):
    """Raised when a sequence does not fit in the positional-encoding table."""
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for non-recurrent neural networks.

    Implementation based on "Attention Is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    Args:
        dropout (float): dropout parameter
        dim (int): embedding size (must be even)
        max_len (int): number of positions held in the precomputed table
    """

    def __init__(self, dropout, dim, max_len=5000):
        if dim % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(dim))
        super().__init__()

        # One row per position, one column per sin/cos frequency pair.
        positions = torch.arange(0, max_len).unsqueeze(1).float()
        inv_freq = torch.exp(torch.arange(0, dim, 2, dtype=torch.float)
                             * -(math.log(10000.0) / dim))
        angles = positions * inv_freq

        # Even columns carry the sine, odd columns the cosine.
        table = torch.zeros(max_len, dim)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # Stored as (max_len, 1, dim) so it broadcasts over the batch axis.
        self.register_buffer('pe', table.unsqueeze(1))
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):
        """Embed inputs.

        Args:
            emb (FloatTensor): Sequence of word vectors
                ``(seq_len, batch_size, self.dim)``
            step (int or NoneType): If stepwise (``seq_len = 1``), use
                the encoding for this position.

        Raises:
            SequenceTooLongError: if ``step + seq_len`` exceeds the number
                of positions in the precomputed table.
        """
        offset = step if step else 0
        seq_len = emb.size(0)
        limit = self.pe.size(0)
        if offset + seq_len > limit:
            raise SequenceTooLongError(
                f"Sequence is {seq_len + offset} but PositionalEncoding is"
                f" limited to {limit}. See max_len argument."
            )
        # Scale the embeddings by sqrt(dim) before adding the position signal.
        scaled = emb * math.sqrt(self.dim)
        return self.dropout(scaled + self.pe[offset:offset + seq_len])
class Embeddings(nn.Module):
    """Words embeddings for encoder/decoder.

    Additionally includes ability to add sparse input features
    based on "Linguistic Input Features Improve Neural Machine Translation"
    :cite:`sennrich2016linguistic`.

    .. mermaid::

       graph LR
          A[Input]
          C[Feature 1 Lookup]
          A-->B[Word Lookup]
          A-->C
          A-->D[Feature N Lookup]
          B-->E[MLP/Concat]
          C-->E
          D-->E
          E-->F[Output]

    Args:
        word_vec_size (int): embedding dimension for words.
        word_vocab_size (int): size of dictionary of embeddings for words.
        word_padding_idx (int): padding index for words in the embeddings.
        position_encoding (bool): see :class:`~onmt.modules.PositionalEncoding`
        feat_merge (string): merge action for the features embeddings:
            concat, sum or mlp.
        feat_vec_exponent (float): when neither `-feat_merge sum` nor a
            positive ``feat_vec_size`` is used, feature embedding size is
            N^feat_vec_exponent, where N is the number of values the
            feature takes.
        feat_vec_size (int): when positive (and not merging with sum),
            embedding dimension used for every feature.
        feat_padding_idx (List[int], optional): padding index for a list
            of features in the embeddings. ``None`` means no features.
        feat_vocab_sizes (List[int], optional): list of size of dictionary
            of embeddings for each feature. ``None`` means no features.
        dropout (float): dropout probability (only used by the positional
            encoding layer).
        sparse (bool): passed through as ``sparse`` to every
            ``nn.Embedding`` look-up table.
        freeze_word_vecs (bool): freeze weights of word vectors.
    """

    def __init__(self, word_vec_size,
                 word_vocab_size,
                 word_padding_idx,
                 position_encoding=False,
                 feat_merge="concat",
                 feat_vec_exponent=0.7,
                 feat_vec_size=-1,
                 feat_padding_idx=None,
                 feat_vocab_sizes=None,
                 dropout=0,
                 sparse=False,
                 freeze_word_vecs=False):
        super(Embeddings, self).__init__()

        # Normalise the optional lists BEFORE validating: the validator
        # calls len() on both, so None must already be replaced here.
        if feat_padding_idx is None:
            feat_padding_idx = []
        if feat_vocab_sizes is None:
            feat_vocab_sizes = []

        self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent,
                            feat_vec_size, feat_padding_idx)

        self.word_padding_idx = word_padding_idx
        self.word_vec_size = word_vec_size

        # Dimensions and padding for constructing the word embedding matrix
        vocab_sizes = [word_vocab_size]
        emb_dims = [word_vec_size]
        pad_indices = [word_padding_idx]

        # Dimensions and padding for feature embedding matrices
        # (these have no effect if feat_vocab_sizes is empty)
        if feat_merge == 'sum':
            feat_dims = [word_vec_size] * len(feat_vocab_sizes)
        elif feat_vec_size > 0:
            feat_dims = [feat_vec_size] * len(feat_vocab_sizes)
        else:
            feat_dims = [int(vocab ** feat_vec_exponent)
                         for vocab in feat_vocab_sizes]
        vocab_sizes.extend(feat_vocab_sizes)
        emb_dims.extend(feat_dims)
        pad_indices.extend(feat_padding_idx)

        # The embedding matrix look-up tables. The first look-up table
        # is for words. Subsequent ones are for features, if any exist.
        emb_params = zip(vocab_sizes, emb_dims, pad_indices)
        embeddings = [nn.Embedding(vocab, dim, padding_idx=pad, sparse=sparse)
                      for vocab, dim, pad in emb_params]
        emb_luts = Elementwise(feat_merge, embeddings)

        # The final output size of word + feature vectors. This can vary
        # from the word vector size if and only if features are defined.
        # This is the attribute you should access if you need to know
        # how big your embeddings are going to be.
        self.embedding_size = (sum(emb_dims) if feat_merge == 'concat'
                               else word_vec_size)

        # The sequence of operations that converts the input sequence
        # into a sequence of embeddings. At minimum this consists of
        # looking up the embeddings for each word and feature in the
        # input. Model parameters may require the sequence to contain
        # additional operations as well.
        self.make_embedding = nn.Sequential()
        self.make_embedding.add_module('emb_luts', emb_luts)

        if feat_merge == 'mlp' and len(feat_vocab_sizes) > 0:
            in_dim = sum(emb_dims)
            mlp = nn.Sequential(nn.Linear(in_dim, word_vec_size), nn.ReLU())
            self.make_embedding.add_module('mlp', mlp)

        self.position_encoding = position_encoding

        if self.position_encoding:
            pe = PositionalEncoding(dropout, self.embedding_size)
            self.make_embedding.add_module('pe', pe)

        if freeze_word_vecs:
            self.word_lut.weight.requires_grad = False

    def _validate_args(self, feat_merge, feat_vocab_sizes, feat_vec_exponent,
                       feat_vec_size, feat_padding_idx):
        """Warn about unused options and reject inconsistent feature args.

        Raises:
            ValueError: if the exponent is needed to size the feature
                vectors but is not positive, or if ``feat_vocab_sizes`` and
                ``feat_padding_idx`` differ in length.
        """
        if feat_merge == "sum":
            # features must use word_vec_size
            if feat_vec_exponent != 0.7:
                warnings.warn("Merging with sum, but got non-default "
                              "feat_vec_exponent. It will be unused.")
            if feat_vec_size != -1:
                warnings.warn("Merging with sum, but got non-default "
                              "feat_vec_size. It will be unused.")
        elif feat_vec_size > 0:
            # features will use feat_vec_size
            # Compare against the real default (0.7, see __init__) so that
            # callers who leave the exponent untouched are not warned.
            if feat_vec_exponent != 0.7:
                warnings.warn("Not merging with sum and positive "
                              "feat_vec_size, but got non-default "
                              "feat_vec_exponent. It will be unused.")
        else:
            if feat_vec_exponent <= 0:
                raise ValueError("Using feat_vec_exponent to determine "
                                 "feature vec size, but got feat_vec_exponent "
                                 "less than or equal to 0.")
        n_feats = len(feat_vocab_sizes)
        if n_feats != len(feat_padding_idx):
            raise ValueError("Got unequal number of feat_vocab_sizes and "
                             "feat_padding_idx ({:d} != {:d})".format(
                                n_feats, len(feat_padding_idx)))

    @property
    def word_lut(self):
        """Word look-up table."""
        return self.make_embedding[0][0]

    @property
    def emb_luts(self):
        """Embedding look-up table."""
        return self.make_embedding[0]

    def load_pretrained_vectors(self, emb_file):
        """Load in pretrained embeddings.

        If the pretrained vectors are narrower than ``word_vec_size`` only
        the leading columns of the word table are overwritten; if they are
        wider, they are truncated to ``word_vec_size`` columns.

        Args:
          emb_file (str) : path to torch serialized embeddings
        """
        if not emb_file:
            return

        # NOTE(security): torch.load unpickles the file; only load
        # embedding files from trusted sources.
        pretrained = torch.load(emb_file)
        pretrained_vec_size = pretrained.size(1)
        if self.word_vec_size > pretrained_vec_size:
            self.word_lut.weight.data[:, :pretrained_vec_size] = pretrained
        elif self.word_vec_size < pretrained_vec_size:
            self.word_lut.weight.data \
                .copy_(pretrained[:, :self.word_vec_size])
        else:
            self.word_lut.weight.data.copy_(pretrained)

    def forward(self, source, step=None):
        """Computes the embeddings for words and features.

        Args:
            source (LongTensor): index tensor ``(len, batch, nfeat)``
            step (int or NoneType): forwarded to the positional encoding
                layer (the last stage) when position encoding is enabled;
                ignored otherwise.

        Returns:
            FloatTensor: Word embeddings ``(len, batch, embedding_size)``
        """
        if not self.position_encoding:
            return self.make_embedding(source)

        # The positional encoding is always appended last and is the only
        # stage that accepts ``step``, so run the stages by hand.
        stages = list(self.make_embedding)
        for stage in stages[:-1]:
            source = stage(source)
        return stages[-1](source, step=step)

    def update_dropout(self, dropout):
        """Set the dropout probability of the positional encoding layer.

        Does nothing when position encoding is disabled.
        """
        if self.position_encoding:
            # Address the layer by name: with ``feat_merge='mlp'`` index 1
            # is the MLP, not the positional encoding.
            self.make_embedding.pe.dropout.p = dropout
|