# coding=utf-8
# Copyleft 2019 project LXRT.

import collections
import os
import random

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from param import args
from pretrain.lxmert_data import InputExample, LXMERTDataset, LXMERTTorchDataset, LXMERTEvaluator
from lxrt.entry import set_visual_config
from lxrt.tokenization import BertTokenizer
from lxrt.modeling import LXRTPretraining

DataTuple = collections.namedtuple("DataTuple", 'dataset torchdset loader evaluator')


def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1) -> DataTuple:
    # Decide which QA datasets will be used in pre-training.
    # Options: vqa, gqa, visual7w
    # Note: visual7w is a part of vgqa; we keep the name visual7w here.
    qa_sets = args.qa_sets
    if qa_sets is not None:
        qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))

    # Build dataset, data loader, and evaluator.
    dset = LXMERTDataset(splits, qa_sets=qa_sets)
    tset = LXMERTTorchDataset(dset, topk)
    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=args.num_workers,
        collate_fn=lambda x: x,   # keep raw InputExample objects; batching happens in LXMERT.forward()
        drop_last=drop_last, pin_memory=True
    )
    evaluator = LXMERTEvaluator(dset)
    print()

    return DataTuple(dataset=dset, torchdset=tset, loader=data_loader, evaluator=evaluator)


train_tuple = get_tuple(args.train, args.batch_size, shuffle=True, drop_last=True)
valid_batch_size = 2048 if args.multiGPU else 512
valid_tuple = get_tuple(args.valid, valid_batch_size, shuffle=False, drop_last=False, topk=5000)


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, lm_label_ids,
                 visual_feats, obj_labels, is_matched, ans):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.lm_label_ids = lm_label_ids

        self.visual_feats = visual_feats
        self.obj_labels = obj_labels

        self.is_matched = is_matched

        self.ans = ans


def random_word(tokens, tokenizer):
    """
    Mask some random tokens for the masked-LM task with the probabilities
    used in the original BERT paper.
    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here).
    :return: (list of str, list of int), masked tokens and related labels for LM prediction.
    """
    output_label = []

    for i, token in enumerate(tokens):
        prob = random.random()
        # Mask this token with probability args.word_mask_rate.
        ratio = args.word_mask_rate
        if prob < ratio:
            # Re-scale prob to [0, 1) to decide among the three masking modes.
            prob /= ratio

            # 80%: replace the token with the mask token
            if prob < 0.8:
                tokens[i] = "[MASK]"
            # 10%: replace the token with a random vocab token
            elif prob < 0.9:
                tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
            # -> remaining 10%: keep the current token unchanged

            # Append the original token id to the labels (we predict it later).
            try:
                output_label.append(tokenizer.vocab[token])
            except KeyError:
                # For unknown words (should not occur with a BPE vocab)
                output_label.append(tokenizer.vocab["[UNK]"])
        else:
            # No masking for this token (ignored by the loss function later).
            output_label.append(-1)
    return tokens, output_label
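# Illustrative note (not part of the pipeline): with a BERT-style mask rate of
# 0.15, each token ends up as [MASK] with probability 0.15 * 0.8 = 12%, as a
# random token with 0.15 * 0.1 = 1.5%, and kept unchanged but still predicted
# with 1.5%; the remaining 85% receive label -1 and are skipped by the loss.
# For example (the token id 3899 is hypothetical):
#   random_word(['a', 'dog', 'runs'], tokenizer)
#   -> (['a', '[MASK]', 'runs'], [-1, 3899, -1])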
def random_feat(feats):
    mask_feats = feats.copy()
    feat_mask = np.zeros(len(feats), dtype=np.float32)
    for i in range(len(feats)):
        prob = random.random()
        # Mask this object feature with probability args.obj_mask_rate.
        if prob < args.obj_mask_rate:
            # Re-scale prob to [0, 1) to decide among the three masking modes.
            prob /= args.obj_mask_rate

            # 80%: zero out the feature
            if prob < 0.8:
                mask_feats[i, :] = 0.
            # 10%: replace with a random feature drawn from the training set
            elif prob < 0.9:
                mask_feats[i, :] = train_tuple.torchdset.random_feat()
            # -> remaining 10%: keep the current feature unchanged

            # This feature needs to be predicted.
            feat_mask[i] = 1.

    return mask_feats, feat_mask


def convert_example_to_features(example: InputExample, max_seq_length, tokenizer) -> InputFeatures:
    """
    Convert a raw sample (pair of sentences as tokenized strings) into a proper
    training sample with IDs, LM labels, input_mask, CLS and SEP tokens, etc.
    :param example: InputExample, containing sentence input as strings and is_next label.
    :param max_seq_length: int, maximum length of the sequence.
    :param tokenizer: Tokenizer
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training).
    """
    tokens = tokenizer.tokenize(example.sent.strip())

    # Account for [CLS] and [SEP] with "- 2"
    if len(tokens) > max_seq_length - 2:
        tokens = tokens[:(max_seq_length - 2)]

    # Get randomly masked words
    masked_tokens, masked_label = random_word(tokens, tokenizer)

    # Concatenate LM labels and account for [CLS] and [SEP].
    masked_tokens = ['[CLS]'] + masked_tokens + ['[SEP]']
    input_ids = tokenizer.convert_tokens_to_ids(masked_tokens)

    # Mask & Segment Word
    lm_label_ids = ([-1] + masked_label + [-1])
    input_mask = [1] * len(input_ids)
    segment_ids = [0] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(lm_label_ids) == max_seq_length

    feat, boxes = example.visual_feats
    obj_labels, obj_confs = example.obj_labels
    attr_labels, attr_confs = example.attr_labels

    # Mask image features
    masked_feat, feat_mask = random_feat(feat)

    # QA answer label
    if example.label is None or len(example.label) == 0 or example.is_matched != 1:
        # 1. No label; 2. Label is pruned; 3. Unmatched visual + language pair.
        ans = -1
    else:
        keys, values = zip(*example.label.items())
        if len(keys) == 1:
            ans = keys[0]
        else:
            # Sample one answer in proportion to its normalized score.
            value_sum = sum(values)
            prob = [value / value_sum for value in values]
            choice = np.random.multinomial(1, prob).argmax()
            ans = keys[choice]

    features = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        lm_label_ids=lm_label_ids,
        visual_feats=(masked_feat, boxes),
        obj_labels={
            'obj': (obj_labels, obj_confs),
            'attr': (attr_labels, attr_confs),
            'feat': (feat, feat_mask),
        },
        is_matched=example.is_matched,
        ans=ans,
    )
    return features
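# Illustrative note (not part of the pipeline): when an example carries soft
# QA labels, one answer is sampled per conversion in proportion to its score.
# For instance (hypothetical label dict mapping answer ids to scores):
#   example.label = {31: 0.9, 57: 0.3}  ->  prob = [0.75, 0.25]
# so answer id 31 is drawn three times out of four on average, and repeated
# epochs see the full label distribution rather than a single hard answer.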
LOSSES_NAME = ('Mask_LM', 'Matched', 'Obj', 'Attr', 'Feat', 'QA')


class LXMERT:
    def __init__(self, max_seq_length):
        super().__init__()
        self.max_seq_length = max_seq_length

        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )

        # Build model
        set_visual_config(args)
        self.model = LXRTPretraining.from_pretrained(
            "bert-base-uncased",
            task_mask_lm=args.task_mask_lm,
            task_obj_predict=args.task_obj_predict,
            task_matched=args.task_matched,
            task_qa=args.task_qa,
            visual_losses=args.visual_losses,
            num_answers=train_tuple.dataset.answer_table.num_answers
        )

        # Weight initialization and loading
        if args.from_scratch:
            print("Train from Scratch: re-initialize all BERT weights.")
            self.model.apply(self.model.init_bert_weights)
        if args.load is not None:
            self.load(args.load)
        if args.load_lxmert is not None:
            # Loading LXMERT weights does not load the answer head.
            self.load_lxmert(args.load_lxmert)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model = nn.DataParallel(self.model)

    def forward(self, examples):
        train_features = [convert_example_to_features(example, self.max_seq_length, self.tokenizer)
                          for example in examples]

        # Language inputs
        input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda()
        input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda()
        segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda()

        # Visual inputs
        feats = torch.from_numpy(np.stack([f.visual_feats[0] for f in train_features])).cuda()
        pos = torch.from_numpy(np.stack([f.visual_feats[1] for f in train_features])).cuda()

        # Language prediction
        lm_labels = torch.tensor([f.lm_label_ids for f in train_features], dtype=torch.long).cuda()

        # Visual prediction
        obj_labels = {}
        for key in ('obj', 'attr', 'feat'):
            visn_labels = torch.from_numpy(np.stack([f.obj_labels[key][0] for f in train_features])).cuda()
            visn_mask = torch.from_numpy(np.stack([f.obj_labels[key][1] for f in train_features])).cuda()
            assert visn_labels.size(0) == visn_mask.size(0) and visn_labels.size(1) == visn_mask.size(1)
            obj_labels[key] = (visn_labels, visn_mask)

        # Joint prediction
        matched_labels = torch.tensor([f.is_matched for f in train_features], dtype=torch.long).cuda()
        ans = torch.from_numpy(np.stack([f.ans for f in train_features])).cuda()

        # The underlying LXRTPretraining signature:
        #   forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
        #           visual_feats=None, pos=None, obj_labels=None, matched_label=None, ans=None)
        loss, losses, ans_logit = self.model(
            input_ids, segment_ids, input_mask, lm_labels,
            feats, pos, obj_labels, matched_labels, ans
        )
        return loss, losses.detach().cpu(), ans_logit

    def train_batch(self, optim, batch):
        optim.zero_grad()
        loss, losses, ans_logit = self.forward(batch)
        if args.multiGPU:
            loss = loss.mean()
            losses = losses.mean(0)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 1.)
        optim.step()

        return loss.item(), losses.cpu().numpy(), ans_logit

    def valid_batch(self, batch):
        with torch.no_grad():
            loss, losses, ans_logit = self.forward(batch)
            if args.multiGPU:
                loss = loss.mean()
                losses = losses.mean(0)
        return loss.item(), losses.cpu().numpy(), ans_logit
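    # Illustrative note (not part of the pipeline): under nn.DataParallel each
    # replica returns its own loss, which PyTorch gathers along dim 0. So with,
    # e.g., 4 GPUs, `loss` arrives as a shape-(4,) tensor and `losses` as a
    # (4, num_loss_terms) tensor; loss.mean() and losses.mean(0) above reduce
    # them back to a scalar and a per-term vector before backward()/logging.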
    def train(self, train_tuple: DataTuple, eval_tuple: DataTuple):
        train_ld = train_tuple.loader

        # Optimizer
        from lxrt.optimization import BertAdam
        batch_per_epoch = len(train_ld)
        t_total = int(batch_per_epoch * args.epochs)
        warmup_ratio = 0.05
        warmup_iters = int(t_total * warmup_ratio)
        print("Batch per epoch: %d" % batch_per_epoch)
        print("Total Iters: %d" % t_total)
        print("Warm up Iters: %d" % warmup_iters)
        optim = BertAdam(self.model.parameters(), lr=args.lr, warmup=warmup_ratio, t_total=t_total)

        # Train
        best_eval_loss = float('inf')
        for epoch in range(args.epochs):
            # Train
            self.model.train()
            total_loss = 0.
            total_losses = 0.
            uid2ans = {}
            for batch in tqdm(train_ld, total=len(train_ld)):
                loss, losses, logit = self.train_batch(optim, batch)
                total_loss += loss
                total_losses += losses

                if args.task_qa:
                    score, label = logit.max(1)
                    for datum, l in zip(batch, label.cpu().numpy()):
                        uid = datum.uid
                        ans = train_tuple.dataset.answer_table.id2ans(l)
                        uid2ans[uid] = ans

            print("The training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch))
            losses_str = "The losses are "
            for name, loss in zip(LOSSES_NAME, total_losses):
                losses_str += "%s: %0.4f " % (name, loss / batch_per_epoch)
            print(losses_str)
            if args.task_qa:
                train_tuple.evaluator.evaluate(uid2ans, pprint=True)

            # Eval
            avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1)

            # Save
            if avg_eval_loss < best_eval_loss:
                best_eval_loss = avg_eval_loss
                self.save("BEST_EVAL_LOSS")
            self.save("Epoch%02d" % (epoch + 1))

    def evaluate_epoch(self, eval_tuple: DataTuple, iters: int = -1):
        self.model.eval()
        eval_ld = eval_tuple.loader
        total_loss = 0.
        total_losses = 0.
        uid2ans = {}
        for i, batch in enumerate(eval_ld):
            loss, losses, logit = self.valid_batch(batch)
            total_loss += loss
            total_losses += losses
            if args.task_qa:
                score, label = logit.max(1)
                for datum, l in zip(batch, label.cpu().numpy()):
                    uid = datum.uid
                    # The answer table is shared across splits, so reusing the
                    # train tuple's table here is safe.
                    ans = train_tuple.dataset.answer_table.id2ans(l)
                    uid2ans[uid] = ans
            if i == iters:
                break

        print("The valid loss is %0.4f" % (total_loss / len(eval_ld)))
        losses_str = "The losses are "
        for name, loss in zip(LOSSES_NAME, total_losses / len(eval_ld)):
            losses_str += "%s: %0.4f " % (name, loss)
        print(losses_str)

        if args.task_qa:
            eval_tuple.evaluator.evaluate(uid2ans, pprint=True)

        return total_loss / len(eval_ld)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(args.output, "%s_LXRT.pth" % name))

    def load(self, path):
        print("Load BERT extractor from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)
        self.model.load_state_dict(state_dict)

    def load_lxmert(self, path):
        print("Load LXMERT model from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)

        # Do not load any answer head.
        for key in list(state_dict.keys()):
            if 'answer' in key:
                state_dict.pop(key)

        # Strip the multi-GPU "module." prefix so single-GPU models can load
        # it; keys without the prefix are kept as-is (otherwise a single-GPU
        # checkpoint would be silently emptied here).
        new_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                new_state_dict[key[len("module."):]] = value
            else:
                new_state_dict[key] = value
        state_dict = new_state_dict

        load_keys = set(state_dict.keys())
        model_keys = set(self.model.state_dict().keys())
        print()
        print("Keys in loaded but not in model:")
        for key in sorted(load_keys.difference(model_keys)):
            print(key)
        print()
        print("Keys in model but not in loaded:")
        for key in sorted(model_keys.difference(load_keys)):
            print(key)
        print()

        self.model.load_state_dict(state_dict, strict=False)


if __name__ == "__main__":
    lxmert = LXMERT(max_seq_length=20)

    lxmert.train(train_tuple, valid_tuple)
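# Illustrative sketch (not part of the pipeline): checkpoints written by
# save() above can be inspected offline. Note that load()/load_lxmert()
# append the "_LXRT.pth" suffix themselves, so --load / --load_lxmert should
# point at the path *without* it. The directory below is hypothetical.
#
#   import torch
#   sd = torch.load("snap/pretrain/BEST_EVAL_LOSS_LXRT.pth", map_location="cpu")
#   print(len(sd), "parameter tensors")
#   # To resume from it:  --load_lxmert snap/pretrain/BEST_EVAL_LOSS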