File size: 3,262 Bytes
5885496
 
 
 
3d9fba4
5885496
 
 
 
3d9fba4
 
 
 
 
 
5885496
 
 
 
 
 
3d9fba4
 
5885496
 
 
 
 
 
 
 
3d9fba4
5885496
 
 
 
 
 
 
 
 
3d9fba4
 
 
5885496
 
3d9fba4
5885496
 
3d9fba4
5885496
3d9fba4
 
 
 
 
5885496
 
 
 
 
3d9fba4
5885496
3d9fba4
5885496
 
 
 
 
 
3d9fba4
5885496
 
3d9fba4
 
 
 
 
 
5885496
 
3d9fba4
 
 
 
5885496
3d9fba4
 
5885496
3d9fba4
5885496
3d9fba4
 
 
5885496
3d9fba4
 
 
5885496
3d9fba4
5885496
3d9fba4
5885496
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import argparse
import json
import os
import random
import re


def _parse_options(value):
    """Convert the raw ``--options`` string into a list of option labels.

    A value without commas is split into single characters, so ``ABCDE``
    becomes ``["A", "B", "C", "D", "E"]`` (this is what the previous
    ``type=list`` did). A value containing commas is split on them instead,
    with surrounding whitespace and empty items dropped, so ``A,B,C`` becomes
    ``["A", "B", "C"]`` rather than a list that includes the commas.
    """
    if "," in value:
        return [label.strip() for label in value.split(",") if label.strip()]
    return list(value)


def get_args():
    """Parse the command-line arguments of the evaluation script.

    Returns:
        argparse.Namespace with ``base_dir``, ``result_file``,
        ``output_file``, ``output_result`` (all ``None`` when omitted),
        ``split`` (default ``"test"``) and ``options`` (a list of option
        labels, default ``["A", "B", "C", "D", "E"]``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-dir", type=str)
    parser.add_argument("--result-file", type=str)
    parser.add_argument("--output-file", type=str)
    parser.add_argument("--output-result", type=str)
    parser.add_argument("--split", type=str, default="test")
    # argparse applies ``type`` only to strings given on the command line;
    # the list default below is used as-is.
    parser.add_argument(
        "--options", type=_parse_options, default=["A", "B", "C", "D", "E"]
    )
    return parser.parse_args()


def convert_caps(results):
    """Reshape prediction records into caption-style records.

    Each input record must provide ``"question_id"`` and ``"text"``; the
    output holds one ``{"image_id": int, "caption": text}`` dict per record,
    in the same order.
    """

    def _as_caption(entry):
        # Read both fields before converting, so a missing key is reported
        # ahead of a non-integer question id.
        qid, text = entry["question_id"], entry["text"]
        return {"image_id": int(qid), "caption": text}

    return [_as_caption(entry) for entry in results]


def get_pred_idx(prediction, choices, options):
    """
    Translate a predicted option label (e.g. 'C') into its index (e.g. 2).

    Only the first ``len(choices)`` entries of ``options`` count as valid
    labels. Any other prediction falls back to a uniformly random index in
    ``range(len(choices))``.
    """
    num_choices = len(choices)
    valid_labels = options[:num_choices]
    if prediction not in valid_labels:
        # Unparseable or out-of-range answer: pick a random choice index.
        return random.choice(range(num_choices))
    return options.index(prediction)


if __name__ == "__main__":
    # Script entry point: score model predictions against the problems of one
    # dataset split and write two JSON reports (a per-question analysis and a
    # summary with accuracy, raw outputs and predicted indices).
    args = get_args()

    base_dir = args.base_dir
    # Context managers ensure every input file is closed after loading.
    with open(os.path.join(base_dir, "pid_splits.json")) as f:
        split_indices = json.load(f)[args.split]
    with open(os.path.join(base_dir, "problems.json")) as f:
        problems = json.load(f)
    with open(args.result_file) as f:
        # One JSON object per line; blank lines are skipped instead of
        # crashing the parser.
        predictions = [json.loads(line) for line in f if line.strip()]
    predictions = {pred["question_id"]: pred for pred in predictions}
    split_problems = {idx: problems[idx] for idx in split_indices}

    results = {"correct": [], "incorrect": []}
    sqa_results = {
        "acc": None,
        "correct": None,
        "count": None,
        "results": {},
        "outputs": {},
    }

    # Compiled once, outside the loop. NOTE: the trailing "." is unescaped
    # and therefore matches any character; it is kept unchanged so scores
    # stay comparable with earlier runs.
    answer_pattern = re.compile(r"The answer is ([A-Z]).")

    for prob_id, prob in split_problems.items():
        if prob_id not in predictions:
            # Problems without a prediction are ignored, not counted wrong.
            continue
        pred = predictions[prob_id]
        pred_text = pred["text"]

        matches = answer_pattern.findall(pred_text)
        if len(matches) == 1:
            answer = matches[0]  # 'A', 'B', ...
        else:
            # Zero or several answer statements: treat as unparseable.
            answer = "FAILED"

        # Resolve the index exactly once. The fallback inside get_pred_idx is
        # random, so a second call could disagree with the value used for the
        # correctness check below.
        pred_idx = get_pred_idx(answer, prob["choices"], args.options)

        analysis = {
            "question_id": prob_id,
            "parsed_ans": answer,
            "ground_truth": args.options[prob["answer"]],
            "question": pred["prompt"],
            "pred": pred_text,
            "is_multimodal": "<image>" in pred["prompt"],
        }

        sqa_results["results"][prob_id] = pred_idx
        sqa_results["outputs"][prob_id] = pred_text

        bucket = "correct" if pred_idx == prob["answer"] else "incorrect"
        results[bucket].append(analysis)

    correct = len(results["correct"])
    total = correct + len(results["incorrect"])
    # Guard against an empty evaluation set (no prediction matched the split)
    # so the reports are still written instead of dying on a division by zero.
    accuracy = correct / total * 100 if total else 0.0
    print(f"Total: {total}, Correct: {correct}, Accuracy: {accuracy:.2f}%")

    sqa_results["acc"] = accuracy
    sqa_results["correct"] = correct
    sqa_results["count"] = total

    with open(args.output_file, "w") as f:
        json.dump(results, f, indent=2)
    with open(args.output_result, "w") as f:
        json.dump(sqa_results, f, indent=2)