Realcat
update: major change
499e141
raw
history blame
4.19 kB
# -*- coding: utf-8 -*-
# @Author : xuelun
import os
import argparse
import numpy as np
from os.path import join
from datetime import datetime
# Threshold labels; the trailing unit character is part of each string.
angular_thresholds = ['5.0°']
dist_thresholds = ['0.1m']


def intt(x):
    """Return a list with int() applied to every element of *x*."""
    return [int(item) for item in x]


def floatt(x):
    """Return a list with float() applied to every element of *x*."""
    return [float(item) for item in x]


def strr(x):
    """Return a list with every element of *x* formatted to 18 decimal places."""
    return [f'{item:.18f}' for item in x]
# Dataset identifiers, kept as a list because the order of entries matters
# to code that iterates over it.
datasets = [
    'GL3D', 'BlendedMVS',
    'ETH3DI', 'ETH3DO',
    'KITTI',
    'RobotcarWeather', 'RobotcarSeason', 'RobotcarNight',
    'Multi-FoV',
    'SceneNetRGBD',
    'ICL-NUIM',
    'GTA-SfM',
]
def error_auc(errs0, errs1, thres, metric):
    """Compute the normalised area under the recall-vs-error curve.

    The error of each sample is the element-wise maximum of ``errs0`` and
    ``errs1``; NaN and infinite entries are replaced by 180 first.  The
    sorted errors (with a leading 0) are paired with evenly spaced recall
    values in [0, 1], and for every threshold the curve is integrated from
    0 up to the threshold and divided by the threshold.

    Args:
        errs0: list or 1-D numpy array of numeric errors.
        errs1: second list/array with the same length as ``errs0``.
        thres: iterable of threshold strings.  The last character of each
            string is dropped before conversion to float (``'5.0°'`` -> 5.0).
            Thresholds are expected to be positive because the integral is
            divided by the threshold.
        metric: prefix used when building the result keys.

    Returns:
        dict mapping ``f'{metric}@ {threshold}'`` to the AUC for that
        threshold, in the order of ``thres``.
    """
    # np.array(...) always copies, so the replacements below never touch an
    # array owned by the caller.
    errs0 = np.array(errs0, dtype=float)
    errs1 = np.array(errs1, dtype=float)
    errs0[~np.isfinite(errs0)] = 180
    errs1[~np.isfinite(errs1)] = 180

    worst = np.max(np.stack([errs0, errs1]), axis=0)
    errors = [0] + sorted(worst.tolist())
    recall = np.linspace(0, 1, len(errors)).tolist()

    # np.trapz is deprecated in favour of np.trapezoid in NumPy 2.x; use
    # whichever one this NumPy provides.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    aucs = []
    for label in thres:
        thr = float(label[:-1])
        last_index = np.searchsorted(errors, thr)
        # Close the curve at the threshold with a flat segment.
        y = recall[:last_index] + [recall[last_index - 1]]
        x = errors[:last_index] + [thr]
        aucs.append(integrate(y, x) / thr)

    return {f'{metric}@ {t}': auc for t, auc in zip(thres, aucs)}
def _timestamp():
    """Return the current local time formatted for the report table."""
    return datetime.now().strftime("%Y-%m-%d, %H:%M:%S")


def _matching_result_files(directory, wid, version):
    """Map dataset id -> result file name for one model/version.

    Each non-directory entry of *directory* is cut at its last '.txt' and the
    part before it is split on whitespace.  An entry is kept when its 2nd
    token equals *wid* and its last token equals *version*; its 3rd token is
    used as the dataset id.  Names with fewer than three tokens (this covers
    every name without '.txt') are skipped rather than raising IndexError.
    A *version* of None matches nothing.
    """
    found = {}
    for name in os.listdir(directory):
        if os.path.isdir(os.path.join(directory, name)):
            continue
        tokens = name.rpartition('.txt')[0].split()
        if len(tokens) < 3:
            continue
        if wid == tokens[1] and version is not None and version == tokens[-1]:
            found[tokens[2]] = name
    return found


