# coding=utf-8
# Copyleft 2019 project LXRT.

import json

import numpy as np
from torch.utils.data import Dataset

from param import args
from utils import load_obj_tsv

# Load part of the dataset for fast checking.
# Notice that here is the number of images instead of the number of data,
# which means all related data to the images would be used.
TINY_IMG_NUM = 512
FAST_IMG_NUM = 5000


class NLVR2Dataset:
    """
    Raw NLVR2 dataset: loads one or more JSON splits and indexes them by uid.

    An NLVR2 data example in json file:
    {
        "identifier": "train-10171-0-0",
        "img0": "train-10171-0-img0",
        "img1": "train-10171-0-img1",
        "label": 0,
        "sent": "An image shows one leather pencil case, displayed open with writing implements tucked inside. ",
        "uid": "nlvr2_train_0"
    }
    """

    def __init__(self, splits: str):
        """
        :param splits: comma-separated split names, e.g. "train" or "train,valid".
                       Each split is read from data/nlvr2/<split>.json.
        """
        self.name = splits
        self.splits = splits.split(',')

        # Loading datasets to data.
        # Use a context manager so each file handle is closed promptly
        # (json.load(open(...)) leaked the descriptor until GC).
        self.data = []
        for split in self.splits:
            with open("data/nlvr2/%s.json" % split) as f:
                self.data.extend(json.load(f))
        print("Load %d data from split(s) %s." % (len(self.data), self.name))

        # List to dict (for evaluation and others): uid -> example.
        self.id2datum = {
            datum['uid']: datum
            for datum in self.data
        }

    def __len__(self):
        return len(self.data)


"""
An example in obj36 tsv:
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
              "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
FIELDNAMES would be keys in the dict returned by load_obj_tsv.
"""


class NLVR2TorchDataset(Dataset):
    """
    Torch-facing NLVR2 dataset: pairs each sentence with the pre-extracted
    detection features of its two images (img0, img1).
    """

    def __init__(self, dataset: NLVR2Dataset):
        super().__init__()
        self.raw_dataset = dataset

        # Optionally cap the number of images loaded for quick debugging runs.
        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM
        else:
            topk = -1

        # Loading detection features to img_data.
        img_data = []
        if 'train' in dataset.splits:
            img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/train_obj36.tsv', topk=topk))
        if 'valid' in dataset.splits:
            img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/valid_obj36.tsv', topk=topk))
        # NOTE: checks dataset.name (the raw comma-joined string), not
        # dataset.splits, so any split name containing "test" matches.
        if 'test' in dataset.name:
            img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/test_obj36.tsv', topk=topk))
        self.imgid2img = {}
        for img_datum in img_data:
            self.imgid2img[img_datum['img_id']] = img_datum

        # Filter out the dataset: keep only examples whose BOTH images have
        # loaded features (topk may have truncated the feature files).
        self.data = []
        for datum in self.raw_dataset.data:
            if datum['img0'] in self.imgid2img and datum['img1'] in self.imgid2img:
                self.data.append(datum)
        print("Use %d data in torch dataset" % (len(self.data)))
        print()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item: int):
        """
        :return: (uid, feats, boxes, sent[, label]) where feats and boxes are
                 stacked over the two images (leading axis of size 2) and the
                 boxes are normalized to [0, 1]. label is included only when
                 present in the example (i.e. not for unlabeled test data).
        """
        datum = self.data[item]

        ques_id = datum['uid']
        ques = datum['sent']

        # Get image info for both images of the pair.
        boxes2 = []
        feats2 = []
        for key in ['img0', 'img1']:
            img_id = datum[key]
            img_info = self.imgid2img[img_id]
            # Copy before in-place normalization so the cached arrays in
            # imgid2img are not mutated across __getitem__ calls.
            boxes = img_info['boxes'].copy()
            feats = img_info['features'].copy()
            assert len(boxes) == len(feats)

            # Normalize the boxes (to 0 ~ 1).
            # Assumes boxes are float (x1, y1, x2, y2) in pixel coordinates.
            img_h, img_w = img_info['img_h'], img_info['img_w']
            boxes[..., (0, 2)] /= img_w
            boxes[..., (1, 3)] /= img_h
            np.testing.assert_array_less(boxes, 1+1e-5)
            np.testing.assert_array_less(-boxes, 0+1e-5)

            boxes2.append(boxes)
            feats2.append(feats)
        feats = np.stack(feats2)
        boxes = np.stack(boxes2)

        # Create target.
        if 'label' in datum:
            label = datum['label']
            return ques_id, feats, boxes, ques, label
        else:
            return ques_id, feats, boxes, ques


class NLVR2Evaluator:
    """Accuracy computation and NLVR2-format result dumping."""

    def __init__(self, dataset: NLVR2Dataset):
        self.dataset = dataset

    def evaluate(self, quesid2ans: dict):
        """
        :param quesid2ans: map from uid to predicted label (same type/values
                           as the 'label' field, i.e. 0/1).
        :return: accuracy over the provided predictions.
        """
        score = 0.
        for quesid, ans in quesid2ans.items():
            datum = self.dataset.id2datum[quesid]
            label = datum['label']
            if ans == label:
                score += 1
        return score / len(quesid2ans)

    def dump_result(self, quesid2ans: dict, path):
        """
        Dump result to a CSV file, which is compatible with NLVR2 evaluation system.

        NLVR2 CSV file requirement:
            Each line contains: identifier, answer

        :param quesid2ans: nlvr2 uid to ans (either "True" or "False")
        :param path: The desired path of saved file.
        :return:
        """
        with open(path, 'w') as f:
            for uid, ans in quesid2ans.items():
                idt = self.dataset.id2datum[uid]["identifier"]
                ans = 'True' if ans == 1 else 'False'
                f.write("%s,%s\n" % (idt, ans))