Spaces:
Running
Running
import json | |
import numpy as np | |
import scipy.io as io | |
from tools.utils.utility import check_install | |
from tools.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area | |
def _is_dont_care(gt):
    """Return True when a ``polygt``-style row is flagged don't-care (``"#"``).

    Rows built by ``_gt_rows_from_dicts`` store the flag at index 5 as a
    one-element string array, so ``==`` yields a one-element boolean array
    that is collapsed to a plain bool here. Rows loaded from ``.mat`` files
    are assumed to use the same layout -- confirm for new datasets.
    """
    return bool(gt[5] == "#")


def _gt_xy(gt):
    """Return the integer x and y coordinate lists of a ``polygt``-style row."""
    gt_x = list(map(int, np.squeeze(gt[1])))
    gt_y = list(map(int, np.squeeze(gt[3])))
    return gt_x, gt_y


def _read_predictions(pred_dict):
    """Parse prediction records into ``(det_x, det_y, text)`` tuples.

    Each record provides ``"points"`` (an array, flattened as
    x0, y0, x1, y1, ...) and ``"texts"``. Coordinates are truncated to int.
    """
    detections = []
    for i in range(len(pred_dict)):
        record = pred_dict[i]
        coords = [int(float(v)) for v in record["points"].reshape(-1)]
        detections.append((coords[0::2], coords[1::2], record["texts"]))
    return detections


def _gt_rows_from_dicts(gt_dict):
    """Convert in-memory ground-truth dicts into six-field ``polygt`` rows.

    Row layout: ``["x:", xs (1, n), "y:", ys (1, n), transcription, flag]``.
    An empty ``"text"`` leaves transcription and flag as ``"#"``
    (don't care); otherwise the flag is ``"c"``.
    """
    rows = []
    for i in range(len(gt_dict)):
        points = gt_dict[i]["points"].tolist()
        text = gt_dict[i]["text"]
        row = [
            np.array(["x:"], dtype="<U2"),
            # int32 (not int16): pixel coordinates above 32767 must not wrap
            # around or raise OverflowError.
            np.array([[p[0] for p in points]], dtype="int32"),
            np.array(["y:"], dtype="<U2"),
            np.array([[p[1] for p in points]], dtype="int32"),
            np.array(["#"], dtype="<U1"),
            np.array(["#"], dtype="<U1"),
        ]
        if text != "":
            row[4] = np.array([text], dtype="U{}".format(len(text)))
            row[5] = np.array(["c"], dtype="<U1")
        rows.append(row)
    return rows


def _filter_dont_care(detections, groundtruths, threshold=0.5):
    """Drop detections whose ``iod`` with any don't-care polygon exceeds threshold.

    Only don't-care rows with more than one point are used for filtering.
    """
    for gt in groundtruths:
        if _is_dont_care(gt) and (gt[1].shape[1] > 1):
            gt_x, gt_y = _gt_xy(gt)
            detections = [
                det for det in detections
                if not iod(det[0], det[1], gt_x, gt_y) > threshold
            ]
    return detections


def _score_tables(detections, groundtruths):
    """Build the sigma/tau tables and the transcription lookups.

    sigma = intersection / gt_area and tau = intersection / det_area, both
    rounded to 2 decimals; a zero-area polygon on either side scores 0.

    Returns:
        ``(sigma_table, tau_table, pred_str, gt_str)``. The tables have shape
        ``(len(groundtruths), len(detections))``. The string dicts are keyed by
        detection / ground-truth index and are only populated when both
        sides are non-empty.
    """
    sigma_table = np.zeros((len(groundtruths), len(detections)))
    tau_table = np.zeros((len(groundtruths), len(detections)))
    pred_str, gt_str = {}, {}
    if not detections or not groundtruths:
        return sigma_table, tau_table, pred_str, gt_str
    det_areas = [area(det_x, det_y) for det_x, det_y, _ in detections]
    for gt_id, gt in enumerate(groundtruths):
        gt_x, gt_y = _gt_xy(gt)
        gt_area = area(gt_x, gt_y)
        gt_str[gt_id] = str(gt[4].tolist()[0])
        for det_id, (det_x, det_y, text) in enumerate(detections):
            inter = area_of_intersection(det_x, det_y, gt_x, gt_y)
            # Guard degenerate ground truths instead of dividing by zero.
            sigma_table[gt_id, det_id] = (
                0 if gt_area == 0.0 else np.round(inter / gt_area, 2))
            det_area = det_areas[det_id]
            tau_table[gt_id, det_id] = (
                0 if det_area == 0.0 else np.round(inter / det_area, 2))
            pred_str[det_id] = text.strip()
    return sigma_table, tau_table, pred_str, gt_str


