# NOTE(review): the lines below were HuggingFace file-viewer chrome captured by
# the scrape ("WwYc's picture / Upload 61 files / 08d7644 verified / raw /
# history blame / 8.85 kB") — not part of the source; commented out so the
# module is valid Python.
# coding=utf-8
# Copyleft 2019 project LXRT.
from collections import defaultdict
import json
import random
import numpy as np
from torch.utils.data import Dataset
from param import args
from pretrain.qa_answer_table import AnswerTable
from utils import load_obj_tsv
TINY_IMG_NUM = 500
FAST_IMG_NUM = 5000
Split2ImgFeatPath = {
'mscoco_train': 'data/mscoco_imgfeat/train2014_obj36.tsv',
'mscoco_minival': 'data/mscoco_imgfeat/val2014_obj36.tsv',
'mscoco_nominival': 'data/mscoco_imgfeat/val2014_obj36.tsv',
'vgnococo': 'data/vg_gqa_imgfeat/vg_gqa_obj36.tsv',
}
class InputExample:
    """A single training/test example for the language model.

    Plain container: every constructor argument is stored verbatim as an
    attribute of the same name.
    """

    def __init__(self, uid, sent, visual_feats=None,
                 obj_labels=None, attr_labels=None,
                 is_matched=None, label=None):
        # Identification and text.
        self.uid = uid
        self.sent = sent
        # Visual side: region features plus object/attribute annotations.
        self.visual_feats = visual_feats
        self.obj_labels = obj_labels
        self.attr_labels = attr_labels
        # Pretraining targets.
        self.is_matched = is_matched  # whether the visual input and the sentence match
        self.label = label
class LXMERTDataset:
    """Aggregate one or more pretraining json splits and harmonize answers."""

    def __init__(self, splits: str, qa_sets=None):
        """
        :param splits: The data sources to be loaded (comma-separated)
        :param qa_sets: if None, no action
                        o.w., only takes the answers appearing in these dsets
                        and remove all unlabeled data (MSCOCO captions)
        """
        self.name = splits
        self.sources = splits.split(',')

        # Loading datasets to data.
        # Fix: the original json.load(open(...)) never closed the file
        # handle; use a context manager instead.
        self.data = []
        for source in self.sources:
            with open("data/lxmert/%s.json" % source) as f:
                self.data.extend(json.load(f))
        print("Load %d data from %s" % (len(self.data), self.name))

        # Create answer table according to the qa_sets
        self.answer_table = AnswerTable(qa_sets)
        print("Load an answer table of size %d." % (len(self.answer_table.ans2id_map())))

        # Modify the answers: map every raw answer to its normalized form,
        # keeping only answers the table marks as used.
        for datum in self.data:
            labelf = datum['labelf']
            for cat, labels in labelf.items():
                for label in labels:
                    # Iterate over a snapshot of keys because we mutate the dict.
                    for ans in list(label.keys()):
                        new_ans = self.answer_table.convert_ans(ans)
                        if self.answer_table.used(new_ans):
                            if ans != new_ans:
                                label[new_ans] = label.pop(ans)
                        else:
                            label.pop(ans)

    def __len__(self):
        return len(self.data)
def make_uid(img_id, dset, sent_idx):
    """Build the unique id "<img_id>_<dset>_<idx:03d>" for one (image, sentence) pair.

    Fix: the original line ended with a stray trailing comma, so the function
    returned a 1-tuple instead of the uid string. All uid producers and
    consumers in this module go through this function, so returning the
    intended string is consistent everywhere.
    """
    return "%s_%s_%03d" % (img_id, dset, sent_idx)
