T-BOD / model.py
fmajer's picture
css change
a634e56
import torch
import torch.nn as nn
class Model(nn.Module):
    """Score image tokens against a text embedding.

    The image goes through ``vit`` and the text through ``roberta``; both
    token sequences are passed through their own learned 1x1 convolution,
    and every remaining image token is dot-multiplied with the first text
    token.  A final ``Conv1d(1, 1, 1)`` applies a learnable scale and shift
    to the resulting scores (initialised to the identity: weight 1, bias 0).

    Args:
        vit: Image encoder module.  Its output is indexed as a rank-3 tensor
            ``(batch, tokens, hidden_dim)``.
        roberta: Text encoder module.  It is called with the tokenizer
            output as keyword arguments and must return an object exposing
            ``last_hidden_state`` of shape ``(batch, tokens, hidden_dim)``.
        tokenizer: Object providing ``encode_plus(...)`` whose result has a
            ``.to(device)`` method and can be unpacked with ``**``.
        device: Device the tokenizer output is moved to before it is fed to
            ``roberta``.
        hidden_dim: Channel width of both encoder outputs, and therefore of
            the two projection convolutions.  Defaults to 768, the value
            that used to be hard-coded.
    """

    def __init__(self, vit, roberta, tokenizer, device, *, hidden_dim: int = 768):
        super().__init__()
        # Kernel-size-1 convolutions: a per-token linear map over channels.
        self.bertmap = nn.Conv1d(hidden_dim, hidden_dim, 1)
        self.vitmap = nn.Conv1d(hidden_dim, hidden_dim, 1)
        self.conv1d = nn.Conv1d(1, 1, 1)
        # Registration order is kept as-is so parameter / state_dict
        # ordering does not change.
        self.add_module("vit", vit)
        self.add_module("roberta", roberta)
        self.tokenizer = tokenizer
        # Start the output scale/shift as the identity mapping.
        self.conv1d.weight = nn.Parameter(torch.tensor([[[1.0]]]))
        self.conv1d.bias = nn.Parameter(torch.tensor([0.0]))
        # NOTE(review): this is a plain attribute, so ``Module.to(...)`` does
        # not update it -- callers have to keep it in sync with the weights.
        self.device = device

    @staticmethod
    def _project_tokens(projection, tokens):
        """Apply a Conv1d ``projection`` to ``(batch, tokens, channels)`` input.

        Conv1d expects channels on dim 1, so the last two dims are swapped
        before the convolution and swapped back afterwards.
        """
        return projection(tokens.transpose(1, 2)).transpose(1, 2)

    def forward(self, image, cats):
        """Return one score per image token for the text ``cats``.

        Args:
            image: Input forwarded unchanged to ``self.vit``.
            cats: Text forwarded unchanged to ``tokenizer.encode_plus``.
                NOTE(review): presumably a single category string; how a
                list is interpreted depends on the tokenizer -- confirm
                against the callers.

        Returns:
            Tensor of shape ``(batch, tokens - 1)``: the scaled dot product
            between each projected image token (first token dropped) and
            the projected first text token.
        """
        # Drop the first token along dim 1 (presumably the ViT class token
        # -- TODO confirm) and keep the remaining tokens.
        image_tokens = self.vit(image)[:, 1:, :]
        image_tokens = self._project_tokens(self.vitmap, image_tokens)

        encoded = self.tokenizer.encode_plus(
            cats,
            padding=True,
            add_special_tokens=True,
            return_token_type_ids=True,
            return_tensors='pt',
        ).to(self.device)
        text_tokens = self.roberta(**encoded).last_hidden_state
        text_tokens = self._project_tokens(self.bertmap, text_tokens)

        # First text token as a column vector: (batch, hidden_dim, 1).
        text_vector = text_tokens[:, 0].unsqueeze(2)

        # (batch, tokens, hidden) @ (batch, hidden, 1) -> (batch, tokens, 1);
        # matmul broadcasts over the batch dimension if the text batch is 1.
        scores = torch.matmul(image_tokens, text_vector)
        # Conv1d wants (batch, channels=1, tokens).
        scores = scores.transpose(1, 2)
        return torch.squeeze(self.conv1d(scores), dim=1)