Spaces:
Sleeping
Sleeping
| # coding=utf-8 | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.nn.modules.loss import _Loss | |
class LabelSmoothingLoss(_Loss):
    """Label-smoothed KL-divergence loss.

    The ground-truth distribution q(w) is smoothed as follows:
    - the target token gets probability ``1 - label_smoothing``;
    - ``ignore_index`` gets probability 0;
    - the remaining ``tgt_vocab_size - 2`` tokens share ``label_smoothing``
      equally.

    ``forward`` returns KL(q || p), summed over the vocabulary, for every
    (batch, position) pair. Here p is the model distribution.

    Args:
        label_smoothing: Probability mass moved off the target token. Must be
            in (0, 1].
        tgt_vocab_size: Number of classes. Must be greater than 2, because the
            smoothing mass is divided among ``tgt_vocab_size - 2`` tokens.
        ignore_index: Class index that never receives probability mass.
            Positions whose target equals it contribute zero loss.
        size_average, reduce, reduction: Forwarded to ``_Loss`` for API
            compatibility. NOTE: ``forward`` does not apply them; it always
            returns the unreduced per-position loss.

    Raises:
        ValueError: If ``label_smoothing`` or ``tgt_vocab_size`` is out of
            range.
    """

    def __init__(self, label_smoothing=0, tgt_vocab_size=0, ignore_index=0,
                 size_average=None, reduce=None, reduction='mean'):
        # Set up the Module machinery before assigning attributes or buffers.
        super(LabelSmoothingLoss, self).__init__(
            size_average=size_average, reduce=reduce, reduction=reduction)
        # Explicit checks rather than ``assert``, so they survive ``python -O``.
        if not 0.0 < label_smoothing <= 1.0:
            raise ValueError(
                "label_smoothing must be in (0, 1], got {}".format(label_smoothing))
        if tgt_vocab_size <= 2:
            # The smoothing mass is split over tgt_vocab_size - 2 tokens.
            # A vocabulary of 2 would divide by zero; a vocabulary of 1
            # would give negative mass.
            raise ValueError(
                "tgt_vocab_size must be greater than 2, got {}".format(tgt_vocab_size))
        self.ignore_index = ignore_index
        # Every token except the target (set per example in forward) and
        # ignore_index gets an equal share of the smoothing mass.
        smoothing_value = label_smoothing / (tgt_vocab_size - 2)
        one_hot = torch.full((tgt_vocab_size,), smoothing_value)
        one_hot[self.ignore_index] = 0
        # Registered as a buffer so it moves with .to()/.cuda()/.half() and is
        # stored in the state_dict.
        self.register_buffer('one_hot', one_hot.unsqueeze(0))
        self.confidence = 1.0 - label_smoothing
        self.tgt_vocab_size = tgt_vocab_size

    def forward(self, output, target):
        """Compute the label-smoothed loss for every target position.

        Args:
            output (FloatTensor): ``batch_size x num_pos x tgt_vocab_size``
                log-probabilities. ``F.kl_div`` expects its input in log space,
                e.g. the result of ``log_softmax``.
            target (LongTensor): ``batch_size x num_pos`` class indices.

        Returns:
            FloatTensor of shape ``batch_size x num_pos``: the KL divergence
            summed over the vocabulary. It is zero wherever
            ``target == ignore_index``.

        Raises:
            ValueError: If ``output.size(2)`` differs from ``tgt_vocab_size``.
        """
        if output.size(2) != self.tgt_vocab_size:
            raise ValueError(
                "expected output.size(2) == {}, got {}".format(
                    self.tgt_vocab_size, output.size(2)))
        batch_size, num_pos = target.size(0), target.size(1)
        # Use reshape (not view) so non-contiguous inputs, such as sliced or
        # transposed tensors, are accepted.
        output = output.reshape(-1, self.tgt_vocab_size)
        target = target.reshape(-1)
        # Start every row from the smoothing distribution, then place the
        # confidence mass on the gold token.
        smoothed_target = self.one_hot.repeat(target.size(0), 1)
        smoothed_target.scatter_(1, target.unsqueeze(1), self.confidence)
        # Rows whose target is ignore_index become all-zero, so their loss is 0.
        smoothed_target.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)
        per_class = F.kl_div(output, smoothed_target, reduction='none')
        return per_class.view(batch_size, num_pos, -1).sum(2)