# coding=utf-8
# Copyleft 2019 project LXRT.

import torch.nn as nn

from transformers import AutoTokenizer, AutoModelForQuestionAnswering

# Max question length in tokens (including [CLS] and [SEP]) fed to the encoder.
MAX_VQA_LENGTH = 20


class VQAModel(nn.Module):
    """VQA model backed by the pretrained HuggingFace LXMERT checkpoint.

    The original hand-assembled LXRT encoder plus answer head was replaced
    by ``unc-nlp/lxmert-vqa-uncased``, which already ships its own
    answer-classification head.
    """

    def __init__(self, num_answers):
        """
        :param num_answers: size of the answer vocabulary. Kept for
            backward-compatible interface; the pretrained checkpoint has a
            fixed-size answer head of its own, so this value is stored but
            not used to build layers.
        """
        super().__init__()
        self.num_answers = num_answers
        self.tokenizer = AutoTokenizer.from_pretrained("unc-nlp/lxmert-vqa-uncased")
        self.model = AutoModelForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased")

    def forward(self, feat, pos, sent):
        """
        b -- batch_size, o -- object_number, f -- visual_feature_size

        :param feat: (b, o, f) visual object features
        :param pos: (b, o, 4) normalized bounding-box coordinates
        :param sent: (b,) list of question strings
        :return: (b, num_answer) logits over the answer vocabulary
        """
        # Bug fix: the original did `self.model(sent, feat, pos)`, passing
        # raw strings positionally where LXMERT expects token-id tensors.
        # Tokenize the questions first (the tokenizer was previously built
        # but never used), then call the model with explicit keywords.
        encoded = self.tokenizer(
            sent,
            padding=True,
            truncation=True,
            max_length=MAX_VQA_LENGTH,
            return_tensors="pt",
        )
        # Keep the text tensors on the same device as the visual features.
        device = feat.device
        out = self.model(
            input_ids=encoded["input_ids"].to(device),
            attention_mask=encoded["attention_mask"].to(device),
            token_type_ids=encoded["token_type_ids"].to(device),
            visual_feats=feat,
            visual_pos=pos,
        )
        # LxmertForQuestionAnswering returns an output object; the answer
        # logits promised by the docstring are `question_answering_score`.
        return out.question_answering_score