|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import re |
|
import sys |
|
import shapely |
|
from shapely.geometry import Polygon |
|
import numpy as np |
|
from collections import defaultdict |
|
import operator |
|
from rapidfuzz.distance import Levenshtein |
|
import argparse |
|
import json |
|
import copy |
|
|
|
|
|
def parse_ser_results_fp(fp, fp_type="gt", ignore_background=True):
    """Parse a SER result file into per-image lists of OCR entries.

    Every non-blank line of the file must hold an image path and a JSON
    object separated by a tab. The JSON object must contain an
    ``"ocr_info"`` list of dicts; the class of each dict is read from its
    ``"label"`` key (ground truth) or its ``"pred"`` key (prediction).

    Args:
        fp: Path of the result file; it is read as UTF-8.
        fp_type: ``"gt"`` to read classes from ``"label"``, ``"pred"`` to
            read them from ``"pred"``.
        ignore_background: When true, entries whose normalized class is
            ``"O"`` are left out of the result.

    Returns:
        A dict mapping each image's base file name to a list of deep-copied
        entries. In every returned entry ``"label"`` holds the upper-cased
        class, with ``"O"``, ``"OTHER"`` and ``"OTHERS"`` collapsed to
        ``"O"``. If an image name occurs on several lines, the last line
        wins.

    Raises:
        AssertionError: If ``fp_type`` is neither ``"gt"`` nor ``"pred"``.
        ValueError: If a non-blank line contains no tab separator.
    """
    assert fp_type in ["gt", "pred"]
    key = "label" if fp_type == "gt" else "pred"
    res_dict = dict()
    with open(fp, "r", encoding="utf-8") as fin:
        for raw_line in fin:
            line = raw_line.strip()
            if not line:
                # Blank lines (typically a trailing newline at the end of the
                # file) carry no record; unpacking them used to crash.
                continue

            # Split on the first tab only: tabs are legal whitespace inside
            # the JSON payload and must not be treated as extra separators.
            img_path, info = line.split("\t", 1)

            image_name = os.path.basename(img_path)
            entries = []
            res_dict[image_name] = entries

            for single_ocr_info in json.loads(info)["ocr_info"]:
                label = single_ocr_info[key].upper()
                if label in ("O", "OTHERS", "OTHER"):
                    label = "O"
                if ignore_background and label == "O":
                    continue
                # Store the normalized class under "label" for both file
                # types so callers can compare gt and pred entries directly.
                entry = copy.deepcopy(single_ocr_info)
                entry["label"] = label
                entries.append(entry)
    return res_dict
|
|
|
|
|
def polygon_from_str(polygon_points):
    """Build the convex-hull geometry of a four-point polygon.

    Args:
        polygon_points: Array-like holding exactly eight coordinate values
            (any nesting that can be reshaped to 4 x 2), e.g. the ``"poly"``
            field of a gt or dt entry.

    Returns:
        The ``convex_hull`` of the shapely ``Polygon`` built from the four
        ``(x, y)`` points.
    """
    corners = np.array(polygon_points).reshape(4, 2)
    # The hull is used instead of the raw ring, so the order in which the
    # four corners are listed does not matter.
    return Polygon(corners).convex_hull
|
|
|
|
|
def polygon_iou(poly1, poly2):
    """Return the intersection-over-union of two shapely geometries.

    Args:
        poly1: First geometry; must provide ``intersects``, ``intersection``
            and ``area``.
        poly2: Second geometry, same requirements.

    Returns:
        ``intersection area / union area`` as a float, or ``0`` when the
        geometries do not intersect, when their union has no area
        (degenerate, zero-area shapes), or when shapely reports a
        topological error while intersecting them.
    """
    if not poly1.intersects(poly2):
        return 0

    try:
        inter_area = poly1.intersection(poly2).area
        union_area = poly1.area + poly2.area - inter_area
    except shapely.geos.TopologicalError:
        # NOTE(review): newer shapely releases expose this exception under
        # shapely.errors -- confirm against the pinned shapely version.
        print('shapely.geos.TopologicalError occurred, iou set to 0')
        return 0

    if union_area <= 0:
        # Zero-area shapes (e.g. a box collapsed to a point or a line) can
        # still "intersect"; without this guard the division below raises
        # ZeroDivisionError.
        return 0
    return float(inter_area) / union_area
