import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, vit, roberta, tokenizer, device):
        super().__init__()
        # 1x1 convolutions that remap the 768-dim ViT patch embeddings and
        # RoBERTa token embeddings into a shared space.
        self.bertmap = nn.Conv1d(768, 768, 1)
        self.vitmap = nn.Conv1d(768, 768, 1)
        # 1x1 conv over the single patch-score channel, initialised to the
        # identity (weight 1, bias 0) so it starts as a learnable pass-through.
        self.conv1d = nn.Conv1d(1, 1, 1)
        self.conv1d.weight = nn.Parameter(torch.tensor([[[1.]]]))
        self.conv1d.bias = nn.Parameter(torch.tensor([0.]))
        self.add_module("vit", vit)
        self.add_module("roberta", roberta)
        self.tokenizer = tokenizer
        self.device = device
    def forward(self, image, cats):
        # ViT patch embeddings: assumes self.vit returns a plain
        # (batch, 1 + num_patches, 768) tensor with the [CLS] token first;
        # drop [CLS] and keep the patch tokens.
        vit_out = self.vit(image)
        vit_out = vit_out[:, 1:, :]
        # Conv1d expects (batch, channels, length), so transpose around it.
        vit_out = torch.transpose(vit_out, 2, 1)
        vit_out = self.vitmap(vit_out)
        vit_out = torch.transpose(vit_out, 2, 1)  # (batch, num_patches, 768)
        # Tokenize the text; calling the tokenizer directly replaces the
        # deprecated encode_plus and handles a single string or a batch.
        token_out = self.tokenizer(
            cats,
            padding=True,
            add_special_tokens=True,
            return_token_type_ids=True,
            return_tensors='pt'
        ).to(self.device)
        bert_out = self.roberta(**token_out)
        hidden_state = bert_out.last_hidden_state  # (batch, seq_len, 768)
        hidden_state = torch.transpose(hidden_state, 2, 1)
        hidden_state = self.bertmap(hidden_state)
        hidden_state = torch.transpose(hidden_state, 2, 1)
        # Pool the text by its first (<s>) token, then score every image
        # patch against it with a dot product.
        pooled_bert_out = hidden_state[:, 0]                        # (batch, 768)
        pooled_bert_out = torch.unsqueeze(pooled_bert_out, dim=2)   # (batch, 768, 1)
        out = torch.matmul(vit_out, pooled_bert_out)                # (batch, num_patches, 1)
        out = torch.transpose(out, 2, 1)                            # (batch, 1, num_patches)
        return torch.squeeze(self.conv1d(out), dim=1)               # (batch, num_patches)
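

# Minimal usage sketch, not part of the model above. Assumptions: the
# Hugging Face `transformers` ViTModel / RobertaModel checkpoints named
# below, plus a hypothetical ViTWrapper adapter so that self.vit(image)
# yields the plain hidden-state tensor Model.forward slices.
from transformers import RobertaModel, RobertaTokenizer, ViTModel


class ViTWrapper(nn.Module):
    """Hypothetical adapter: unwraps ViTModel's output object into the raw
    (batch, 1 + num_patches, 768) hidden-state tensor."""

    def __init__(self, vit):
        super().__init__()
        self.vit = vit

    def forward(self, pixel_values):
        return self.vit(pixel_values=pixel_values).last_hidden_state


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
roberta = RobertaModel.from_pretrained("roberta-base")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

model = Model(ViTWrapper(vit), roberta, tokenizer, device).to(device)

image = torch.randn(1, 3, 224, 224, device=device)  # dummy 224x224 RGB batch
scores = model(image, "a photo of a cat")            # -> (1, 196) patch scores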