File size: 4,662 Bytes
6581de9
5279e45
6581de9
 
43a5321
6581de9
 
43a5321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebb030c
43a5321
 
 
 
 
 
 
 
 
 
6581de9
 
 
4bf6412
 
 
 
 
 
 
 
 
4b73e05
4bf6412
f9ce3f3
 
 
 
 
 
 
4b73e05
1f662e3
f7fe7ff
df73b43
 
c91f43f
ef55841
c91f43f
 
 
 
 
 
ef55841
c91f43f
 
 
 
 
 
 
 
 
 
 
 
 
 
df73b43
c91f43f
4b73e05
c91f43f
f756684
71bf396
c91f43f
f756684
71bf396
f9ce3f3
71bf396
f9ce3f3
71bf396
c91f43f
4b73e05
71bf396
f9ce3f3
71bf396
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import streamlit as st

from PIL import Image
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig , DonutProcessor


def run_prediction(sample):
    """Run Donut inference on *sample* and return ``(prediction, target)``.

    Relies on module-level globals set by the app before the call:
    ``pretrained_model``, ``processor``, ``task_prompt`` and ``device``.

    Args:
        sample: either a preprocessed dataset dict carrying "pixel_values"
            (and "target_sequence"), or a PIL image to be processed here.

    Returns:
        Tuple of (prediction, target): the parsed JSON prediction, and the
        parsed reference target when *sample* is a dict, else the string
        "<not_provided>".
    """
    global pretrained_model, processor, task_prompt
    if isinstance(sample, dict):
        # Dataset sample: pixel values are already computed, just add a batch dim.
        pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0)
    else:  # sample is an image
        # BUGFIX: previously read the global `image` instead of the `sample`
        # argument, so any image passed in was silently ignored.
        pixel_values = processor(sample, return_tensors="pt").pixel_values

    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # run inference (greedy decoding, num_beams=1)
    outputs = pretrained_model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=pretrained_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # decode generated token ids back into a token string
    prediction = processor.batch_decode(outputs.sequences)[0]

    # post-processing: the CORD model emits eos/pad tokens that must be stripped
    if "cord" in task_prompt:
        prediction = prediction.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    prediction = processor.token2json(prediction)

    # load the reference target when the sample carries one
    if isinstance(sample, dict):
        target = processor.token2json(sample["target_sequence"])
    else:
        target = "<not_provided>"

    return prediction, target
    

# Default task prompt; overwritten below once a model mode is selected.
task_prompt = "<s>"

st.text('''
This is OCR-free Document Understanding Transformer nicknamed 🍩. It was fine-tuned with 1000 receipt images -> SROIE dataset.
The original 🍩 implementation can be found on: https://github.com/clovaai/donut
''')

with st.sidebar:
    # FIX: label previously read "inside the are you" — missing the word "receipt".
    information = st.radio(
        "What information inside the receipt are you interested in?",
        ('Receipt Summary', 'Receipt Menu Details', 'Extract all!'))
    receipt = st.selectbox('Pick one 🧾', ['1', '2', '3', '4', '5', '6'], index=5)

    # file upload (disabled for now)
    # uploaded_file = st.file_uploader("Choose a file")
    # if uploaded_file is not None:
    #     bytes_data = uploaded_file.getvalue()
    #     st.write(bytes_data)

# FIX: message previously said ".png" while the code below opens a ".jpg" file.
st.text(f'{information} mode is ON!\nTarget 🧾: {receipt}\n(opening image @:./img/receipt-{receipt}.jpg)')

image = Image.open(f"./img/receipt-{receipt}.jpg")
st.image(image, caption='Your target receipt')

st.text('baking the 🍩s...')

# Pick the model(s) matching the requested extraction mode.
# Device selection is mode-independent, so compute it once up front
# (previously duplicated in every branch).
device = "cuda" if torch.cuda.is_available() else "cpu"

if information == 'Receipt Summary':
    # SROIE-fine-tuned model: receipt-level summary fields.
    processor = DonutProcessor.from_pretrained("unstructuredio/donut-base-sroie")
    pretrained_model = VisionEncoderDecoderModel.from_pretrained("unstructuredio/donut-base-sroie")
    task_prompt = "<s>"
    pretrained_model.to(device)

elif information == 'Receipt Menu Details':
    # CORD-v2 model: line-item menu details.
    processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
    pretrained_model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
    task_prompt = "<s_cord-v2>"
    pretrained_model.to(device)

else:
    # 'Extract all!': load both models; device placement happens right
    # before each run in the dispatch block below.
    processor_a = DonutProcessor.from_pretrained("unstructuredio/donut-base-sroie")
    processor_b = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
    pretrained_model_a = VisionEncoderDecoderModel.from_pretrained("unstructuredio/donut-base-sroie")
    pretrained_model_b = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")

if information == 'Extract all!':
    st.text('parsing 🧾 (extracting all)...')
    # Run both models back-to-back on the same image, swapping the globals
    # that run_prediction reads before each call.
    pretrained_model, processor, task_prompt = pretrained_model_a, processor_a, "<s>"
    pretrained_model.to(device)
    parsed_receipt_info_a, _ = run_prediction(image)
    pretrained_model, processor, task_prompt = pretrained_model_b, processor_b, "<s_cord-v2>"
    pretrained_model.to(device)
    parsed_receipt_info_b, _ = run_prediction(image)
    st.text('\nReceipt Summary:')
    st.json(parsed_receipt_info_a)
    st.text('\nReceipt Menu Details:')
    st.json(parsed_receipt_info_b)
else:
    st.text('parsing 🧾...')
    parsed_receipt_info, _ = run_prediction(image)
    st.text(f'\n{information}')
    st.json(parsed_receipt_info)