from typing import Tuple, List

import torch
import torch.nn as nn


def box_iou_xyxy(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """Pairwise IoU between two sets of boxes in (x1, y1, x2, y2) format.

    Args:
        boxes1: tensor of shape (N, 4).
        boxes2: tensor of shape (M, 4).

    Returns:
        Tensor of shape (N, M) with the IoU of every box in ``boxes1``
        against every box in ``boxes2``.
    """
    N = boxes1.size(0)
    M = boxes2.size(0)
    x1_1, y1_1, x2_1, y2_1 = boxes1[:, 0], boxes1[:, 1], boxes1[:, 2], boxes1[:, 3]
    x1_2, y1_2, x2_2, y2_2 = boxes2[:, 0], boxes2[:, 1], boxes2[:, 2], boxes2[:, 3]
    # Broadcast boxes1 along dim 1 and boxes2 along dim 0 to form all N*M pairs.
    x1_1 = x1_1.unsqueeze(1).expand(N, M)
    y1_1 = y1_1.unsqueeze(1).expand(N, M)
    x2_1 = x2_1.unsqueeze(1).expand(N, M)
    y2_1 = y2_1.unsqueeze(1).expand(N, M)
    x1_2 = x1_2.unsqueeze(0).expand(N, M)
    y1_2 = y1_2.unsqueeze(0).expand(N, M)
    x2_2 = x2_2.unsqueeze(0).expand(N, M)
    y2_2 = y2_2.unsqueeze(0).expand(N, M)
    # Intersection rectangle; width/height clamp to zero when boxes do not overlap.
    inter_x1 = torch.max(x1_1, x1_2)
    inter_y1 = torch.max(y1_1, y1_2)
    inter_x2 = torch.min(x2_1, x2_2)
    inter_y2 = torch.min(y2_1, y2_2)
    inter_w = (inter_x2 - inter_x1).clamp(min=0)
    inter_h = (inter_y2 - inter_y1).clamp(min=0)
    inter_area = inter_w * inter_h
    area1 = (x2_1 - x1_1).clamp(min=0) * (y2_1 - y1_1).clamp(min=0)
    area2 = (x2_2 - x1_2).clamp(min=0) * (y2_2 - y1_2).clamp(min=0)
    # Small epsilon guards against division by zero for degenerate boxes.
    union = area1 + area2 - inter_area + 1e-16
    return inter_area / union
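
# Illustrative check (hypothetical values, not part of the original module): for an
# identical pair the IoU is 1, and for [0, 0, 10, 10] vs. [5, 5, 15, 15] the
# intersection is 5 * 5 = 25 and the union is 100 + 100 - 25 = 175, so IoU ~= 0.143.
#
#   a = torch.tensor([[0., 0., 10., 10.]])
#   b = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
#   box_iou_xyxy(a, b)  # -> tensor([[1.0000, 0.1429]])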


def box_giou_xyxy(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """Element-wise GIoU for matched box pairs in (x1, y1, x2, y2) format.

    Both inputs have shape (N, 4); box ``i`` of ``boxes1`` is compared against
    box ``i`` of ``boxes2``, returning a tensor of shape (N,).
    """
    # Intersection rectangle of each matched pair.
    xA = torch.max(boxes1[:, 0], boxes2[:, 0])
    yA = torch.max(boxes1[:, 1], boxes2[:, 1])
    xB = torch.min(boxes1[:, 2], boxes2[:, 2])
    yB = torch.min(boxes1[:, 3], boxes2[:, 3])
    inter_w = (xB - xA).clamp(min=0)
    inter_h = (yB - yA).clamp(min=0)
    inter_area = inter_w * inter_h
    area1 = (boxes1[:, 2] - boxes1[:, 0]).clamp(min=0) * (boxes1[:, 3] - boxes1[:, 1]).clamp(min=0)
    area2 = (boxes2[:, 2] - boxes2[:, 0]).clamp(min=0) * (boxes2[:, 3] - boxes2[:, 1]).clamp(min=0)
    union = area1 + area2 - inter_area + 1e-16
    iou = inter_area / union
    # Smallest enclosing box of each pair.
    xC1 = torch.min(boxes1[:, 0], boxes2[:, 0])
    yC1 = torch.min(boxes1[:, 1], boxes2[:, 1])
    xC2 = torch.max(boxes1[:, 2], boxes2[:, 2])
    yC2 = torch.max(boxes1[:, 3], boxes2[:, 3])
    enclose_w = (xC2 - xC1).clamp(min=0)
    enclose_h = (yC2 - yC1).clamp(min=0)
    enclose_area = enclose_w * enclose_h + 1e-16
    # GIoU = IoU - (part of the enclosing box not covered by the union) / enclosing area.
    giou = iou - (enclose_area - union) / enclose_area
    return giou
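
# Illustrative check (hypothetical values): GIoU equals IoU when the boxes coincide
# and becomes negative for disjoint boxes. For [0, 0, 10, 10] vs. [20, 0, 30, 10]:
# IoU = 0, enclosing box area = 30 * 10 = 300, union = 200, so
# GIoU = 0 - (300 - 200) / 300 ~= -0.333; the loss term 1 - GIoU therefore still
# penalises non-overlapping predictions.
#
#   p = torch.tensor([[0., 0., 10., 10.]])
#   g = torch.tensor([[20., 0., 30., 10.]])
#   box_giou_xyxy(p, g)  # -> tensor([-0.3333])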


class YoloLoss(nn.Module):
    """YOLOv3-style multi-scale loss: objectness + weighted class BCE + GIoU box regression."""

    def __init__(
        self,
        class_counts: List[int],
        anchors_l: List[Tuple[int, int]] = [(128, 152), (182, 205), (103, 124)],
        anchors_m: List[Tuple[int, int]] = [(78, 88), (55, 59), (37, 42)],
        anchors_s: List[Tuple[int, int]] = [(26, 28), (17, 19), (10, 11)],
        image_size: Tuple[int, int] = (416, 416),
        num_classes: int = 3,
        ignore_thresh: float = 0.7,
        lambda_noobj: float = 5.0,
    ):
        super().__init__()
        self.anchors_l = anchors_l
        self.anchors_m = anchors_m
        self.anchors_s = anchors_s
        self.image_size = image_size
        self.num_classes = num_classes
        self.ignore_thresh = ignore_thresh
        self.lambda_noobj = lambda_noobj
        # Inverse-frequency class weights with hand-tuned boosts for class 0 (x2) and
        # class 2 (x3); len(class_counts) is expected to equal num_classes.
        total = sum(class_counts)
        w_list = [total / (c + 1e-5) * (2.0 if c_id == 0 else (3.0 if c_id == 2 else 1.0)) for c_id, c in enumerate(class_counts)]
        self.class_weight = torch.tensor(w_list, dtype=torch.float32)
        self.bce_obj = nn.BCEWithLogitsLoss(reduction="none")
        self.bce_cls = nn.BCEWithLogitsLoss(weight=self.class_weight, reduction="none")
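
    # Worked example of the weighting above (hypothetical counts, not from the
    # original code): with class_counts = [100, 400, 50] the total is 550 and
    # class_weight becomes roughly [550/100 * 2, 550/400 * 1, 550/50 * 3]
    # = [11.0, 1.375, 33.0], so rare classes contribute more to the class BCE.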

    def forward(self, outputs: Tuple[torch.Tensor, ...], targets: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        # Outputs and targets arrive in (large, medium, small) scale order; each
        # scale's grid size is inferred from the tensor shape in _loss_single_scale.
        out_l, out_m, out_s = outputs
        t_l, t_m, t_s = targets
        loss_l = self._loss_single_scale(out_l, t_l, self.anchors_l)
        loss_m = self._loss_single_scale(out_m, t_m, self.anchors_m)
        loss_s = self._loss_single_scale(out_s, t_s, self.anchors_s)
        return loss_l + loss_m + loss_s

    def _loss_single_scale(self, pred: torch.Tensor, target: torch.Tensor, anchors: List[Tuple[int, int]]) -> torch.Tensor:
        device = pred.device
        B, _, H, W = pred.shape
        A = len(anchors)
        # (B, A*(5+C), H, W) -> (B, H, W, A, 5+C) so the last dim holds one prediction.
        pred = pred.view(B, A, 5 + self.num_classes, H, W)
        pred = pred.permute(0, 3, 4, 1, 2).contiguous()
        pred_tx = pred[..., 0]
        pred_ty = pred[..., 1]
        pred_tw = pred[..., 2]
        pred_th = pred[..., 3]
        pred_obj = pred[..., 4]
        pred_cls = pred[..., 5:]
        # Targets use the same layout: tx, ty, tw, th, objectness, one-hot classes.
        tgt_tx = target[..., 0]
        tgt_ty = target[..., 1]
        tgt_tw = target[..., 2]
        tgt_th = target[..., 3]
        tgt_obj = target[..., 4]
        tgt_cls = target[..., 5:]
        obj_mask = (tgt_obj == 1)
        noobj_mask = (tgt_obj == 0)
        # Stride maps grid coordinates back to input-image pixels.
        img_w, img_h = self.image_size
        stride_x = img_w / W
        stride_y = img_h / H
        grid_x = torch.arange(W, device=device).view(1, 1, W, 1).expand(1, H, W, 1)
        grid_y = torch.arange(H, device=device).view(1, H, 1, 1).expand(1, H, W, 1)
        anchors_t = torch.tensor(anchors, dtype=torch.float, device=device)
        anchor_w = anchors_t[:, 0].view(1, 1, 1, A)
        anchor_h = anchors_t[:, 1].view(1, 1, 1, A)
        # Standard YOLO decode of predictions to corner boxes in image pixels.
        pred_box_xc = (grid_x + torch.sigmoid(pred_tx)) * stride_x
        pred_box_yc = (grid_y + torch.sigmoid(pred_ty)) * stride_y
        pred_box_w = torch.exp(pred_tw) * anchor_w
        pred_box_h = torch.exp(pred_th) * anchor_h
        pred_x1 = pred_box_xc - pred_box_w / 2
        pred_y1 = pred_box_yc - pred_box_h / 2
        pred_x2 = pred_box_xc + pred_box_w / 2
        pred_y2 = pred_box_yc + pred_box_h / 2
        # Decode targets the same way (tgt_tx/tgt_ty are already cell offsets in [0, 1]).
        gt_box_xc = (grid_x + tgt_tx) * stride_x
        gt_box_yc = (grid_y + tgt_ty) * stride_y
        gt_box_w = torch.exp(tgt_tw) * anchor_w
        gt_box_h = torch.exp(tgt_th) * anchor_h
        gt_x1 = gt_box_xc - gt_box_w / 2
        gt_y1 = gt_box_yc - gt_box_h / 2
        gt_x2 = gt_box_xc + gt_box_w / 2
        gt_y2 = gt_box_yc + gt_box_h / 2
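        # Worked example of the decode above (hypothetical numbers): on a 13x13 grid
        # with a 416-pixel input the stride is 32. For cell (3, 4) with anchor
        # (128, 152), sigmoid(tx) = 0.5, sigmoid(ty) = 0.25, tw = 0.2, th = -0.1:
        #   xc = (3 + 0.5) * 32 = 112,  yc = (4 + 0.25) * 32 = 136
        #   w  = exp(0.2)  * 128 ~= 156.3,  h = exp(-0.1) * 152 ~= 137.5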
        # Predictions in no-object cells that still overlap some ground-truth box above
        # ignore_thresh are ignored: they are neither positives nor penalised as negatives.
        with torch.no_grad():
            ignore_mask_buf = torch.zeros_like(tgt_obj, dtype=torch.bool)
            noobj_flat = noobj_mask.view(-1)
            obj_flat = obj_mask.view(-1)
            px1f = pred_x1.view(-1)
            py1f = pred_y1.view(-1)
            px2f = pred_x2.view(-1)
            py2f = pred_y2.view(-1)
            gx1f = gt_x1.view(-1)[obj_flat]
            gy1f = gt_y1.view(-1)[obj_flat]
            gx2f = gt_x2.view(-1)[obj_flat]
            gy2f = gt_y2.view(-1)[obj_flat]
            if noobj_flat.sum() > 0 and obj_flat.sum() > 0:
                noobj_idx = noobj_flat.nonzero(as_tuple=True)[0]
                noobj_boxes_xyxy = torch.stack([px1f[noobj_idx], py1f[noobj_idx], px2f[noobj_idx], py2f[noobj_idx]], dim=-1)
                obj_boxes_xyxy = torch.stack([gx1f, gy1f, gx2f, gy2f], dim=-1)
                # Best IoU of each negative prediction against all ground-truth boxes.
                ious = box_iou_xyxy(noobj_boxes_xyxy, obj_boxes_xyxy)
                best_iou, _ = ious.max(dim=1)
                ignore_flags = (best_iou > self.ignore_thresh)
                ignore_mask_buf.view(-1)[noobj_idx[ignore_flags]] = True
            ignore_mask = ignore_mask_buf
        # Objectness: positives are pushed toward 1, negatives (minus ignored cells)
        # toward 0, with the negative term scaled by lambda_noobj.
        obj_loss = self.bce_obj(pred_obj[obj_mask], torch.ones_like(pred_obj[obj_mask]))
        obj_loss = obj_loss.mean() if obj_loss.numel() > 0 else torch.tensor(0., device=device)
        noobj_mask_final = noobj_mask & ~ignore_mask
        noobj_loss = self.bce_obj(pred_obj[noobj_mask_final], torch.zeros_like(pred_obj[noobj_mask_final]))
        noobj_loss = noobj_loss.mean() if noobj_loss.numel() > 0 else torch.tensor(0., device=device)
        objectness_loss = obj_loss + self.lambda_noobj * noobj_loss
        # Class loss: weighted BCE over the one-hot class targets of positive cells.
        class_loss = torch.tensor(0., device=device, requires_grad=True)
        if obj_mask.sum() > 0:
            # Keep the per-class weights on the same device as the predictions.
            self.bce_cls.weight = self.class_weight.to(device)
            cls_pred = pred_cls[obj_mask].to(device)
            cls_gt = tgt_cls[obj_mask].to(device)
            c_loss = self.bce_cls(cls_pred, cls_gt)
            class_loss = c_loss.mean()
        # Box regression: 1 - GIoU over the decoded positive boxes.
        giou_loss = torch.tensor(0., device=device, requires_grad=True)
        if obj_mask.sum() > 0:
            px1_ = pred_x1[obj_mask]
            py1_ = pred_y1[obj_mask]
            px2_ = pred_x2[obj_mask]
            py2_ = pred_y2[obj_mask]
            p_xyxy = torch.stack([px1_, py1_, px2_, py2_], dim=-1)
            gx1_ = gt_x1[obj_mask]
            gy1_ = gt_y1[obj_mask]
            gx2_ = gt_x2[obj_mask]
            gy2_ = gt_y2[obj_mask]
            g_xyxy = torch.stack([gx1_, gy1_, gx2_, gy2_], dim=-1)
            giou = box_giou_xyxy(p_xyxy, g_xyxy)
            giou_loss = (1. - giou).mean()
        total_loss = objectness_loss + class_loss + giou_loss
        return total_loss
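

if __name__ == "__main__":
    # Minimal smoke test, a sketch with assumed shapes rather than real data:
    # predictions are (B, A * (5 + C), H, W) per scale and targets are
    # (B, H, W, A, 5 + C), matching the layout expected by _loss_single_scale.
    torch.manual_seed(0)
    B, A, C = 2, 3, 3
    criterion = YoloLoss(class_counts=[100, 400, 50])  # hypothetical class counts
    outputs = tuple(torch.randn(B, A * (5 + C), s, s) for s in (13, 26, 52))
    targets = tuple(torch.zeros(B, s, s, A, 5 + C) for s in (13, 26, 52))
    # Mark one positive cell per scale with a plausible encoded box and class 0.
    for t in targets:
        t[0, 5, 5, 0, 0:2] = 0.5   # tx, ty offsets inside the cell
        t[0, 5, 5, 0, 2:4] = 0.0   # tw, th in log space (0 -> anchor-sized box)
        t[0, 5, 5, 0, 4] = 1.0     # objectness
        t[0, 5, 5, 0, 5] = 1.0     # one-hot class 0
    loss = criterion(outputs, targets)
    print("total loss:", loss.item())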