Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2021 The IDEA Authors. All rights reserved. | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import torch | |
from torch.nn import functional as F | |
class FocalLoss(torch.nn.Module):
    """Focal loss for multi-class classification.

    Each class log-probability is scaled by ``(1 - p) ** gamma`` (``p`` being
    the softmax probability of that class) before the result is reduced with
    ``F.nll_loss``.  With ``gamma == 0`` the scaling factor is 1 and the loss
    is the plain (optionally weighted) cross entropy.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        """Store the hyper-parameters.

        Args:
            gamma: exponent applied to ``1 - p``.
            weight: forwarded as the ``weight`` argument of ``F.nll_loss``.
            ignore_index: forwarded as ``ignore_index`` of ``F.nll_loss``.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.weight = weight
        self.gamma = gamma

    def forward(self, input, target):
        """Return the focal loss of ``input`` against ``target``.

        Args:
            input: logits, softmax is taken over dim 1 (documented as [N, C]).
            target: class indices (documented as [N, ]).

        Returns:
            The value produced by ``F.nll_loss`` with its default reduction.
        """
        log_probs = F.log_softmax(input, dim=1)
        # Down-weight classes that are already predicted with high probability.
        modulator = (1 - log_probs.exp()) ** self.gamma
        return F.nll_loss(
            modulator * log_probs,
            target,
            weight=self.weight,
            ignore_index=self.ignore_index,
        )
# Label-smoothed cross entropy, used to reduce overfitting
class LabelSmoothingCorrectionCrossEntropy(torch.nn.Module):
    """Label-smoothed cross entropy plus a task-specific scalar correction.

    The returned value is::

        eps / C * smooth + (1 - eps) * nll + correction

    where ``C`` is the size of the last dimension of ``output``, ``smooth`` is
    the negated sum of log-probabilities (reduced according to ``reduction``),
    ``nll`` is ``F.nll_loss`` on the log-softmax, and ``correction`` is the
    batch average computed by :meth:`_correction`.

    NOTE: ``smooth`` is computed over every row, including rows whose target
    equals ``ignore_index``; only ``nll`` and ``correction`` skip those rows.
    """

    # Empirical, task-specific constants. Their derivation is not documented
    # in this file -- NOTE(review): confirm with the original authors.
    _SUM_ONE_CREDIT = 0.5945275813408382
    _OTHER_SUM_PENALTY = 1 / 0.32447699714575207

    def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
        """Store the hyper-parameters.

        Args:
            eps: smoothing factor; also scales the correction term.
            reduction: ``'mean'``, ``'sum'`` or anything else for no
                reduction; also forwarded to ``F.nll_loss``.
            ignore_index: target value excluded from ``nll`` and from the
                correction term.
        """
        super().__init__()
        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, output, target):
        """Return the corrected, label-smoothed loss.

        Args:
            output: logits; the class dimension is the last one and the
                prediction is taken with ``argmax`` over dim 1, so a 2-D
                ``[N, C]`` tensor is expected.
            target: integer class indices, one per row of ``output``.
        """
        num_classes = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)

        if self.reduction == 'sum':
            smooth = -log_preds.sum()
        else:
            smooth = -log_preds.sum(dim=-1)
            if self.reduction == 'mean':
                smooth = smooth.mean()

        nll = F.nll_loss(
            log_preds,
            target,
            reduction=self.reduction,
            ignore_index=self.ignore_index,
        )
        return (
            smooth * self.eps / num_classes
            + (1 - self.eps) * nll
            + self._correction(output, target)
        )

    def _correction(self, output, target):
        """Return the task-specific correction averaged over the batch.

        For every sample (ignored targets excluded) let ``s`` be the sum of
        predicted and target label and ``d`` their absolute difference:

        * ``s == 0`` -> no contribution;
        * ``s == 1`` -> no contribution when ``d == 1``, otherwise a credit
          of ``eps * _SUM_ONE_CREDIT`` is subtracted;
        * any other ``s`` -> a penalty of ``eps * _OTHER_SUM_PENALTY``.

        The previous implementation looped ``range(num_classes)`` while
        indexing these per-sample tensors, which raised ``IndexError`` for
        batches smaller than the class count and skipped every sample beyond
        it. The counts are now taken over the whole batch and normalised by
        the batch size. The result is a plain Python float so the dtype of
        the final loss is unaffected.
        """
        num_samples = target.numel()
        if num_samples == 0:
            return 0.0

        labels_hat = torch.argmax(output, dim=1)
        lt_sum = labels_hat + target
        abs_lt_sub = (labels_hat - target).abs()

        # Rows marked with ignore_index carry no label and must not be scored.
        counted = target != self.ignore_index
        credited = counted & (lt_sum == 1) & (abs_lt_sub != 1)
        penalised = counted & (lt_sum != 0) & (lt_sum != 1)

        total = (
            int(penalised.sum()) * self._OTHER_SUM_PENALTY
            - int(credited.sum()) * self._SUM_ONE_CREDIT
        )
        return self.eps * total / num_samples