File size: 3,384 Bytes
d4e7f2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.metrics import auc, accuracy_score, balanced_accuracy_score
from scipy.optimize import brentq
from scipy.interpolate import interp1d


def frame_level_acc(labels, y_preds):
    """Return frame-level accuracy as a percentage.

    Args:
        labels: Ground-truth label of each frame.
        y_preds: Hard (already thresholded) prediction of each frame.

    Returns:
        ``sklearn.metrics.accuracy_score`` scaled from [0, 1] to [0, 100].
    """
    fraction_correct = accuracy_score(labels, y_preds)
    return fraction_correct * 100.


def frame_level_balanced_acc(labels, y_preds):
    """Return frame-level balanced accuracy as a percentage.

    Balanced accuracy is the average of per-class recall, which makes it
    less sensitive to class imbalance than plain accuracy.

    Args:
        labels: Ground-truth label of each frame.
        y_preds: Hard (already thresholded) prediction of each frame.

    Returns:
        ``sklearn.metrics.balanced_accuracy_score`` scaled to [0, 100].
    """
    balanced_fraction = balanced_accuracy_score(labels, y_preds)
    return balanced_fraction * 100.


def frame_level_auc(labels, preds):
    """Return the frame-level ROC AUC as a percentage.

    Args:
        labels: Ground-truth label of each frame.
        preds: Continuous score of each frame (not thresholded).

    Returns:
        ``sklearn.metrics.roc_auc_score`` scaled to [0, 100].
    """
    area = roc_auc_score(labels, preds)
    return area * 100.


def frame_level_eer(labels, preds):
    """Return the frame-level Equal Error Rate (EER) as a fraction in [0, 1].

    The EER is the point on the ROC curve where the false positive rate
    equals the false negative rate (``1 - TPR``). The ROC curve is linearly
    interpolated and ``1 - FPR - TPR(FPR) = 0`` is solved with Brent's
    method. This is the recommended, more accurate variant than picking the
    nearest discrete ROC point; MaskRelation (TIFS'23) computes it the same way.

    Note: unlike the accuracy/AUC helpers, the result is NOT multiplied
    by 100.

    Args:
        labels: Ground-truth label of each frame; label ``1`` is the
            positive class.
        preds: Continuous score of each frame.

    Returns:
        The EER as a float.
    """
    # pos_label must be given explicitly when the labels are not binary.
    fpr, tpr, _ = roc_curve(labels, preds, pos_label=1)
    # Build the interpolator once rather than on every brentq evaluation.
    tpr_at = interp1d(fpr, tpr)
    # The matching decision threshold could be obtained with
    # interp1d(fpr, thresholds)(eer); it is not returned here.
    return brentq(lambda x: 1. - x - tpr_at(x), 0., 1.)


# def frame_level_eer(labels, preds):
#     fpr, tpr, thresholds = roc_curve(labels, preds, pos_label=1)
#     eer_threshold = thresholds[(fpr + (1 - tpr)).argmin()]
#     fpr_eer = fpr[thresholds == eer_threshold][0]
#     fnr_eer = 1 - tpr[thresholds == eer_threshold][0]
#     eer = (fpr_eer + fnr_eer) / 2
#     metric_logger.meters['eer'].update(eer)
#     return eer, eer_thresh


