import os
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset
from PIL import Image
from transformers import PreTrainedModel, GenerationConfig
from transformers.modeling_outputs import Seq2SeqLMOutput


class MyModel(PreTrainedModel):
    """Glues a Nougat vision encoder to an external decoder via a linear projection."""

    def __init__(self, config, trans_model, nougat_model):
        super().__init__(config)
        self.encoder = nougat_model.encoder
        self.decoder = trans_model.decoder
        # Project encoder hidden states into the decoder's hidden size so the
        # cross-attention dimensions line up.
        self.project = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        # Reuse precomputed encoder outputs if the caller provides them
        # (e.g. during incremental decoding) instead of re-encoding the image.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        encoder_hidden_states = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
        encoder_hidden_states_proj = self.project(encoder_hidden_states)

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states_proj,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
        )

        # Compute the loss here rather than inside the decoder, since some
        # decoders shift the logits internally. The dataset below already
        # pre-shifts the labels by one position relative to decoder_input_ids,
        # and CrossEntropyLoss ignores the -100 padding by default.
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.reshape(-1, self.decoder.config.vocab_size),
                labels.reshape(-1).long(),
            )

        if not return_dict:
            outputs = decoder_outputs + encoder_outputs
            return (loss,) + outputs if loss is not None else outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_hidden_states,
        )

    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: bool = True,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ):
        # Encode the image once, project the hidden states, and delegate the
        # actual decoding loop to the decoder's own generate().
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        encoder_hidden_states = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
        encoder_hidden_states_proj = self.project(encoder_hidden_states)
        return self.decoder.generate(
            encoder_hidden_states=encoder_hidden_states_proj,
            generation_config=generation_config,
        )


class MyDataset(Dataset):
    """Pairs page images (<name>.png) with their markdown transcriptions (<name>.mmd)."""

    def __init__(self, processor, tokenizer, name_list, max_length, image_dir, text_dir):
        self.processor = processor
        self.tokenizer = tokenizer
        self.name_list = name_list
        self.max_length = max_length
        self.image_dir = image_dir
        self.text_dir = text_dir

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, index):
        encoding = {}

        image_file_path = os.path.join(self.image_dir, self.name_list[index] + '.png')
        image = Image.open(image_file_path)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        encoding['pixel_values'] = self.processor(image, return_tensors="pt").pixel_values.squeeze(0)

        text_file_path = os.path.join(self.text_dir, self.name_list[index] + '.mmd')
        with open(text_file_path, 'r') as f:
            text = f.read()

        input_ids = self.tokenizer(text, max_length=self.max_length, truncation=True).input_ids
        # Drop token id 6 (a tokenizer-specific token filtered out here) and
        # force the sequence to start with BOS.
        input_ids = [x for x in input_ids if x != 6]
        input_ids = [self.tokenizer.bos_token_id] + input_ids[1:]

        # Right-pad the decoder inputs to max_length; labels are the same
        # sequence shifted one position left, padded with -100 so the loss
        # ignores the padding positions.
        decoder_input_ids = input_ids + [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        labels = input_ids[1:] + [-100] * (self.max_length - len(input_ids) + 1)

        encoding['decoder_input_ids'] = torch.tensor(decoder_input_ids, dtype=torch.long)
        encoding['labels'] = torch.tensor(labels, dtype=torch.long)
        return encoding
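# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original code). It
# wires the two classes together under these assumptions: Nougat is loaded as
# a VisionEncoderDecoderModel from the public "facebook/nougat-base"
# checkpoint, and trans_model is any model exposing a `.decoder` that accepts
# `encoder_hidden_states` (a second Nougat checkpoint is used as a stand-in
# here). The sample name, directory paths, and hyperparameters are
# placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import NougatProcessor, VisionEncoderDecoderModel

    processor = NougatProcessor.from_pretrained("facebook/nougat-base")
    nougat = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-base")
    trans = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-base")  # stand-in decoder source

    model = MyModel(nougat.config, trans_model=trans, nougat_model=nougat)

    dataset = MyDataset(
        processor=processor,
        tokenizer=processor.tokenizer,
        name_list=["sample_0001"],  # expects sample_0001.png / sample_0001.mmd on disk
        max_length=512,
        image_dir="images",
        text_dir="texts",
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    # One training step: the dataset pre-shifts the labels, so forward() can
    # compare logits and labels position by position.
    model.train()
    batch = next(iter(loader))
    outputs = model(
        pixel_values=batch["pixel_values"],
        decoder_input_ids=batch["decoder_input_ids"],
        labels=batch["labels"],
    )
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Inference: note that whether encoder_hidden_states actually reaches the
    # decoder at every step depends on the decoder's
    # prepare_inputs_for_generation, so verify this for the decoder you plug in.
    model.eval()
    with torch.no_grad():
        tokens = model.generate(
            pixel_values=batch["pixel_values"],
            generation_config=GenerationConfig(max_new_tokens=32),
        )
    print(processor.tokenizer.batch_decode(tokens, skip_special_tokens=True))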