import os
import json
import pickle
import random
import time
import itertools
import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from torch.utils.data import Dataset
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
class SingleSlideVQADataset(Dataset):
def __init__(self, vis_processor, text_processor, vis_root, ann_path):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_root = vis_root
self.vis_processor = vis_processor
self.text_processor = text_processor
self.data = self.create_data(ann_path)
        # self.instruction_pool = [
        #     "###Human: <Img><ImageHere></Img> {}###Assistant: ",
        #     "###Human: <Img><ImageHere></Img> From this slide, {}###Assistant: ",
        # ]
        self.instruction_pool = [
            "<Img><ImageHere></Img> {}",
            "<Img><ImageHere></Img> From this slide, {}",
        ]
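        # <Img><ImageHere></Img> is the MiniGPT-4-style image placeholder:
        # <ImageHere> is substituted with the image embeddings downstream, and
        # <Img></Img> delimits the image span in the prompt.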
def create_data(self, ann_path):
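        # Expected annotation format: one json object per line. Field names are
        # taken from the reads below; the concrete values are only illustrative:
        # {"qa_id": 0, "question": "...", "answer": "...",
        #  "deck_name": "<deck dir>", "evidence_pages": [4]}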
with open(ann_path, 'r') as f:
samples = f.readlines()
data = []
for sample in samples:
sample = json.loads(sample)
if len(sample['evidence_pages']) != 1: continue # skip questions that need more than one slide page
page = sample['evidence_pages'][0]
image_name = 'slide_{}_1024.jpg'.format(page)
image_path = os.path.join(sample['deck_name'], image_name)
data.append({
'qa_id': sample['qa_id'],
'question': sample['question'],
'answer': sample['answer'],
'image_path': image_path
})
print("single slide ",len(data))
return data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB")
image = self.vis_processor(image)
instruction = random.choice(self.instruction_pool).format(self.text_processor(sample["question"]))
return {
"image": image,
"instruction_input": instruction,
"answer": sample['answer'],
"qa_id": sample['qa_id'],
}
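
# A minimal usage sketch, assuming identity processors and hypothetical local
# paths; in practice the MiniGPT-4 vis/text processors and real data dirs go here:
#
#   dataset = SingleSlideVQADataset(
#       vis_processor=lambda img: img,
#       text_processor=lambda txt: txt,
#       vis_root="data/slidevqa/images",
#       ann_path="data/slidevqa/annotations/train.jsonl",
#   )
#   sample = dataset[0]
#   print(sample["instruction_input"], "->", sample["answer"])
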
class OCRVQADataset(Dataset):
def __init__(self, vis_processor, text_processor, vis_root, ann_path):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_root = vis_root
self.vis_processor = vis_processor
self.text_processor = text_processor
self.data = self.create_data(ann_path)
        self.instruction_pool = [
"Q: {} A: ",
]
def create_data(self, ann_path):
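        # Expected annotation format: a dict keyed by image id. Field names are
        # taken from the reads below; the concrete values are only illustrative:
        # {"<image_id>": {"split": 1, "imageURL": "http://.../cover.jpg",
        #                 "questions": ["..."], "answers": ["..."],
        #                 "title": "...", "genre": "..."}}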
processed_data = []
with open(ann_path, 'r') as f:
data = json.load(f)
for k in data.keys():
if data[k]['split'] != 1: continue # 1 for training, 2 for validation, 3 for test
ext = os.path.splitext(data[k]['imageURL'])[1]
imageFile = k + ext
assert len(data[k]['questions']) == len(data[k]['answers'])
for q, a in zip(data[k]['questions'], data[k]['answers']):
processed_data.append(
{'question': q,
'answer': a,
'image_path': imageFile,
'image_id': k,
'title': data[k]['title'],
'genre': data[k]['genre'],
}
)
print("ocr vqa", len(processed_data))
return processed_data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB")
image = self.vis_processor(image)
question = self.text_processor(sample["question"])
answer = self.text_processor(sample["answer"])
instruction = random.choice(self.instruction_pool).format(question)
instruction = "
{} ".format(instruction)
return {
"image": image,
"instruction_input": instruction,
"answer": answer,
"image_id": sample['image_id']
}
class TextOCRDataset(Dataset):
def __init__(self, vis_processor, text_processor, vis_root, ann_path):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_root = vis_root
self.vis_processor = vis_processor
self.text_processor = text_processor
self.data = self.create_data(ann_path)
        self.instruction_pool = [
            "<Img><ImageHere></Img> [OCR] {}"
        ]
def create_data(self, ann_path):
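        # Expected annotation format (TextOCR-style). Field names are taken from
        # the reads below; the concrete values are only illustrative:
        # {"anns": {"<ann_id>": {"image_id": "img001", "bbox": [x, y, w, h],
        #                        "utf8_string": "recognized text"}}}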
processed_data = []
with open(ann_path, 'r') as f:
data = json.load(f)
for k in data["anns"].keys():
imageFile = data["anns"][k]["image_id"]+".jpg"
bbox = data["anns"][k]["bbox"]
text = data["anns"][k]["utf8_string"]
processed_data.append(
{'bbox': bbox,
'answer': text,
'image_path': imageFile,
'image_id': k,
}
)
return processed_data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB")
width, height = image.size
image = self.vis_processor(image)
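        # The raw pixel bbox [x, y, w, h] is quantized to a 0-100 coordinate grid
        # and serialized as location tokens "<x1><y1><x2><y2>". For example, a
        # 50x20 box at (100, 40) in a 200x100 image becomes " <50><40><75><60>".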
new_bbox =""
image_size = 100
bbox = sample['bbox']
for index in range(len(bbox)):
x1 = int(bbox[0]/width*image_size)
y1 = int(bbox[1]/height*image_size)
x2 = x1 + int(bbox[2]/width*image_size)
y2 = y1 + int(bbox[3]/height*image_size)
assert x1>=0 and x1<=image_size
assert x2>=0 and x2<=image_size
assert y1>=0 and y1<=image_size
assert y2>=0 and y2<=image_size
new_bbox = " <"+str(x1)+"><"+str(y1)+"><"+str(x2)+"><"+str(y2)+">"
instruction = random.choice(self.instruction_pool).format(new_bbox)
return {
"image": image,
"instruction_input": instruction,
"answer": sample['answer'],
"image_id": sample['image_id']
}
class PlotVQADataset(Dataset):
def __init__(self, vis_processor, text_processor, vis_root, ann_path):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_root = vis_root
self.vis_processor = vis_processor
self.text_processor = text_processor
self.data = self.create_data(ann_path)
self.instruction_pool = [
'Q: {} A:',
]
def create_data(self, ann_path):
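        # Expected annotation format. Field names are taken from the reads below;
        # the concrete values are only illustrative:
        # {"qa_pairs": [{"image_index": 0, "question_string": "...",
        #                "answer": "..."}]}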
processed_data = []
with open(ann_path, 'r') as f:
data = json.load(f)
for da in data["qa_pairs"]:
imageFile = str(da["image_index"])+".png"
question = da["question_string"]
answer = str(da["answer"])
processed_data.append(
{'question': question,
'answer': answer,
'image_path': imageFile,
'image_id': str(da["image_index"]),
}
)
return processed_data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
image = Image.open(os.path.join(self.vis_root, sample['image_path'])).convert("RGB")
image = self.vis_processor(image)
instruction = "
{} ".format(sample["question"])
instruction = random.choice(self.instruction_pool).format(instruction)
answer = sample["answer"]
return {
"image": image,
"instruction_input": instruction,
"answer": answer,
"image_id": sample['image_id']
}
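
# Minimal smoke test; the processors are identity stand-ins and the paths are
# hypothetical placeholders for wherever the PlotVQA images/annotations live:
if __name__ == "__main__":
    dataset = PlotVQADataset(
        vis_processor=lambda img: img,
        text_processor=lambda txt: txt,
        vis_root="data/plotvqa/png",
        ann_path="data/plotvqa/qa_pairs.json",
    )
    print("dataset size:", len(dataset))
    item = dataset[0]
    print(item["instruction_input"], "->", item["answer"])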