|
|
|
|
|
def ed(args, str1, str2):
    """Return the Levenshtein distance between two optionally normalized strings.

    Args:
        args: Options object; ``args.ignore_space`` removes every ``" "``
            character from both strings and ``args.ignore_case`` lower-cases
            both strings before they are compared.
        str1: First string.
        str2: Second string.

    Returns:
        The value of ``Levenshtein.distance`` for the normalized strings.
    """
    texts = [str1, str2]
    if args.ignore_space:
        texts = [text.replace(" ", "") for text in texts]
    if args.ignore_case:
        texts = [text.lower() for text in texts]
    first, second = texts
    return Levenshtein.distance(first, second)
|
|
|
|
|
def convert_bbox_to_polygon(bbox):
    """Expand an axis-aligned box into its four corner points.

    Args:
        bbox: Sequence of exactly four values ``[xmin, ymin, xmax, ymax]``.

    Returns:
        A list of four ``[x, y]`` lists in the order ``[xmin, ymin]``,
        ``[xmax, ymin]``, ``[xmax, ymax]``, ``[xmin, ymax]``.
    """
    x_lo, y_lo, x_hi, y_hi = bbox
    corner_order = ((x_lo, y_lo), (x_hi, y_lo), (x_hi, y_hi), (x_lo, y_hi))
    return [[x, y] for x, y in corner_order]
|
|
|
|
|
def _box_polygon(ocr_entry):
    """Return the polygon of one OCR entry, deriving "poly" from "bbox" if absent."""
    if "poly" not in ocr_entry:
        ocr_entry["poly"] = convert_bbox_to_polygon(ocr_entry["bbox"])
    return polygon_from_str(ocr_entry["poly"])


