Spaces:
Sleeping
Sleeping
File size: 1,821 Bytes
6fc683c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
class LabelSmoothingLoss(_Loss):
    """KL-divergence loss against a label-smoothed target distribution.

    With label smoothing, the KL-divergence between
    q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w)
    is minimized.

    The smoothed distribution built here puts ``1 - label_smoothing`` on the
    gold class, zero on the ``ignore_index`` class, and spreads
    ``label_smoothing`` uniformly over the remaining ``tgt_vocab_size - 2``
    classes. Positions whose target equals ``ignore_index`` get an all-zero
    distribution and therefore contribute zero loss.
    """

    def __init__(self, label_smoothing=0, tgt_vocab_size=0, ignore_index=0,
                 size_average=None, reduce=None, reduction='mean'):
        """Precompute the smoothed distribution template.

        Args:
            label_smoothing (float): probability mass moved away from the gold
                class; must satisfy ``0 < label_smoothing <= 1``. The default
                of 0 is rejected, so a value must be supplied.
            tgt_vocab_size (int): number of classes; must be greater than 2
                because the smoothing mass is divided by
                ``tgt_vocab_size - 2``. The default of 0 is rejected.
            ignore_index (int): class index treated as padding.
                NOTE(review): presumably a valid non-negative class index;
                negative values are not checked here -- confirm with callers.
            size_average, reduce, reduction: forwarded unchanged to
                ``_Loss.__init__``. ``forward`` does not consult them; it
                always returns the per-position loss.

        Raises:
            AssertionError: if ``label_smoothing`` or ``tgt_vocab_size`` is
                out of range.
        """
        assert 0.0 < label_smoothing <= 1.0, \
            'label_smoothing must be in the interval (0, 1]'
        # A vocabulary of 2 would divide by zero below and a vocabulary of 1
        # would produce a negative smoothing value, so reject both up front.
        assert tgt_vocab_size > 2, \
            'tgt_vocab_size must be greater than 2'

        self.ignore_index = ignore_index
        super(LabelSmoothingLoss, self).__init__(
            size_average=size_average, reduce=reduce, reduction=reduction)

        # The gold class and the ignored class receive no smoothing mass,
        # hence the "- 2" in the denominator.
        smoothing_value = label_smoothing / (tgt_vocab_size - 2)
        one_hot = torch.full((tgt_vocab_size,), smoothing_value)
        one_hot[self.ignore_index] = 0
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('one_hot', one_hot.unsqueeze(0))

        self.confidence = 1.0 - label_smoothing
        self.tgt_vocab_size = tgt_vocab_size

    def forward(self, output, target):
        """Return the smoothed KL-divergence for every target position.

        Args:
            output (FloatTensor): batch_size * num_pos * n_classes; passed as
                the ``input`` of ``F.kl_div``, which expects log-probabilities.
            target (LongTensor): batch_size * num_pos

        Returns:
            FloatTensor: batch_size * num_pos, the KL-divergence summed over
            the class dimension. No reduction over positions is applied.
        """
        assert self.tgt_vocab_size == output.size(2), \
            'last dimension of output must equal tgt_vocab_size'
        batch_size, num_pos = target.size(0), target.size(1)

        # reshape (rather than view) also accepts non-contiguous tensors,
        # e.g. a time-major batch that was transposed to batch-major.
        flat_output = output.reshape(-1, self.tgt_vocab_size)
        flat_target = target.reshape(-1)

        # Start every row from the smoothing template, then place the
        # confidence mass on the gold class of that row.
        model_prob = self.one_hot.repeat(flat_target.size(0), 1)
        model_prob.scatter_(1, flat_target.unsqueeze(1), self.confidence)
        # Rows whose target is the ignored class are zeroed entirely so they
        # add nothing to the loss.
        ignored = (flat_target == self.ignore_index).unsqueeze(1)
        model_prob.masked_fill_(ignored, 0)

        pointwise = F.kl_div(flat_output, model_prob, reduction='none')
        return pointwise.view(batch_size, num_pos, -1).sum(2)
|