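"""Pick the best checkpoint(s) from the evaluation results in e_results/.

For each metric, every checkpoint's scores are summed over all *_eval.txt
files; the row(s) of the best-scoring checkpoint(s) are then written to
e_results/eval-{task}_best_on_{metric}.txt.
"""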
import os
from glob import glob
import numpy as np
from config import Config


config = Config()

eval_txts = sorted(glob('e_results/*_eval.txt'))
print('eval_txts:', [os.path.basename(p) for p in eval_txts])
score_panel = {}
sep = '&'
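# Assumed row layout of the eval txt files (hypothetical example; the real
# columns come from the evaluation script):
#   & DIS-VD & ckpt--epoch_100 & 0.905 & 0.850 & ... &
# i.e. '&'-separated fields after a 3-line header, with the dataset name in
# field 1 and the checkpoint name (containing '--epoch_') in field 2.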
metrics = ['sm', 'wfm', 'hce']    # HCE is used only for DIS5K; wFm is the selection metric for the other tasks.
if 'DIS5K' not in config.task:
    metrics.remove('hce')

for metric in metrics:
    print('Metric:', metric)
    # First pass: count the data rows in each file (skip the 3-line header; only rows containing '.' hold scores).
    line_counts = []
    for eval_txt in eval_txts:
        with open(eval_txt, 'r') as f:
            line_counts.append(len([l for l in f.readlines()[3:] if '.' in l]))
    min_line_count = min(line_counts)
    # Second pass: collect scores, truncating every file to the shortest one so the rows stay aligned.
    for idx_et, eval_txt in enumerate(eval_txts):
        with open(eval_txt, 'r') as f:
            lines = [l for l in f.readlines()[3:] if '.' in l]
        for line in lines[:min_line_count]:
            properties = line.strip().strip(sep).split(sep)
            dataset = properties[0].strip()
            ckpt = properties[1].strip()
            if int(ckpt.split('--epoch_')[-1].strip()) < 0:
                continue
            # Column index of the current metric in the eval txt; the layout differs per task.
            target_idx = {
                'sm': [5, 2, 2, 5, 5, 2],
                'wfm': [3, 3, 8, 3, 3, 8],
                'hce': [7, -1, -1, 7, 7, -1]
            }[metric][['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'].index(config.task)]
            if metric != 'hce':
                score = float(properties[target_idx].strip())
            else:
                score = int(properties[target_idx].strip().strip('.'))
            if idx_et == 0:
                score_panel[ckpt] = []
            score_panel[ckpt].append(score)

    # HCE and MAE are lower-is-better; all other metrics are higher-is-better.
    metrics_min = ['hce', 'mae']
    max_or_min = min if metric in metrics_min else max
    score_best = max_or_min(score_panel.values(), key=lambda x: np.sum(x))

    good_models = []
    for ckpt, scores in score_panel.items():
        is_best = np.sum(scores) <= np.sum(score_best) if metric in metrics_min else np.sum(scores) >= np.sum(score_best)
        if is_best:
            print(ckpt, scores)
            good_models.append(ckpt)

    # Write the header plus the rows of the best checkpoints to a summary file.
    # Note: `eval_txt` still refers to the last file of the loop above; its 3-line header is reused here.
    with open(eval_txt, 'r') as f:
        lines = f.readlines()
    info4good_models = lines[:3]
    metric_names = [m.strip() for m in lines[1].strip().strip('&').split('&')[2:]]
    testset_mean_values = {metric_name: [] for metric_name in metric_names}
    for good_model in good_models:
        for idx_et, eval_txt in enumerate(eval_txts):
            with open(eval_txt, 'r') as f:
                lines = f.readlines()
            for line in lines:
                if good_model in (field.strip() for field in line.split(sep)):
                    info4good_models.append(line)
                    metric_scores = [float(m.strip()) for m in line.strip().strip('&').split('&')[2:]]
                    for idx_score, metric_score in enumerate(metric_scores):
                        testset_mean_values[metric_names[idx_score]].append(metric_score)

    if 'DIS5K' in config.task:
        # Append a 'DIS-TEs' row holding the mean over the DIS-TE subsets; [:-1] drops DIS-VD from the average.
        testset_mean_values_lst = [
            '{:<4}'.format(int(np.mean(v_lst[:-1]).round())) if name == 'HCE'
            else '{:.3f}'.format(np.mean(v_lst[:-1])).lstrip('0')
            for name, v_lst in testset_mean_values.items()
        ]
        # Reuse the layout of an existing row (relabelled as 'DIS-TEs') so the column widths stay aligned.
        sample_line = info4good_models[-2]
        dataset_field = sample_line.split('&')[1].strip()
        numbers_placed_well = sample_line.replace(dataset_field, 'DIS-TEs').strip().split('&')[3:]
        for idx_number, (number_placed_well, testset_mean_value) in enumerate(zip(numbers_placed_well, testset_mean_values_lst)):
            numbers_placed_well[idx_number] = number_placed_well.replace(number_placed_well.strip(), testset_mean_value)
        testset_mean_line = '&'.join(sample_line.replace(dataset_field, 'DIS-TEs').split('&')[:3] + numbers_placed_well) + '\n'
        info4good_models.append(testset_mean_line)
    info4good_models.append(lines[-1])
    info = ''.join(info4good_models)
    print(info)
    with open(os.path.join('e_results', 'eval-{}_best_on_{}.txt'.format(config.task, metric)), 'w') as f:
        f.write(info + '\n')
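
# Usage (sketch): run this script after evaluation has produced the
# e_results/*_eval.txt files; one summary file is written per metric, e.g.
# e_results/eval-DIS5K_best_on_sm.txt (file name pattern taken from the code above).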