# mask-detection/src/loss.py
from typing import Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
def box_iou_xyxy(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
N = boxes1.size(0)
M = boxes2.size(0)
x1_1, y1_1, x2_1, y2_1 = boxes1[:, 0], boxes1[:, 1], boxes1[:, 2], boxes1[:, 3]
x1_2, y1_2, x2_2, y2_2 = boxes2[:, 0], boxes2[:, 1], boxes2[:, 2], boxes2[:, 3]
x1_1 = x1_1.unsqueeze(1).expand(N, M)
y1_1 = y1_1.unsqueeze(1).expand(N, M)
x2_1 = x2_1.unsqueeze(1).expand(N, M)
y2_1 = y2_1.unsqueeze(1).expand(N, M)
x1_2 = x1_2.unsqueeze(0).expand(N, M)
y1_2 = y1_2.unsqueeze(0).expand(N, M)
x2_2 = x2_2.unsqueeze(0).expand(N, M)
y2_2 = y2_2.unsqueeze(0).expand(N, M)
interX1 = torch.max(x1_1, x1_2)
interY1 = torch.max(y1_1, y1_2)
interX2 = torch.min(x2_1, x2_2)
interY2 = torch.min(y2_1, y2_2)
interW = (interX2 - interX1).clamp(min=0)
interH = (interY2 - interY1).clamp(min=0)
interArea = interW * interH
area1 = (x2_1 - x1_1).clamp(min=0) * (y2_1 - y1_1).clamp(min=0)
area2 = (x2_2 - x1_2).clamp(min=0) * (y2_2 - y1_2).clamp(min=0)
union = area1 + area2 - interArea + 1e-16
iou = interArea / union
return iou
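
# Usage note: box_iou_xyxy is pairwise -- given boxes1 of shape (N, 4) and
# boxes2 of shape (M, 4) in (x1, y1, x2, y2) pixel coordinates, it returns an
# (N, M) IoU matrix.  For example,
#   box_iou_xyxy(torch.tensor([[0., 0., 10., 10.]]),
#                torch.tensor([[5., 5., 15., 15.]]))
# yields approximately 0.1429 (intersection 25 over union 175).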
def box_giou_xyxy(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
xA = torch.max(boxes1[:, 0], boxes2[:, 0])
yA = torch.max(boxes1[:, 1], boxes2[:, 1])
xB = torch.min(boxes1[:, 2], boxes2[:, 2])
yB = torch.min(boxes1[:, 3], boxes2[:, 3])
interW = (xB - xA).clamp(min=0)
interH = (yB - yA).clamp(min=0)
interArea = interW * interH
area1 = (boxes1[:, 2] - boxes1[:, 0]).clamp(min=0) * (boxes1[:, 3] - boxes1[:, 1]).clamp(min=0)
area2 = (boxes2[:, 2] - boxes2[:, 0]).clamp(min=0) * (boxes2[:, 3] - boxes2[:, 1]).clamp(min=0)
union = area1 + area2 - interArea + 1e-16
iou = interArea / union
xC1 = torch.min(boxes1[:, 0], boxes2[:, 0])
yC1 = torch.min(boxes1[:, 1], boxes2[:, 1])
xC2 = torch.max(boxes1[:, 2], boxes2[:, 2])
yC2 = torch.max(boxes1[:, 3], boxes2[:, 3])
encloseW = (xC2 - xC1).clamp(min=0)
encloseH = (yC2 - yC1).clamp(min=0)
encloseArea = encloseW * encloseH + 1e-16
giou = iou - (encloseArea - union) / encloseArea
return giou
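
# Usage note: unlike box_iou_xyxy, box_giou_xyxy is elementwise -- boxes1 and
# boxes2 must both have shape (N, 4), and the result is an (N,)-shaped GIoU in
# [-1, 1]; identical boxes give ~1.0, while distant boxes approach -1.0 as the
# enclosing box grows relative to the union.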
class YoloLoss(nn.Module):
    def __init__(self, class_counts: List[int], anchors_l: List[Tuple[int, int]] = [(128, 152), (182, 205), (103, 124)], anchors_m: List[Tuple[int, int]] = [(78, 88), (55, 59), (37, 42)], anchors_s: List[Tuple[int, int]] = [(26, 28), (17, 19), (10, 11)], image_size: Tuple[int, int] = (416, 416), num_classes: int = 3, ignore_thresh: float = 0.7, lambda_noobj: float = 5.0) -> None:
super().__init__()
self.anchors_l = anchors_l
self.anchors_m = anchors_m
self.anchors_s = anchors_s
self.image_size = image_size
self.num_classes = num_classes
self.ignore_thresh = ignore_thresh
self.lambda_noobj = lambda_noobj
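        # Inverse-frequency class weights; class 0 is additionally scaled by 2.0
        # and class 2 by 3.0 (the factors hard-coded below), presumably to
        # emphasise rarer or harder classes.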
total = sum(class_counts)
w_list = [total / (c + 1e-5) * (2.0 if c_id == 0 else (3.0 if c_id == 2 else 1.0)) for c_id, c in enumerate(class_counts)]
self.class_weight = torch.tensor(w_list, dtype=torch.float32)
self.bce_obj = nn.BCEWithLogitsLoss(reduction="none")
self.bce_cls = nn.BCEWithLogitsLoss(weight=self.class_weight, reduction="none")
    def forward(self, outputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], targets: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
out_l, out_m, out_s = outputs
t_l, t_m, t_s = targets
loss_l = self._loss_single_scale(out_l, t_l, self.anchors_l, scale_wh=(13, 13))
loss_m = self._loss_single_scale(out_m, t_m, self.anchors_m, scale_wh=(26, 26))
loss_s = self._loss_single_scale(out_s, t_s, self.anchors_s, scale_wh=(52, 52))
return loss_l + loss_m + loss_s
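    # Shape reference (inferred from the reshaping in _loss_single_scale): each
    # raw head output is (B, A * (5 + num_classes), H, W) and each target is
    # (B, H, W, A, 5 + num_classes), with the last dimension ordered as
    # (tx, ty, tw, th, objectness, one-hot classes).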
    def _loss_single_scale(self, pred: torch.Tensor, target: torch.Tensor, anchors: List[Tuple[int, int]], scale_wh: Tuple[int, int]) -> torch.Tensor:
device = pred.device
B, _, H, W = pred.shape
A = len(anchors)
pred = pred.view(B, A, (5 + self.num_classes), H, W)
pred = pred.permute(0, 3, 4, 1, 2).contiguous()
pred_tx = pred[..., 0]
pred_ty = pred[..., 1]
pred_tw = pred[..., 2]
pred_th = pred[..., 3]
pred_obj = pred[..., 4]
pred_cls = pred[..., 5:]
tgt_tx = target[..., 0]
tgt_ty = target[..., 1]
tgt_tw = target[..., 2]
tgt_th = target[..., 3]
tgt_obj = target[..., 4]
tgt_cls = target[..., 5:]
obj_mask = (tgt_obj == 1)
noobj_mask = (tgt_obj == 0)
img_w, img_h = self.image_size
stride_x = img_w / W
stride_y = img_h / H
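        # Decode predictions and targets from grid-relative (tx, ty, tw, th) to
        # absolute xyxy pixel boxes so IoU/GIoU can be computed directly.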
grid_x = torch.arange(W, device=device).view(1, 1, W, 1).expand(1, H, W, 1)
grid_y = torch.arange(H, device=device).view(1, H, 1, 1).expand(1, H, W, 1)
anchors_t = torch.tensor(anchors, dtype=torch.float, device=device)
anchor_w = anchors_t[:, 0].view(1, 1, 1, A)
anchor_h = anchors_t[:, 1].view(1, 1, 1, A)
pred_box_xc = (grid_x + torch.sigmoid(pred_tx)) * stride_x
pred_box_yc = (grid_y + torch.sigmoid(pred_ty)) * stride_y
pred_box_w = torch.exp(pred_tw) * anchor_w
pred_box_h = torch.exp(pred_th) * anchor_h
pred_x1 = pred_box_xc - pred_box_w / 2
pred_y1 = pred_box_yc - pred_box_h / 2
pred_x2 = pred_box_xc + pred_box_w / 2
pred_y2 = pred_box_yc + pred_box_h / 2
gt_box_xc = (grid_x + tgt_tx) * stride_x
gt_box_yc = (grid_y + tgt_ty) * stride_y
gt_box_w = torch.exp(tgt_tw) * anchor_w
gt_box_h = torch.exp(tgt_th) * anchor_h
gt_x1 = gt_box_xc - gt_box_w / 2
        gt_y1 = gt_box_yc - gt_box_h / 2
gt_x2 = gt_box_xc + gt_box_w / 2
gt_y2 = gt_box_yc + gt_box_h / 2
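        # Build an "ignore" mask: negative cells whose decoded prediction already
        # overlaps a ground-truth box with IoU above ignore_thresh are excluded
        # from the no-object loss (YOLOv3-style ignore region).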
with torch.no_grad():
ignore_mask_buf = torch.zeros_like(tgt_obj, dtype=torch.bool)
noobj_flat = noobj_mask.view(-1)
obj_flat = obj_mask.view(-1)
px1f = pred_x1.view(-1)
py1f = pred_y1.view(-1)
px2f = pred_x2.view(-1)
py2f = pred_y2.view(-1)
gx1f = gt_x1.view(-1)[obj_flat]
gy1f = gt_y1.view(-1)[obj_flat]
gx2f = gt_x2.view(-1)[obj_flat]
gy2f = gt_y2.view(-1)[obj_flat]
if noobj_flat.sum() > 0 and obj_flat.sum() > 0:
noobj_idx = noobj_flat.nonzero(as_tuple=True)[0]
noobj_boxes_xyxy = torch.stack([px1f[noobj_idx], py1f[noobj_idx], px2f[noobj_idx], py2f[noobj_idx]], dim=-1)
obj_boxes_xyxy = torch.stack([gx1f, gy1f, gx2f, gy2f], dim=-1)
ious = box_iou_xyxy(noobj_boxes_xyxy, obj_boxes_xyxy)
best_iou, _ = ious.max(dim=1)
ignore_flags = (best_iou > self.ignore_thresh)
all_idx = noobj_idx[ignore_flags]
ignore_mask_buf.view(-1)[all_idx] = True
ignore_mask = ignore_mask_buf
obj_loss = self.bce_obj(pred_obj[obj_mask], torch.ones_like(pred_obj[obj_mask]))
obj_loss = obj_loss.mean() if obj_loss.numel() > 0 else torch.tensor(0., device=device)
noobj_mask_final = (noobj_mask & (~ignore_mask))
noobj_loss = self.bce_obj(pred_obj[noobj_mask_final], torch.zeros_like(pred_obj[noobj_mask_final]))
noobj_loss = noobj_loss.mean() if noobj_loss.numel() > 0 else torch.tensor(0., device=device)
objectness_loss = obj_loss + self.lambda_noobj * noobj_loss
class_loss = torch.tensor(0., device=device, requires_grad=True)
if obj_mask.sum() > 0:
self.bce_cls.weight = self.class_weight.to(device)
cls_pred = pred_cls[obj_mask].to(device)
cls_gt = tgt_cls[obj_mask].to(device)
c_loss = self.bce_cls(cls_pred, cls_gt)
class_loss = c_loss.mean()
giou_loss = torch.tensor(0., device=device, requires_grad=True)
if obj_mask.sum() > 0:
px1_ = pred_x1[obj_mask]
py1_ = pred_y1[obj_mask]
px2_ = pred_x2[obj_mask]
py2_ = pred_y2[obj_mask]
            p_xyxy = torch.stack([px1_, py1_, px2_, py2_], dim=-1)
gx1_ = gt_x1[obj_mask]
gy1_ = gt_y1[obj_mask]
gx2_ = gt_x2[obj_mask]
gy2_ = gt_y2[obj_mask]
            g_xyxy = torch.stack([gx1_, gy1_, gx2_, gy2_], dim=-1)
giou = box_giou_xyxy(p_xyxy, g_xyxy)
giou_loss = (1. - giou).mean()
total_loss = objectness_loss + class_loss + giou_loss
return total_loss
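
# --- Minimal smoke test (illustrative sketch) ---
# Assumes the shapes noted above: each head output (B, A * (5 + C), H, W) and
# each target (B, H, W, A, 5 + C) with channels (tx, ty, tw, th, obj, classes).
# The class counts and the single marked object below are dummy values used
# only to exercise the objectness, class, and GIoU branches of the loss.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_classes = 3
    anchors_per_scale = 3
    channels = anchors_per_scale * (5 + num_classes)
    batch = 2
    outputs = (
        torch.randn(batch, channels, 13, 13),
        torch.randn(batch, channels, 26, 26),
        torch.randn(batch, channels, 52, 52),
    )
    targets = tuple(
        torch.zeros(batch, h, w, anchors_per_scale, 5 + num_classes)
        for (h, w) in ((13, 13), (26, 26), (52, 52))
    )
    # Mark one fake object on the 13x13 scale so obj_mask is non-empty.
    targets[0][0, 6, 6, 0, 0:2] = 0.5   # tx, ty: offsets within the cell
    targets[0][0, 6, 6, 0, 2:4] = 0.0   # tw, th: log-scale, 0 -> anchor size
    targets[0][0, 6, 6, 0, 4] = 1.0     # objectness
    targets[0][0, 6, 6, 0, 5] = 1.0     # one-hot class 0
    criterion = YoloLoss(class_counts=[100, 300, 50])  # dummy per-class counts
    loss = criterion(outputs, targets)
    print("total loss:", loss.item())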