Spaces:
Sleeping
Sleeping
File size: 785 Bytes
fdc4786 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
# -*- coding: utf-8 -*-
# @Time : 2022/2/17 6:05 PM
# @Author : JianingWang
# @File : loss
import torch
from torch import nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    """Multi-class focal loss.

    For every sample the loss term is ``-(1 - p_t) ** gamma * log(p_t)``,
    where ``p_t`` is the softmax probability assigned to the target class.
    The modulating factor ``(1 - p_t) ** gamma`` shrinks the contribution of
    samples that are already classified with high confidence. With
    ``gamma=0`` the result equals (optionally weighted) cross-entropy.

    Args:
        gamma: Exponent of the modulating factor.
        weight: Optional per-class rescaling weights with one entry per
            class. They are converted to the dtype and device of the logits
            on every forward call.
        ignore_index: Target value whose samples are excluded from the loss.
        reduction: Reduction forwarded to ``F.nll_loss``: ``"none"``,
            ``"mean"`` (default) or ``"sum"``.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: Unnormalised class scores (logits) of shape ``[N, C]``;
                the softmax is taken over dimension 1.
            target: Class indices of shape ``[N]``.

        Returns:
            The reduced loss, or the per-sample losses when
            ``reduction="none"``.
        """
        log_probs = F.log_softmax(input, dim=1)
        probs = torch.exp(log_probs)
        # Scale every log-probability by its focal factor; nll_loss below
        # then picks (and negates) the entry that belongs to the target class.
        focal_log_probs = (1 - probs) ** self.gamma * log_probs

        weight = self.weight
        if weight is not None:
            # nll_loss requires the weight to share dtype and device with its
            # input; align it here so weights created on the CPU or in another
            # precision do not fail at call time.
            weight = torch.as_tensor(
                weight,
                dtype=focal_log_probs.dtype,
                device=focal_log_probs.device,
            )

        return F.nll_loss(
            focal_log_probs,
            target,
            weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
        )
|