import argparse
import json
import os

import torch
from my_model import MyDataset, MyModel
from transformers import (
    AutoTokenizer,
    BertConfig,
    DonutProcessor,
    EncoderDecoderConfig,
    EncoderDecoderModel,
    Trainer,
    TrainingArguments,
    VisionEncoderDecoderModel,
)


def train(args):
    # Image preprocessing: reuse the Donut processor but override the input
    # resolution and switch to ImageNet normalization statistics.
    processor = DonutProcessor.from_pretrained(args.donut_dir)
    processor.image_processor.size = {'height': 896, 'width': 672}
    processor.image_processor.image_mean = [0.485, 0.456, 0.406]
    processor.image_processor.image_std = [0.229, 0.224, 0.225]

    # Chinese tokenizer, stored under base_dir.
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.base_dir, 'zh_tokenizer'))

    # Dataset layout: page images in imgs/, Chinese markdown targets in zh_mmd/,
    # and a JSON file listing the train/validation split by sample name.
    image_dir = os.path.join(args.dataset_dir, 'imgs')
    text_dir = os.path.join(args.dataset_dir, 'zh_mmd')
    json_file_path = os.path.join(args.dataset_dir, 'split_dataset.json')
    with open(json_file_path, 'r', encoding='utf-8') as f:
        json_dict = json.load(f)
    train_name_list = json_dict['train_name_list']
    valid_name_list = json_dict['valid_name_list']

    train_dataset = MyDataset(processor, tokenizer, train_name_list, args.max_length, image_dir, text_dir)
    valid_dataset = MyDataset(processor, tokenizer, valid_name_list, args.max_length, image_dir, text_dir)

    # Build a small BERT-style encoder-decoder (the translation branch) from
    # scratch, with the decoder sized to the Chinese tokenizer.
    encoder_config = BertConfig()
    decoder_config = BertConfig()
    encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
    encoder_decoder_config.decoder.bos_token_id = tokenizer.bos_token_id
    encoder_decoder_config.decoder.decoder_start_token_id = tokenizer.bos_token_id
    encoder_decoder_config.decoder.eos_token_id = tokenizer.eos_token_id
    encoder_decoder_config.decoder.hidden_size = 512
    encoder_decoder_config.decoder.intermediate_size = 2048
    encoder_decoder_config.decoder.max_length = args.max_length
    encoder_decoder_config.decoder.max_position_embeddings = args.max_length
    encoder_decoder_config.decoder.num_attention_heads = 8
    encoder_decoder_config.decoder.num_hidden_layers = 6
    encoder_decoder_config.decoder.pad_token_id = tokenizer.pad_token_id
    encoder_decoder_config.decoder.type_vocab_size = 1
    encoder_decoder_config.decoder.vocab_size = tokenizer.vocab_size
    trans_model = EncoderDecoderModel(config=encoder_decoder_config)

    # Pretrained Nougat (vision encoder + text decoder); MyModel wires the
    # two branches together.
    nougat_model = VisionEncoderDecoderModel.from_pretrained(args.nougat_dir)
    model = MyModel(nougat_model.config, trans_model, nougat_model)

    # Derive gradient accumulation so the effective batch size matches
    # args.batch_size; guard against CPU-only runs where device_count() is 0.
    num_gpu = max(torch.cuda.device_count(), 1)
    gradient_accumulation_steps = max(1, args.batch_size // (num_gpu * args.batch_size_per_gpu))

    training_args = TrainingArguments(
        output_dir=os.path.join(args.base_dir, 'models'),
        per_device_train_batch_size=args.batch_size_per_gpu,
        per_device_eval_batch_size=args.batch_size_per_gpu,
        gradient_accumulation_steps=gradient_accumulation_steps,
        logging_strategy='steps',
        logging_steps=1,
        evaluation_strategy='steps',
        eval_steps=args.eval_steps,
        save_strategy='steps',
        save_steps=args.save_steps,
        fp16=args.fp16,
        learning_rate=args.learning_rate,
        max_steps=args.max_steps,
        warmup_steps=args.warmup_steps,
        dataloader_num_workers=args.dataloader_num_workers,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
    )
    trainer.train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Path arguments are used unconditionally, so make them required rather
    # than letting os.path.join fail on None.
    parser.add_argument("--base_dir", type=str, required=True)
    parser.add_argument("--dataset_dir", type=str, required=True)
    parser.add_argument("--donut_dir", type=str, required=True)
    parser.add_argument("--nougat_dir", type=str, required=True)
    parser.add_argument("--max_length", type=int, default=1536)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--batch_size_per_gpu", type=int, default=4)
parser.add_argument("--eval_steps", type=int, default=1000) parser.add_argument("--save_steps", type=int, default=1000) parser.add_argument("--fp16", type=bool, default=True) parser.add_argument("--learning_rate", type=float, default=5e-5) parser.add_argument("--max_steps", type=int, default=10000) parser.add_argument("--warmup_steps", type=int, default=1000) parser.add_argument("--dataloader_num_workers", type=int, default=8) args = parser.parse_args() train(args)