File size: 1,453 Bytes
91ef820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import numpy as np

from utils.box_utils import bbox_iou, xywh2xyxy


def trans_vg_eval_val(pred_boxes, gt_boxes):
    batch_size = pred_boxes.shape[0]
    pred_boxes = xywh2xyxy(pred_boxes)
    pred_boxes = torch.clamp(pred_boxes, 0, 1)
    gt_boxes = xywh2xyxy(gt_boxes)
    iou = bbox_iou(pred_boxes, gt_boxes)
    accu = torch.sum(iou >= 0.5) / float(batch_size)

    return iou, accu

def trans_vg_eval_test(pred_boxes, gt_boxes, sum=True):
    pred_boxes = xywh2xyxy(pred_boxes)
    pred_boxes = torch.clamp(pred_boxes, 0, 1)
    gt_boxes = xywh2xyxy(gt_boxes)
    iou = bbox_iou(pred_boxes, gt_boxes)
    accu = torch.sum(iou >= 0.5) if sum else iou >= 0.5
    return iou, accu

def eval_category(category_id_list, iou, accu):
    # category_id_list包含 [1,2,3...8]id子类,此处需要 -1 表示从0编码
    category_id_list = category_id_list.cpu().numpy()
    sub = list(set(category_id_list)).__len__()
    iou = iou.cpu().numpy()
    accu = accu.cpu().numpy()
    category_iou = [0] * sub
    category_accu = [0] * sub
    sub_num = [0] * sub
    for (id, iou_, accu_) in zip(category_id_list, iou, accu):
        category_iou[id-1] += iou_
        category_accu[id-1] += accu_
        sub_num[id-1] += 1
    category_iou = [i / s for i, s in zip(category_iou, sub_num)]
    category_accu = [a / s for a, s in zip(category_accu, sub_num)]
    return category_iou, category_accu