"""
Example in obj tsv:
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
"attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
"""
class LXMERTTorchDataset(Dataset):
    """Flattened (image, sentence) pretraining dataset backed by obj-36 TSV features."""

    def __init__(self, dataset: 'LXMERTDataset', topk=-1):
        """
        :param dataset: the LXMERTDataset providing sentences/labels
        :param topk: load only the first topk images per split (-1 = all);
                     overridden by the --tiny / --fast command-line flags.
        """
        super().__init__()
        self.raw_dataset = dataset
        self.task_matched = args.task_matched

        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM

        # Load the image features for every source split.
        img_data = []
        for source in self.raw_dataset.sources:
            img_data.extend(load_obj_tsv(Split2ImgFeatPath[source], topk))

        self.imgid2img = {}
        for img_datum in img_data:
            self.imgid2img[img_datum['img_id']] = img_datum

        # Keep only the textual data whose image features were loaded.
        used_data = [datum for datum in self.raw_dataset.data
                     if datum['img_id'] in self.imgid2img]

        # Flatten the dataset into one (sentence, image) entry per item.
        self.data = []
        for datum in used_data:
            sentf = datum['sentf']
            for sents_cat, sents in sentf.items():
                labels = datum['labelf'].get(sents_cat)  # None for caption-only sets
                for sent_idx, sent in enumerate(sents):
                    new_datum = {
                        'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
                        'img_id': datum['img_id'],
                        'sent': sent
                    }
                    if labels is not None:
                        new_datum['label'] = labels[sent_idx]
                    self.data.append(new_datum)
        print("Use %d data in torch dataset" % (len(self.data)))

    def __len__(self):
        return len(self.data)

    def random_feat(self):
        """Get a random obj feat from the dataset (used as a negative feature).

        Fix: index within the actual number of features instead of the
        hard-coded 0..35 range (the obj36 TSVs happen to hold 36 boxes,
        but the code should not rely on that).
        """
        datum = self.data[random.randint(0, len(self.data) - 1)]
        img_info = self.imgid2img[datum['img_id']]
        feats = img_info['features']
        return feats[random.randint(0, len(feats) - 1)]

    def __getitem__(self, item: int):
        datum = self.data[item]
        uid = datum['uid']
        img_id = datum['img_id']

        # Get image info (copies so the cached arrays are never mutated).
        img_info = self.imgid2img[img_id]
        obj_num = img_info['num_boxes']
        feats = img_info['features'].copy()
        boxes = img_info['boxes'].copy()
        obj_labels = img_info['objects_id'].copy()
        obj_confs = img_info['objects_conf'].copy()
        attr_labels = img_info['attrs_id'].copy()
        attr_confs = img_info['attrs_conf'].copy()
        assert obj_num == len(boxes) == len(feats)

        # Normalize the boxes to [0, 1]: x coords by width, y coords by height.
        # (The original copied `boxes` a second time here; that was redundant
        # since it was already copied above, so it has been removed.)
        img_h, img_w = img_info['img_h'], img_info['img_w']
        boxes[:, (0, 2)] /= img_w
        boxes[:, (1, 3)] /= img_h
        np.testing.assert_array_less(boxes, 1 + 1e-5)
        np.testing.assert_array_less(-boxes, 0 + 1e-5)

        # If calculating the matched loss, with probability 0.5 replace the
        # sentence with a sentence belonging to a different image.
        is_matched = 1
        sent = datum['sent']
        if self.task_matched:
            if random.random() < 0.5:
                is_matched = 0
                other_datum = self.data[random.randint(0, len(self.data) - 1)]
                # NOTE(review): this spins forever if every entry shares one
                # img_id; assumed impossible for real multi-image data.
                while other_datum['img_id'] == img_id:
                    other_datum = self.data[random.randint(0, len(self.data) - 1)]
                sent = other_datum['sent']

        # Label: convert each answer string to its answer-table id.
        if 'label' in datum:
            label = datum['label'].copy()
            for ans in list(label.keys()):
                label[self.raw_dataset.answer_table.ans2id(ans)] = label.pop(ans)
        else:
            label = None

        # Create target
        return InputExample(
            uid, sent, (feats, boxes),
            (obj_labels, obj_confs), (attr_labels, attr_confs),
            is_matched, label
        )
class LXMERTEvaluator:
    """Computes (soft) QA accuracy over the labeled portion of a dataset."""

    def __init__(self, dataset: 'LXMERTDataset'):
        self.raw_dataset = dataset

        # Create QA Eval Data: keep only sentences from labeled categories.
        self.data = []
        for datum in self.raw_dataset.data:
            sentf = datum['sentf']
            for sents_cat, sents in sentf.items():
                if sents_cat in datum['labelf']:  # A labeled dataset
                    labels = datum['labelf'][sents_cat]
                    for sent_idx, sent in enumerate(sents):
                        new_datum = {
                            'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
                            'img_id': datum['img_id'],
                            'sent': sent,
                            'dset': sents_cat,
                            'label': labels[sent_idx]
                        }
                        self.data.append(new_datum)

        # uid -> datum for O(1) lookup during evaluation.
        self.uid2datum = {}
        for datum in self.data:
            self.uid2datum[datum['uid']] = datum

    def evaluate(self, uid2ans: dict, pprint=False):
        """Score predicted answers against the soft labels.

        :param uid2ans: mapping uid -> predicted answer string; predictions
                        whose uid is not in the labeled data are ignored.
        :param pprint: if True, print overall and per-dataset accuracies.
        :return: (overall accuracy, {dset: accuracy})
        """
        score = 0.
        cnt = 0
        dset2score = defaultdict(lambda: 0.)
        dset2cnt = defaultdict(lambda: 0)
        for uid, ans in uid2ans.items():
            if uid not in self.uid2datum:  # Not a labeled data
                continue
            datum = self.uid2datum[uid]
            label = datum['label']
            dset = datum['dset']
            if ans in label:
                score += label[ans]
                dset2score[dset] += label[ans]
            cnt += 1
            dset2cnt[dset] += 1
        # Fix: guard against division by zero when no prediction matched a
        # labeled uid (the original crashed with ZeroDivisionError here).
        accu = score / cnt if cnt > 0 else 0.
        dset2accu = {}
        for dset in dset2cnt:
            dset2accu[dset] = dset2score[dset] / dset2cnt[dset]

        if pprint:
            accu_str = "Overall Accu %0.4f, " % (accu)
            sorted_keys = sorted(dset2accu.keys())
            for key in sorted_keys:
                accu_str += "%s Accu %0.4f, " % (key, dset2accu[key])
            print(accu_str)

        return accu, dset2accu

    def dump_result(self, uid2ans: dict, path):
        # Fix: `raise NotImplemented` raises a TypeError (NotImplemented is a
        # sentinel value, not an exception class); NotImplementedError is the
        # intended signal for an unimplemented method.
        raise NotImplementedError