import argparse
import json
import os

import torch
from transformers import (
    AutoTokenizer,
    BertConfig,
    DonutProcessor,
    EncoderDecoderConfig,
    EncoderDecoderModel,
    Trainer,
    TrainingArguments,
    VisionEncoderDecoderModel,
)

from my_model import MyDataset, MyModel


def train(args):
    # Image preprocessing: fixed page size plus ImageNet mean/std normalization.
    processor = DonutProcessor.from_pretrained(args.donut_dir)
    processor.image_processor.size = {'height': 896, 'width': 672}
    processor.image_processor.image_mean = [0.485, 0.456, 0.406]
    processor.image_processor.image_std = [0.229, 0.224, 0.225]
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.base_dir, 'zh_tokenizer'))

    # Dataset layout: page images in `imgs/`, text targets in `zh_mmd/`, and a JSON
    # file that lists the sample names for the train/validation split.
    image_dir = os.path.join(args.dataset_dir, 'imgs')
    text_dir = os.path.join(args.dataset_dir, 'zh_mmd')
    json_file_path = os.path.join(args.dataset_dir, 'split_dataset.json')
    with open(json_file_path, 'r') as f:
        json_dict = json.load(f)
    train_name_list = json_dict['train_name_list']
    valid_name_list = json_dict['valid_name_list']
    train_dataset = MyDataset(processor, tokenizer, train_name_list, args.max_length, image_dir, text_dir)
    valid_dataset = MyDataset(processor, tokenizer, valid_name_list, args.max_length, image_dir, text_dir)
    # Build a small BERT-style encoder-decoder; only the decoder config is customized
    # here so that generation uses the Chinese tokenizer's special tokens and vocabulary.
    encoder_config = BertConfig()
    decoder_config = BertConfig()
    encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
    encoder_decoder_config.decoder.bos_token_id = tokenizer.bos_token_id
    encoder_decoder_config.decoder.decoder_start_token_id = tokenizer.bos_token_id
    encoder_decoder_config.decoder.eos_token_id = tokenizer.eos_token_id
    encoder_decoder_config.decoder.hidden_size = 512
    encoder_decoder_config.decoder.intermediate_size = 2048
    encoder_decoder_config.decoder.max_length = args.max_length
    encoder_decoder_config.decoder.max_position_embeddings = args.max_length
    encoder_decoder_config.decoder.num_attention_heads = 8
    encoder_decoder_config.decoder.num_hidden_layers = 6
    encoder_decoder_config.decoder.pad_token_id = tokenizer.pad_token_id
    encoder_decoder_config.decoder.type_vocab_size = 1
    encoder_decoder_config.decoder.vocab_size = tokenizer.vocab_size
    trans_model = EncoderDecoderModel(config=encoder_decoder_config)

    # Load the pretrained Nougat vision encoder-decoder and combine both models in MyModel.
    nougat_model = VisionEncoderDecoderModel.from_pretrained(args.nougat_dir)
    model = MyModel(nougat_model.config, trans_model, nougat_model)
    # Choose gradient accumulation so the effective global batch size equals args.batch_size:
    # global batch = num_gpu * batch_size_per_gpu * gradient_accumulation_steps.
    num_gpu = torch.cuda.device_count()
    gradient_accumulation_steps = args.batch_size // (num_gpu * args.batch_size_per_gpu)

    training_args = TrainingArguments(
        output_dir=os.path.join(args.base_dir, 'models'),
        per_device_train_batch_size=args.batch_size_per_gpu,
        per_device_eval_batch_size=args.batch_size_per_gpu,
        gradient_accumulation_steps=gradient_accumulation_steps,
        logging_strategy='steps',
        logging_steps=1,
        evaluation_strategy='steps',
        eval_steps=args.eval_steps,
        save_strategy='steps',
        save_steps=args.save_steps,
        fp16=args.fp16,
        learning_rate=args.learning_rate,
        max_steps=args.max_steps,
        warmup_steps=args.warmup_steps,
        dataloader_num_workers=args.dataloader_num_workers,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
    )
    trainer.train()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_dir", type=str)
    parser.add_argument("--dataset_dir", type=str)
    parser.add_argument("--donut_dir", type=str)
    parser.add_argument("--nougat_dir", type=str)
    parser.add_argument("--max_length", type=int, default=1536)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--batch_size_per_gpu", type=int, default=4)
    parser.add_argument("--eval_steps", type=int, default=1000)
    parser.add_argument("--save_steps", type=int, default=1000)
    # argparse's type=bool treats any non-empty string (including "False") as True,
    # so expose fp16 as an on/off flag (--fp16 / --no-fp16) instead.
    parser.add_argument("--fp16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--max_steps", type=int, default=10000)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--dataloader_num_workers", type=int, default=8)
    args = parser.parse_args()
    train(args)