from collections import namedtuple

import torch
from torch import Tensor
from typing import List, Sequence

from . import Sequential, ModuleList, Linear
from .module import Module
from ..functional import log_softmax

__all__ = ['AdaptiveLogSoftmaxWithLoss']

_ASMoutput = namedtuple('_ASMoutput', ['output', 'loss'])


class AdaptiveLogSoftmaxWithLoss(Module):
    r"""Efficient softmax approximation.

    As described in
    `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
    Moustapha Cissé, David Grangier, and Hervé Jégou
    <https://arxiv.org/abs/1609.04309>`__.

    Adaptive softmax is an approximate strategy for training models with large
    output spaces. It is most effective when the label distribution is highly
    imbalanced, for example in natural language modelling, where the word
    frequency distribution approximately follows `Zipf's law`_.

    Adaptive softmax partitions the labels into several clusters, according to
    their frequency. These clusters may each contain a different number of
    targets.
    Additionally, clusters containing less frequent labels assign
    lower-dimensional embeddings to those labels, which speeds up the
    computation. For each minibatch, only the clusters for which at least one
    target is present are evaluated.

    The idea is that the clusters which are accessed frequently
    (like the first one, containing the most frequent labels) should also be
    cheap to compute -- that is, contain a small number of assigned labels.

    We highly recommend taking a look at the original paper for more details.

    * :attr:`cutoffs` should be a Sequence of integers sorted in increasing
      order.
      It controls the number of clusters and the partitioning of targets into
      clusters. For example, setting ``cutoffs = [10, 100, 1000]``
      means that the first `10` targets (indices `0, 1, ..., 9`) will be
      assigned to the 'head' of the adaptive softmax, targets `10, 11, ..., 99`
      will be assigned to the first cluster, targets `100, 101, ..., 999` will
      be assigned to the second cluster, and targets
      `1000, 1001, ..., n_classes - 1` will be assigned to the last, third
      cluster.

    * :attr:`div_value` is used to compute the size of each additional cluster,
      which is given as
      :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
      where :math:`idx` is the cluster index (with clusters
      for less frequent words having larger indices,
      and indices starting from :math:`1`), as illustrated in the example
      below.

    * :attr:`head_bias`, if set to ``True``, adds a bias term to the 'head' of
      the adaptive softmax. See paper for details. Set to ``False`` in the
      official implementation.

    .. warning::
        Labels passed as inputs to this module should be sorted according to
        their frequency. This means that the most frequent label should be
        represented by the index `0`, and the least frequent
        label should be represented by the index `n_classes - 1`.

    .. note::
        This module returns a ``NamedTuple`` with ``output``
        and ``loss`` fields. See further documentation for details.

    .. note::
        To compute log-probabilities for all classes, the ``log_prob``
        method can be used.

    Args:
        in_features (int): Number of features in the input tensor
        n_classes (int): Number of classes in the dataset
        cutoffs (Sequence): Cutoffs used to assign targets to their buckets
        div_value (float, optional): value used as an exponent to compute sizes
            of the clusters. Default: 4.0
        head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the
            adaptive softmax. Default: ``False``

    Returns:
        ``NamedTuple`` with ``output`` and ``loss`` fields:

        * **output** is a Tensor of size ``N`` containing computed target
          log probabilities for each example
        * **loss** is a Scalar representing the computed negative
          log likelihood loss

    Shape:
        - input: :math:`(N, \texttt{in\_features})` or :math:`(\texttt{in\_features})`
        - target: :math:`(N)` or :math:`()` where each value satisfies :math:`0 <= \texttt{target[i]} <= \texttt{n\_classes} - 1`
        - output1: :math:`(N)` or :math:`()`
        - output2: ``Scalar``
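
    Example (illustrative sizes; a minimal usage sketch, not a tuned
    configuration)::

        >>> asm = AdaptiveLogSoftmaxWithLoss(in_features=64, n_classes=1000,
        ...                                  cutoffs=[10, 100, 500])
        >>> # with the default div_value=4.0, the three tail clusters project
        >>> # through hidden sizes floor(64 / 4), floor(64 / 16), floor(64 / 64)
        >>> [proj[0].out_features for proj in asm.tail]
        [16, 4, 1]
        >>> input = torch.randn(128, 64)
        >>> target = torch.randint(0, 1000, (128,))
        >>> out, loss = asm(input, target)
        >>> out.shape
        torch.Size([128])
        >>> loss.backward()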

    .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
    """

    in_features: int
    n_classes: int
    cutoffs: List[int]
    div_value: float
    head_bias: bool
    head: Linear
    tail: ModuleList

    def __init__(
        self,
        in_features: int,
        n_classes: int,
        cutoffs: Sequence[int],
        div_value: float = 4.,
        head_bias: bool = False,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        cutoffs = list(cutoffs)

        if len(cutoffs) == 0:
            raise ValueError("cutoffs should be a sequence of length larger than 0")

        if (cutoffs != sorted(cutoffs)) \
                or (min(cutoffs) <= 0) \
                or (max(cutoffs) > (n_classes - 1)) \
                or (len(set(cutoffs)) != len(cutoffs)) \
                or any(int(c) != c for c in cutoffs):
            raise ValueError("cutoffs should be a sequence of unique, positive "
                             "integers sorted in an increasing order, where "
                             "each value is between 1 and n_classes-1")

        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]
        self.div_value = div_value
        self.head_bias = head_bias

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        # the head predicts the shortlist classes plus one entry per tail cluster
        self.head_size = self.shortlist_size + self.n_clusters

        self.head = Linear(self.in_features, self.head_size, bias=self.head_bias,
                           **factory_kwargs)
        self.tail = ModuleList()

        for i in range(self.n_clusters):
            # each tail cluster projects the input down to a smaller hidden size
            # before producing logits for the classes it contains
            hsz = int(self.in_features // (self.div_value ** (i + 1)))
            osz = self.cutoffs[i + 1] - self.cutoffs[i]

            projection = Sequential(
                Linear(self.in_features, hsz, bias=False, **factory_kwargs),
                Linear(hsz, osz, bias=False, **factory_kwargs),
            )

            self.tail.append(projection)

    def reset_parameters(self) -> None:
        self.head.reset_parameters()
        for i2h, h2o in self.tail:
            i2h.reset_parameters()
            h2o.reset_parameters()

    def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
        targ_dim = target_.dim()

        if targ_dim == 1:
            if input_.size(0) != target_.size(0):
                raise RuntimeError('Input and target should have the same size '
                                   'in the batch dimension.')
            if input_.dim() != 2:
                raise RuntimeError('1D target tensor expects 2D input tensors, '
                                   'but found inputs with size', input_.size())
        elif targ_dim == 0:
            if input_.dim() != 1:
                raise RuntimeError('0D target tensor expects 1D input tensors, '
                                   'but found inputs with size', input_.size())
        else:
            raise RuntimeError('0D or 1D target tensor expected, '
                               'multi-target not supported')

        is_batched = targ_dim > 0
        input = input_ if is_batched else input_.unsqueeze(0)
        target = target_ if is_batched else target_.unsqueeze(0)

        used_rows = 0
        batch_size = target.size(0)

        output = input.new_zeros(batch_size)
        gather_inds = target.new_empty(batch_size)

        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):
            low_idx = cutoff_values[i]
            high_idx = cutoff_values[i + 1]

            # select the rows whose target falls into this bucket
            target_mask = (target >= low_idx) & (target < high_idx)
            row_indices = target_mask.nonzero().squeeze()

            if row_indices.numel() == 0:
                continue

            if i == 0:
                # shortlist targets are gathered directly from the head output
                gather_inds.index_copy_(0, row_indices, target[target_mask])

            else:
                relative_target = target[target_mask] - low_idx
                input_subset = input.index_select(0, row_indices)

                cluster_output = self.tail[i - 1](input_subset)
                cluster_index = self.shortlist_size + i - 1

                # for tail targets, gather the head's cluster logit and add the
                # within-cluster log-probability of the target
                gather_inds.index_fill_(0, row_indices, cluster_index)
                cluster_logprob = log_softmax(cluster_output, dim=1)
                local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
                output.index_copy_(0, row_indices, local_logprob.squeeze(1))

            used_rows += row_indices.numel()

        if used_rows != batch_size:
            raise RuntimeError(f"Target values should be in [0, {self.n_classes - 1}], "
                               f"but values in range [{target.min().item()}, {target.max().item()}] "
                               "were found. ")

        head_output = self.head(input)
        head_logprob = log_softmax(head_output, dim=1)
        output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
        loss = (-output).mean()

        if not is_batched:
            output = output.squeeze(0)

        return _ASMoutput(output, loss)

    def _get_full_log_prob(self, input, head_output):
        """Given input tensor, and output of ``self.head``, compute the log of the full distribution."""
        out = input.new_empty((head_output.size(0), self.n_classes))
        head_logprob = log_softmax(head_output, dim=1)

        # shortlist classes get their log-probabilities straight from the head
        out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size]

        for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
            # classes in cluster ``i`` combine the head's log-probability for
            # that cluster with the within-cluster log-probabilities
            cluster_output = self.tail[i](input)
            cluster_logprob = log_softmax(cluster_output, dim=1)
            output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1)

            out[:, start_idx:stop_idx] = output_logprob

        return out

    def log_prob(self, input: Tensor) -> Tensor:
        r"""Compute log probabilities for all :math:`\texttt{n\_classes}` classes.

        Args:
            input (Tensor): a minibatch of examples

        Returns:
            log-probabilities for each class :math:`c`
            in range :math:`0 <= c <= \texttt{n\_classes} - 1`, where :math:`\texttt{n\_classes}` is a
            parameter passed to the ``AdaptiveLogSoftmaxWithLoss`` constructor.

        Shape:
            - Input: :math:`(N, \texttt{in\_features})`
            - Output: :math:`(N, \texttt{n\_classes})`
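
        Example (illustrative sizes; since a full distribution over classes is
        returned, each exponentiated row should sum to one)::

            >>> asm = AdaptiveLogSoftmaxWithLoss(16, 100, cutoffs=[5, 20])
            >>> x = torch.randn(4, 16)
            >>> log_p = asm.log_prob(x)
            >>> log_p.shape
            torch.Size([4, 100])
            >>> torch.allclose(log_p.exp().sum(dim=1), torch.ones(4))
            True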
""" | |
head_output = self.head(input) | |
return self._get_full_log_prob(input, head_output) | |

    def predict(self, input: Tensor) -> Tensor:
        r"""Return the class with the highest probability for each example in the input minibatch.

        This is equivalent to ``self.log_prob(input).argmax(dim=1)``, but is more efficient in some cases.

        Args:
            input (Tensor): a minibatch of examples

        Returns:
            output (Tensor): the class with the highest probability for each example

        Shape:
            - Input: :math:`(N, \texttt{in\_features})`
            - Output: :math:`(N)`
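
        Example (illustrative sizes; as noted above, the result matches the
        argmax of ``log_prob``)::

            >>> asm = AdaptiveLogSoftmaxWithLoss(16, 100, cutoffs=[5, 20])
            >>> x = torch.randn(4, 16)
            >>> torch.equal(asm.predict(x), asm.log_prob(x).argmax(dim=1))
            True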
""" | |
head_output = self.head(input) | |
output = torch.argmax(head_output, dim=1) | |
not_in_shortlist = (output >= self.shortlist_size) | |
all_in_shortlist = not (not_in_shortlist.any()) | |
if all_in_shortlist: | |
return output | |
elif not_in_shortlist.all(): | |
log_prob = self._get_full_log_prob(input, head_output) | |
return torch.argmax(log_prob, dim=1) | |
else: | |
log_prob = self._get_full_log_prob(input[not_in_shortlist], | |
head_output[not_in_shortlist]) | |
output[not_in_shortlist] = torch.argmax(log_prob, dim=1) | |
return output | |