import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, vit, roberta, tokenizer, device):
        super().__init__()
        # 1x1 convolutions that project the RoBERTa and ViT features
        # within the shared 768-dimensional embedding space.
        self.bertmap = nn.Conv1d(768, 768, 1)
        self.vitmap = nn.Conv1d(768, 768, 1)
        # Scalar 1x1 convolution over the per-patch scores, initialised
        # to the identity (weight 1, bias 0) so it starts as a no-op.
        self.conv1d = nn.Conv1d(1, 1, 1)
        self.add_module("vit", vit)
        self.add_module("roberta", roberta)
        self.tokenizer = tokenizer
        self.conv1d.weight = nn.Parameter(torch.tensor([[[1.0]]]))
        self.conv1d.bias = nn.Parameter(torch.tensor([0.0]))
        self.device = device

    def forward(self, image, cats):
        # Encode the image; self.vit is expected to return the token
        # embeddings as a tensor, with the [CLS] token in position 0.
        vit_out = self.vit(image)
        vit_out = vit_out[:, 1:, :]  # drop [CLS], keep the patch tokens
        # Conv1d expects (batch, channels, length), so transpose the
        # patch embeddings before and after the projection.
        vit_out = torch.transpose(vit_out, 2, 1)
        vit_out = self.vitmap(vit_out)
        vit_out = torch.transpose(vit_out, 2, 1)

        # Tokenize the category text and run it through RoBERTa.
        token_out = self.tokenizer.encode_plus(
            cats,
            padding=True,
            add_special_tokens=True,
            return_token_type_ids=True,
            return_tensors='pt',
        ).to(self.device)
        bert_out = self.roberta(**token_out)

        # Project the text hidden states with the same transpose pattern.
        hidden_state = bert_out.last_hidden_state
        hidden_state = torch.transpose(hidden_state, 2, 1)
        hidden_state = self.bertmap(hidden_state)
        hidden_state = torch.transpose(hidden_state, 2, 1)

        # Use the first (<s>) token embedding as the pooled text vector.
        pooled_bert_out = hidden_state[:, 0]
        pooled_bert_out = torch.unsqueeze(pooled_bert_out, dim=2)

        # Dot product between each patch embedding and the pooled text
        # embedding yields one score per patch; conv1d then applies a
        # learnable scale and shift to those scores.
        out = torch.matmul(vit_out, pooled_bert_out)
        out = torch.transpose(out, 2, 1)
        return torch.squeeze(self.conv1d(out), dim=1)
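

# A minimal usage sketch, not from the original source: it assumes the
# HuggingFace `transformers` ViTModel / RobertaModel / RobertaTokenizer,
# with a thin wrapper so that self.vit(image) returns the token-embedding
# tensor (including [CLS]) that Model.forward indexes into. The checkpoint
# names and the dummy input are illustrative assumptions.
if __name__ == "__main__":
    from transformers import RobertaModel, RobertaTokenizer, ViTModel

    class ViTWrapper(nn.Module):
        """Unwrap the HF output so Model.forward sees a plain tensor."""

        def __init__(self, vit):
            super().__init__()
            self.vit = vit

        def forward(self, image):
            return self.vit(image).last_hidden_state

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vit = ViTWrapper(ViTModel.from_pretrained("google/vit-base-patch16-224-in21k"))
    roberta = RobertaModel.from_pretrained("roberta-base")
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    model = Model(vit, roberta, tokenizer, device).to(device)
    image = torch.randn(1, 3, 224, 224, device=device)  # dummy image batch
    scores = model(image, "a photo of a cat")
    print(scores.shape)  # (1, 196): one score per 16x16 patch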