Spaces:
Runtime error
Runtime error
import json | |
import os | |
import pandas as pd | |
# Directory two levels above the one containing this file (kept as an
# un-normalized path with a trailing separator); presumably the project root.
ROOT_DIR = os.path.join(
    os.path.dirname(__file__),
    "../../",
)
class Question_Image_Match_Classifier(object):
    """Look up precomputed question/image match predictions for MMQA.

    The results come from a T5-3b model fine-tuned on the train set of MMQA.
    Predictions are read from two CSV files and image captions from a JSON
    file, all located under ``ROOT_DIR/utils/mmqa``.

    Attributes are populated by ``__init__``:
        whether_retrieve_image: maps a question ``id`` to ``True`` when the
            ``prediction`` column of ``qc_mmqa_dev.csv`` equals ``"['yes']"``.
        qi_pairs_should_retrieve: maps a lower-cased ``question`` string from
            ``qimc_mmqa_dev.csv`` to ``True`` when its ``prediction`` equals
            ``"['yes']"``.
        caption_info: the object parsed from ``mmqa_captions.json``; it is
            indexed by image file stem in ``judge_match``.
    """

    # Exact string the prediction column holds for a positive answer.
    _YES_PREDICTION = "['yes']"

    def __init__(self):
        """Load the prediction tables and the caption file from disk.

        Raises:
            OSError: if any of the data files cannot be opened.
        """
        self.whether_retrieve_image = None
        self.qi_pairs_should_retrieve = None
        self.load_retrieve_info()

        self.caption_info = None
        captions_path = os.path.join(ROOT_DIR, "utils", "mmqa", "mmqa_captions.json")
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(captions_path, "r", encoding="utf-8") as f:
            self.caption_info = json.load(f)

    @classmethod
    def _load_yes_flags(cls, csv_path, key_column, lowercase_keys=False):
        """Return ``{key: prediction == "['yes']"}`` for every row of a CSV.

        Args:
            csv_path: path of a CSV file with ``key_column`` and
                ``prediction`` columns.
            key_column: name of the column whose values become dict keys.
            lowercase_keys: when true, ``.lower()`` is applied to each key.

        Returns:
            A dict of plain ``bool`` values. When a key occurs more than
            once, the last row wins.
        """
        frame = pd.read_csv(csv_path)
        flags = {}
        for key, prediction in zip(frame[key_column], frame["prediction"]):
            if lowercase_keys:
                key = key.lower()
            flags[key] = prediction == cls._YES_PREDICTION
        return flags

    def load_retrieve_info(self):
        """Populate both prediction lookup tables from their CSV files."""
        data_dir = os.path.join(ROOT_DIR, "utils", "mmqa")
        self.whether_retrieve_image = self._load_yes_flags(
            os.path.join(data_dir, "qc_mmqa_dev.csv"),
            "id",
        )
        self.qi_pairs_should_retrieve = self._load_yes_flags(
            os.path.join(data_dir, "qimc_mmqa_dev.csv"),
            "question",
            lowercase_keys=True,
        )

    def judge_match(self, _id, question, pic):
        """Return the stored prediction for a question/image pair.

        Args:
            _id: key into ``whether_retrieve_image``.
            question: question text; it is lower-cased before the lookup.
            pic: image path; the file name up to its first ``.`` is used as
                the key into ``caption_info``.

        Returns:
            ``False`` when ``whether_retrieve_image[_id]`` is falsy; otherwise
            the value stored in ``qi_pairs_should_retrieve`` for the
            question/caption pair.

        Raises:
            KeyError: if ``_id``, the image stem, or the composed
                question/caption key is missing from its table.
        """
        # FIXME: hard-coded because it is done in the pipeline; change that
        # in the future.
        if not self.whether_retrieve_image[_id]:
            return False

        image_stem = os.path.split(pic)[-1].split(".")[0]
        image_caption = self.caption_info[image_stem]
        # Must reproduce the key layout of the `question` column exactly,
        # including the space before the newline.
        lookup_key = 'qa: {} \n{}'.format(question.lower(), image_caption.lower())
        return self.qi_pairs_should_retrieve[lookup_key]