import argparse
import json
import os

import torch
from transformers import (
    AutoTokenizer,
    BertConfig,
    DonutProcessor,
    EncoderDecoderConfig,
    EncoderDecoderModel,
    Trainer,
    TrainingArguments,
    VisionEncoderDecoderModel,
)

from my_model import MyModel, MyDataset


def train(args):
    # Donut image processor, adjusted to a 896x672 input resolution and the
    # standard ImageNet normalization statistics.
    processor = DonutProcessor.from_pretrained(args.donut_dir)
    processor.image_processor.size = {'height': 896, 'width': 672}
    processor.image_processor.image_mean = [0.485, 0.456, 0.406]
    processor.image_processor.image_std = [0.229, 0.224, 0.225]

    # Tokenizer for the Chinese targets, stored alongside the other assets in base_dir.
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.base_dir, 'zh_tokenizer'))

    # Dataset layout: page images in `imgs/`, Chinese target text in `zh_mmd/`.
    image_dir = os.path.join(args.dataset_dir, 'imgs')
    text_dir = os.path.join(args.dataset_dir, 'zh_mmd')

    # The train/validation split is stored as two lists of sample names.
    json_file_path = os.path.join(args.dataset_dir, 'split_dataset.json')
    with open(json_file_path, 'r', encoding='utf-8') as f:
        json_dict = json.load(f)
    train_name_list = json_dict['train_name_list']
    valid_name_list = json_dict['valid_name_list']

    train_dataset = MyDataset(processor, tokenizer, train_name_list, args.max_length, image_dir, text_dir)
    valid_dataset = MyDataset(processor, tokenizer, valid_name_list, args.max_length, image_dir, text_dir)

    # Configure a small BERT-style decoder (6 layers, hidden size 512, 8 heads) for the
    # randomly initialized encoder-decoder model; token ids and sizes come from the
    # Chinese tokenizer and --max_length.
    encoder_config = BertConfig()
    decoder_config = BertConfig()
    encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
    encoder_decoder_config.decoder.bos_token_id = tokenizer.bos_token_id
    encoder_decoder_config.decoder.decoder_start_token_id = tokenizer.bos_token_id
    encoder_decoder_config.decoder.eos_token_id = tokenizer.eos_token_id
    encoder_decoder_config.decoder.hidden_size = 512
    encoder_decoder_config.decoder.intermediate_size = 2048
    encoder_decoder_config.decoder.max_length = args.max_length
    encoder_decoder_config.decoder.max_position_embeddings = args.max_length
    encoder_decoder_config.decoder.num_attention_heads = 8
    encoder_decoder_config.decoder.num_hidden_layers = 6
    encoder_decoder_config.decoder.pad_token_id = tokenizer.pad_token_id
    encoder_decoder_config.decoder.type_vocab_size = 1
    encoder_decoder_config.decoder.vocab_size = tokenizer.vocab_size

    # trans_model is the freshly initialized encoder-decoder configured above;
    # nougat_model is a pretrained Nougat VisionEncoderDecoder checkpoint.
    trans_model = EncoderDecoderModel(config=encoder_decoder_config)
    nougat_model = VisionEncoderDecoderModel.from_pretrained(args.nougat_dir)

    # MyModel (see my_model.py) wires the two models together.
    model = MyModel(nougat_model.config, trans_model, nougat_model)

    # Derive gradient-accumulation steps from the target global batch size.
    # max(..., 1) guards against a division by zero on CPU-only machines.
    num_gpu = max(torch.cuda.device_count(), 1)
    gradient_accumulation_steps = args.batch_size // (num_gpu * args.batch_size_per_gpu)
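    # Worked example with the defaults below and two GPUs (the GPU count here is only an
    # illustrative assumption): 64 // (2 * 4) = 8 accumulation steps, i.e. an effective
    # batch of 2 GPUs * 4 per GPU * 8 steps = 64 samples per optimizer update.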

    # Step-based logging, evaluation and checkpointing; total length is set by --max_steps.
    training_args = TrainingArguments(
        output_dir=os.path.join(args.base_dir, 'models'),
        per_device_train_batch_size=args.batch_size_per_gpu,
        per_device_eval_batch_size=args.batch_size_per_gpu,
        gradient_accumulation_steps=gradient_accumulation_steps,
        logging_strategy='steps',
        logging_steps=1,
        evaluation_strategy='steps',
        eval_steps=args.eval_steps,
        save_strategy='steps',
        save_steps=args.save_steps,
        fp16=args.fp16,
        learning_rate=args.learning_rate,
        max_steps=args.max_steps,
        warmup_steps=args.warmup_steps,
        dataloader_num_workers=args.dataloader_num_workers,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
    )

    trainer.train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Paths: working/output directory, dataset root, and the pretrained Donut processor
    # and Nougat model checkpoints. All four are needed, so fail fast if any is missing.
    parser.add_argument("--base_dir", type=str, required=True)
    parser.add_argument("--dataset_dir", type=str, required=True)
    parser.add_argument("--donut_dir", type=str, required=True)
    parser.add_argument("--nougat_dir", type=str, required=True)

    # Training hyperparameters.
    parser.add_argument("--max_length", type=int, default=1536)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--batch_size_per_gpu", type=int, default=4)
    parser.add_argument("--eval_steps", type=int, default=1000)
    parser.add_argument("--save_steps", type=int, default=1000)
    # `type=bool` would turn any non-empty string (even "False") into True; the
    # BooleanOptionalAction flag (Python 3.9+) gives proper --fp16 / --no-fp16 switches.
    parser.add_argument("--fp16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--max_steps", type=int, default=10000)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--dataloader_num_workers", type=int, default=8)

    args = parser.parse_args()

    train(args)
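
# Example invocation (the script name and all paths are placeholders, not taken from
# the original repository):
#   python train.py \
#       --base_dir /path/to/workdir \
#       --dataset_dir /path/to/dataset \
#       --donut_dir /path/to/donut-base \
#       --nougat_dir /path/to/nougat-base \
#       --batch_size 64 --batch_size_per_gpu 4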