# coding=utf-8
# Copyleft 2019 project LXRT.

import torch.nn as nn

from lxrt.entry import LXRTEncoder
from lxrt.modeling import GeLU, BertLayerNorm
from param import args


class NLVR2Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lxrt_encoder = LXRTEncoder(
            args,
            max_seq_length=20
        )
        self.hid_dim = hid_dim = self.lxrt_encoder.dim
        # Two pooled vectors (one per image) are concatenated, hence hid_dim * 2.
        self.logit_fc = nn.Sequential(
            nn.Linear(hid_dim * 2, hid_dim * 2),
            GeLU(),
            BertLayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, 2)
        )
        self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)

    def forward(self, feat, pos, sent):
        """
        :param feat: b, 2, o, f
        :param pos: b, 2, o, 4
        :param sent: b, (string)
        :return: b, 2 (logits over the two answers)
        """
        # Pairing images and sentences:
        # Each NLVR2 example is two images and one sentence. At the batch level
        # they arrive as
        #   [[img0_0, img0_1], [img1_0, img1_1], ...] and [sent0, sent1, ...].
        # Here we flatten them so that each image is paired with a copy of its
        # sentence:
        #   feat/pos = [img0_0, img0_1, img1_0, img1_1, ...]
        #   sent     = [sent0, sent0, sent1, sent1, ...]
        # e.g. ("a", "b") -> ("a", "a", "b", "b")
        sent = sum(zip(sent, sent), ())

        batch_size, img_num, obj_num, feat_size = feat.size()
        assert img_num == 2 and obj_num == 36 and feat_size == 2048
        feat = feat.view(batch_size * 2, obj_num, feat_size)
        pos = pos.view(batch_size * 2, obj_num, 4)

        # Extract a pooled feature for each (image, sentence) pair, then
        # concatenate the two vectors belonging to the same example.
        x = self.lxrt_encoder(sent, (feat, pos))
        x = x.view(-1, self.hid_dim * 2)

        # Compute the logits over the two answers.
        logit = self.logit_fc(x)
        return logit
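
# --- Minimal usage sketch (illustrative; not part of the original model code) ---
# A rough sketch of how this model can be driven, assuming the repo's
# `param.args` has already been populated by its argument parser and that
# pretrained weights are loaded elsewhere. The tensors below are random
# stand-ins for the Faster R-CNN object features and boxes the model expects;
# the sentence is a made-up example.
if __name__ == "__main__":
    import torch

    model = NLVR2Model()
    batch_size = 4
    feat = torch.randn(batch_size, 2, 36, 2048)  # b, 2 images, 36 objects, 2048-d features
    pos = torch.rand(batch_size, 2, 36, 4)       # b, 2 images, 36 objects, 4 box coords
    sent = ["the left image contains twice the number of dogs as the right image"] * batch_size
    logit = model(feat, pos, sent)               # shape (b, 2): logits over {False, True}
    print(logit.shape)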