File size: 4,137 Bytes
6fc683c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import ast
import json
from tqdm import tqdm, trange
from collections import defaultdict
import re
import os, sys
import pdb
import sys, json

def clean_special_tokens(input_string):
    """Drop ``<...>`` spans from a string and normalise its whitespace.

    Every shortest ``<`` ... ``>`` span that sits on a single line is
    removed (the pattern's ``.`` does not match a newline). After that,
    all whitespace runs are collapsed to one space and leading/trailing
    whitespace is trimmed.

    Args:
        input_string: Text that may contain angle-bracket tokens.

    Returns:
        The cleaned text.
    """
    tag_matcher = re.compile("<.*?>")
    without_tags = tag_matcher.sub("", input_string)
    # split() with no argument discards empty pieces, so re-joining with a
    # single space squeezes any run of whitespace.
    words = without_tags.split()
    return ' '.join(words)

def find_consecutive_int_indices(numbers, first=20032, second=55):
    """Return the index just past the first adjacent ``first, second`` pair.

    Scans ``numbers`` left to right for a position ``i`` where
    ``numbers[i] == first`` and ``numbers[i + 1] == second``.

    Args:
        numbers: Sliceable sequence of values (e.g. a list of ints).
        first: Value expected at the first position of the pair.
        second: Value expected immediately after ``first``.
            NOTE(review): the defaults 20032 and 55 are the values this
            function has always searched for; presumably they are token
            ids of a marker in the tokenized prompt -- confirm against the
            tokenizer in use.

    Returns:
        ``i + 2`` for the first matching position ``i`` (the index of the
        element following the pair), or ``None`` when the pair does not
        occur, including for sequences with fewer than two elements.
    """
    # Pair each element with its successor; zip stops one short of the end,
    # so the last element is never used as the start of a pair.
    for index, (current, following) in enumerate(zip(numbers, numbers[1:])):
        if current == first and following == second:
            return index + 2

    return None

