import torch.nn as nn
from transformers import VisionEncoderDecoderModel, PreTrainedModel, AutoConfig


class LegibilityModel(PreTrainedModel):

    def __init__(self, config):
        # The incoming config is discarded and replaced by the TrOCR-base
        # config, so the encoder architecture always matches
        # microsoft/trocr-base-handwritten regardless of the argument.
        config = AutoConfig.from_pretrained("microsoft/trocr-base-handwritten")
        super().__init__(config=config)

        # Build the full encoder-decoder model but keep only the ViT image
        # encoder; the text decoder is not needed for legibility scoring.
        self.model = VisionEncoderDecoderModel(config).encoder

        # Regression head: map the 768-dim pooled encoder features to a
        # single legibility score. The Dropout(0) layers are no-ops kept as
        # placeholders so a nonzero rate can be swapped in during training.
        self.stack = nn.Sequential(
            nn.Dropout(0),
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Dropout(0),
            nn.Linear(768, 1),
        )

    def forward(self, img_batch, choice=None, img0=None, img1=None):
        # choice, img0, and img1 are unused here; they keep the signature
        # compatible with a pairwise-comparison training loop.
        output = self.model(img_batch)
        # Mean-pool the patch embeddings into one 768-dim vector per image.
        output = output.last_hidden_state.mean(dim=1)
        scores = self.stack(output)
        # squeeze(-1) drops only the trailing feature dimension, so a batch
        # of size 1 still yields a 1-element tensor rather than a 0-d scalar.
        return scores.squeeze(-1)
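

# Minimal usage sketch (an illustration, not part of the original module).
# Assumes the TrOCR processor is used to turn images into the pixel-value
# tensors the ViT encoder expects; the file paths below are hypothetical.
if __name__ == "__main__":
    import torch
    from PIL import Image
    from transformers import TrOCRProcessor

    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    model = LegibilityModel(config=None)  # config argument is overridden internally
    model.eval()

    # Note: the weights here are randomly initialized; real use would load a
    # trained checkpoint (e.g. via LegibilityModel.from_pretrained).
    images = [Image.open(p).convert("RGB") for p in ["sample0.png", "sample1.png"]]
    pixel_values = processor(images=images, return_tensors="pt").pixel_values

    with torch.no_grad():
        scores = model(pixel_values)  # shape: (batch_size,)
    print(scores)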