Spaces:
Sleeping
Sleeping
# coding=utf-8 | |
# Copyleft 2019 project LXRT. | |
from collections import defaultdict | |
import json | |
import random | |
import numpy as np | |
from torch.utils.data import Dataset | |
from param import args | |
from pretrain.qa_answer_table import AnswerTable | |
from utils import load_obj_tsv | |
# Image caps used as `topk` when loading features: TINY_IMG_NUM is selected
# by `args.tiny`, FAST_IMG_NUM by `args.fast`.
TINY_IMG_NUM = 500
FAST_IMG_NUM = 5000

# Source (split) name -> path of the TSV file holding its object features.
# Both COCO validation splits read the same val2014 file.
Split2ImgFeatPath = {
    "mscoco_train": "data/mscoco_imgfeat/train2014_obj36.tsv",
    "mscoco_minival": "data/mscoco_imgfeat/val2014_obj36.tsv",
    "mscoco_nominival": "data/mscoco_imgfeat/val2014_obj36.tsv",
    "vgnococo": "data/vg_gqa_imgfeat/vg_gqa_obj36.tsv",
}
class InputExample:
    """One training/test example for the language model.

    A plain value holder: every constructor argument is stored unchanged
    on the attribute of the same name.
    """

    def __init__(self, uid, sent, visual_feats=None,
                 obj_labels=None, attr_labels=None,
                 is_matched=None, label=None):
        """Store the given fields on the instance.

        :param uid: identifier of the example
        :param sent: the sentence of the example
        :param visual_feats: visual features, ``None`` if absent
        :param obj_labels: object labels, ``None`` if absent
        :param attr_labels: attribute labels, ``None`` if absent
        :param is_matched: whether the visual input and the sentence match
        :param label: answer label, ``None`` if the example is unlabeled
        """
        # Text side
        self.uid = uid
        self.sent = sent

        # Visual side
        self.visual_feats = visual_feats
        self.obj_labels = obj_labels
        self.attr_labels = attr_labels

        # Supervision targets
        self.is_matched = is_matched
        self.label = label
class LXMERTDataset:
    """Raw LXMERT annotation entries loaded from ``data/lxmert/<source>.json``.

    After construction ``self.data`` holds the concatenated entries of all
    requested sources, with the answer keys of every label rewritten through
    the answer table.
    """

    def __init__(self, splits: str, qa_sets=None):
        """
        :param splits: The data sources to be loaded, as a comma-separated
            string (one JSON file is read per source)
        :param qa_sets: forwarded unchanged to ``AnswerTable``; answers that
            the resulting table does not report as used are removed from
            the labels
        """
        self.name = splits
        self.sources = splits.split(',')

        # Loading datasets to data
        self.data = []
        for source in self.sources:
            # Context manager: close each annotation file right after parsing
            # instead of leaking the handle until garbage collection.
            with open("data/lxmert/%s.json" % source) as f:
                self.data.extend(json.load(f))
        print("Load %d data from %s" % (len(self.data), self.name))

        # Create answer table according to the qa_sets
        self.answer_table = AnswerTable(qa_sets)
        print("Load an answer table of size %d." % (len(self.answer_table.ans2id_map())))

        # Modify the answers: rename each answer to its converted form and
        # drop the ones the answer table does not use.
        for datum in self.data:
            for labels in datum['labelf'].values():
                for label in labels:
                    # Iterate over a snapshot of the keys because the dict
                    # is mutated inside the loop.
                    for ans in list(label.keys()):
                        new_ans = self.answer_table.convert_ans(ans)
                        if not self.answer_table.used(new_ans):
                            label.pop(ans)
                        elif ans != new_ans:
                            label[new_ans] = label.pop(ans)

    def __len__(self):
        """Return the number of raw entries."""
        return len(self.data)
def make_uid(img_id, dset, sent_idx):
    """Build the unique id of one sentence attached to one image.

    The formatted string ``<img_id>_<dset>_<sent_idx zero-padded to 3>`` is
    returned wrapped in a 1-tuple, not as a bare string. Callers use the
    value as a dictionary key, so the tuple form must be preserved.

    :param img_id: image identifier
    :param dset: name of the sentence category / dataset
    :param sent_idx: index of the sentence within that category
    :return: a 1-tuple containing the formatted id string
    """
    uid = "%s_%s_%03d" % (img_id, dset, sent_idx)
    return (uid,)
