# receipt-parser / app.py
# Hugging Face Space: OCR-free receipt parsing demo (Donut models, Streamlit UI).
import torch
import streamlit as st
from PIL import Image
from io import BytesIO
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig , DonutProcessor
def run_prediction(sample):
    """Run Donut inference on *sample* and return ``(prediction, target)``.

    Parameters
    ----------
    sample : dict | PIL.Image.Image
        Either a preprocessed dataset sample (dict with ``"pixel_values"``
        and optionally ``"target_sequence"``) or a raw PIL image.

    Returns
    -------
    tuple
        ``(prediction, target)`` where ``prediction`` is the parsed JSON-like
        dict from ``processor.token2json`` and ``target`` is the reference
        parse for dataset samples, or the string ``"<not_provided>"`` for
        raw images.

    Relies on the module-level globals ``pretrained_model``, ``processor``,
    ``task_prompt`` and ``device`` being set before the call.
    """
    global pretrained_model, processor, task_prompt
    if isinstance(sample, dict):
        # Dataset sample: pixel values were precomputed; add a batch dim.
        pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0)
    else:  # sample is a PIL image
        # BUG FIX: the original read the module-level `image` here instead of
        # the `sample` argument, silently ignoring whatever the caller passed.
        pixel_values = processor(sample, return_tensors="pt").pixel_values

    # Seed the decoder with the task prompt for the currently loaded model.
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # run inference (greedy, num_beams=1)
    outputs = pretrained_model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=pretrained_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Decode the generated token ids back to a string.
    prediction = processor.batch_decode(outputs.sequences)[0]

    # Post-processing: the CORD model keeps eos/pad tokens in the decode,
    # strip them before converting tokens to JSON.
    if "cord" in task_prompt:
        prediction = prediction.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    prediction = processor.token2json(prediction)

    # Load the reference target when the sample provides one.
    if isinstance(sample, dict):
        target = processor.token2json(sample["target_sequence"])
    else:
        target = "<not_provided>"

    return prediction, target
# Initial decoder prompt; overwritten below once the user picks a mode.
task_prompt = "<s>"

# Page header (logo) and short intro text.
_HEADER_HTML = """
<h3 align="center">
![unstructured](./img/unstructured_logo.png "unstructured logo")
</h3>
"""
_INTRO_TEXT = '''
This is OCR-free Document Understanding Transformer nicknamed 🍩. It was fine-tuned with 1000 receipt images -> SROIE dataset.
The original 🍩 implementation can be found on: https://github.com/clovaai/donut
'''

st.markdown(_HEADER_HTML, unsafe_allow_html=True)
st.text(_INTRO_TEXT)
# Sidebar: pick the extraction mode, a bundled sample receipt, and
# optionally upload a custom receipt image.
image_upload = None
with st.sidebar:
    information = st.radio(
        # BUG FIX: label was missing its noun ("inside the are you interested in?")
        "What information inside the receipt are you interested in?",
        ('Receipt Summary', 'Receipt Menu Details', 'Extract all'))
    receipt = st.selectbox('Pick one 🧾', ['1', '2', '3', '4', '5', '6'], index=1)
    # file upload
    uploaded_file = st.file_uploader("Upload a 🧾")
    if uploaded_file is not None:
        # Read the upload as raw bytes and decode it with PIL.
        image_bytes_data = uploaded_file.getvalue()
        image_upload = Image.open(BytesIO(image_bytes_data))

st.text(f'{information} mode is ON!\nTarget 🧾: {receipt}')
# Bundled sample receipt selected in the sidebar.
image = Image.open(f"./img/receipt-{receipt}.jpg")
# Two-column layout: receipt image on the left, parse results on the right.
col1, col2 = st.columns(2)

with col1:
    # Show the uploaded receipt when present, otherwise the bundled sample.
    st.image(image_upload if image_upload else image, caption='Your target receipt')

st.text('baking the 🍩s...')

# Target device is the same for every mode — compute it once.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model(s) matching the selected mode.
if information == 'Receipt Summary':
    processor = DonutProcessor.from_pretrained("unstructuredio/donut-base-sroie")
    pretrained_model = VisionEncoderDecoderModel.from_pretrained("unstructuredio/donut-base-sroie")
    task_prompt = "<s>"
    pretrained_model.to(device)
elif information == 'Receipt Menu Details':
    processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
    pretrained_model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
    task_prompt = "<s_cord-v2>"
    pretrained_model.to(device)
else:
    # 'Extract all' mode: load both models; they are moved onto the
    # device later, right before each prediction.
    processor_a = DonutProcessor.from_pretrained("unstructuredio/donut-base-sroie")
    processor_b = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
    pretrained_model_a = VisionEncoderDecoderModel.from_pretrained("unstructuredio/donut-base-sroie")
    pretrained_model_b = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
with col2:
    # BUG FIX: the radio option (sidebar) is 'Extract all' — the original
    # compared against 'Extract all!', so this branch was unreachable and
    # the else path crashed with NameError (no processor/task_prompt set
    # in that mode).
    # NOTE(review): predictions always run on the bundled sample `image`,
    # never on `image_upload` — confirm whether that is intended.
    if information == 'Extract all':
        st.text('parsing 🧾 (extracting all)...')
        # Run both models sequentially on the same image: first the SROIE
        # summary model, then the CORD menu-details model.
        pretrained_model, processor, task_prompt = pretrained_model_a, processor_a, "<s>"
        pretrained_model.to(device)
        parsed_receipt_info_a, _ = run_prediction(image)
        pretrained_model, processor, task_prompt = pretrained_model_b, processor_b, "<s_cord-v2>"
        pretrained_model.to(device)
        parsed_receipt_info_b, _ = run_prediction(image)
        st.text('\nReceipt Summary:')
        st.json(parsed_receipt_info_a)
        st.text('\nReceipt Menu Details:')
        st.json(parsed_receipt_info_b)
    else:
        # Single-model modes: the model/prompt were loaded above.
        st.text('parsing 🧾...')
        parsed_receipt_info, _ = run_prediction(image)
        st.text(f'\n{information}')
        st.json(parsed_receipt_info)