def _score_image(detections, groundtruths):
    """Filter don't-care regions and package one image's DetEval tables."""
    detections = _filter_dont_care(detections, groundtruths)
    groundtruths = [gt for gt in groundtruths if not _is_dont_care(gt)]
    sigma_table, tau_table, pred_str, gt_str = _score_tables(
        detections, groundtruths)
    return {
        "sigma": sigma_table,
        "global_tau": tau_table,
        "global_pred_str": pred_str,
        "global_gt_str": gt_str,
    }


def get_socre_A(gt_dir, pred_dict):
    """Compute DetEval overlap tables for one image from in-memory annotations.

    Args:
        gt_dir: despite its name, an indexable sequence of ground-truth dicts,
            each with ``"points"`` (an array whose ``tolist()`` yields
            ``[x, y]`` pairs) and ``"text"`` (empty string = don't care).
        pred_dict: indexable sequence of prediction dicts with ``"points"``
            (an array flattened as x0, y0, x1, y1, ...) and ``"texts"``.

    Returns:
        dict with ``"sigma"`` / ``"global_tau"`` tables of shape
        ``(num_care_gt, num_kept_det)`` and ``"global_pred_str"`` /
        ``"global_gt_str"`` dicts mapping detection / ground-truth ids to
        transcriptions.
    """
    detections = _read_predictions(pred_dict)
    groundtruths = _gt_rows_from_dicts(gt_dir)
    return _score_image(detections, groundtruths)


def get_socre_B(gt_dir, img_id, pred_dict):
    """Compute DetEval overlap tables for one image from a ``.mat`` annotation.

    The ground truth is loaded from ``<gt_dir>/poly_gt_img<img_id>.mat``
    (variable ``"polygt"``). Its rows are assumed to follow the same six-field
    layout that ``_gt_rows_from_dicts`` builds; confirm for new datasets.

    Args:
        gt_dir: directory holding the ``poly_gt_img*.mat`` files.
        img_id: image identifier substituted into the file name.
        pred_dict: indexable sequence of prediction dicts with ``"points"``
            and ``"texts"``, as in ``get_socre_A``.

    Returns:
        Same structure as ``get_socre_A``.
    """
    detections = _read_predictions(pred_dict)
    groundtruths = io.loadmat(
        "%s/poly_gt_img%s.mat" % (gt_dir, img_id))["polygt"].tolist()
    return _score_image(detections, groundtruths)