""" | |
Example in obj tsv:
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
              "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
""" | |
class LXMERTTorchDataset(Dataset):
    """Torch dataset yielding one ``InputExample`` per (sentence, image) pair.

    ``self.imgid2img`` maps an image id to its loaded feature record and
    ``self.data`` is the flattened list of per-sentence entries
    (``uid``, ``img_id``, ``sent`` and, when available, ``label``).
    """

    def __init__(self, dataset: "LXMERTDataset", topk=-1):
        """Load the image features and flatten the raw data.

        :param dataset: raw dataset providing ``sources``, ``data`` and
            ``answer_table``
        :param topk: passed to ``load_obj_tsv``; replaced by TINY_IMG_NUM
            when ``args.tiny`` is set, else by FAST_IMG_NUM when
            ``args.fast`` is set
        """
        super().__init__()
        self.raw_dataset = dataset
        self.task_matched = args.task_matched

        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM

        # Load the dataset
        img_data = []
        for source in self.raw_dataset.sources:
            img_data.extend(load_obj_tsv(Split2ImgFeatPath[source], topk))

        self.imgid2img = {}
        for img_datum in img_data:
            self.imgid2img[img_datum['img_id']] = img_datum

        # Filter out the dataset: keep only entries whose image was loaded
        used_data = [
            datum for datum in self.raw_dataset.data
            if datum['img_id'] in self.imgid2img
        ]

        # Flatten the dataset (into one sent + one image entries)
        self.data = []
        for datum in used_data:
            sentf = datum['sentf']
            for sents_cat, sents in sentf.items():
                if sents_cat in datum['labelf']:
                    labels = datum['labelf'][sents_cat]
                else:
                    labels = None
                for sent_idx, sent in enumerate(sents):
                    new_datum = {
                        'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
                        'img_id': datum['img_id'],
                        'sent': sent
                    }
                    if labels is not None:
                        new_datum['label'] = labels[sent_idx]
                    self.data.append(new_datum)
        print("Use %d data in torch dataset" % (len(self.data)))

    def __len__(self):
        """Return the number of flattened (sentence, image) entries."""
        return len(self.data)

    def random_feat(self):
        """Get a random obj feat from the dataset.

        A random entry is drawn first, then a random row of that entry's
        image features. The row index is bounded by the actual number of
        feature rows rather than a fixed 36, so images with a different
        number of boxes neither raise ``IndexError`` nor have rows that can
        never be picked.
        """
        datum = self.data[random.randint(0, len(self.data)-1)]
        features = self.imgid2img[datum['img_id']]['features']
        return features[random.randint(0, len(features)-1)]

    def __getitem__(self, item: int):
        """Build the ``InputExample`` for entry ``item``.

        :param item: index into ``self.data``
        :return: an ``InputExample`` whose ``visual_feats`` is
            ``(feats, boxes)``, ``obj_labels`` is ``(obj_labels, obj_confs)``,
            ``attr_labels`` is ``(attr_labels, attr_confs)``, and whose
            ``label`` maps answer ids to their values (``None`` when the
            entry has no label)
        """
        datum = self.data[item]

        uid = datum['uid']
        img_id = datum['img_id']

        # Get image info (copies, so the cached record is never modified)
        img_info = self.imgid2img[img_id]
        obj_num = img_info['num_boxes']
        feats = img_info['features'].copy()
        boxes = img_info['boxes'].copy()
        obj_labels = img_info['objects_id'].copy()
        obj_confs = img_info['objects_conf'].copy()
        attr_labels = img_info['attrs_id'].copy()
        attr_confs = img_info['attrs_conf'].copy()
        assert obj_num == len(boxes) == len(feats)

        # Normalize the boxes (to 0 ~ 1): columns 0 and 2 are scaled by the
        # image width, columns 1 and 3 by the image height.
        img_h, img_w = img_info['img_h'], img_info['img_w']
        boxes[:, (0, 2)] /= img_w
        boxes[:, (1, 3)] /= img_h
        np.testing.assert_array_less(boxes, 1+1e-5)
        np.testing.assert_array_less(-boxes, 0+1e-5)

        # If calculating the matched loss, replace the sentence with a sentence
        # corresponding to other image.
        is_matched = 1
        sent = datum['sent']
        if self.task_matched:
            if random.random() < 0.5:
                is_matched = 0
                other_datum = self.data[random.randint(0, len(self.data)-1)]
                # NOTE: this resampling never terminates if every entry in
                # self.data belongs to the same image.
                while other_datum['img_id'] == img_id:
                    other_datum = self.data[random.randint(0, len(self.data)-1)]
                sent = other_datum['sent']

        # Label, convert answer to id (on a copy, so repeated calls still
        # see the original answer strings)
        if 'label' in datum:
            label = datum['label'].copy()
            for ans in list(label.keys()):
                label[self.raw_dataset.answer_table.ans2id(ans)] = label.pop(ans)
        else:
            label = None

        # Create target
        example = InputExample(
            uid, sent, (feats, boxes),
            (obj_labels, obj_confs), (attr_labels, attr_confs),
            is_matched, label
        )
        return example
