# Source: Binder/utils/mmqa/qimc.py (1.77 kB)
# Last commit: 7de3018 "Add missed MMQA" by Timothyxxx
import json
import os
import pandas as pd
# Project root, expressed as "two levels up" from the directory that holds
# this module; data files are addressed relative to it.
_MODULE_DIR = os.path.dirname(__file__)
ROOT_DIR = os.path.join(_MODULE_DIR, "../../")
class Question_Image_Match_Classifier(object):
    """Serve precomputed question/image matching decisions for MMQA.

    Nothing is predicted at run time. Per the original author, the stored
    results are from a T5-3b model finetuned on the train set of MMQA; this
    class only loads them from ``<ROOT_DIR>/utils/mmqa`` and looks them up.
    """

    # Exact text that marks a positive prediction in the CSV files.
    _YES_PREDICTION = "['yes']"

    def __init__(self):
        """Load both prediction tables and the image captions from disk.

        Raises:
            OSError: if one of the data files cannot be opened.
            KeyError: if a CSV file lacks an expected column.
            ValueError: if the captions file is not valid JSON.
        """
        # id -> bool, filled from qc_mmqa_dev.csv.
        self.whether_retrieve_image = None
        # Lowercased ``question`` column text -> bool, filled from
        # qimc_mmqa_dev.csv.
        self.qi_pairs_should_retrieve = None
        self.load_retrieve_info()

        # Caption lookup; judge_match indexes it by the image file stem.
        self.caption_info = None
        captions_path = os.path.join(ROOT_DIR, "utils", "mmqa", "mmqa_captions.json")
        # JSON is UTF-8 by specification, so do not rely on the locale's
        # default encoding when captions contain non-ASCII characters.
        with open(captions_path, "r", encoding="utf-8") as f:
            self.caption_info = json.load(f)

    def _load_predictions(self, csv_name, key_column, normalize_key=None):
        """Read one prediction CSV into a ``{key: bool}`` dict.

        Args:
            csv_name: file name inside ``<ROOT_DIR>/utils/mmqa``.
            key_column: column whose values become the dict keys.
            normalize_key: optional callable applied to every key before it
                is stored.

        Returns:
            A dict mapping each (normalized) key to ``True`` when the row's
            ``prediction`` cell is exactly ``"['yes']"`` and ``False``
            otherwise. When a key occurs more than once, the last row wins.
        """
        table = pd.read_csv(os.path.join(ROOT_DIR, "utils", "mmqa", csv_name))
        keys = table[key_column]
        if normalize_key is not None:
            # Normalize while building the dict (not afterwards) so that
            # "last row wins" also holds for keys that only collide after
            # normalization.
            keys = map(normalize_key, keys)
        yes = self._YES_PREDICTION
        # Iterating columns avoids building a Series per row as
        # DataFrame.iterrows() does.
        return {
            key: prediction == yes
            for key, prediction in zip(keys, table["prediction"])
        }

    def load_retrieve_info(self):
        """Populate ``whether_retrieve_image`` and ``qi_pairs_should_retrieve``.

        ``qc_mmqa_dev.csv`` is keyed by its ``id`` column as-is, while
        ``qimc_mmqa_dev.csv`` is keyed by its lowercased ``question`` column.
        """
        self.whether_retrieve_image = self._load_predictions("qc_mmqa_dev.csv", "id")
        self.qi_pairs_should_retrieve = self._load_predictions(
            "qimc_mmqa_dev.csv",
            "question",
            normalize_key=lambda question: question.lower(),
        )

    def judge_match(self, _id, question, pic):
        """Return the stored decision for a question/image pair.

        Args:
            _id: key into ``whether_retrieve_image``.
            question: question text; it is lowercased before the lookup.
            pic: image path; the file name up to its first ``"."`` is used
                as the key into ``caption_info``.

        Returns:
            ``False`` when ``whether_retrieve_image[_id]`` is falsy;
            otherwise the value stored in ``qi_pairs_should_retrieve`` for
            the lowercased ``'qa: <question> \\n<caption>'`` key.

        Raises:
            KeyError: if ``_id``, the image's caption, or the resulting
                question/caption pair is not present in the loaded data.
        """
        # fixme: hardcode since it is done in pipeline, change that in the future
        if not self.whether_retrieve_image[_id]:
            return False
        image_stem = os.path.split(pic)[-1].split(".")[0]
        image_caption = self.caption_info[image_stem]
        # Must reproduce the exact key layout used in qimc_mmqa_dev.csv.
        pair_key = 'qa: {} \n{}'.format(question.lower(), image_caption.lower())
        return self.qi_pairs_should_retrieve[pair_key]