Tzktz's picture
Upload 7664 files
6fc683c verified
import ast
import json
from tqdm import tqdm, trange
from collections import defaultdict
import re
import os, sys
import pdb
import sys, json
def clean_special_tokens(input_string):
pattern = "<.*?>"
result = re.sub(pattern, "", input_string)
result = ' '.join(result.split()) # Remove extra spaces
return result
def find_consecutive_int_indices(numbers):
for index in range(len(numbers) - 1):
if numbers[index] == 20032 and numbers[index + 1] == 55:
return index+2
return None
def eval(answer_file, json_file, result_file, split_str='Answer:'):
question_type_dict = json.load(open(json_file, 'rb'))['question_type']
question_type_r_dict = {}
for k,v in question_type_dict.items():
question_type_r_dict[v] = k
split_str = 'Answer:' # ['Answer:', 'A:'][idx]
all_answers = defaultdict(dict)
all_index = []
with open(answer_file, 'r', encoding='utf-8') as reader:
for line in reader:
# pdb.set_trace()
cols = line.strip().split("\t")
all_index.append(cols[0])
all_answers[cols[0]]["answer"] = cols[3]
all_answers[cols[0]]["question"] = cols[1]
all_answers[cols[0]]["question_type_id"] = cols[-1]
all_predictions = defaultdict(list)
all_prediction_probs = defaultdict(list)
answer_length = None
answer_index = None
with open(result_file, 'r', encoding='utf-8') as f:
for i, line in enumerate(f):
if line.startswith('ST-'):
src_tokens = ast.literal_eval(line.split('\t')[-1])
answer_index = find_consecutive_int_indices(src_tokens)
answer_length = len(src_tokens[answer_index:])
elif line.startswith('H-'):
idx = line.split('\t')[0][2:]
line = line.split('</image>')[-1]
line = line.split('<image>')[0]
answer = line.split(split_str)[1].strip()
answer = clean_special_tokens(answer)
all_predictions[all_index[int(idx)]].append(answer)
elif line.startswith('P-'):
idx = line.split('\t')[0][2:]
scores_list = list(map(float, line.split('\t')[1].split(" ")))
answer_scores_list = scores_list[(answer_index-1):]
mean_score = sum(answer_scores_list) / len(answer_scores_list)
all_prediction_probs[all_index[int(idx)]].append(mean_score)
correct = 0
total = 0
answer_map_dict = {0:"A", 1:"B", 2:"C", 3:"D", 4:"E", 5:"F"}
question_type_correct = {}
question_type_total = {}
for k,v in question_type_r_dict.items():
question_type_correct[k] = 0
question_type_total[k] = 0
for qid in all_answers:
hit = True
prediction = all_prediction_probs[qid].index(max(all_prediction_probs[qid]))
if answer_map_dict[prediction] != all_answers[qid]["answer"]:
hit = False
if hit:
correct += 1
question_type_id = int(all_answers[qid]["question_type_id"])
question_type_total[question_type_id] += 1
if hit:
question_type_correct[question_type_id] += 1
total += 1
final_scores = {}
final_scores["acc"] = correct / total * 100.0
print("{}\t{}\t{}".format(correct, total, final_scores))
for k,v in question_type_r_dict.items():
print(k, v, question_type_correct[k] / max(question_type_total[k], 1))
if __name__ == "__main__":
save_dir = '/path/to/data'
json_file = f'{save_dir}/SEED-Bench/SEED-Bench.json'
result_file = sys.argv[1]
if 'task12' in result_file:
answer_file = f'{save_dir}/SEED-Bench/seed_bench_task12_pplformat.answer'
elif 'task10' in result_file:
answer_file = f'{save_dir}/SEED-Bench/seed_bench_task10_pplformat.answer'
elif 'task11' in result_file:
answer_file = f'{save_dir}/SEED-Bench/seed_bench_task11_pplformat.answer'
else:
answer_file = f'{save_dir}/SEED-Bench/seed_bench_pplformat.answer'
eval(answer_file, json_file, result_file)