|
""" |
|
Copyright (c) 2022, salesforce.com, inc. |
|
All rights reserved. |
|
SPDX-License-Identifier: BSD-3-Clause |
|
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
|
""" |
|
|
|
import os |
|
import json |
|
import torch |
|
import numpy as np |
|
|
|
from PIL import Image |
|
from PIL import ImageFile |
|
|
|
# Tell Pillow to load partially-downloaded/truncated image files instead of
# raising an error mid-training.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
from minigpt4.datasets.datasets.caption_datasets import COCOCaptionDataset, CaptionEvalDataset |
|
|
|
# Plain alias: the COCO caption training dataset is the shared
# COCOCaptionDataset implementation, unchanged.
COCOCapDataset = COCOCaptionDataset
|
|
|
|
|
|
|
|
|
|
|
class COCOCapEvalDataset(CaptionEvalDataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list): paths of the annotation files
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):
        """Return one evaluation sample: processed image, image id, instance id."""
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)

        # Derive the numeric COCO image id from a filename such as
        # "COCO_val2014_000000391895.jpg".
        # BUGFIX: the original used .strip(".jpg"), which strips any of the
        # characters '.', 'j', 'p', 'g' from BOTH ends (not the suffix), and
        # could corrupt ids; os.path.splitext removes exactly the extension.
        filename = ann["image"].split("/")[-1]
        img_id = os.path.splitext(filename)[0].split("_")[-1]

        return {
            "image": image,
            "image_id": img_id,
            "instance_id": ann["instance_id"],
        }
|
|
|
|
|
class NoCapsEvalDataset(CaptionEvalDataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list): paths of the annotation files
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):
        """Return one NoCaps evaluation sample as a dict."""
        record = self.annotation[index]

        # Load the image referenced by the annotation and run it through
        # the visual processor.
        path = os.path.join(self.vis_root, record["image"])
        pil_image = Image.open(path).convert("RGB")
        processed = self.vis_processor(pil_image)

        # NoCaps annotations carry their own image id under "img_id".
        return {
            "image": processed,
            "image_id": record["img_id"],
            "instance_id": record["instance_id"],
        }
|
|
|
|
|
class RefCOCOEvalData(torch.utils.data.Dataset):
    """Evaluation dataset for RefCOCO-style referring-expression grounding.

    Each item pairs a preprocessed image with a "[refer]" question built
    from the annotation's referring sentence.
    """

    def __init__(self, loaded_data, vis_processor, root_path):
        self.loaded_data = loaded_data
        self.vis_processor = vis_processor
        self.root_path = root_path

    def __len__(self):
        return len(self.loaded_data)

    def __getitem__(self, idx):
        entry = self.loaded_data[idx]
        image_id = entry['img_id']
        sentence = entry['sents']
        # NOTE(review): only the first 27 characters of img_id form the file
        # name — presumably the id carries a suffix beyond the file stem;
        # confirm against the annotation format.
        file_path = os.path.join(self.root_path, f'{image_id[:27]}.jpg')
        pil_img = Image.open(file_path).convert('RGB')
        tensor_img = self.vis_processor(pil_img)
        prompt = f"[refer] where is {sentence}?"
        return tensor_img, prompt, image_id
|
|
|
class EvalCaptionData(torch.utils.data.Dataset):
    """Caption-evaluation dataset that deduplicates annotations by image id.

    Several annotation entries may reference the same image; one
    (image_id, image) pair per id is kept — the last occurrence's image
    path wins, while ids keep first-seen order.
    """

    def __init__(self, loaded_data, vis_processor, root_path):
        self.loaded_data = loaded_data
        self.vis_processor = vis_processor
        self.root_path = root_path
        # Map each image_id to its (last seen) image path, then flatten
        # back into a list of records.
        unique = {}
        for entry in loaded_data:
            unique[entry['image_id']] = entry['image']
        self.ann = [
            {'image_id': key, 'image': value} for key, value in unique.items()
        ]

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, idx):
        entry = self.ann[idx]
        image_id = entry['image_id']
        # Only the basename of the stored path is used; images are expected
        # directly under root_path.
        file_name = entry['image'].split('/')[-1]
        full_path = os.path.join(self.root_path, file_name)
        pil_img = Image.open(full_path).convert('RGB')

        tensor_img = self.vis_processor(pil_img)
        prompt = f"[caption] please describe this image?"
        return tensor_img, prompt, image_id
|
|