def get_score_C(gt_label, text, pred_bboxes):
    """
    get score for CentripetalText (CT) prediction.

    Args:
        gt_label: sequence of ground-truth polygons. Each element must expose
            ``.numpy()``, returning an array whose ``shape[1] // 2`` is the
            point count (presumably a framework tensor -- confirm at call site).
        text: per-polygon transcriptions. ``text[i][0]`` is used, and ``"###"``
            marks a don't-care region.
        pred_bboxes: iterable of predicted point arrays. Each is reversed
            along its second axis and flattened before scoring.

    Returns:
        dict with ``"sigma"`` / ``"global_tau"`` tables of shape
        ``(num_care_gt, num_kept_det)``. ``"global_pred_str"`` and
        ``"global_gt_str"`` are empty strings because no recognition results
        are available in this mode.
    """
    check_install("Polygon", "Polygon3")
    import Polygon as plg

    def gt_reading_mod(gt_label, text):
        """Pair each ground-truth polygon with its transcription."""
        return [{
            "transcription": text[i][0],
            "points": gt_label[i].numpy()
        } for i in range(len(gt_label))]

    def get_intersection(pD, pG):
        """Area of the intersection of two polygons (0 if they are disjoint)."""
        pInt = pD & pG
        if len(pInt) == 0:
            return 0
        return pInt.area()

    def gt_polygon(gt):
        """Build an int32 Polygon from a ground-truth record's flat points."""
        point_num = gt["points"].shape[1] // 2
        return plg.Polygon(
            np.array(gt["points"]).reshape(point_num, 2).astype("int32"))

    def det_polygon(detection):
        """Build a Polygon from a flattened detection.

        Odd entries become the first coordinate and even entries the second.
        This undoes the column reversal applied when detections are flattened
        below, so the vertices equal the rows of the original ``pred_bboxes``
        item.
        """
        det_y = detection[0::2]
        det_x = detection[1::2]
        det_p = np.concatenate((np.array(det_x), np.array(det_y)))
        return plg.Polygon(det_p.reshape(2, -1).transpose())

    def detection_filtering(detections, groundtruths, threshold=0.5):
        """Drop detections lying mostly (> threshold) inside a "###" region."""
        for gt in groundtruths:
            point_num = gt["points"].shape[1] // 2
            if gt["transcription"] == "###" and (point_num > 1):
                gt_p = gt_polygon(gt)
                kept = []
                for detection in detections:
                    det_p = det_polygon(detection)
                    det_area = det_p.area()
                    # A zero-area detection cannot be "inside" the region.
                    # Previously this raised into a bare except and then
                    # reused a stale or undefined ratio.
                    if det_area == 0.0:
                        det_gt_iou = 0.0
                    else:
                        det_gt_iou = get_intersection(det_p, gt_p) / det_area
                    if not det_gt_iou > threshold:
                        kept.append(detection)
                # Rebuild the list rather than blanking entries and filtering
                # on `item != []`. That comparison pits a numpy array against an
                # empty list and raises a broadcasting error on NumPy >= 1.25.
                detections = kept
        return detections

    def sigma_calculation(det_p, gt_p):
        """
        sigma = inter_area / gt_area
        """
        if gt_p.area() == 0.0:
            return 0
        return get_intersection(det_p, gt_p) / gt_p.area()

    def tau_calculation(det_p, gt_p):
        """
        tau = inter_area / det_area
        """
        if det_p.area() == 0.0:
            return 0
        return get_intersection(det_p, gt_p) / det_p.area()

    detections = [item[:, ::-1].reshape(-1) for item in pred_bboxes]
    groundtruths = gt_reading_mod(gt_label, text)
    # filters detections overlapping with DC area
    detections = detection_filtering(detections, groundtruths)
    # NOTE: source code use 'orin' to indicate '#', here we use 'anno',
    # which may cause slight drop in fscore, about 0.12
    groundtruths = [
        gt for gt in groundtruths if gt["transcription"] != "###"
    ]
    local_sigma_table = np.zeros((len(groundtruths), len(detections)))
    local_tau_table = np.zeros((len(groundtruths), len(detections)))
    if detections and groundtruths:
        # Build each polygon once instead of once per (gt, det) pair.
        det_polys = [det_polygon(detection) for detection in detections]
        for gt_id, gt in enumerate(groundtruths):
            gt_p = gt_polygon(gt)
            for det_id, det_p in enumerate(det_polys):
                local_sigma_table[gt_id, det_id] = sigma_calculation(det_p,
                                                                     gt_p)
                local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p)
    data = {}
    data["sigma"] = local_sigma_table
    data["global_tau"] = local_tau_table
    data["global_pred_str"] = ""
    data["global_gt_str"] = ""
    return data