def get_video_level_label_pred(f_label_list, v_name_list, f_pred_list, threshold=0.5):
    """Aggregate frame-level results into video-level labels and predictions.

    Frames are grouped by video name. For each video:
    - the score is the arithmetic mean of its frame scores;
    - the label is the label of the first frame seen for that video.

    Videos appear in the output in the order their first frame was seen.

    References:
    CADDM: https://github.com/megvii-research/CADDM

    Args:
        f_label_list: Ground-truth label of each frame.
        v_name_list: Name/identifier of the video each frame belongs to
            (must be hashable).
        f_pred_list: Predicted score of each frame.
        threshold: Mean score at or above which a video is predicted as
            positive. Defaults to 0.5, the previously hard-coded value.

    Returns:
        Tuple ``(video_label_list, video_pred_list, video_y_pred_list)``:
        per-video labels, mean scores, and boolean hard predictions
        (``mean_score >= threshold``).

    Note:
        The three input sequences are consumed with ``zip``, so trailing
        items of a longer sequence are silently ignored.
    """
    # Group frame scores per video; remember the first label seen per video.
    scores_by_video = {}
    label_by_video = {}
    for label, video, score in zip(f_label_list, v_name_list, f_pred_list):
        if video not in scores_by_video:
            scores_by_video[video] = []
            label_by_video[video] = label
        scores_by_video[video].append(score)

    # Reduce each video to its mean score, label and hard prediction.
    video_label_list, video_pred_list, video_y_pred_list = [], [], []
    for video, scores in scores_by_video.items():
        mean_score = sum(scores) / len(scores)
        video_label_list.append(label_by_video[video])
        video_pred_list.append(mean_score)
        video_y_pred_list.append(mean_score >= threshold)

    return video_label_list, video_pred_list, video_y_pred_list


def video_level_acc(video_label_list, video_y_pred_list):
    """Return video-level accuracy as a percentage.

    Args:
        video_label_list: Ground-truth label of each video.
        video_y_pred_list: Hard (already thresholded) prediction of each video.

    Returns:
        ``sklearn.metrics.accuracy_score`` scaled from [0, 1] to [0, 100].
    """
    fraction_correct = accuracy_score(video_label_list, video_y_pred_list)
    return fraction_correct * 100.


def video_level_balanced_acc(video_label_list, video_y_pred_list):
    """Return video-level balanced accuracy as a percentage.

    Args:
        video_label_list: Ground-truth label of each video.
        video_y_pred_list: Hard (already thresholded) prediction of each video.

    Returns:
        ``sklearn.metrics.balanced_accuracy_score`` (mean per-class recall)
        scaled to [0, 100].
    """
    balanced_fraction = balanced_accuracy_score(video_label_list, video_y_pred_list)
    return balanced_fraction * 100.


def video_level_auc(video_label_list, video_pred_list):
    """Return the video-level ROC AUC as a percentage.

    Args:
        video_label_list: Ground-truth label of each video.
        video_pred_list: Continuous (e.g. mean-of-frames) score of each video.

    Returns:
        ``sklearn.metrics.roc_auc_score`` scaled to [0, 100].
    """
    area = roc_auc_score(video_label_list, video_pred_list)
    return area * 100.


def video_level_eer(video_label_list, video_pred_list):
    """Return the video-level Equal Error Rate (EER) as a fraction in [0, 1].

    The EER is the point on the ROC curve where the false positive rate
    equals the false negative rate (``1 - TPR``). The ROC curve is linearly
    interpolated and ``1 - FPR - TPR(FPR) = 0`` is solved with Brent's
    method. This is the recommended, more accurate variant than picking the
    nearest discrete ROC point; MaskRelation (TIFS'23) computes it the same way.

    Note: unlike the accuracy/AUC helpers, the result is NOT multiplied
    by 100.

    Args:
        video_label_list: Ground-truth label of each video; label ``1`` is
            the positive class.
        video_pred_list: Continuous score of each video.

    Returns:
        The EER as a float.
    """
    # pos_label must be given explicitly when the labels are not binary.
    fpr, tpr, _ = roc_curve(video_label_list, video_pred_list, pos_label=1)
    # Build the interpolator once rather than on every brentq evaluation.
    tpr_at = interp1d(fpr, tpr)
    # The matching decision threshold could be obtained with
    # interp1d(fpr, thresholds)(v_eer); it is not returned here.
    v_eer = brentq(lambda x: 1. - x - tpr_at(x), 0., 1.)
    return v_eer