def _load_error_columns(path):
    """Read a whitespace-separated result table into {column: [str, ...]}.

    The first line is the header.  The first three columns are not collected.
    A row whose first field repeats that of an earlier row is ignored, and
    blank rows are skipped.

    Raises:
        ValueError: if the file has no header line.
    """
    with open(path, 'r') as f:
        lines = f.readlines()
    if not lines:
        raise ValueError(f'{path} is empty; expected a header line')

    columns = lines[0].split()[3:]
    details = {name: [] for name in columns}
    seen = set()  # set membership keeps de-duplication linear
    for line in lines[1:]:
        fields = line.split()
        if not fields or fields[0] in seen:
            continue
        seen.add(fields[0])
        for name, value in zip(columns, fields[3:]):
            details[name].append(value)
    return details


def _render_table(results, sceids, wid, stamp_rows):
    """Format *results* as a fixed-width text table.

    Args:
        results: {sceid: {metric name: value}}.
        sceids: non-empty list of dataset ids, one table column each.
        wid: model id printed in every data row.
        stamp_rows: when true each data row starts with a timestamp,
            otherwise with blanks of the same width.
    """
    width = 56 + 25 * len(sceids)
    parts = ['=' * width, "\n"]

    # header row
    parts.append('{:<25}'.format(_timestamp()))
    parts.append('{:<15} '.format('Model'))
    parts.append('{:<14} '.format('Metric'))
    parts.extend('{:<25} '.format(sceid) for sceid in sceids)
    parts.append("\n")
    parts.append('-' * width)
    parts.append("\n")

    # one row per metric; metric names are taken from the first dataset
    for metric in list(results[sceids[0]]):
        parts.append('{:<25}'.format(_timestamp() if stamp_rows else ' '))
        parts.append('{:<15} '.format(wid))
        parts.append('{:<14} '.format(metric))
        parts.extend('{:<25} '.format(results[sceid][metric]) for sceid in sceids)
        parts.append("\n")

    parts.append('=' * width)
    parts.append("\n")
    parts.append("\n")
    return ''.join(parts)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir', type=str, default='.')
    parser.add_argument('--wid', type=str, required=True)
    parser.add_argument('--version', type=str, default=None)
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--log', action='store_true')
    parser.add_argument('--sceids', type=str, choices=datasets, nargs='+',
                        default=None, help=f'Test Datasets: {datasets}', )
    opt = parser.parse_args()

    directory = opt.dir  # local name avoids shadowing the builtin dir()
    wid = opt.wid
    version = opt.version

    found = _matching_result_files(directory, wid, version)
    # Keep only known datasets, in the order given by `datasets`.
    _data = {k: found[k] for k in datasets if k in found}

    sceids = list(opt.sceids) if opt.sceids is not None else list(_data)
    if not sceids:
        # Previously this ended in an IndexError while building the table.
        parser.error(
            f'no result files found in {directory!r} '
            f'for wid={wid!r} and version={version!r}'
        )
    missing = [sceid for sceid in sceids if sceid not in _data]
    if missing:
        parser.error(
            f'no result file in {directory!r} for dataset(s) {missing} '
            f'(wid={wid!r}, version={version!r})'
        )

    results = {}
    for sceid in sceids:
        if not opt.verbose: print('{:^13} {}'.format(sceid, wid))

        # read txt
        details = _load_error_columns(join(directory, _data[sceid]))

        mAP = error_auc(floatt(details['R_errs']), floatt(details['t_errs']), angular_thresholds, 'auc')
        results[sceid] = dict(mAP)

    output = _render_table(results, sceids, wid, opt.log)

    if opt.verbose:
        print(output)

    if opt.log:
        path = 'ANALYSIS RESULTS.txt'
        # The table contains non-ASCII characters (e.g. '°'), so the
        # encoding is fixed instead of depending on the locale.
        with open(path, 'a', encoding='utf-8') as file:
            file.write(output)