import io
import os
import boto3
import traceback
import gradio as gr
from PIL import Image, ImageDraw
from docquery.document import load_document, ImageDocument
from docquery.ocr_reader import get_ocr_reader
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from transformers import DonutProcessor, VisionEncoderDecoderModel
# avoid ssl errors
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
os.environ["TOKENIZERS_PARALLELISM"] = "false"
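# Wrap a single value in a list so callers can treat scalar and list inputs uniformly.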
def ensure_list(x):
if isinstance(x, list):
return x
else:
return [x]
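# Model choices offered in the UI; run_pipeline dispatches on the display name (the dict key).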
CHECKPOINTS = {
# "LayoutLMv1 🦉": "impira/layoutlm-document-qa",
# "LayoutLMv1 for Invoices 💸": "impira/layoutlm-invoices",
"Textract Query": "Textract",
"LayoutLM FineTuned": "LayoutLM FineTuned",
"Donut": "naver-clova-ix/donut-base-finetuned-rvlcdip",
"LiLT": "philschmid/lilt-en-funsd",
# "LiLT" : "nielsr/lilt-xlm-roberta-base"
}
PIPELINES = {}
#
#
# def construct_pipeline(task, model):
# global PIPELINES
# if model in PIPELINES:
# return PIPELINES[model]
#
# device = "cuda" if torch.cuda.is_available() else "cpu"
# ret = pipeline(task=task, model=CHECKPOINTS[model], device=device)
# PIPELINES[model] = ret
# return ret
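# Serialize a PIL image to PNG bytes, the format passed to Textract's Document payload.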
def image_to_byte_array(image: Image.Image) -> bytes:
image_as_byte_array = io.BytesIO()
image.save(image_as_byte_array, format="PNG")
image_as_byte_array = image_as_byte_array.getvalue()
return image_as_byte_array
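# Ask Amazon Textract's Queries feature the question directly on the page image and return the first QUERY_RESULT block.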
def run_textract_query(question, document):
    image_bytes = image_to_byte_array(image=document.b)
response = boto3.client('textract').analyze_document(
Document={
            'Bytes': image_bytes,
},
FeatureTypes=[
'QUERIES',
],
QueriesConfig={
'Queries': [
{
'Text': question,
'Pages': [
'*',
]
},
]
}
)
for element in response["Blocks"]:
if element["BlockType"] == "QUERY_RESULT":
return {
"score": element["Confidence"],
"answer": element["Text"],
# "word_ids": element
}
    # Fail loudly if Textract returned no QUERY_RESULT block.
    raise Exception("No QUERY_RESULT found in the response from Textract.")
def run_layoutlm_finetuned(question, document):
from transformers import pipeline
nlp = pipeline(
"document-question-answering",
model="impira/layoutlm-document-qa",
)
result = nlp(document.context["image"][0][0], question)[0]
# [{'score': 0.9999411106109619, 'answer': 'LETTER OF CREDIT', 'start': 106, 'end': 108}]
return {
"score": result["score"],
"answer": result["answer"],
"word_ids": [result["start"], result["end"]],
"page": 0
}
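# Extractive QA with a LiLT model over the OCR words and bounding boxes that docquery extracted for the first page.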
def run_lilt_model(question, document):
    # note: the tokenizer and the QA model come from two different LiLT checkpoints
lilt_tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-infoxlm-base")
model = AutoModelForQuestionAnswering.from_pretrained("nielsr/lilt-xlm-roberta-base")
processed_document = document.context["image"][0][1]
words = [x[0] for x in processed_document]
boxes = [x[1] for x in processed_document]
encoding = lilt_tokenizer(text=question, text_pair=words, boxes=boxes, add_special_tokens=True, return_tensors="pt")
outputs = model(**encoding)
answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()
predict_answer_tokens = encoding.input_ids[0, answer_start_index: answer_end_index + 1]
predict_answer = lilt_tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
return {
"score": "n/a",
"answer": predict_answer,
# "word_ids": element
}
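# OCR-free document VQA with Donut: encode the page image and generate the answer from a <s_docvqa> task prompt.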
def run_donut(question, document):
# nlp = pipeline(
# "document-question-answering",
# model="naver-clova-ix/donut-base-finetuned-docvqa",
# )
#
# result = nlp(document.context["image"][0][0], question)[0]
# # [{'score': 0.9999411106109619, 'answer': 'LETTER OF CREDIT', 'start': 106, 'end': 108}]
# return {
# "score": result["score"],
# "answer": result["answer"],
# "word_ids": [result["start"], result["end"]],
# "page": 0
# }
donut_processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
donut_model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
# prepare encoder inputs
pixel_values = donut_processor(document.context["image"][0][0], return_tensors="pt").pixel_values
# prepare decoder inputs
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
prompt = task_prompt.replace("{user_input}", question)
decoder_input_ids = donut_processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
# generate answer
outputs = donut_model.generate(
pixel_values,
decoder_input_ids=decoder_input_ids,
max_length=donut_model.decoder.config.max_position_embeddings,
early_stopping=True,
pad_token_id=donut_processor.tokenizer.pad_token_id,
eos_token_id=donut_processor.tokenizer.eos_token_id,
use_cache=True,
num_beams=1,
bad_words_ids=[[donut_processor.tokenizer.unk_token_id]],
return_dict_in_generate=True,
)
import re
# postprocess
sequence = donut_processor.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(donut_processor.tokenizer.eos_token, "").replace(donut_processor.tokenizer.pad_token, "")
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
result = donut_processor.token2json(sequence)
return {
"score": "n/a",
"answer": result["answer"],
# "word_ids": element
}
def run_pipeline(model, question, document, top_k):
""" Run pipeline selected by the user.
:return: expect an object like
[{'score': 0.251716673374176, 'answer': 'CREDIT', 'word_ids': [38], 'page': 0},
{'score': 0.15292450785636902, 'answer': 'LETTER OF CREDIT', 'word_ids': [37, 38], 'page': 0},
{'score': 0.009600160643458366, 'answer': 'Payment Tens LETTER OF CREDIT', 'word_ids': [36, 37, 38], 'page': 0}]
"""
if model == "Textract Query":
return run_textract_query(question, document)
elif model == "LiLT":
return run_lilt_model(question, document)
elif model == "LayoutLM FineTuned":
return run_layoutlm_finetuned(question=question, document=document)
elif model == "Donut":
return run_donut(question=question, document=document)
else:
return {"answer": "model not found", "score": "n/a"}
def process_path(path):
    if path:
        try:
            document = load_document(path)
            return (
                document,
                gr.update(visible=True, value=document.preview),
                gr.update(visible=True),
                gr.update(visible=False, value=None),
                gr.update(visible=False, value=None),
            )
        except Exception:
            traceback.print_exc()
    return (
        None,
        gr.update(visible=False, value=None),
        gr.update(visible=False),
        gr.update(visible=False, value=None),
        gr.update(visible=False, value=None),
    )
def process_upload(file):
    if file:
        return process_path(file.name)
    else:
        return (
            None,
            gr.update(visible=False, value=None),
            gr.update(visible=False),
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None),
        )
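# docquery stores an (image, word_boxes) pair per page in document.context["image"]; return the word boxes for the given page.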
def lift_word_boxes(document, page):
return document.context["image"][page][1]
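# Merge a list of word boxes into a single enclosing box [min_x, min_y, max_x, max_y].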
def expand_bbox(word_boxes):
if len(word_boxes) == 0:
return None
min_x, min_y, max_x, max_y = zip(*[x[1] for x in word_boxes])
min_x, min_y, max_x, max_y = [min(min_x), min(min_y), max(max_x), max(max_y)]
return [min_x, min_y, max_x, max_y]
# LayoutLM boxes are normalized to 0, 1000
def normalize_bbox(box, width, height, padding=0.005):
min_x, min_y, max_x, max_y = [c / 1000 for c in box]
if padding != 0:
min_x = max(0, min_x - padding)
min_y = max(0, min_y - padding)
max_x = min(max_x + padding, 1)
max_y = min(max_y + padding, 1)
return [min_x * width, min_y * height, max_x * width, max_y * height]
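# Run the selected model and, when word ids are available, highlight the answer's bounding box on the page preview.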
def process_question(question, document, model=list(CHECKPOINTS.keys())[0]):
    # Guard against events (e.g. model.change) firing before a question and a document are available.
    if not question or document is None:
        return None, None, None
    prediction = run_pipeline(model, question, document, 3)
pages = [x.copy().convert("RGB") for x in document.preview]
text_value = prediction["answer"]
if "word_ids" in prediction:
image = pages[prediction["page"]]
draw = ImageDraw.Draw(image, "RGBA")
word_boxes = lift_word_boxes(document, prediction["page"])
x1, y1, x2, y2 = normalize_bbox(
expand_bbox([word_boxes[i] for i in prediction["word_ids"]]),
image.width,
image.height,
)
draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
return (
gr.update(visible=True, value=pages),
gr.update(visible=True, value=prediction),
gr.update(
visible=True,
value=text_value,
),
)
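# Turn an example image into an ImageDocument and answer the example question with the selected model.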
def load_example_document(img, question, model):
if img is not None:
document = ImageDocument(Image.fromarray(img), get_ocr_reader())
preview, answer, answer_text = process_question(question, document, model)
return document, question, preview, gr.update(visible=True), answer, answer_text
else:
return None, None, None, gr.update(visible=False), None, None
CSS = """
#question input {
font-size: 16px;
}
#url-textbox {
padding: 0 !important;
}
#short-upload-box .w-full {
min-height: 10rem !important;
}
/* I think something like this can be used to re-shape
* the table
*/
/*
.gr-samples-table tr {
display: inline;
}
.gr-samples-table .p-2 {
width: 100px;
}
*/
#select-a-file {
width: 100%;
}
#file-clear {
padding-top: 2px !important;
padding-bottom: 2px !important;
padding-left: 8px !important;
padding-right: 8px !important;
margin-top: 10px;
}
.gradio-container .gr-button-primary {
background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
border: 1px solid #B0DCCC;
border-radius: 8px;
color: #1B8700;
}
.gradio-container.dark button#submit-button {
background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
border: 1px solid #B0DCCC;
border-radius: 8px;
color: #1B8700
}
table.gr-samples-table tr td {
border: none;
outline: none;
}
table.gr-samples-table tr td:first-of-type {
width: 0%;
}
div#short-upload-box div.absolute {
display: none !important;
}
gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
gap: 0px 2%;
}
gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
gap: 0px;
}
gradio-app h2, .gradio-app h2 {
padding-top: 10px;
}
#answer {
overflow-y: scroll;
color: white;
background: #666;
border-color: #666;
font-size: 20px;
font-weight: bold;
}
#answer span {
color: white;
}
#answer textarea {
color:white;
background: #777;
border-color: #777;
font-size: 18px;
}
#url-error input {
color: red;
}
"""
examples = [
[
"scenario-1.png",
"What is the final consignee?",
],
[
"scenario-1.png",
"What are the payment terms?",
],
[
"scenario-2.png",
"What is the actual manufacturer?",
],
[
"scenario-3.png",
'What is the "ship to" destination?',
],
[
"scenario-4.png",
'What is the color?',
],
[
"scenario-5.png",
'What is the "said to contain"?',
],
[
"scenario-5.png",
'What is the "Net Weight"?',
],
[
"scenario-5.png",
'What is the "Freight Collect"?',
],
[
"bill_of_lading_1.png",
"What is the shipper?",
],
[
"bill_of_lading_1.png",
"What is the consignee?",
],
[
"bill_of_lading_1.png",
"What is the consignee id?",
],
[
"bill_of_lading_1.png",
"What is the carrier id?",
],
[
"bill_of_lading_1.png",
"What is the description of the products?",
],
[
"bill_of_lading_1.png",
"What is the quantity of the products?",
],
]
with gr.Blocks(css=CSS) as demo:
gr.Markdown("# Document Query Engine")
gr.Markdown(
"Original version comes from DocQuery [here](https://huggingface.co/spaces/impira/docquery) (created by [Impira](https://impira.com?utm_source=huggingface&utm_medium=referral&utm_campaign=docquery_space))"
)
document = gr.Variable()
example_question = gr.Textbox(visible=False)
example_image = gr.Image(visible=False)
with gr.Row(equal_height=True):
with gr.Column():
with gr.Row():
gr.Markdown("## 1. Select a file", elem_id="select-a-file")
img_clear_button = gr.Button(
"Clear", variant="secondary", elem_id="file-clear", visible=False
)
image = gr.Gallery(visible=False)
upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
gr.Examples(
examples=examples,
inputs=[example_image, example_question],
)
with gr.Column() as col:
gr.Markdown("## 2. Ask a question")
question = gr.Textbox(
label="Question",
placeholder="e.g. What is the invoice number?",
lines=1,
max_lines=1,
)
model = gr.Radio(
choices=list(CHECKPOINTS.keys()),
value=list(CHECKPOINTS.keys())[0],
label="Model",
)
with gr.Row():
clear_button = gr.Button("Clear", variant="secondary")
submit_button = gr.Button(
"Submit", variant="primary", elem_id="submit-button"
)
with gr.Column():
output_text = gr.Textbox(
label="Top Answer", visible=False, elem_id="answer"
)
output = gr.JSON(label="Output", visible=False)
for cb in [img_clear_button, clear_button]:
cb.click(
            lambda _: (
                gr.update(visible=False, value=None),
                None,
                gr.update(visible=False, value=None),
                gr.update(visible=False, value=None),
                gr.update(visible=False),
                None,
                None,
                None,
            ),
inputs=clear_button,
outputs=[
image,
document,
output,
output_text,
img_clear_button,
example_image,
upload,
question,
],
)
upload.change(
fn=process_upload,
inputs=[upload],
outputs=[document, image, img_clear_button, output, output_text],
)
question.submit(
fn=process_question,
inputs=[question, document, model],
outputs=[image, output, output_text],
)
submit_button.click(
process_question,
inputs=[question, document, model],
outputs=[image, output, output_text],
)
model.change(
process_question,
inputs=[question, document, model],
outputs=[image, output, output_text],
)
example_image.change(
fn=load_example_document,
inputs=[example_image, example_question, model],
outputs=[document, question, image, img_clear_button, output, output_text],
)
if __name__ == "__main__":
demo.launch(enable_queue=False)