Spaces:
Runtime error
Runtime error
File size: 4,662 Bytes
6581de9 5279e45 6581de9 43a5321 6581de9 43a5321 ebb030c 43a5321 6581de9 4bf6412 4b73e05 4bf6412 f9ce3f3 4b73e05 1f662e3 f7fe7ff df73b43 c91f43f ef55841 c91f43f ef55841 c91f43f df73b43 c91f43f 4b73e05 c91f43f f756684 71bf396 c91f43f f756684 71bf396 f9ce3f3 71bf396 f9ce3f3 71bf396 c91f43f 4b73e05 71bf396 f9ce3f3 71bf396 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import torch
import streamlit as st
from PIL import Image
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig , DonutProcessor
def run_prediction(sample):
    """Run Donut inference on one sample and parse the result into JSON.

    The function reads the module-level ``pretrained_model``, ``processor``,
    ``task_prompt`` and ``device``; the caller must bind them before calling.

    Args:
        sample: Either a dict holding precomputed ``"pixel_values"`` and a
            reference ``"target_sequence"``, or an image that ``processor``
            can turn into pixel values.

    Returns:
        A ``(prediction, target)`` tuple. ``prediction`` is the output of
        ``processor.token2json`` for the generated sequence. ``target`` is the
        parsed reference sequence for dict samples and the placeholder string
        ``"<not_provided>"`` otherwise.
    """
    global pretrained_model, processor, task_prompt

    has_reference = isinstance(sample, dict)

    # prepare encoder inputs
    if has_reference:
        pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0)
    else:
        # Encode the image that was passed in, not the module-level ``image``.
        pixel_values = processor(sample, return_tensors="pt").pixel_values

    # The decoder is primed with the task prompt only; no special tokens are
    # added around it.
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # run inference
    outputs = pretrained_model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=pretrained_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # never emit the unknown token
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # process output
    prediction = processor.batch_decode(outputs.sequences)[0]

    # post-processing: for the CORD prompt, strip EOS/PAD markers from the
    # decoded text before it is converted to JSON.
    if "cord" in task_prompt:
        prediction = prediction.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        # prediction = re.sub(r"<.*?>", "", prediction, count=1).strip()  # remove first task start token
    prediction = processor.token2json(prediction)

    # load reference target
    if has_reference:
        target = processor.token2json(sample["target_sequence"])
    else:
        target = "<not_provided>"

    return prediction, target
# Default decoder prompt; it is overwritten once an extraction mode is chosen.
task_prompt = "<s>"

st.text('''
This is OCR-free Document Understanding Transformer nicknamed 🍩. It was fine-tuned with 1000 receipt images -> SROIE dataset.
The original 🍩 implementation can be found on: https://github.com/clovaai/donut
''')

with st.sidebar:
    # Extraction mode; decides which checkpoint(s) are loaded further down.
    information = st.radio(
        "What information inside the receipt are you interested in?",
        ('Receipt Summary', 'Receipt Menu Details', 'Extract all!'))
    # Bundled sample receipts; index=5 preselects receipt '6'.
    receipt = st.selectbox('Pick one 🧾', ['1', '2', '3', '4', '5', '6'], index=5)

    # TODO: file upload is not wired up yet; sketch kept for reference.
    # uploaded_file = st.file_uploader("Choose a file")
    # if uploaded_file is not None:
    ## To read file as bytes:
    # bytes_data = uploaded_file.getvalue()
    # st.write(bytes_data)

# Build the path once so the status message always names the file that is
# actually opened (the message used to say .png while the .jpg was loaded).
image_path = f"./img/receipt-{receipt}.jpg"
st.text(f'{information} mode is ON!\nTarget 🧾: {receipt}\n(opening image @:{image_path})')

image = Image.open(image_path)
st.image(image, caption='Your target receipt')

st.text('baking the 🍩s...')
# Hugging Face Hub checkpoints used by the extraction modes.
SROIE_CHECKPOINT = "unstructuredio/donut-base-sroie"
CORD_CHECKPOINT = "naver-clova-ix/donut-base-finetuned-cord-v2"


def _load_donut(checkpoint):
    """Return the ``(processor, model)`` pair loaded from ``checkpoint``."""
    return (
        DonutProcessor.from_pretrained(checkpoint),
        VisionEncoderDecoderModel.from_pretrained(checkpoint),
    )


# NOTE(review): Streamlit re-executes the script on every widget interaction,
# so the checkpoints are presumably reloaded each time; consider caching the
# loader if the installed Streamlit version supports it.

# Prefer the GPU when one is visible to torch.
device = "cuda" if torch.cuda.is_available() else "cpu"

if information == 'Receipt Summary':
    processor, pretrained_model = _load_donut(SROIE_CHECKPOINT)
    task_prompt = "<s>"
    pretrained_model.to(device)
elif information == 'Receipt Menu Details':
    processor, pretrained_model = _load_donut(CORD_CHECKPOINT)
    task_prompt = "<s_cord-v2>"
    pretrained_model.to(device)
else:
    # 'Extract all!': load both checkpoints. They are moved to ``device`` and
    # bound to the names ``run_prediction`` reads right before each is used.
    processor_a, pretrained_model_a = _load_donut(SROIE_CHECKPOINT)
    processor_b, pretrained_model_b = _load_donut(CORD_CHECKPOINT)
# Run inference for the chosen mode and render the parsed receipt as JSON.
if information != 'Extract all!':
    # Single-checkpoint modes: the globals read by run_prediction are already
    # bound by the model-loading step.
    st.text('parsing 🧾...')
    parsed_receipt_info, _ = run_prediction(image)
    st.text(f'\n{information}')
    st.json(parsed_receipt_info)
else:
    st.text('parsing 🧾 (extracting all)...')

    # Run both checkpoints back to back, rebinding the globals that
    # run_prediction reads before each call.
    extraction_results = []
    for pretrained_model, processor, task_prompt in (
        (pretrained_model_a, processor_a, "<s>"),
        (pretrained_model_b, processor_b, "<s_cord-v2>"),
    ):
        pretrained_model.to(device)
        extraction_results.append(run_prediction(image)[0])
    parsed_receipt_info_a, parsed_receipt_info_b = extraction_results

    # Show both results only after both predictions have finished.
    st.text('\nReceipt Summary:')
    st.json(parsed_receipt_info_a)
    st.text('\nReceipt Menu Details:')
    st.json(parsed_receipt_info_b)