import ast
import json
import re
import sys
from collections import defaultdict


def clean_special_tokens(input_string):
    """Strip <...> special tokens and collapse repeated whitespace."""
    result = re.sub(r"<.*?>", "", input_string)
    return " ".join(result.split())


def find_consecutive_int_indices(numbers):
    """Return the index just past the first occurrence of the token pair
    (20032, 55), which marks where the answer span starts in the source
    tokens; return None if the pair never occurs."""
    for index in range(len(numbers) - 1):
        if numbers[index] == 20032 and numbers[index + 1] == 55:
            return index + 2
    return None


def eval(answer_file, json_file, result_file, split_str="Answer:"):
    """Score predictions in result_file against answer_file and print overall
    and per-question-type accuracy for SEED-Bench multiple-choice questions."""
    with open(json_file, "r", encoding="utf-8") as jf:
        question_type_dict = json.load(jf)["question_type"]
    # Reverse map: question type id -> question type name.
    question_type_r_dict = {v: k for k, v in question_type_dict.items()}

    # Ground truth, keyed by question id, from the tab-separated answer file:
    # column 0 is the question id, column 1 the question text, column 3 the
    # gold option letter, and the last column the question type id.
    all_answers = defaultdict(dict)
    all_index = []
    with open(answer_file, "r", encoding="utf-8") as reader:
        for line in reader:
            cols = line.strip().split("\t")
            all_index.append(cols[0])
            all_answers[cols[0]]["answer"] = cols[3]
            all_answers[cols[0]]["question"] = cols[1]
            all_answers[cols[0]]["question_type_id"] = cols[-1]

    all_predictions = defaultdict(list)
    all_prediction_probs = defaultdict(list)
    answer_length = None
    answer_index = None
    with open(result_file, "r", encoding="utf-8") as f:
        for line in f:
            if line.startswith("ST-"):
                # Source-token line: locate the start of the answer span so the
                # per-token scores on the matching P- line can be restricted to it.
                src_tokens = ast.literal_eval(line.split("\t")[-1])
                answer_index = find_consecutive_int_indices(src_tokens)
                answer_length = len(src_tokens[answer_index:])
            elif line.startswith("H-"):
                # Hypothesis line "H-<idx>\t...\t<text>": take the text field,
                # keep the part after split_str, and strip special tokens.
                idx = line.split("\t")[0][2:]
                text = line.strip().split("\t")[-1]
                answer = clean_special_tokens(text.split(split_str)[1].strip())
                all_predictions[all_index[int(idx)]].append(answer)
            elif line.startswith("P-"):
                # Per-token score line: average the scores over the answer span
                # (the scores are shifted by one relative to the source tokens).
                idx = line.split("\t")[0][2:]
                scores_list = list(map(float, line.split("\t")[1].split(" ")))
                answer_scores_list = scores_list[answer_index - 1:]
                mean_score = sum(answer_scores_list) / len(answer_scores_list)
                all_prediction_probs[all_index[int(idx)]].append(mean_score)

    # One scored candidate per option: the option with the highest mean score
    # is the prediction, mapped back to its letter for comparison with gold.
    correct = 0
    total = 0
    answer_map_dict = {0: "A", 1: "B", 2: "C", 3: "D", 4: "E", 5: "F"}
    question_type_correct = {k: 0 for k in question_type_r_dict}
    question_type_total = {k: 0 for k in question_type_r_dict}
    for qid in all_answers:
        prediction = all_prediction_probs[qid].index(max(all_prediction_probs[qid]))
        hit = answer_map_dict[prediction] == all_answers[qid]["answer"]
        if hit:
            correct += 1
        question_type_id = int(all_answers[qid]["question_type_id"])
        question_type_total[question_type_id] += 1
        if hit:
            question_type_correct[question_type_id] += 1
        total += 1

    final_scores = {"acc": correct / total * 100.0}
    print("{}\t{}\t{}".format(correct, total, final_scores))
    # Per-question-type accuracy; max(..., 1) guards against empty types.
    for k, v in question_type_r_dict.items():
        print(k, v, question_type_correct[k] / max(question_type_total[k], 1))


if __name__ == "__main__":
    save_dir = "/path/to/data"
    json_file = f"{save_dir}/SEED-Bench/SEED-Bench.json"
    result_file = sys.argv[1]
    # Choose the answer file matching the task encoded in the result file name;
    # tasks 10-12 have dedicated answer files, everything else shares one.
    if "task12" in result_file:
        answer_file = f"{save_dir}/SEED-Bench/seed_bench_task12_pplformat.answer"
    elif "task10" in result_file:
        answer_file = f"{save_dir}/SEED-Bench/seed_bench_task10_pplformat.answer"
    elif "task11" in result_file:
        answer_file = f"{save_dir}/SEED-Bench/seed_bench_task11_pplformat.answer"
    else:
        answer_file = f"{save_dir}/SEED-Bench/seed_bench_pplformat.answer"
    eval(answer_file, json_file, result_file)
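# Example invocation (a sketch: the script filename and result path below are
# hypothetical placeholders; the code only requires one positional argument,
# the result file, whose name selects the task-specific answer file):
#   python eval_seed_bench.py /path/to/results/seed_bench_task10_generate.out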