def eval_e2e(args):
    """Evaluate end-to-end OCR + SER predictions against ground truth.

    For every image in the prediction file, gt and dt boxes whose IoU is at
    least ``args.iou_thres`` are paired greedily in descending IoU order,
    each box being used at most once. A pair counts as a hit when the two
    texts are identical and, unless ``args.ignore_ser_prediction`` is set,
    the normalized classes are identical too. Edit distances are
    accumulated over matched pairs and, against the empty string, over
    every unmatched gt or dt box.

    The configuration, the raw counts and the metrics (character accuracy,
    average edit distance per field and per image, precision, recall,
    f-measure) are printed to stdout.

    Args:
        args: Options object providing ``gt_json_path``, ``pred_json_path``,
            ``ignore_background``, ``iou_thres``, ``ignore_ser_prediction``
            and the ``ignore_space`` / ``ignore_case`` flags read by ``ed``.

    Returns:
        None.

    Raises:
        KeyError: If the prediction file names an image that is missing from
            the ground-truth file.

    Note:
        Only images present in the prediction file contribute to the counts;
        ``avg_edit_dist_img`` is nevertheless divided by the number of images
        in the ground-truth file.
    """
    gt_results = parse_ser_results_fp(args.gt_json_path, "gt",
                                      args.ignore_background)
    dt_results = parse_ser_results_fp(args.pred_json_path, "pred",
                                      args.ignore_background)
    iou_thresh = args.iou_thres
    num_gt_chars = 0
    gt_count = 0
    dt_count = 0
    hit = 0
    ed_sum = 0

    for img_name, dt_info in dt_results.items():
        gt_info = gt_results[img_name]
        gt_count += len(gt_info)
        dt_count += len(dt_info)

        dt_match = [False] * len(dt_info)
        gt_match = [False] * len(gt_info)

        # Build every polygon once per image instead of rebuilding the dt
        # polygons for each gt box. dt polygons are only needed when there is
        # at least one gt box to compare them with.
        gt_polys = [_box_polygon(gt) for gt in gt_info]
        dt_polys = [_box_polygon(dt) for dt in dt_info] if gt_info else []

        # Collect every (gt, dt) pair that overlaps enough to be a candidate.
        candidate_ious = {}
        for index_gt, gt_poly in enumerate(gt_polys):
            for index_dt, dt_poly in enumerate(dt_polys):
                iou = polygon_iou(dt_poly, gt_poly)
                if iou >= iou_thresh:
                    candidate_ious[(index_gt, index_dt)] = iou

        # Greedy one-to-one assignment, best overlap first. The sort is
        # stable, so ties keep their (gt, dt) insertion order.
        sorted_gt_dt_pairs = sorted(
            candidate_ious, key=candidate_ious.get, reverse=True)

        for index_gt, index_dt in sorted_gt_dt_pairs:
            if gt_match[index_gt] or dt_match[index_dt]:
                continue
            gt_match[index_gt] = True
            dt_match[index_dt] = True

            gt_text = gt_info[index_gt]["text"]
            dt_text = dt_info[index_dt]["text"]

            # Both sides use the normalized class stored under "label" by
            # parse_ser_results_fp. Comparing against the raw "pred" string
            # made e.g. "question" vs "QUESTION" or "OTHER" vs "O" mismatch.
            gt_label = gt_info[index_gt]["label"]
            dt_label = dt_info[index_dt]["label"]

            ed_sum += ed(args, gt_text, dt_text)
            num_gt_chars += len(gt_text)
            if gt_text == dt_text and (args.ignore_ser_prediction or
                                       gt_label == dt_label):
                hit += 1

        # Unmatched detections are pure insertions.
        for matched, dt in zip(dt_match, dt_info):
            if not matched:
                ed_sum += ed(args, dt["text"], "")

        # Unmatched ground-truth boxes are pure deletions.
        for matched, gt in zip(gt_match, gt_info):
            if not matched:
                gt_text = gt["text"]
                ed_sum += ed(args, gt_text, "")
                num_gt_chars += len(gt_text)

    eps = 1e-9
    print("config: ", args)
    print('hit, dt_count, gt_count', hit, dt_count, gt_count)
    precision = hit / (dt_count + eps)
    recall = hit / (gt_count + eps)
    fmeasure = 2.0 * precision * recall / (precision + recall + eps)
    # eps keeps an empty ground-truth file from raising ZeroDivisionError,
    # consistent with the other ratios.
    avg_edit_dist_img = ed_sum / (len(gt_results) + eps)
    avg_edit_dist_field = ed_sum / (gt_count + eps)
    character_acc = 1 - ed_sum / (num_gt_chars + eps)

    print('character_acc: %.2f' % (character_acc * 100) + "%")
    print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
    print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
    print('precision: %.2f' % (precision * 100) + "%")
    print('recall: %.2f' % (recall * 100) + "%")
    print('fmeasure: %.2f' % (fmeasure * 100) + "%")

    return
|
|
|
|
|
def parse_args():
    """Parse the evaluation script's command-line options.

    Options:
        --gt_json_path / --pred_json_path: required string paths.
        --iou_thres: float, default 0.5.
        --ignore_case, --ignore_space, --ignore_background,
        --ignore_ser_prediction: boolean switches given as text; only
            "true", "t" and "1" (case-insensitive) are read as True, any
            other value as False.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """

    def str2bool(v):
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()

    for path_option in ("--gt_json_path", "--pred_json_path"):
        parser.add_argument(
            path_option, default=None, type=str, required=True)

    parser.add_argument("--iou_thres", default=0.5, type=float)

    # (option name, default value, help text) of every boolean switch, in the
    # order they appear in the generated --help output.
    boolean_options = (
        ("--ignore_case", False, "whether to do lower case for the strs"),
        ("--ignore_space", True, "whether to ignore space"),
        ("--ignore_background", True, "whether to ignore other label"),
        ("--ignore_ser_prediction", False,
         "whether to ignore ocr pred results"),
    )
    for option_name, default_value, help_text in boolean_options:
        parser.add_argument(
            option_name,
            default=default_value,
            type=str2bool,
            help=help_text)

    return parser.parse_args()
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command-line options and run the
    # evaluation, which prints its metrics to stdout.
    eval_e2e(parse_args())
|
|