File size: 6,095 Bytes
bc391b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os

# Runtime environment bootstrap for a Hugging Face Space: detectron2 is installed
# from a prebuilt wheel (CUDA 10.2 / torch 1.9) because no matching PyPI wheel exists.
# NOTE(review): os.system with a shell string is fragile and returns only an exit
# code that is ignored here — a failed install surfaces later as an ImportError.
os.system('pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
# AWS credentials assembled from Space secrets; presumably consumed by a downstream
# component (not used anywhere in this file) — TODO confirm this is still needed.
credentials_kwargs={"aws_access_key_id": os.environ["ACCESS_KEY"],"aws_secret_access_key": os.environ["SECRET_KEY"]}

# work around: https://discuss.huggingface.co/t/how-to-install-a-specific-version-of-gradio-in-spaces/13552
os.system("pip uninstall -y gradio")
os.system("pip install gradio==3.4.1")
# SECURITY NOTE(review): executes an arbitrary shell command taken from the
# DD_ADDONS environment variable (presumably a private pip install command).
# Anyone who can set that variable controls this process — confirm it is a
# trusted Space secret only.
os.system(os.environ["DD_ADDONS"])

import time
from os import getcwd, path

import deepdoctection as dd
from deepdoctection.dataflow.serialize import DataFromList
from deepdoctection.utils.settings import get_type

from dd_addons.analyzer.loader import get_loader
from dd_addons.extern.guidance import TOKEN_DEFAULT_INSTRUCTION
from dd_addons.utils.settings import register_llm_token_tag, register_string_categories_from_list
from dd_addons.extern.openai import OpenAiLmmTokenClassifier

import gradio as gr

# Build the document analyzer pipeline once at module load; reset_config_file=True
# forces a fresh configuration rather than reusing a cached one.
analyzer = get_loader(reset_config_file=True)

# Root Blocks container for the UI; populated below inside `with demo:`.
demo = gr.Blocks(css="scrollbar.css")


def process_analyzer(openai_api_key, categories_str, instruction_str, img, pdf, max_datapoints):
    """Run the deepdoctection pipeline with a GPT token classifier over an image or PDF.

    :param openai_api_key: OpenAI API key used to authenticate the token classifier
    :param categories_str: comma separated entity names (snake case) to extract
    :param instruction_str: optional custom prompt; empty string falls back to the
                            classifier's default instruction
    :param img: numpy image from the gradio ``Image`` component, or ``None``
    :param pdf: uploaded PDF file object from the gradio ``File`` component, or ``None``
    :param max_datapoints: maximum number of PDF pages to process
    :return: tuple of (list of per-page visualization images, dict mapping
             ``page_{idx}`` to the extracted tokens of that page)
    :raises ValueError: if neither an image nor a PDF has been provided
    """
    categories_list = categories_str.split(",")
    # Register the user-supplied entity names as a deepdoctection object-type group,
    # then fetch the registered enum back from the registry.
    register_string_categories_from_list(categories_list, "custom_token_classes")
    custom_token_class = dd.object_types_registry.get("custom_token_classes")
    # Materialize the token classes once (the original built this list twice,
    # once only to print it for debugging).
    token_classes = list(custom_token_class)
    register_llm_token_tag(token_classes)
    # Classifier category ids are 1-based strings, as expected by the pipeline.
    categories = {str(idx + 1): get_type(val) for idx, val in enumerate(categories_list)}

    gpt_token_classifier = OpenAiLmmTokenClassifier(
        model_name="gpt-3.5-turbo",
        categories=categories,
        api_key=openai_api_key,
        instruction=instruction_str if instruction_str else None,
    )
    # Swap the language model of the token-classification pipeline component.
    # NOTE(review): index 8 is a hard-coded position in the component list —
    # verify it stays in sync with the loader configuration.
    analyzer.pipe_component_list[8].language_model = gpt_token_classifier

    if img is not None:
        # Wrap the raw numpy array in a dd.Image with a unique synthetic file name.
        image = dd.Image(file_name=str(time.time()).replace(".", "") + ".png", location="")
        # Reverse the channel axis — presumably converting gradio's RGB to the
        # BGR layout the pipeline expects. TODO confirm.
        image.image = img[:, :, ::-1]

        df = DataFromList(lst=[image])
        df = analyzer.analyze(dataset_dataflow=df)
    elif pdf:
        df = analyzer.analyze(path=pdf.name, max_datapoints=max_datapoints)
    else:
        # Give the UI user an actionable message instead of a bare ValueError.
        raise ValueError("Please upload an image or a PDF before running the model")

    df.reset_state()

    json_out = {}
    dpts = []

    for idx, dp in enumerate(df):
        dpts.append(dp)
        json_out[f"page_{idx}"] = dp.get_token()

    # Visualize only words and their token classes; suppress layout/table overlays.
    return [
        dp.viz(
            show_cells=False,
            show_layouts=False,
            show_tables=False,
            show_words=True,
            show_token_class=True,
            ignore_default_token_class=True,
        )
        for dp in dpts
    ], json_out


# Declarative UI layout (gradio 3.4.1 Blocks API): header, input column
# (image/PDF upload, API key, entity list, prompt), settings, and output column
# (JSON tokens + annotated gallery).
with demo:
    with gr.Box():
        gr.Markdown("<h1><center>Document AI GPT</center></h1>")
        gr.Markdown("<h2 ><center>Zero or few-shot Entity Extraction powered by ChatGPT and <strong>deep</strong>doctection </center></h2>"
                    "<center>This pipeline consists of a stack of models powered for layout analysis and table recognition "
                    "to prepare a prompt for ChatGPT. </center>"
                    "<center>Be aware! The Space is still very fragile.</center><br />")
    with gr.Box():
        gr.Markdown("<h2><center>Upload a document and choose setting</center></h2>")
        with gr.Row():
            with gr.Column():
                # Two mutually exclusive upload paths: process_analyzer prefers the
                # image if both are set, hence the "remove it first" hint below.
                with gr.Tab("Image upload"):
                    with gr.Column():
                        inputs = gr.Image(type='numpy', label="Original Image")
                with gr.Tab("PDF upload *"):
                    with gr.Column():
                        inputs_pdf = gr.File(label="PDF")
                    gr.Markdown("<sup>* If an image is cached in tab, remove it first</sup>")
                with gr.Box():
                    # Clickable example image shipped with the Space.
                    gr.Examples(
                        examples=[path.join(getcwd(), "sample_2.png")],
                        inputs = inputs)
                with gr.Box():
                    gr.Markdown("Enter your OpenAI API Key* ")
                    # Masked input; the key is passed per-request, never stored.
                    user_token = gr.Textbox(value='', placeholder="OpenAI API Key", type="password", show_label=False)
                    gr.Markdown("<sup>* Your API key will not be saved. However, it is always recommended to deactivate the"
                                "API key once it is entered into an unknown source</sup>")
            with gr.Column():
                with gr.Box():
                    gr.Markdown(
                        "Enter a list of comma seperated entities. Use a snake case style. Avoid special characters. "
                        "Best way is to only use `a-z` and `_`")
                    # Comma separated entity names, parsed by process_analyzer.
                    categories = gr.Textbox(value='', placeholder="mitarbeiter_anzahl", show_label=False)
                with gr.Box():
                    gr.Markdown("Optional: Enter a prompt for additional guidance. Will use the placeholder as fallback")
                    # Empty string means process_analyzer passes instruction=None (default prompt).
                    instruction = gr.Textbox(value='', placeholder=TOKEN_DEFAULT_INSTRUCTION, show_label=False)
        with gr.Row():
            # Page cap for multi-page PDFs, forwarded as max_datapoints.
            max_imgs = gr.Slider(1, 3, value=1, step=1, label="Number of pages in multi page PDF",
                                 info="Will stop after 3 pages")

        with gr.Row():
            btn = gr.Button("Run model", variant="primary")

    with gr.Box():
        gr.Markdown("<h2><center>Outputs</center></h2>")
        with gr.Row():
            with gr.Column():
                with gr.Box():
                    gr.Markdown("<center><strong>JSON</strong></center>")
                    # Receives the page_{idx} -> tokens mapping.
                    json = gr.JSON()
            with gr.Column():
                with gr.Box():
                    gr.Markdown("<center><strong>Layout detection</strong></center>")
                    # Receives the annotated page visualizations.
                    gallery = gr.Gallery(
                        label="Output images", show_label=False, elem_id="gallery"
                    ).style(grid=2)
        with gr.Row():
            with gr.Box():
                gr.Markdown("<center><strong>Table</strong></center>")
                # NOTE(review): this component is never wired to any callback output —
                # the "Table" box stays empty. Confirm whether it is dead UI.
                html = gr.HTML()

    # Wire the button: inputs map positionally onto process_analyzer's parameters.
    btn.click(fn=process_analyzer, inputs=[user_token, categories,  instruction, inputs, inputs_pdf, max_imgs],
              outputs=[gallery, json])

demo.launch()