"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from collections import OrderedDict
import json
import os
import torch
from PIL import Image
from lavis.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset


class __DisplMixin:
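    # Debug helper mixin: renders a processed sample next to its raw
    # annotation as an OrderedDict for quick manual inspection.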
def displ_item(self, index):
sample, ann = self.__getitem__(index), self.annotation[index]
return OrderedDict(
{
"file": ann["image"],
"question": ann["question"],
"question_id": ann["question_id"],
"direct_answers": "; ".join(ann["direct_answers"]),
"choices": "; ".join(ann["choices"]),
"correct_choice": ann["choices"][ann["correct_choice_idx"]],
"image": sample["image"],
}
        )


class AOKVQADataset(VQADataset, __DisplMixin):
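    # Training-split dataset: each item pairs a processed image/question with
    # the soft-weighted direct answers used as supervision.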
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.vis_root, ann["image"])
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
question = self.text_processor(ann["question"])
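        # Build soft labels: each unique direct answer is weighted by its
        # relative frequency among the annotator-provided answers.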
answer_key = "direct_answers"
answer_weight = {}
for answer in ann[answer_key]:
            if answer in answer_weight:
answer_weight[answer] += 1 / len(ann[answer_key])
else:
answer_weight[answer] = 1 / len(ann[answer_key])
answers = list(answer_weight.keys())
weights = list(answer_weight.values())
return {
"image": image,
"text_input": question,
"answers": answers,
"weights": weights,
        }


class AOKVQAEvalDataset(VQAEvalDataset, __DisplMixin):
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
self.vis_root = vis_root
        with open(ann_paths[0]) as f:
            self.annotation = json.load(f)
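        # Candidate answer list used for ranking-style answer evaluation; it
        # is optional, so fall back to None when the file is absent.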
answer_list_path = ann_paths[1]
if os.path.exists(answer_list_path):
            with open(answer_list_path) as f:
                self.answer_list = json.load(f)
else:
self.answer_list = None
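        # COCO-format question/annotation files are optional; when provided
        # they allow scoring with the official VQA evaluation tools.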
try:
self.coco_fmt_qust_file = ann_paths[2]
self.coco_fmt_anno_file = ann_paths[3]
except IndexError:
self.coco_fmt_qust_file = None
self.coco_fmt_anno_file = None
self.vis_processor = vis_processor
self.text_processor = text_processor
self._add_instance_ids()
def collater(self, samples):
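        # Stack image tensors into a single batch; the remaining fields stay
        # as Python lists since they hold variable-length text data.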
(
image_list,
question_list,
question_id_list,
instance_id_list,
choices_list,
correct_choice_idx_list,
direct_answers_list,
) = ([], [], [], [], [], [], [])
for sample in samples:
image_list.append(sample["image"])
question_list.append(sample["text_input"])
question_id_list.append(sample["question_id"])
instance_id_list.append(sample["instance_id"])
choices_list.append(sample["choices"])
correct_choice_idx_list.append(sample["correct_choice_idx"])
direct_answers_list.append(sample["direct_answers"])
return {
"image": torch.stack(image_list, dim=0),
"text_input": question_list,
"question_id": question_id_list,
"instance_id": instance_id_list,
"choices": choices_list,
"correct_choice_idx": correct_choice_idx_list,
"direct_answers": direct_answers_list,
}
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.vis_root, ann["image"])
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
question = self.text_processor(ann["question"])
choices = ann["choices"]
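        # Ground-truth fields may be absent (e.g. on a hidden test split);
        # .get() falls back to None instead of raising a KeyError.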
if "correct_choice_idx" in ann:
correct_choice_idx = ann["correct_choice_idx"]
else:
correct_choice_idx = None
if "direct_answers" in ann:
direct_answers = ann["direct_answers"]
else:
direct_answers = None
return {
"image": image,
"text_input": question,
"question_id": ann["question_id"],
"instance_id": ann["instance_id"],
"choices": choices,
"correct_choice_idx": correct_choice_idx,
"direct_answers": direct_answers,
}