import argparse
import json
import os
import random
import re

# Pulls the option letter out of a response such as "The answer is B.".
# NOTE: the trailing "." is an unescaped regex wildcard, so any single
# character after the letter is accepted (and a letter at the very end of the
# text does not match). Kept byte-identical to preserve existing scoring.
_ANSWER_PATTERN = re.compile(r"The answer is ([A-Z]).")

# Substring used to flag a prompt as multimodal.
# NOTE(review): the original test was `"" in pred["prompt"]`, which is True
# for every string; "<image>" is assumed to be the image placeholder used in
# the prompts -- confirm against the prompt format that produced the results.
_IMAGE_TOKEN = "<image>"


def get_args():
    """Parse the command-line arguments of the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-dir", type=str)
    parser.add_argument("--result-file", type=str)
    parser.add_argument("--output-file", type=str)
    parser.add_argument("--output-result", type=str)
    parser.add_argument("--split", type=str, default="test")
    # `type=list` makes argparse call list() on the raw string, so
    # `--options ABCDE` yields one option per character.
    parser.add_argument("--options", type=list, default=["A", "B", "C", "D", "E"])
    return parser.parse_args()


def convert_caps(results):
    """Convert result records into ``{"image_id", "caption"}`` dicts.

    Each input record must provide ``question_id`` (converted with ``int``)
    and ``text``.
    """
    return [
        {"image_id": int(result["question_id"]), "caption": result["text"]}
        for result in results
    ]


def get_pred_idx(prediction, choices, options):
    """
    Get the index (e.g. 2) from the prediction (e.g. 'C')

    If ``prediction`` is not one of the first ``len(choices)`` entries of
    ``options``, a random index into ``choices`` is returned instead.
    """
    if prediction in options[: len(choices)]:
        return options.index(prediction)
    return random.choice(range(len(choices)))


def evaluate_predictions(split_problems, predictions, options):
    """Score predictions against the problems of one split.

    Args:
        split_problems: mapping of problem id to a problem dict; ``choices``
            and ``answer`` (index of the correct choice) are read from it.
        predictions: mapping of problem id to a prediction dict; ``text`` and
            ``prompt`` are read from it. Problems without a prediction are
            skipped and do not count towards the total.
        options: option letters, indexed by choice position.

    Returns:
        A ``(results, sqa_results)`` tuple. ``results`` holds the per-problem
        analysis dicts under ``"correct"`` / ``"incorrect"``; ``sqa_results``
        holds ``acc`` (percentage), ``correct``, ``count``, and the per-problem
        predicted indices (``results``) and raw texts (``outputs``).
    """
    results = {"correct": [], "incorrect": []}
    sqa_results = {
        "acc": None,
        "correct": None,
        "count": None,
        "results": {},
        "outputs": {},
    }

    for prob_id, prob in split_problems.items():
        if prob_id not in predictions:
            continue
        pred = predictions[prob_id]
        pred_text = pred["text"]

        # Only an unambiguous single match counts as a parsed answer.
        matches = _ANSWER_PATTERN.findall(pred_text)
        answer = matches[0] if len(matches) == 1 else "FAILED"

        # Computed exactly once: the fallback is random, so calling
        # get_pred_idx again could store an index that disagrees with the
        # correct/incorrect classification below.
        pred_idx = get_pred_idx(answer, prob["choices"], options)

        analysis = {
            "question_id": prob_id,
            "parsed_ans": answer,
            "ground_truth": options[prob["answer"]],
            "question": pred["prompt"],
            "pred": pred_text,
            "is_multimodal": _IMAGE_TOKEN in pred["prompt"],
        }

        sqa_results["results"][prob_id] = pred_idx
        sqa_results["outputs"][prob_id] = pred_text

        bucket = "correct" if pred_idx == prob["answer"] else "incorrect"
        results[bucket].append(analysis)

    correct = len(results["correct"])
    total = correct + len(results["incorrect"])
    # No scored problems: report 0% instead of raising ZeroDivisionError.
    sqa_results["acc"] = correct / total * 100 if total else 0.0
    sqa_results["correct"] = correct
    sqa_results["count"] = total
    return results, sqa_results


def main():
    """Load the inputs, score the predictions, and write both output files."""
    args = get_args()
    base_dir = args.base_dir

    with open(os.path.join(base_dir, "pid_splits.json")) as f:
        split_indices = json.load(f)[args.split]
    with open(os.path.join(base_dir, "problems.json")) as f:
        problems = json.load(f)
    # The result file is JSON Lines; blank lines are ignored.
    with open(args.result_file) as f:
        prediction_list = [json.loads(line) for line in f if line.strip()]

    predictions = {pred["question_id"]: pred for pred in prediction_list}
    split_problems = {idx: problems[idx] for idx in split_indices}

    results, sqa_results = evaluate_predictions(
        split_problems, predictions, args.options
    )

    total = sqa_results["count"]
    correct = sqa_results["correct"]
    print(
        f"Total: {total}, Correct: {correct}, "
        f"Accuracy: {sqa_results['acc']:.2f}%"
    )

    with open(args.output_file, "w") as f:
        json.dump(results, f, indent=2)
    with open(args.output_result, "w") as f:
        json.dump(sqa_results, f, indent=2)


if __name__ == "__main__":
    main()