File size: 3,685 Bytes
064752a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import json
import torch
from transformers import DonutProcessor, AutoTokenizer
import argparse
from transformers import VisionEncoderDecoderModel, EncoderDecoderModel, EncoderDecoderConfig, BertConfig
from my_model import MyModel, MyDataset
from transformers import GenerationConfig
from PIL import Image

def _build_processor(args):
    """Load the Donut processor from ``args.donut_dir`` and override its image geometry/normalisation.

    The resize target and mean/std assigned here replace whatever was saved with
    the processor. NOTE(review): presumably these mirror the training-time
    preprocessing — confirm against the training script.
    """
    processor = DonutProcessor.from_pretrained(args.donut_dir)
    image_processor = processor.image_processor
    image_processor.size = {'height': 896, 'width': 672}
    image_processor.image_mean = [0.485, 0.456, 0.406]
    image_processor.image_std = [0.229, 0.224, 0.225]
    return processor


def _build_model(args, tokenizer, device):
    """Assemble ``MyModel`` from a BERT encoder-decoder plus a pretrained Nougat model and load weights.

    The BERT decoder hyper-parameters below must match the ones the checkpoint in
    ``args.checkpoint_dir`` was trained with, otherwise ``load_state_dict`` fails.
    Returns the model in eval mode, moved to ``device``.
    """
    encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(), BertConfig())

    # Only the decoder side is customised; it is aliased so each field is set on the same config object.
    decoder_cfg = encoder_decoder_config.decoder
    decoder_cfg.bos_token_id = tokenizer.bos_token_id
    decoder_cfg.decoder_start_token_id = tokenizer.bos_token_id
    decoder_cfg.eos_token_id = tokenizer.eos_token_id
    decoder_cfg.hidden_size = 512
    decoder_cfg.intermediate_size = 2048
    decoder_cfg.max_length = args.max_length
    # Position table must cover the longest sequence generation may produce.
    decoder_cfg.max_position_embeddings = args.max_length
    decoder_cfg.num_attention_heads = 8
    decoder_cfg.num_hidden_layers = 6
    decoder_cfg.pad_token_id = tokenizer.pad_token_id
    decoder_cfg.type_vocab_size = 1
    decoder_cfg.vocab_size = tokenizer.vocab_size

    trans_model = EncoderDecoderModel(config=encoder_decoder_config)
    nougat_model = VisionEncoderDecoderModel.from_pretrained(args.nougat_dir)
    model = MyModel(nougat_model.config, trans_model, nougat_model)

    checkpoint_file_path = os.path.join(args.checkpoint_dir, 'pytorch_model.bin')
    # SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary code.
    # Only load trusted checkpoints; on torch >= 1.13 consider weights_only=True.
    checkpoint = torch.load(checkpoint_file_path, map_location='cpu')
    model.load_state_dict(checkpoint)
    model.eval()
    model.to(device)
    return model


def _result_file_path(base_dir, image_file_path):
    """Return ``<base_dir>/outputs/<image stem>.txt``, creating the ``outputs`` directory if needed.

    The stem is derived with ``os.path`` so extensions of any length (``.jpeg``,
    ``.webp``, none at all) and platform path separators are handled correctly.
    """
    result_dir = os.path.join(base_dir, 'outputs')
    os.makedirs(result_dir, exist_ok=True)
    stem = os.path.splitext(os.path.basename(image_file_path))[0]
    return os.path.join(result_dir, stem + '.txt')


def inference(args):
    """Run the image-to-text model on one image and write the decoded text to disk.

    Args:
        args: parsed CLI namespace providing ``base_dir`` (holds ``zh_tokenizer``;
            results are written under ``base_dir/outputs``), ``donut_dir``,
            ``nougat_dir``, ``checkpoint_dir`` (holds ``pytorch_model.bin``),
            ``image_file_path``, ``max_length`` and ``num_beams``.

    Side effects:
        Writes ``<base_dir>/outputs/<image stem>.txt`` (UTF-8) containing the
        decoded output of the first generated sequence.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    processor = _build_processor(args)
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.base_dir, 'zh_tokenizer'))
    model = _build_model(args, tokenizer, device)

    generation_config = GenerationConfig(
        max_length=args.max_length,
        early_stopping=True,
        num_beams=args.num_beams,
        use_cache=True,
        length_penalty=1.0,
        bos_token_id=tokenizer.bos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Context manager closes the underlying file handle once preprocessing is done.
    with Image.open(args.image_file_path) as image:
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        pixel_values = processor(rgb_image, return_tensors="pt").pixel_values.to(device)

    # Inference only: skip autograd bookkeeping to save memory during beam search.
    with torch.no_grad():
        generation_ids = model.generate(
            pixel_values=pixel_values,
            generation_config=generation_config,
        )

    # NOTE(review): special tokens (e.g. BOS/EOS) are kept in the decoded text;
    # pass skip_special_tokens=True if they are unwanted in the output file.
    zh_text = tokenizer.decode(generation_ids[0])

    result_file_path = _result_file_path(args.base_dir, args.image_file_path)
    with open(result_file_path, 'w', encoding='utf-8') as f:
        f.write(zh_text)

if __name__ == '__main__':
    # CLI entry point. Every path argument is consumed unconditionally by
    # inference(), so they are required rather than silently defaulting to None.
    parser = argparse.ArgumentParser(
        description="Run single-image inference and write the decoded text to <base_dir>/outputs/."
    )
    parser.add_argument("--base_dir", type=str, required=True,
                        help="Directory containing 'zh_tokenizer'; results are written to its 'outputs' subdirectory.")
    parser.add_argument("--donut_dir", type=str, required=True,
                        help="Directory of the pretrained Donut processor.")
    parser.add_argument("--nougat_dir", type=str, required=True,
                        help="Directory of the pretrained Nougat VisionEncoderDecoder model.")
    parser.add_argument("--checkpoint_dir", type=str, required=True,
                        help="Directory containing the trained 'pytorch_model.bin'.")
    parser.add_argument("--image_file_path", type=str, required=True,
                        help="Path of the input image.")

    parser.add_argument("--max_length", type=int, default=1536,
                        help="Maximum generated sequence length (also sizes the decoder position table).")
    parser.add_argument("--num_beams", type=int, default=4,
                        help="Beam width for generation.")

    args = parser.parse_args()

    inference(args)