File size: 3,716 Bytes
5885496
 
 
 
3d9fba4
5885496
 
 
 
 
3d9fba4
 
 
 
 
5885496
 
 
 
 
 
3d9fba4
 
5885496
 
 
 
 
 
 
 
3d9fba4
5885496
 
 
 
 
 
 
 
 
3d9fba4
 
 
5885496
 
3d9fba4
5885496
 
3d9fba4
5885496
 
 
 
 
 
 
 
3d9fba4
5885496
 
3d9fba4
5885496
 
 
 
 
 
 
 
 
 
 
3d9fba4
 
5885496
3d9fba4
 
5885496
 
 
 
 
 
 
 
 
 
 
 
 
3d9fba4
 
5885496
3d9fba4
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
import json
import os
import random
import re
from collections import defaultdict


def _parse_option_letters(value):
    """Split an ``--options`` value into single answer letters.

    Accepts both the compact form ``"ABCDE"`` and a comma/space separated
    form such as ``"A,B,C"``; separators are dropped rather than being kept
    as bogus "letters" (which is what ``type=list`` used to produce).
    """
    return [ch for ch in value if ch != "," and not ch.isspace()]


def get_args(argv=None):
    """Parse command-line arguments for the GPT-4 / model comparison script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace with ``base_dir``, ``gpt4_result``, ``our_result``,
        ``split`` and ``options``.

    Raises:
        SystemExit: if a required path argument is missing (argparse prints
            the usage message), instead of failing later with a TypeError.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base-dir",
        type=str,
        required=True,
        help="Directory containing problems.json and pid_splits.json.",
    )
    parser.add_argument(
        "--gpt4-result",
        type=str,
        required=True,
        help="JSON file whose 'outputs' maps problem id to GPT-4 text.",
    )
    parser.add_argument(
        "--our-result",
        type=str,
        required=True,
        help="JSONL file of model predictions with 'question_id' and 'text'.",
    )
    parser.add_argument("--split", type=str, default="test")
    # The default is already a list; argparse only applies `type` to strings,
    # so the default is used as-is.
    parser.add_argument(
        "--options",
        type=_parse_option_letters,
        default=["A", "B", "C", "D", "E"],
        help="Answer letters, e.g. ABCDE or A,B,C,D,E.",
    )
    return parser.parse_args(argv)


def convert_caps(results):
    """Build caption-style records from prediction records.

    Each input record contributes one output dict, in the same order:
    ``image_id`` is ``int(record["question_id"])`` and ``caption`` is
    ``record["text"]``.
    """
    # Read both fields before converting, so a missing key surfaces before
    # any int() conversion error, exactly as in a per-record loop.
    id_text_pairs = ((record["question_id"], record["text"]) for record in results)
    return [{"image_id": int(qid), "caption": text} for qid, text in id_text_pairs]


def get_pred_idx(prediction, choices, options):
    """
    Map a predicted option letter to its choice index (e.g. 'C' -> 2).

    Only the first ``len(choices)`` entries of ``options`` count as valid.
    Any other prediction (including the "FAILED" sentinel) yields a uniformly
    random index in ``range(len(choices))``, drawn from the global ``random``
    module state.
    """
    valid_letters = options[: len(choices)]
    if prediction not in valid_letters:
        # Unusable prediction: fall back to a random guess among the choices.
        return random.choice(range(len(choices)))
    return options.index(prediction)


# NOTE(review): the trailing "." is unescaped, so it matches any character,
# not only a literal period. It is kept as-is so scores stay comparable with
# earlier runs of this script.
_ANSWER_PATTERN = re.compile(r"The answer is ([A-Z]).")


def _load_json(path):
    """Read and parse a UTF-8 JSON file, closing the handle afterwards."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _extract_answer(text):
    """Return the single answer letter found in ``text``, else "FAILED".

    Zero matches, or more than one match, are both treated as a failure.
    """
    matches = _ANSWER_PATTERN.findall(text)
    return matches[0] if len(matches) == 1 else "FAILED"


def _percent(count, total):
    """Return ``count / total`` as a percentage; 0.0 when ``total`` is 0."""
    return count / total * 100 if total else 0.0


def main():
    """Score GPT-4 predictions, falling back to our model's when GPT-4 fails.

    Only problems in the requested split that have both an "our" prediction
    and a GPT-4 output are counted. The script prints three lines:
    - accuracy of the combined prediction;
    - an upper bound where either model being right counts as correct;
    - how often GPT-4 produced no parseable answer.
    """
    args = get_args()
    base_dir = args.base_dir

    split_indices = _load_json(os.path.join(base_dir, "pid_splits.json"))[args.split]
    problems = _load_json(os.path.join(base_dir, "problems.json"))
    with open(args.our_result, encoding="utf-8") as f:
        # Skip blank lines so a stray empty line does not break json.loads.
        our_records = [json.loads(line) for line in f if line.strip()]
    our_predictions = {pred["question_id"]: pred for pred in our_records}
    split_problems = {idx: problems[idx] for idx in split_indices}

    gpt4_predictions = _load_json(args.gpt4_result)["outputs"]

    results = defaultdict(int)

    for prob_id, prob in split_problems.items():
        if prob_id not in our_predictions or prob_id not in gpt4_predictions:
            continue

        our_answer = _extract_answer(our_predictions[prob_id]["text"])  # 'A', 'B', ... or "FAILED"
        gpt4_answer = _extract_answer(gpt4_predictions[prob_id])

        # Both indices are computed here, in this order, before the fallback
        # below. get_pred_idx may consume the global random state, so
        # reordering or skipping a call would change later random guesses.
        our_pred_idx = get_pred_idx(our_answer, prob["choices"], args.options)
        gpt4_pred_idx = get_pred_idx(gpt4_answer, prob["choices"], args.options)

        if gpt4_answer == "FAILED":
            results["gpt4_failed"] += 1
            # No usable GPT-4 answer: defer to our model's prediction.
            gpt4_pred_idx = our_pred_idx

        if gpt4_pred_idx == prob["answer"]:
            results["correct"] += 1
        else:
            results["incorrect"] += 1

        if gpt4_pred_idx == prob["answer"] or our_pred_idx == prob["answer"]:
            results["correct_upperbound"] += 1

    correct = results["correct"]
    total = results["correct"] + results["incorrect"]
    upper = results["correct_upperbound"]
    failed = results["gpt4_failed"]
    print(f"Total: {total}, Correct: {correct}, Accuracy: {_percent(correct, total):.2f}%")
    print(
        f"Total: {total}, Correct (upper): {upper}, Accuracy: {_percent(upper, total):.2f}%"
    )
    # NOTE(review): despite the "(RANDOM)" label, failed GPT-4 answers fall back
    # to our model's prediction in the loop above. The label is kept unchanged
    # for output compatibility.
    print(
        f"Total: {total}, GPT-4 NO-ANS (RANDOM): {failed}, Percentage: {_percent(failed, total):.2f}%"
    )


if __name__ == "__main__":
    main()