File size: 4,344 Bytes
7e02e28
 
 
 
 
20f10b4
 
707702f
20f10b4
 
345110e
707702f
20f10b4
 
 
 
 
 
 
 
 
 
 
 
7e02e28
 
 
345110e
 
 
 
7e02e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6015b45
7e02e28
 
 
 
 
 
 
20f10b4
707702f
20f10b4
7e02e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
from gradio_client import Client
import os
import json

import datasets

def save_to_dataset(image_path, question, answer_p2s_base, answer_p2s_large, answer_layoutlm, answer_donut,
                    path="img_question_dataset"):
    """Write one question/answer record as a ``datasets.Dataset`` on local disk.

    Args:
        image_path: File path of the document image the question was asked about.
        question: The question text.
        answer_p2s_base: Answer returned by the Pix2Struct base Space.
        answer_p2s_large: Answer returned by the Pix2Struct large Space.
        answer_layoutlm: Answer returned by the LayoutLM Space.
        answer_donut: Answer returned by the Donut Space.
        path: Directory passed to ``Dataset.save_to_disk``.

    Note:
        Each call writes a fresh one-row dataset to ``path``, replacing
        whatever was stored there before; earlier records are not kept.
        Nothing is uploaded to the Hugging Face Hub.
    """
    # Dataset.from_dict takes column-oriented data (one list per column), so
    # each scalar is wrapped in a single-element list to build exactly one row.
    dataset_dict = {
        "image": [image_path],
        "question": [question],
        "answer_p2s_base": [answer_p2s_base],
        "answer_p2s_large": [answer_p2s_large],
        "answer_layoutlm": [answer_layoutlm],
        "answer_donut": [answer_donut],
    }

    dataset = datasets.Dataset.from_dict(dataset_dict)

    # Local filesystem write (overwrites any previous save at `path`).
    dataset.save_to_disk(path)


def generate_answer(image_path, question, model_name, space_id):
    """Ask one remote Hugging Face Space a question about a document image.

    Args:
        image_path: Local file path of the document image to send.
        question: Question text sent together with the image.
        model_name: Subdomain of the Space (the client connects to
            ``https://{model_name}.hf.space/``); it also selects which key is
            read from a JSON result.
        space_id: Human-readable Space name, shown in the failure warning.

    Returns:
        The Space's answer, or ``""`` when the call fails (a Gradio warning
        is displayed in that case).
    """
    if model_name == "qtoino-pix2struct":
        # NOTE(review): pinned to a specific replica path; replica ids may
        # change when the Space restarts — confirm this URL is still valid.
        url = f"https://{model_name}.hf.space/--replicas/uax51/"
    else:
        url = f"https://{model_name}.hf.space/"

    # LayoutLM's Space reports its answer under "label"; the others use "answer".
    answer_key = "label" if model_name == "TusharGoel-LayoutLM-DocVQA" else "answer"

    try:
        client = Client(url)
        result = client.predict(image_path, question, api_name="/predict")
        if isinstance(result, dict):
            # A JSON output may arrive already decoded (depends on the
            # gradio_client version — verify against the deployed client).
            return result[answer_key]
        if isinstance(result, str) and result.endswith(".json"):
            # Otherwise a JSON output arrives as a path to a downloaded file.
            with open(result, "rb") as json_file:
                return json.load(json_file)[answer_key]
        return result
    except Exception:
        # Best effort: one unavailable Space must not break the other models.
        gr.Warning(f"The {space_id} Space ({model_name}) is currently unavailable. Please try again later.")
        return ""


def generate_answers(image_path, question):
    """Ask every compared Space the same question and record the results.

    Args:
        image_path: File path of the uploaded document image (presumably
            ``None`` when nothing was uploaded — confirm with gr.Image).
        question: Question text from the textbox.

    Returns:
        A 4-tuple of answers in UI order: Pix2Struct base, Pix2Struct large,
        LayoutLM, Donut. A Space that fails contributes ``""``.
    """
    # Without both inputs every remote call would fail and raise four
    # misleading "Space unavailable" warnings; warn once instead.
    if not image_path or not (question or "").strip():
        gr.Warning("Please upload a document and enter a question.")
        return "", "", "", ""

    # (model_name, space_id) pairs, in the same order as the output textboxes.
    spaces = (
        ("qtoino-pix2struct", "Pix2Struct"),
        ("akdeniz27-pix2struct-DocVQA", "Pix2Struct Large"),
        ("TusharGoel-LayoutLM-DocVQA", "LayoutLM DocVQA"),
        ("nielsr-donut-docvqa", "Donut DocVQA"),
    )
    answers = tuple(
        generate_answer(image_path, question, model_name=model_name, space_id=space_id)
        for model_name, space_id in spaces
    )

    # Recording the query is secondary; a storage failure must not hide the answers.
    try:
        save_to_dataset(image_path, question, *answers)
    except Exception:
        gr.Warning("Could not save this query to the dataset.")

    return answers
   
# Sample (document image, question) pairs offered to the user.
examples = [
    ["docvqa_example.png", "How many items are sold?"],
    ["document-question-answering-input.png", "What is the objective?"],
]

# Markdown-style page title.
# NOTE(review): not referenced in the visible code — confirm it is still needed.
title = "# Interactive demo: comparing document question answering (VQA) models"

# Custom stylesheet handed to gr.Blocks: gives the element with id "mkd" a
# fixed height, scrolling overflow and a light border.
css = """
  #mkd {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

with gr.Blocks(css=css) as demo:
    # Static, centred page header; every opened tag is closed in reverse order.
    gr.HTML("<h1><center>Compare Document Question Answering Models 📄</center></h1>")
    gr.HTML("<h3><center>Document question answering is the task of answering questions from documents in visual form. 📔📕</center></h3>")
    gr.HTML("<h3><center>To try this Space, simply upload documents and questions. </center></h3>")
    gr.HTML("<h3><center>If prompted to wait and try again, please try again. This Space uses other Spaces as APIs, so it might take time to get those Spaces up and running if they're stopped. </center></h3>")

    with gr.Row():
        # Left column: the document, the question and the trigger button.
        with gr.Column():
            input_image = gr.Image(label = "Input Document", type="filepath")
            question = gr.Textbox(label = "question")
            run_button = gr.Button("Answer")
        # Right column: one answer box per compared model.
        with gr.Column():
            out_p2s_base = gr.Textbox(label="Answer generated by Pix2Struct Base")
            out_p2s_large = gr.Textbox(label="Answer generated by Pix2Struct Large")
            out_layoutlm = gr.Textbox(label="Answer generated by LayoutLM")
            out_donut = gr.Textbox(label="Answer generated by Donut")

    # Order must match the 4-tuple returned by generate_answers.
    outputs = [
        out_p2s_base,
        out_p2s_large,
        out_layoutlm,
        out_donut,
    ]

    # Reuse the module-level example list instead of repeating it here.
    # cache_examples=True makes Gradio compute the example outputs through
    # generate_answers (i.e. it calls the remote Spaces) and cache them.
    gr.Examples(
        examples=examples,
        inputs=[input_image, question],
        outputs=outputs,
        fn=generate_answers,
        cache_examples=True
    )

    run_button.click(
        fn=generate_answers,
        inputs=[input_image, question],
        outputs=outputs
    )

if __name__ == "__main__":
    # Turn on Gradio's request queue, then serve the app with debug output.
    queued_app = demo.queue()
    queued_app.launch(debug=True)