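# Gradio demo for visual question answering with Pix2Struct base checkpoints
# (AI2D diagrams, Screen2Words UI screens, ChartQA charts, and DocVQA documents).
# Each inference function is decorated with @spaces.GPU so that, when hosted on
# Hugging Face Spaces, the call is scheduled onto a GPU.
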
import gradio as gr
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import spaces

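# Infographic/diagram QA handler. Note that it loads the AI2D (diagram QA) base checkpoint.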
@spaces.GPU
def infer_infographics(image, question):
    model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-ai2d-base") 
    processor = Pix2StructProcessor.from_pretrained("google/pix2struct-ai2d-base")

    inputs = processor(images=image, text=question, return_tensors="pt") 

    predictions = model.generate(**inputs)
    return processor.decode(predictions[0], skip_special_tokens=True)

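# UI screen understanding handler, using the Screen2Words base checkpoint (screen summarization).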
@spaces.GPU
def infer_ui(image, question):
    model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-screen2words-base") 
    processor = Pix2StructProcessor.from_pretrained("google/pix2struct-screen2words-base")

    inputs = processor(images=image, text=question, return_tensors="pt")

    predictions = model.generate(**inputs)
    return processor.decode(predictions[0], skip_special_tokens=True)

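# Chart question answering handler, using the ChartQA base checkpoint.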
@spaces.GPU
def infer_chart(image, question):
    model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-chartqa-base")
    processor = Pix2StructProcessor.from_pretrained("google/pix2struct-chartqa-base")

    inputs = processor(images=image, text=question, return_tensors="pt")

    predictions = model.generate(**inputs)
    return processor.decode(predictions[0], skip_special_tokens=True)

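# Document question answering handler, using the DocVQA base checkpoint.
# This is the handler wired to the demo UI below.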
@spaces.GPU
def infer_doc(image, question):
    model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-base")
    processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-base")
    inputs = processor(images=image, text=question, return_tensors="pt")
    predictions = model.generate(**inputs)
    return processor.decode(predictions[0], skip_special_tokens=True)
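
# Note: each handler above reloads its checkpoint from the local cache on every call.
# A minimal sketch of an alternative (not used by the handlers above): cache one
# model/processor pair per checkpoint so repeated requests reuse the same weights.
from functools import lru_cache

@lru_cache(maxsize=None)
def load_pix2struct(checkpoint):
    """Load a Pix2Struct model/processor pair once per checkpoint and reuse it."""
    model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint)
    processor = Pix2StructProcessor.from_pretrained(checkpoint)
    return model, processor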

css = """
  #mkd {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

with gr.Blocks(css=css) as demo:
    gr.HTML("<h1><center>Pix2Struct 📄<center><h1>")
    gr.HTML("<h3><center>Pix2Struct is a powerful backbone for visual question answering. ⚡</h3>")
    gr.HTML("<h3><center>This app has base version of the model. For better performance, use large checkpoints.<h3>")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Input Document")
            question = gr.Text(label="Question")
            submit_btn = gr.Button(value="Submit")
        output = gr.Text(label="Answer")
    gr.Examples(
        [["docvqa_example.png", "How many items are sold?"]],
        inputs=[input_img, question],
        outputs=[output],
        fn=infer_doc,
        cache_examples=True,
        label="Click on any example below to get Document Question Answering results quickly 👇",
    )
    
    submit_btn.click(infer_doc, [input_img, question], [output])

demo.launch(debug=True)