def combine_results(all_data, rec_flag=True):
    """Aggregate per-image DetEval overlap tables into dataset-level metrics.

    Each entry of ``all_data`` is a per-image dict holding ``"sigma"`` and
    ``"global_tau"`` tables (rows = ground truths, columns = detections), plus
    ``"global_pred_str"`` / ``"global_gt_str"`` lookups indexed by detection /
    ground-truth id. Each image is matched in three passes: one-to-one, then
    one-to-many, then many-to-one. Later passes skip anything an earlier pass
    flagged as matched.

    Args:
        all_data: iterable of per-image result dicts (see above).
        rec_flag: when True, every matched pair also compares transcriptions
            (exact, then case-insensitive) to count end-to-end hits.

    Returns:
        dict with ``total_num_gt``, ``total_num_det``,
        ``global_accumulative_recall``, ``hit_str_count``, detection
        ``recall`` / ``precision`` / ``f_score``, ``seqerr`` and end-to-end
        ``recall_e2e`` / ``precision_e2e`` / ``f_score_e2e``. A ratio whose
        denominator is zero falls back to 0 (``seqerr`` falls back to 1).
    """
    tr = 0.7  # threshold on sigma (intersection / ground-truth area)
    tp = 0.6  # threshold on tau (intersection / detection area)
    fsc_k = 0.8  # partial credit for split (one-to-many / many-to-one) matches
    k = 2  # minimum number of overlapping partners before a split is tried
    # Gather the per-image tables so they can be indexed by image position.
    global_sigma = []
    global_tau = []
    global_pred_str = []
    global_gt_str = []
    for data in all_data:
        global_sigma.append(data["sigma"])
        global_tau.append(data["global_tau"])
        global_pred_str.append(data["global_pred_str"])
        global_gt_str.append(data["global_gt_str"])
    global_accumulative_recall = 0
    global_accumulative_precision = 0
    total_num_gt = 0
    total_num_det = 0
    hit_str_count = 0
    hit_count = 0  # NOTE(review): never read below

    def one_to_one(
            local_sigma_table,
            local_tau_table,
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            idy,
            rec_flag, ):
        """Match ground truths that pair with exactly one detection.

        A ground truth is matched when exactly one detection has sigma > tr
        and exactly one has tau > tp for it, and those detections' columns
        in turn contain a single qualifying ground truth. Each match adds 1.0
        to recall and precision and sets both flags. ``num_gt`` is read from
        the enclosing loop of ``combine_results``.

        Returns the updated accumulators, the flag arrays and the number of
        transcription hits for image ``idy``.
        """
        hit_str_num = 0
        for gt_id in range(num_gt):
            gt_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[gt_id, :] > tr)
            gt_matching_num_qualified_sigma_candidates = (
                gt_matching_qualified_sigma_candidates[0].shape[0])
            gt_matching_qualified_tau_candidates = np.where(
                local_tau_table[gt_id, :] > tp)
            gt_matching_num_qualified_tau_candidates = (
                gt_matching_qualified_tau_candidates[0].shape[0])
            # Reverse check: how many ground truths qualify in the candidate
            # detection column(s).
            det_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
                > tr)
            det_matching_num_qualified_sigma_candidates = (
                det_matching_qualified_sigma_candidates[0].shape[0])
            det_matching_qualified_tau_candidates = np.where(
                local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
                tp)
            det_matching_num_qualified_tau_candidates = (
                det_matching_qualified_tau_candidates[0].shape[0])
            if ((gt_matching_num_qualified_sigma_candidates == 1) and
                (gt_matching_num_qualified_tau_candidates == 1) and
                (det_matching_num_qualified_sigma_candidates == 1) and
                (det_matching_num_qualified_tau_candidates == 1)):
                global_accumulative_recall = global_accumulative_recall + 1.0
                global_accumulative_precision = global_accumulative_precision + 1.0
                local_accumulative_recall = local_accumulative_recall + 1.0
                local_accumulative_precision = local_accumulative_precision + 1.0
                gt_flag[0, gt_id] = 1
                # Exactly one sigma candidate exists here: the matched detection.
                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
                # recg start
                if rec_flag:
                    gt_str_cur = global_gt_str[idy][gt_id]
                    pred_str_cur = global_pred_str[idy][matched_det_id[0]
                                                        .tolist()[0]]
                    if pred_str_cur == gt_str_cur:
                        hit_str_num += 1
                    else:
                        if pred_str_cur.lower() == gt_str_cur.lower():
                            hit_str_num += 1
                # recg end
                det_flag[0, matched_det_id] = 1
        return (
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            hit_str_num, )

    def one_to_many(
            local_sigma_table,
            local_tau_table,
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            idy,
            rec_flag, ):
        """Match unmatched ground truths covered by several detections.

        Only ground truths overlapping at least ``k`` detections are tried.
        Unmatched detections with tau >= tp are the candidates. A single
        candidate that also has sigma >= tr counts as a full one-to-one
        match. Otherwise, if the candidates' summed sigma reaches tr, recall
        gains ``fsc_k`` and precision gains ``fsc_k`` per candidate. Only the
        first candidate's transcription is compared. ``num_gt`` comes from
        the enclosing loop.
        """
        hit_str_num = 0
        for gt_id in range(num_gt):
            # skip the following if the groundtruth was matched
            if gt_flag[0, gt_id] > 0:
                continue
            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
            if num_non_zero_in_sigma >= k:
                ####search for all detections that overlaps with this groundtruth
                qualified_tau_candidates = np.where((local_tau_table[
                    gt_id, :] >= tp) & (det_flag[0, :] == 0))
                num_qualified_tau_candidates = qualified_tau_candidates[
                    0].shape[0]
                if num_qualified_tau_candidates == 1:
                    if (local_tau_table[gt_id, qualified_tau_candidates] >= tp
                        ) and (
                            local_sigma_table[gt_id, qualified_tau_candidates]
                            >= tr):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = (
                            global_accumulative_precision + 1.0)
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = (
                            local_accumulative_precision + 1.0)
                        gt_flag[0, gt_id] = 1
                        det_flag[0, qualified_tau_candidates] = 1
                        # recg start
                        if rec_flag:
                            gt_str_cur = global_gt_str[idy][gt_id]
                            pred_str_cur = global_pred_str[idy][
                                qualified_tau_candidates[0].tolist()[0]]
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                        # recg end
                elif np.sum(local_sigma_table[gt_id,
                                              qualified_tau_candidates]) >= tr:
                    # Split match: several detections jointly cover this
                    # ground truth.
                    gt_flag[0, gt_id] = 1
                    det_flag[0, qualified_tau_candidates] = 1
                    # recg start
                    if rec_flag:
                        gt_str_cur = global_gt_str[idy][gt_id]
                        pred_str_cur = global_pred_str[idy][
                            qualified_tau_candidates[0].tolist()[0]]
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                    # recg end
                    global_accumulative_recall = global_accumulative_recall + fsc_k
                    global_accumulative_precision = (
                        global_accumulative_precision +
                        num_qualified_tau_candidates * fsc_k)
                    local_accumulative_recall = local_accumulative_recall + fsc_k
                    local_accumulative_precision = (
                        local_accumulative_precision +
                        num_qualified_tau_candidates * fsc_k)
        return (
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            hit_str_num, )

    def many_to_one(
            local_sigma_table,
            local_tau_table,
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            idy,
            rec_flag, ):
        """Match unmatched detections that cover several ground truths.

        Only detections overlapping at least ``k`` ground truths are tried.
        The candidates are unmatched ground truths whose sigma for this
        detection is >= tp. A single candidate that passes both thresholds
        counts as one-to-one. Otherwise, if the candidates' summed tau
        reaches tp, recall gains ``fsc_k`` per candidate and precision gains
        ``fsc_k``. ``num_det`` comes from the enclosing loop.
        """
        hit_str_num = 0
        for det_id in range(num_det):
            # skip the following if the detection was matched
            if det_flag[0, det_id] > 0:
                continue
            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
            if num_non_zero_in_tau >= k:
                # search for all ground truths that overlap with this detection
                # NOTE(review): sigma is thresholded with tp here (elsewhere tr);
                # presumably kept for parity with the reference DetEval code --
                # confirm before changing.
                qualified_sigma_candidates = np.where((
                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
                num_qualified_sigma_candidates = qualified_sigma_candidates[
                    0].shape[0]
                if num_qualified_sigma_candidates == 1:
                    if (
                            local_tau_table[qualified_sigma_candidates, det_id]
                            >= tp
                    ) and (local_sigma_table[qualified_sigma_candidates, det_id]
                           >= tr):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = (
                            global_accumulative_precision + 1.0)
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = (
                            local_accumulative_precision + 1.0)
                        gt_flag[0, qualified_sigma_candidates] = 1
                        det_flag[0, det_id] = 1
                        # recg start
                        if rec_flag:
                            pred_str_cur = global_pred_str[idy][det_id]
                            gt_len = len(qualified_sigma_candidates[0])
                            for idx in range(gt_len):
                                ele_gt_id = qualified_sigma_candidates[
                                    0].tolist()[idx]
                                # Transcriptions exist only for ids that were
                                # scored against at least one detection.
                                if ele_gt_id not in global_gt_str[idy]:
                                    continue
                                gt_str_cur = global_gt_str[idy][ele_gt_id]
                                if pred_str_cur == gt_str_cur:
                                    hit_str_num += 1
                                    break
                                else:
                                    if pred_str_cur.lower() == gt_str_cur.lower(
                                    ):
                                        hit_str_num += 1
                                        break
                        # recg end
                elif np.sum(local_tau_table[qualified_sigma_candidates,
                                            det_id]) >= tp:
                    # Split match: this detection jointly covers several
                    # ground truths.
                    det_flag[0, det_id] = 1
                    gt_flag[0, qualified_sigma_candidates] = 1
                    # recg start
                    if rec_flag:
                        pred_str_cur = global_pred_str[idy][det_id]
                        gt_len = len(qualified_sigma_candidates[0])
                        for idx in range(gt_len):
                            ele_gt_id = qualified_sigma_candidates[0].tolist()[
                                idx]
                            if ele_gt_id not in global_gt_str[idy]:
                                continue
                            gt_str_cur = global_gt_str[idy][ele_gt_id]
                            # First ground truth whose text matches counts as
                            # one hit for this detection.
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                                break
                            else:
                                # NOTE(review): nesting of this `break` was
                                # reconstructed from a whitespace-stripped copy
                                # (mirrors the exact-match branch) -- confirm.
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                                    break
                    # recg end
                    global_accumulative_recall = (
                        global_accumulative_recall +
                        num_qualified_sigma_candidates * fsc_k)
                    global_accumulative_precision = (
                        global_accumulative_precision + fsc_k)
                    local_accumulative_recall = (
                        local_accumulative_recall +
                        num_qualified_sigma_candidates * fsc_k)
                    local_accumulative_precision = local_accumulative_precision + fsc_k
        return (
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            hit_str_num, )

    for idx in range(len(global_sigma)):
        local_sigma_table = np.array(global_sigma[idx])
        local_tau_table = global_tau[idx]
        # num_gt / num_det are read as closure variables by the three
        # matching helpers above.
        num_gt = local_sigma_table.shape[0]
        num_det = local_sigma_table.shape[1]
        total_num_gt = total_num_gt + num_gt
        total_num_det = total_num_det + num_det
        local_accumulative_recall = 0
        local_accumulative_precision = 0
        # Per-image match flags (1 = already matched), shape (1, n).
        gt_flag = np.zeros((1, num_gt))
        det_flag = np.zeros((1, num_det))
        #######first check for one-to-one case##########
        (
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            hit_str_num, ) = one_to_one(
                local_sigma_table,
                local_tau_table,
                local_accumulative_recall,
                local_accumulative_precision,
                global_accumulative_recall,
                global_accumulative_precision,
                gt_flag,
                det_flag,
                idx,
                rec_flag, )
        hit_str_count += hit_str_num
        #######then check for one-to-many case##########
        (
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            hit_str_num, ) = one_to_many(
                local_sigma_table,
                local_tau_table,
                local_accumulative_recall,
                local_accumulative_precision,
                global_accumulative_recall,
                global_accumulative_precision,
                gt_flag,
                det_flag,
                idx,
                rec_flag, )
        hit_str_count += hit_str_num
        #######then check for many-to-one case##########
        (
            local_accumulative_recall,
            local_accumulative_precision,
            global_accumulative_recall,
            global_accumulative_precision,
            gt_flag,
            det_flag,
            hit_str_num, ) = many_to_one(
                local_sigma_table,
                local_tau_table,
                local_accumulative_recall,
                local_accumulative_precision,
                global_accumulative_recall,
                global_accumulative_precision,
                gt_flag,
                det_flag,
                idx,
                rec_flag, )
        hit_str_count += hit_str_num
    # Final ratios; each falls back to a fixed value on a zero denominator.
    try:
        recall = global_accumulative_recall / total_num_gt
    except ZeroDivisionError:
        recall = 0
    try:
        precision = global_accumulative_precision / total_num_det
    except ZeroDivisionError:
        precision = 0
    try:
        f_score = 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        f_score = 0
    try:
        # Fraction of matched (recall-weighted) instances whose text was wrong.
        seqerr = 1 - float(hit_str_count) / global_accumulative_recall
    except ZeroDivisionError:
        seqerr = 1
    try:
        recall_e2e = float(hit_str_count) / total_num_gt
    except ZeroDivisionError:
        recall_e2e = 0
    try:
        precision_e2e = float(hit_str_count) / total_num_det
    except ZeroDivisionError:
        precision_e2e = 0
    try:
        f_score_e2e = 2 * precision_e2e * recall_e2e / (
            precision_e2e + recall_e2e)
    except ZeroDivisionError:
        f_score_e2e = 0
    final = {
        "total_num_gt": total_num_gt,
        "total_num_det": total_num_det,
        "global_accumulative_recall": global_accumulative_recall,
        "hit_str_count": hit_str_count,
        "recall": recall,
        "precision": precision,
        "f_score": f_score,
        "seqerr": seqerr,
        "recall_e2e": recall_e2e,
        "precision_e2e": precision_e2e,
        "f_score_e2e": f_score_e2e,
    }
    return final