import json import os import re from torch.utils.data import Dataset def prompt_processor(prompt): if prompt.startswith('OCR tokens: '): pattern = r"Question: (.*?) Short answer:" match = re.search(pattern, prompt, re.DOTALL) question = match.group(1) elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: if prompt.startswith('Reference OCR token:'): question = prompt.split('\n')[1] else: question = prompt.split('\n')[0] elif len(prompt.split('\n')) == 2: question = prompt.split('\n')[0] else: assert False return question.lower() class textVQADataset(Dataset): def __init__( self, image_dir="./downloads/TextVQA/train_images", ann_path="./downloads/TextVQA/TextVQA_0.5.1_val.json", ): self.data = json.load(open(ann_path, "r"))["data"] self.image_dir = image_dir def __len__(self): return len(self.data) def __getitem__(self, idx): question = self.data[idx]['question'] answers = self.data[idx]['answers'] img_id = self.data[idx]['image_id'] qid = self.data[idx]['question_id'] img_path = os.path.join(self.image_dir, f"{img_id}.jpg") item = { "question_id": qid, "image_path": img_path, "question": question, "gt_answers": answers } return item class docVQADataset(Dataset): def __init__( self, image_dir= "./downloads/DocVQA/spdocvqa_images", ann_path= "./downloads/DocVQA/val_v1.0_withQT.json", ocr_token_path=None ): self.data = json.load(open(ann_path, "r"))["data"] self.image_dir = image_dir self.ann_path = ann_path if ocr_token_path: self.ocr_token_data = {item['image_id']: item for item in json.load(open(ocr_token_path, "r"))["data"]} def __len__(self): return len(self.data) def __getitem__(self, idx): question_id = self.data[idx]['questionId'] relative_img_path = self.data[idx]['image'] corrected_relative_img_path = relative_img_path.replace("documents", "images") img_path = os.path.join(self.image_dir, corrected_relative_img_path) question = self.data[idx]['question'] answers = self.data[idx]['answers'] question_type = self.data[idx]['question_types'] return { "question_id": question_id, "image_path": img_path, "question": question, "gt_answers": answers, 'question_type': question_type, } class docVQATESTDataset(Dataset): def __init__( self, image_dir= "./downloads/DocVQA/spdocvqa_images", ann_path= "./downloads/DocVQA/test_v1.0.json", ocr_token_path=None ): self.data = json.load(open(ann_path, "r"))["data"] self.image_dir = image_dir self.ann_path = ann_path def __len__(self): return len(self.data) def __getitem__(self, idx): question_id = self.data[idx]['questionId'] relative_img_path = self.data[idx]['image'] corrected_relative_img_path = relative_img_path.replace("documents", "images") img_path = os.path.join(self.image_dir, corrected_relative_img_path) question = self.data[idx]['question'] return { "question_id": question_id, "image_path": img_path, "question": question, "gt_answers": "", 'question_type': "", }