Spaces:
Paused
Paused
File size: 3,262 Bytes
5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 3d9fba4 5885496 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import argparse
import json
import os
import random
import re
def get_args():
    """Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace with ``base_dir``, ``result_file``,
        ``output_file``, ``output_result`` (all ``None`` when omitted),
        ``split`` (default ``"test"``) and ``options`` (default
        ``["A", "B", "C", "D", "E"]``).
    """
    cli = argparse.ArgumentParser()

    # Plain string flags that share the same configuration and have no default.
    for flag in ("--base-dir", "--result-file", "--output-file", "--output-result"):
        cli.add_argument(flag, type=str)

    cli.add_argument("--split", default="test", type=str)
    # NOTE: with type=list argparse calls list() on the raw string, so a value
    # passed on the command line (e.g. "ABC") is split into single characters.
    cli.add_argument("--options", default=["A", "B", "C", "D", "E"], type=list)

    return cli.parse_args()
def convert_caps(results):
    """Turn result records into a list of ``image_id``/``caption`` dicts.

    Args:
        results: iterable of mappings, each with a ``"question_id"`` and a
            ``"text"`` entry.

    Returns:
        A list with one dict per record: ``"image_id"`` is
        ``int(question_id)`` and ``"caption"`` is the record's text.
    """

    def _to_caption(record):
        # Read both fields before converting the id to int.
        qid, text = record["question_id"], record["text"]
        return {"image_id": int(qid), "caption": text}

    return [_to_caption(record) for record in results]
def get_pred_idx(prediction, choices, options):
    """Map a predicted option letter to its index (e.g. ``'C'`` -> ``2``).

    Args:
        prediction: the predicted option label.
        choices: the answer choices of the problem; only its length is used.
        options: the ordered option labels (e.g. ``["A", "B", ...]``).

    Returns:
        The position of ``prediction`` in ``options`` when it is one of the
        first ``len(choices)`` labels; otherwise a randomly chosen index in
        ``range(len(choices))``.
    """
    usable_labels = options[: len(choices)]
    if prediction not in usable_labels:
        # Unusable prediction: fall back to a random guess among the choices.
        return random.choice(range(len(choices)))
    return options.index(prediction)
if __name__ == "__main__":
    # Entry point: score multiple-choice predictions against the problems of
    # one split, then write a per-question analysis and a summary as JSON.
    args = get_args()
    base_dir = args.base_dir

    # Context managers make sure every input file is closed after reading.
    with open(os.path.join(base_dir, "pid_splits.json")) as f:
        split_indices = json.load(f)[args.split]
    with open(os.path.join(base_dir, "problems.json")) as f:
        problems = json.load(f)
    # The result file holds one JSON object per line.
    with open(args.result_file) as f:
        predictions = [json.loads(line) for line in f]

    # Index predictions by question id for direct lookup per problem.
    predictions = {pred["question_id"]: pred for pred in predictions}
    split_problems = {idx: problems[idx] for idx in split_indices}

    results = {"correct": [], "incorrect": []}
    sqa_results = {
        "acc": None,
        "correct": None,
        "count": None,
        "results": {},
        "outputs": {},
    }

    # Loop-invariant, so compiled once. NOTE(review): the trailing "." is an
    # unescaped wildcard (any single character except a newline); kept as-is
    # because changing it would change which predictions parse.
    pattern = re.compile(r"The answer is ([A-Z]).")

    for prob_id, prob in split_problems.items():
        # Problems without a prediction are skipped and not counted.
        if prob_id not in predictions:
            continue
        pred = predictions[prob_id]
        pred_text = pred["text"]

        # Exactly one match is required; zero or several is a parse failure.
        res = pattern.findall(pred_text)
        answer = res[0] if len(res) == 1 else "FAILED"  # 'A', 'B', ...

        # Resolve the index exactly once: get_pred_idx falls back to a random
        # guess, so a second call could record a different index than the one
        # that is scored below.
        pred_idx = get_pred_idx(answer, prob["choices"], args.options)

        analysis = {
            "question_id": prob_id,
            "parsed_ans": answer,
            "ground_truth": args.options[prob["answer"]],
            "question": pred["prompt"],
            "pred": pred_text,
            "is_multimodal": "<image>" in pred["prompt"],
        }

        sqa_results["results"][prob_id] = pred_idx
        sqa_results["outputs"][prob_id] = pred_text

        if pred_idx == prob["answer"]:
            results["correct"].append(analysis)
        else:
            results["incorrect"].append(analysis)

    correct = len(results["correct"])
    total = correct + len(results["incorrect"])
    # Avoid ZeroDivisionError when no prediction matched the split.
    accuracy = correct / total * 100 if total else 0.0

    print(f"Total: {total}, Correct: {correct}, Accuracy: {accuracy:.2f}%")

    sqa_results["acc"] = accuracy
    sqa_results["correct"] = correct
    sqa_results["count"] = total

    with open(args.output_file, "w") as f:
        json.dump(results, f, indent=2)
    with open(args.output_result, "w") as f:
        json.dump(sqa_results, f, indent=2)
|