import argparse
import json
import os

import torch
from transformers import (
    AutoTokenizer,
    BertConfig,
    DonutProcessor,
    EncoderDecoderConfig,
    EncoderDecoderModel,
    Trainer,
    TrainingArguments,
    VisionEncoderDecoderModel,
)

from my_model import MyDataset, MyModel
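
# Sketch (hypothetical helper, not part of the original script and not called
# by train()): what the image_mean / image_std overrides in train() do to one
# RGB pixel, assuming the image processor first rescales 0-255 values to
# [0, 1] and then applies per-channel (x - mean) / std normalization.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Rescale an (R, G, B) pixel from 0-255 to [0, 1], then standardize each channel."""
    if not len(rgb) == len(mean) == len(std):
        raise ValueError('pixel, mean and std must have the same number of channels')
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, mean, std))


# Example: a white pixel (255, 255, 255) maps to roughly (2.249, 2.429, 2.640),
# a black pixel (0, 0, 0) to roughly (-2.118, -2.036, -1.804).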

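def train(args):
    """Build MyModel from a pretrained Nougat VisionEncoderDecoderModel and a
    freshly initialized BERT-style EncoderDecoderModel, then train it with the
    Hugging Face Trainer."""
    # The image processor targets 896 x 672 (height x width) inputs and
    # normalizes them with the ImageNet channel mean/std.
    processor = DonutProcessor.from_pretrained(args.donut_dir)
    processor.image_processor.size = {'height': 896, 'width': 672}
    processor.image_processor.image_mean = [0.485, 0.456, 0.406]
    processor.image_processor.image_std = [0.229, 0.224, 0.225]
    # Target-side tokenizer, loaded from <base_dir>/zh_tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.base_dir, 'zh_tokenizer'))

    # Page images live in <dataset_dir>/imgs and their text targets in <dataset_dir>/zh_mmd.
    image_dir = os.path.join(args.dataset_dir, 'imgs')
    text_dir = os.path.join(args.dataset_dir, 'zh_mmd')

    # split_dataset.json lists the sample names of the training and validation splits.
    json_file_path = os.path.join(args.dataset_dir, 'split_dataset.json')
    with open(json_file_path, 'r', encoding='utf-8') as f:
        json_dict = json.load(f)
    train_name_list = json_dict['train_name_list']
    valid_name_list = json_dict['valid_name_list']

    train_dataset = MyDataset(processor, tokenizer, train_name_list, args.max_length, image_dir, text_dir)
    valid_dataset = MyDataset(processor, tokenizer, valid_name_list, args.max_length, image_dir, text_dir)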

    # Randomly initialized BERT encoder-decoder. The encoder keeps the default
    # BertConfig; the decoder is shrunk to 6 layers, 8 heads and hidden size 512,
    # and sized to the tokenizer and the target sequence length.
    encoder_config = BertConfig()
    decoder_config = BertConfig()
    encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
    encoder_decoder_config.decoder.bos_token_id = tokenizer.bos_token_id
    encoder_decoder_config.decoder.decoder_start_token_id = tokenizer.bos_token_id
    encoder_decoder_config.decoder.eos_token_id = tokenizer.eos_token_id
    encoder_decoder_config.decoder.hidden_size = 512
    encoder_decoder_config.decoder.intermediate_size = 2048
    encoder_decoder_config.decoder.max_length = args.max_length
    encoder_decoder_config.decoder.max_position_embeddings = args.max_length
    encoder_decoder_config.decoder.num_attention_heads = 8
    encoder_decoder_config.decoder.num_hidden_layers = 6
    encoder_decoder_config.decoder.pad_token_id = tokenizer.pad_token_id
    encoder_decoder_config.decoder.type_vocab_size = 1
    # len(tokenizer) also counts tokens added on top of the base vocabulary;
    # tokenizer.vocab_size does not, so an embedding matrix sized by it would
    # be too small for the ids of any added tokens.
    encoder_decoder_config.decoder.vocab_size = len(tokenizer)

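    # Fresh text encoder-decoder plus the pretrained Nougat model; MyModel combines the two.
    trans_model = EncoderDecoderModel(config=encoder_decoder_config)
    nougat_model = VisionEncoderDecoderModel.from_pretrained(args.nougat_dir)

    model = MyModel(nougat_model.config, trans_model, nougat_model)

    # Reach the requested global batch size through gradient accumulation:
    # batch_size = batch_size_per_gpu * num_gpu * gradient_accumulation_steps.
    num_gpu = max(1, torch.cuda.device_count())  # count a CPU-only run as one device
    samples_per_step = num_gpu * args.batch_size_per_gpu
    if args.batch_size % samples_per_step != 0:
        raise ValueError(
            f'batch_size ({args.batch_size}) must be divisible by '
            f'num_gpu * batch_size_per_gpu ({samples_per_step})'
        )
    gradient_accumulation_steps = args.batch_size // samples_per_step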

    training_args = TrainingArguments(
        output_dir=os.path.join(args.base_dir, 'models'),
        per_device_train_batch_size=args.batch_size_per_gpu,
        per_device_eval_batch_size=args.batch_size_per_gpu,
        gradient_accumulation_steps=gradient_accumulation_steps,
        logging_strategy='steps',
        logging_steps=1,
        # Newer transformers releases rename this argument to eval_strategy.
        evaluation_strategy='steps',
        eval_steps=args.eval_steps,
        save_strategy='steps',
        save_steps=args.save_steps,
        fp16=args.fp16,
        learning_rate=args.learning_rate,
        max_steps=args.max_steps,
        warmup_steps=args.warmup_steps,
        dataloader_num_workers=args.dataloader_num_workers,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
    )

    trainer.train()
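
# Sketch (hypothetical helper, not part of the original script and not called
# by train()): the batch-size arithmetic behind gradient_accumulation_steps.
# One optimizer step sees
#     per_device_batch * num_devices * accumulation_steps
# samples, so the accumulation count is the requested global batch divided by
# the samples processed in one forward/backward pass across all devices.
def accumulation_steps(global_batch, per_device_batch, num_devices):
    """Return how many micro-batches to accumulate per optimizer step."""
    num_devices = max(1, num_devices)  # treat a CPU-only run as a single device
    samples_per_pass = per_device_batch * num_devices
    if global_batch % samples_per_pass != 0:
        raise ValueError(
            f'global batch {global_batch} is not divisible by '
            f'{per_device_batch} x {num_devices} = {samples_per_pass}'
        )
    return global_batch // samples_per_pass


# Example with the defaults below (batch_size=64, batch_size_per_gpu=4):
# 1 GPU -> 16, 4 GPUs -> 4, 8 GPUs -> 2; 3 GPUs raises ValueError.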

if __name__ == '__main__':
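    parser = argparse.ArgumentParser()
    parser.add_argument("--base_dir", type=str, required=True,
                        help="directory containing zh_tokenizer/; checkpoints go to <base_dir>/models")
    parser.add_argument("--dataset_dir", type=str, required=True,
                        help="directory containing imgs/, zh_mmd/ and split_dataset.json")
    parser.add_argument("--donut_dir", type=str, required=True,
                        help="path or hub id of the pretrained DonutProcessor")
    parser.add_argument("--nougat_dir", type=str, required=True,
                        help="path or hub id of the pretrained Nougat VisionEncoderDecoderModel")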
    
    parser.add_argument("--max_length", type=int, default=1536)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--batch_size_per_gpu", type=int, default=4)
    parser.add_argument("--eval_steps", type=int, default=1000)
    parser.add_argument("--save_steps", type=int, default=1000)
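    # type=bool would turn any non-empty string (including "False") into True;
    # BooleanOptionalAction (Python 3.9+) provides --fp16 / --no-fp16 instead.
    parser.add_argument("--fp16", action=argparse.BooleanOptionalAction, default=True)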
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--max_steps", type=int, default=10000)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--dataloader_num_workers", type=int, default=8)
    
    args = parser.parse_args()
    
    train(args)