|
import os |
|
import torch |
|
from transformers import PreTrainedModel, GenerationConfig, BertLMHeadModel |
|
from transformers.modeling_outputs import Seq2SeqLMOutput |
|
from torch import nn |
|
from torch.nn import CrossEntropyLoss |
|
from typing import Optional, Tuple, Union |
|
from torch.utils.data import Dataset |
|
from PIL import Image |
|
|
|
class MyModel(PreTrainedModel):
    """Image-to-text model pairing a Nougat image encoder with an external
    text decoder.

    The encoder's hidden states are linearly projected into the decoder's
    hidden size so the decoder's cross-attention can consume them.
    """

    def __init__(self, config, trans_model, nougat_model):
        """
        Args:
            config: configuration object handed to ``PreTrainedModel``.
            trans_model: model whose ``.decoder`` is used as the text decoder.
            nougat_model: model whose ``.encoder`` is used as the image encoder.
        """
        super().__init__(config)
        self.encoder = nougat_model.encoder
        self.decoder = trans_model.decoder
        # Bridge between the two (possibly different) hidden sizes.
        self.project = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict=True,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode ``pixel_values``, project the states, and run the decoder.

        Args:
            pixel_values: input images for the encoder (ignored when
                ``encoder_outputs`` is supplied).
            decoder_input_ids / decoder_attention_mask: decoder inputs.
            encoder_outputs: precomputed encoder outputs; if given, the
                encoder is NOT re-run.
            labels: token ids for the LM loss; positions equal to ``-100``
                are ignored by ``CrossEntropyLoss``.
            return_dict: when False, tuple outputs are returned.

        Returns:
            ``Seq2SeqLMOutput`` (default) or a tuple when
            ``return_dict=False``.
        """
        # BUG FIX: the original always re-encoded pixel_values, silently
        # discarding a caller-supplied encoder_outputs.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # BUG FIX: with return_dict=False the encoder returns a tuple, so the
        # original `.last_hidden_state` access raised AttributeError.
        encoder_hidden_states = (
            encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
        )
        encoder_hidden_states_proj = self.project(encoder_hidden_states)

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states_proj,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
        )

        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            # Flatten to (batch * seq_len, vocab) vs (batch * seq_len,).
            loss = loss_fct(
                logits.reshape(-1, self.decoder.config.vocab_size),
                labels.reshape(-1).long(),
            )

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_hidden_states,
            # FIX: previously dropped even when the caller requested them.
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict=True,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ):
        """Encode the image(s) and autoregressively generate token ids with
        the decoder's own ``generate``.

        Returns:
            Whatever ``self.decoder.generate`` returns (token id tensor by
            default).
        """
        # BUG FIX: reuse caller-supplied encoder outputs instead of always
        # re-running the encoder (mirrors forward()).
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        encoder_hidden_states = (
            encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
        )
        encoder_hidden_states_proj = self.project(encoder_hidden_states)

        generation_outputs = self.decoder.generate(
            encoder_hidden_states=encoder_hidden_states_proj,
            generation_config=generation_config,
        )

        return generation_outputs
|
|
|
class MyDataset(Dataset):
    """Dataset pairing page images (``<name>.png``) with text transcriptions
    (``<name>.mmd``) for encoder-decoder training.

    Each item is a dict with:
        pixel_values:      processor output for the image, shape squeezed of
                           its batch dim.
        decoder_input_ids: BOS-prefixed token ids, right-padded to
                           ``max_length`` with the pad token.
        labels:            inputs shifted left by one, padded with ``-100``
                           so CrossEntropyLoss ignores those positions.
    """

    def __init__(self, processor, tokenizer, name_list, max_length, image_dir, text_dir):
        """
        Args:
            processor: image processor producing ``pixel_values``.
            tokenizer: text tokenizer (must define bos/pad token ids).
            name_list: basenames shared by the image and text files.
            max_length: fixed sequence length for ids/labels.
            image_dir: directory containing ``<name>.png`` files.
            text_dir: directory containing ``<name>.mmd`` files.
        """
        self.processor = processor
        self.tokenizer = tokenizer
        self.name_list = name_list
        self.max_length = max_length
        self.image_dir = image_dir
        self.text_dir = text_dir

    def __len__(self):
        """Number of (image, text) pairs."""
        return len(self.name_list)

    def __getitem__(self, index):
        """Load, tokenize, and pad one sample. Returns the dict described on
        the class docstring."""
        name = self.name_list[index]

        # BUG FIX: use the context manager so the image file handle is closed
        # deterministically (the original leaked it via PIL's lazy loading).
        image_file_path = os.path.join(self.image_dir, name + '.png')
        with Image.open(image_file_path) as image:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze(0)

        # BUG FIX: read with an explicit encoding instead of the platform
        # default (``.mmd`` transcriptions are UTF-8 markdown).
        text_file_path = os.path.join(self.text_dir, name + '.mmd')
        with open(text_file_path, 'r', encoding='utf-8') as f:
            text = f.read()

        input_ids = self.tokenizer(text, max_length=self.max_length, truncation=True).input_ids
        # Drop token id 6 — tokenizer-specific filtering; presumably a spurious
        # token produced by this tokenizer. TODO(review): confirm the id.
        input_ids = [tok for tok in input_ids if tok != 6]
        # Force the sequence to start with BOS (replaces the first token).
        input_ids = [self.tokenizer.bos_token_id] + input_ids[1:]

        pad_len = self.max_length - len(input_ids)
        decoder_input_ids = torch.tensor(
            input_ids + [self.tokenizer.pad_token_id] * pad_len, dtype=torch.long
        )
        # Shift-left targets; padded tail (and the slot lost by the shift) is
        # -100 so the loss ignores it. Total length is exactly max_length.
        labels = torch.tensor(
            input_ids[1:] + [-100] * (pad_len + 1), dtype=torch.long
        )

        return {
            'pixel_values': pixel_values,
            'decoder_input_ids': decoder_input_ids,
            'labels': labels,
        }