Spaces:
Build error
Build error
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
def Focal_Loss(pred, gt, alpha=0.25, gamma=2):
    """Return the mean sigmoid focal loss between logits ``pred`` and targets ``gt``.

    With ``p = sigmoid(pred)``, the per-element loss is::

        -alpha * (1 - p) ** gamma * gt * log(p)
        - (1 - alpha) * p ** gamma * (1 - gt) * log(1 - p)

    The result is the mean of that loss over all elements.

    Args:
        pred: Tensor of raw (pre-sigmoid) scores.
        gt: Target tensor, broadcastable against ``pred``. Presumably holds
            values in ``[0, 1]`` (1 = positive, 0 = negative); confirm
            against callers.
        alpha: Weight of the positive term; ``1 - alpha`` weights the
            negative term. Defaults to 0.25, the previously hard-coded value.
        gamma: Focusing exponent applied to ``(1 - p)`` / ``p``. Defaults to
            2, the previously hard-coded value.

    Returns:
        A 0-dim tensor holding the element-wise loss reduced with ``.mean()``.
    """
    p = torch.sigmoid(pred)
    # Use logsigmoid instead of torch.log(p) / torch.log(1 - p).
    # log(1 - sigmoid(x)) == logsigmoid(-x). Once sigmoid saturates to exactly
    # 0 or 1, the naive logs return -inf, and 0 * -inf then makes the whole
    # mean NaN.
    log_p = F.logsigmoid(pred)
    log_one_minus_p = F.logsigmoid(-pred)
    pos_term = alpha * (1 - p) ** gamma * gt * log_p
    neg_term = (1 - alpha) * p ** gamma * (1 - gt) * log_one_minus_p
    return -(pos_term + neg_term).mean()