#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import List, Tuple

import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    """Soft Dice loss over per-item probability sequences (e.g. VAD output).

    For each batch item the loss is
    ``1 - (2 * sum(x * y) + eps) / (sum(x) + sum(y) + eps)``, with the sums
    taken over the last axis after a trailing size-1 axis is dropped.
    ``eps`` is added to both numerator and denominator, so an all-zero
    prediction against an all-zero target gives a loss of 0 instead of NaN.
    """

    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-6,
                 ):
        """
        :param reduction: how the per-item losses are combined: "mean" or "sum".
        :param eps: smoothing term added to both numerator and denominator.
        :raises AssertionError: if `reduction` is not "sum" or "mean".
        """
        super().__init__()
        # Validate before storing anything on the module.
        if reduction not in ("sum", "mean"):
            raise AssertionError(
                f"param reduction must be sum or mean, got {reduction!r}."
            )
        self.reduction = reduction
        self.eps = eps

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        :param inputs: torch.Tensor, shape: [b, t, 1] (or [b, t]). vad prob, after sigmoid activation.
        :param targets: same shape as `inputs`; a trailing size-1 axis on either side is ignored.
        :return: 0-dim tensor, the mean or sum (per `reduction`) of the per-item losses.
        :raises ValueError: if `inputs` and `targets` differ in shape once a
            trailing size-1 axis has been dropped.
        """
        # Drop the trailing channel axis (no-op if it is not size 1); shape: [b, t]
        inputs_ = torch.squeeze(inputs, dim=-1)
        targets_ = torch.squeeze(targets, dim=-1)
        # Mismatched shapes would otherwise broadcast silently and give a wrong loss.
        if inputs_.shape != targets_.shape:
            raise ValueError(
                f"inputs and targets must have the same shape, "
                f"got {tuple(inputs.shape)} and {tuple(targets.shape)}."
            )

        # shape: [b,]
        intersection = (inputs_ * targets_).sum(dim=-1)
        union = (inputs_ + targets_).sum(dim=-1)

        # eps keeps the ratio defined (and equal to 1) when both sums are 0.
        dice = (2. * intersection + self.eps) / (union + self.eps)
        loss = 1. - dice

        if self.reduction == "mean":
            return torch.mean(loss)
        if self.reduction == "sum":
            return torch.sum(loss)
        # Only reachable if `reduction` was mutated after construction.
        raise AssertionError(
            f"param reduction must be sum or mean, got {self.reduction!r}."
        )


def main():
    """Run a small smoke test: all-zero predictions against all-zero targets."""
    inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)

    loss_fn = DiceLoss()

    # Call the module rather than .forward() so nn.Module hooks are honoured.
    loss = loss_fn(inputs, inputs)
    print(loss)


if __name__ == "__main__":
    main()