def eval(answer_file, json_file, result_file, split_str='Answer:'):
    """Score multiple-choice results by per-option mean score and print accuracy.

    NOTE: this function shadows the builtin ``eval`` inside this module; the
    name is kept so existing callers keep working.

    Inputs, as read by the code below:

    * ``json_file``: JSON object with a ``"question_type"`` mapping. It is
      inverted here, so its values (looked up with ``int`` keys below) become
      the keys of the per-type counters.
    * ``answer_file``: UTF-8, one tab-separated row per line. Column 0 is the
      question id, column 1 the question, column 3 the gold answer letter,
      and the last column the question-type id.
    * ``result_file``: UTF-8 lines, dispatched on their prefix:
        - ``ST-``: the last tab field is a Python literal list of source
          tokens; it sets the answer offset for the lines that follow.
        - ``H-``: hypothesis text; the part after ``split_str`` is cleaned and
          stored, but does not influence the reported accuracy.
        - ``P-``: the second tab field holds space-separated float scores;
          the mean of the scores from ``answer_index - 1`` onward is stored.
      The number after ``H-``/``P-`` is a row index into ``answer_file``.

    For each question the option with the highest mean score wins; options are
    lettered A-F in the order their ``P-`` lines appear. Overall counts and
    accuracy (percent) are printed, followed by one line per question type.

    Args:
        answer_file: Path to the tab-separated answer file.
        json_file: Path to the benchmark JSON containing ``question_type``.
        result_file: Path to the generation output to score.
        split_str: Separator preceding the answer text in ``H-`` lines.

    Raises:
        ValueError: If a ``P-`` line appears before any ``ST-`` line that
            contains the answer marker, or if a question has no ``P-``
            scores at all.
    """
    with open(json_file, 'rb') as json_handle:
        question_type_dict = json.load(json_handle)['question_type']
    # Inverted view of the mapping: value -> key.
    question_type_r_dict = {v: k for k, v in question_type_dict.items()}

    all_answers = defaultdict(dict)
    all_index = []  # question id for every answer-file row, in file order
    with open(answer_file, 'r', encoding='utf-8') as reader:
        for line in reader:
            cols = line.strip().split("\t")
            qid = cols[0]
            all_index.append(qid)
            all_answers[qid]["answer"] = cols[3]
            all_answers[qid]["question"] = cols[1]
            all_answers[qid]["question_type_id"] = cols[-1]

    # Hypothesis texts are parsed but not consulted when scoring below.
    all_predictions = defaultdict(list)
    all_prediction_probs = defaultdict(list)
    answer_index = None
    with open(result_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.startswith('ST-'):
                src_tokens = ast.literal_eval(line.split('\t')[-1])
                answer_index = find_consecutive_int_indices(src_tokens)
            elif line.startswith('H-'):
                idx = line.split('\t')[0][2:]
                # Keep only the text after the last '</image>' and before
                # the next '<image>'.
                line = line.split('</image>')[-1]
                line = line.split('<image>')[0]
                answer = line.split(split_str)[1].strip()
                answer = clean_special_tokens(answer)

                all_predictions[all_index[int(idx)]].append(answer)
            elif line.startswith('P-'):
                if answer_index is None:
                    raise ValueError(
                        "cannot score a 'P-' line in {!r}: no preceding 'ST-' "
                        "line contained the answer marker".format(result_file)
                    )
                fields = line.split('\t')
                idx = fields[0][2:]
                scores_list = list(map(float, fields[1].split(" ")))
                # Scores are taken from one position before the answer
                # offset through the end of the line.
                answer_scores_list = scores_list[(answer_index - 1):]
                mean_score = sum(answer_scores_list) / len(answer_scores_list)
                all_prediction_probs[all_index[int(idx)]].append(mean_score)

    correct = 0
    total = 0
    answer_map_dict = {0: "A", 1: "B", 2: "C", 3: "D", 4: "E", 5: "F"}

    question_type_correct = dict.fromkeys(question_type_r_dict, 0)
    question_type_total = dict.fromkeys(question_type_r_dict, 0)

    for qid in all_answers:
        option_scores = all_prediction_probs[qid]
        if not option_scores:
            raise ValueError(
                "no 'P-' scores found in {!r} for question {!r}".format(
                    result_file, qid
                )
            )
        # Highest mean score wins; ties resolve to the earliest option.
        prediction = option_scores.index(max(option_scores))
        hit = answer_map_dict[prediction] == all_answers[qid]["answer"]

        question_type_id = int(all_answers[qid]["question_type_id"])
        question_type_total[question_type_id] += 1
        if hit:
            correct += 1
            question_type_correct[question_type_id] += 1

        total += 1

    final_scores = {}
    # Guard against an empty answer file instead of dividing by zero.
    final_scores["acc"] = correct / total * 100.0 if total else 0.0
    print("{}\t{}\t{}".format(correct, total, final_scores))
    for k, v in question_type_r_dict.items():
        print(k, v, question_type_correct[k] / max(question_type_total[k], 1))



if __name__ == "__main__":
    # Usage: python <this script> <result_file>
    # The answer file is selected from a task tag found anywhere in the
    # result file path; with no tag, the generic answer file is used.
    if len(sys.argv) < 2:
        # Exit with a readable message rather than an IndexError traceback.
        sys.exit("usage: {} <result_file>".format(sys.argv[0]))

    save_dir = '/path/to/data'
    json_file = f'{save_dir}/SEED-Bench/SEED-Bench.json'
    result_file = sys.argv[1]

    answer_file = f'{save_dir}/SEED-Bench/seed_bench_pplformat.answer'
    # Checked in this order; the first tag contained in the path wins.
    for task_tag in ('task12', 'task10', 'task11'):
        if task_tag in result_file:
            answer_file = f'{save_dir}/SEED-Bench/seed_bench_{task_tag}_pplformat.answer'
            break

    eval(answer_file, json_file, result_file)