# Hugging Face upload header (uploader: liangyupu — "Upload 10 files" — commit
# 064752a, verified), converted to a comment so this file parses as Python.
import os
import torch
from transformers import PreTrainedModel, GenerationConfig, BertLMHeadModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from torch import nn
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
from torch.utils.data import Dataset
from PIL import Image
class MyModel(PreTrainedModel):
    """Vision-encoder / text-decoder model.

    Combines the image encoder taken from a Nougat model with the text decoder
    taken from a translation model; a linear layer projects encoder hidden
    states into the decoder's hidden size so they can be consumed via
    cross-attention.
    """

    def __init__(self, config, trans_model, nougat_model):
        super().__init__(config)
        self.encoder = nougat_model.encoder
        self.decoder = trans_model.decoder
        # Bridge between the two pretrained stacks' hidden sizes.
        self.project = nn.Linear(self.encoder.config.hidden_size,
                                 self.decoder.config.hidden_size)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict=True,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode the image (unless cached `encoder_outputs` are supplied),
        project the hidden states, and run the decoder.

        Returns a Seq2SeqLMOutput (or a plain tuple when return_dict=False)
        with a cross-entropy loss when `labels` is given. Labels are expected
        to be pre-shifted by the dataset (see MyDataset.__getitem__), so no
        shifting happens here.
        """
        # Fix: honour a caller-supplied `encoder_outputs` instead of always
        # re-running the encoder (the parameter was previously ignored).
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # Fix: with return_dict=False the encoder returns a tuple, which has
        # no `.last_hidden_state` attribute.
        encoder_hidden_states = (
            encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
        )
        encoder_hidden_states_proj = self.project(encoder_hidden_states)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states_proj,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
        )
        # Compute loss independent from the decoder (some decoders shift the
        # logits internally).
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.reshape(-1, self.decoder.config.vocab_size),
                labels.reshape(-1).long(),
            )
        if not return_dict:
            outputs = decoder_outputs + encoder_outputs
            return (loss,) + outputs if loss is not None else outputs
        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_hidden_states,
        )

    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict=True,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ):
        """Encode the image once, project the hidden states, and delegate text
        generation to the decoder's own `generate`."""
        # Fix: reuse caller-supplied `encoder_outputs` when available (the
        # parameter was previously ignored and the encoder always re-run).
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # Fix: tuple access when return_dict=False (no `.last_hidden_state`).
        encoder_hidden_states = (
            encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
        )
        return self.decoder.generate(
            encoder_hidden_states=self.project(encoder_hidden_states),
            generation_config=generation_config,
        )
class MyDataset(Dataset):
    """Image-to-markdown training dataset.

    Pairs a page image (`<name>.png` in `image_dir`) with its ground-truth
    markdown text (`<name>.mmd` in `text_dir`) and produces fixed-length
    decoder inputs and loss labels.
    """

    def __init__(self, processor, tokenizer, name_list, max_length, image_dir, text_dir):
        self.processor = processor    # image processor producing `pixel_values`
        self.tokenizer = tokenizer    # text tokenizer with bos/pad token ids
        self.name_list = name_list    # basenames shared by image and text files
        self.max_length = max_length  # fixed decoder sequence length
        self.image_dir = image_dir
        self.text_dir = text_dir

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, index):
        """Return a dict with 'pixel_values', 'decoder_input_ids', 'labels'.

        `decoder_input_ids` start with BOS and are right-padded with the pad
        token to `max_length`; `labels` are the same ids shifted left by one,
        padded with -100 so CrossEntropyLoss ignores the padding positions.
        """
        name = self.name_list[index]
        encoding = {}

        # Fix: use a context manager so the image file handle is closed
        # (PIL opens files lazily and the original leaked the descriptor).
        image_path = os.path.join(self.image_dir, name + '.png')
        with Image.open(image_path) as image:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            encoding['pixel_values'] = self.processor(
                image, return_tensors="pt"
            ).pixel_values.squeeze(0)

        # Fix: pin the text encoding instead of relying on the platform
        # default; .mmd files are assumed UTF-8 — TODO confirm.
        text_path = os.path.join(self.text_dir, name + '.mmd')
        with open(text_path, 'r', encoding='utf-8') as f:
            text = f.read()

        input_ids = self.tokenizer(text, max_length=self.max_length, truncation=True).input_ids
        # Drop token id 6 — presumably an unwanted special/whitespace token;
        # verify against the tokenizer's vocabulary.
        input_ids = [tok for tok in input_ids if tok != 6]
        # Force the sequence to start with BOS, replacing whatever token the
        # tokenizer emitted first.
        input_ids = [self.tokenizer.bos_token_id] + input_ids[1:]

        pad_len = self.max_length - len(input_ids)
        decoder_input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
        # Shift left by one; -100 masks padding (and the position after the
        # final real token) from the loss.
        labels = input_ids[1:] + [-100] * (pad_len + 1)

        encoding['decoder_input_ids'] = torch.tensor(decoder_input_ids, dtype=torch.long)
        encoding['labels'] = torch.tensor(labels, dtype=torch.long)
        return encoding