import os
import torch
from transformers import PreTrainedModel, GenerationConfig
from transformers.modeling_outputs import Seq2SeqLMOutput
from torch import nn
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
from torch.utils.data import Dataset
from PIL import Image
class MyModel(PreTrainedModel):
    """Image-to-markup model: a Nougat vision encoder feeding an external
    text decoder through a linear projection bridge."""

    def __init__(self, config, trans_model, nougat_model):
        super().__init__(config)
        self.encoder = nougat_model.encoder
        self.decoder = trans_model.decoder
        # Map encoder hidden states into the decoder's hidden size so the
        # decoder can attend to them via cross-attention.
        self.project = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict=True,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        # Encode the image unless precomputed encoder outputs were supplied.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        encoder_hidden_states = encoder_outputs.last_hidden_state
        encoder_hidden_states_proj = self.project(encoder_hidden_states)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states_proj,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
        )
        # Compute the loss independently of the decoder, since some decoder
        # classes shift the logits internally.
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1).long())
        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            return decoder_outputs + encoder_outputs
        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_hidden_states,
        )
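    # Label convention assumed by this forward(): for tokens [BOS, A, B],
    # the dataset below supplies decoder_input_ids = [BOS, A, B, PAD, ...]
    # and labels = [A, B, -100, ...], i.e. labels already sit one step left
    # of the inputs, so no additional shifting happens here.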
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict=True,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ):
        # Encode the image once, then delegate decoding to the decoder's own
        # generate() loop.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        encoder_hidden_states = encoder_outputs.last_hidden_state
        encoder_hidden_states_proj = self.project(encoder_hidden_states)
        # Caveat: whether encoder_hidden_states is actually threaded through
        # every decoding step depends on the decoder's
        # prepare_inputs_for_generation; verify this for the concrete decoder
        # class and transformers version in use.
        generation_outputs = self.decoder.generate(
            encoder_hidden_states=encoder_hidden_states_proj,
            generation_config=generation_config,
        )
        return generation_outputs
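# --- Generation sketch (illustrative, not from the original file) ---
# How inference might look once `model` and `processor` exist (see the
# __main__ block at the end of this file); the GenerationConfig values are
# placeholders:
#
#     from transformers import GenerationConfig
#     pixel_values = processor(image, return_tensors="pt").pixel_values
#     gen_cfg = GenerationConfig(max_new_tokens=512)
#     token_ids = model.generate(pixel_values=pixel_values, generation_config=gen_cfg)
#     text = processor.tokenizer.batch_decode(token_ids, skip_special_tokens=True)[0]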
class MyDataset(Dataset):
    """Pairs page images (<name>.png) with markup transcriptions (<name>.mmd)
    and builds padded decoder inputs plus shifted labels."""

    def __init__(self, processor, tokenizer, name_list, max_length, image_dir, text_dir):
        self.processor = processor
        self.tokenizer = tokenizer
        self.name_list = name_list
        self.max_length = max_length
        self.image_dir = image_dir
        self.text_dir = text_dir

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, index):
        encoding = {}
        # Load the page image and normalize it to RGB before preprocessing.
        image_file_path = os.path.join(self.image_dir, self.name_list[index] + '.png')
        image = Image.open(image_file_path)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze(0)
        encoding['pixel_values'] = pixel_values
        # Read the ground-truth markup text.
        text_file_path = os.path.join(self.text_dir, self.name_list[index] + '.mmd')
        with open(text_file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        input_ids = self.tokenizer(text, max_length=self.max_length, truncation=True).input_ids
        # Drop token id 6 (an unwanted token specific to the tokenizer used in
        # the original pipeline) and force the sequence to start with BOS.
        input_ids = [x for x in input_ids if x != 6]
        input_ids = [self.tokenizer.bos_token_id] + input_ids[1:]
        # Decoder inputs are padded to max_length; labels are the same
        # sequence shifted one step left, padded with -100 so that
        # CrossEntropyLoss ignores the padded positions.
        decoder_input_ids = input_ids + [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        decoder_input_ids = torch.tensor(decoder_input_ids, dtype=torch.long)
        labels = input_ids[1:] + [-100] * (self.max_length - len(input_ids) + 1)
        labels = torch.tensor(labels, dtype=torch.long)
        encoding['decoder_input_ids'] = decoder_input_ids
        encoding['labels'] = labels
        return encoding
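# --- Usage sketch (illustrative, not from the original file) ---
# A minimal sketch of wiring the pieces together. It assumes both source
# models are VisionEncoderDecoderModel checkpoints (Nougat itself is one:
# a Swin encoder plus an MBart causal-LM decoder). The checkpoint name
# "facebook/nougat-base", the directory paths, the sample name, and
# max_length are placeholders, not values from the original setup.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import NougatProcessor, VisionEncoderDecoderModel

    processor = NougatProcessor.from_pretrained("facebook/nougat-base")
    nougat_model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-base")
    # Stand-in for the (unspecified) model supplying the decoder; MyModel only
    # requires that it expose a `.decoder` attribute.
    trans_model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-base")
    model = MyModel(nougat_model.config, trans_model, nougat_model)

    dataset = MyDataset(
        processor=processor,
        tokenizer=processor.tokenizer,
        name_list=["sample_0001"],   # expects data/images/sample_0001.png etc.
        max_length=512,
        image_dir="data/images",
        text_dir="data/text",
    )
    loader = DataLoader(dataset, batch_size=1)
    batch = next(iter(loader))
    outputs = model(
        pixel_values=batch["pixel_values"],
        decoder_input_ids=batch["decoder_input_ids"],
        labels=batch["labels"],
    )
    print("loss:", outputs.loss.item())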