class LXMERTEvaluator:
    """Scores predicted answers against the labeled sentences of a dataset.

    ``self.data`` lists one entry per labeled sentence and
    ``self.uid2datum`` indexes those entries by their uid (the value
    returned by ``make_uid``).
    """

    def __init__(self, dataset: "LXMERTDataset"):
        """Collect every labeled sentence of ``dataset`` and index it by uid.

        :param dataset: raw dataset whose ``data`` entries carry ``sentf``
            and ``labelf`` mappings keyed by sentence category
        """
        self.raw_dataset = dataset

        # Create QA Eval Data
        self.data = []
        for datum in self.raw_dataset.data:
            sentf = datum['sentf']
            for sents_cat, sents in sentf.items():
                if sents_cat in datum['labelf']:    # A labeled dataset
                    labels = datum['labelf'][sents_cat]
                    for sent_idx, sent in enumerate(sents):
                        new_datum = {
                            'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
                            'img_id': datum['img_id'],
                            'sent': sent,
                            'dset': sents_cat,
                            'label': labels[sent_idx]
                        }
                        self.data.append(new_datum)

        # uid2datum
        self.uid2datum = {}
        for datum in self.data:
            self.uid2datum[datum['uid']] = datum

    def evaluate(self, uid2ans: dict, pprint=False):
        """Compute the overall and per-category accuracy of the predictions.

        Predictions whose uid is not labeled are ignored. Every labeled
        prediction counts once; it earns ``label[ans]`` when the predicted
        answer appears in the label and nothing otherwise.

        :param uid2ans: mapping from uid to the predicted answer
        :param pprint: when true, print a one-line summary of the accuracies
        :return: ``(accu, dset2accu)`` - the overall accuracy and a dict of
            per-category accuracies; the overall accuracy is ``0.`` when no
            prediction is labeled
        """
        score = 0.
        cnt = 0
        dset2score = defaultdict(float)
        dset2cnt = defaultdict(int)
        for uid, ans in uid2ans.items():
            if uid not in self.uid2datum:   # Not a labeled data
                continue
            datum = self.uid2datum[uid]
            label = datum['label']
            dset = datum['dset']
            if ans in label:
                score += label[ans]
                dset2score[dset] += label[ans]
            cnt += 1
            dset2cnt[dset] += 1

        # Guard against an empty / fully unlabeled prediction set, which
        # would otherwise divide by zero.
        accu = score / cnt if cnt else 0.
        dset2accu = {}
        for dset in dset2cnt:
            dset2accu[dset] = dset2score[dset] / dset2cnt[dset]

        if pprint:
            accu_str = "Overall Accu %0.4f, " % (accu)
            sorted_keys = sorted(dset2accu.keys())
            for key in sorted_keys:
                accu_str += "%s Accu %0.4f, " % (key, dset2accu[key])
            print(accu_str)

        return accu, dset2accu

    def dump_result(self, uid2ans: dict, path):
        """Not implemented.

        :raises NotImplementedError: always
        """
        # `NotImplemented` is a comparison sentinel, not an exception;
        # raising it produces a TypeError instead of the intended error.
        raise NotImplementedError