import torch.nn as nn
from transformers import AutoConfig, PreTrainedModel, VisionEncoderDecoderModel


class LegibilityModel(PreTrainedModel):
    def __init__(self, config):
        # The passed-in config is ignored and replaced with the TrOCR checkpoint
        # config so the encoder architecture always matches the base model.
        config = AutoConfig.from_pretrained("microsoft/trocr-base-handwritten")
        super().__init__(config)
        # Base model architecture: only the ViT encoder of TrOCR is used.
        self.model = VisionEncoderDecoderModel(config).encoder
        # Regression head; the dropout probability is changed during training.
        self.stack = nn.Sequential(
            nn.Dropout(0),
            nn.Linear(768, 768),  # 768 = encoder hidden size
            nn.ReLU(),
            nn.Dropout(0),
            nn.Linear(768, 1),
        )

    # choice, img0, and img1 are not used by the model but are passed by the trainer.
    def forward(self, img_batch, choice=None, img0=None, img1=None):
        output = self.model(img_batch)
        # Mean-pool the last hidden states over the sequence (patch) dimension.
        output = output.last_hidden_state.mean(dim=1)
        scores = self.stack(output)
        # Squeeze only the feature dimension so a batch of size 1 still
        # returns a 1-D tensor rather than a 0-D scalar.
        return scores.squeeze(-1)
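

# A minimal inference sketch, not part of the original source: it assumes images
# are preprocessed with TrOCRProcessor from the same checkpoint, and "sample.png"
# is a hypothetical handwriting image on disk. The score scale depends on how
# the head was trained.
if __name__ == "__main__":
    import torch
    from PIL import Image
    from transformers import TrOCRProcessor

    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    model = LegibilityModel(config=None)  # config is replaced inside __init__
    model.eval()

    # Convert the image to the pixel-value tensor the ViT encoder expects.
    image = Image.open("sample.png").convert("RGB")  # hypothetical input file
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        score = model(pixel_values)
    print(f"legibility score: {score.item():.3f}")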