Spaces:
Runtime error
Runtime error
File size: 2,822 Bytes
7decfba ee97a5d 7decfba ee97a5d 7decfba 728c2e4 7decfba 728c2e4 7decfba ee97a5d 7decfba 728c2e4 7decfba 728c2e4 7decfba ee97a5d 7decfba 728c2e4 7decfba 728c2e4 7decfba ee97a5d 7decfba 728c2e4 7decfba 728c2e4 7decfba 9b395d3 9b65e39 7decfba 18d6320 7decfba 9b395d3 7decfba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import functools

import gradio as gr
import requests
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import spaces
@functools.lru_cache(maxsize=None)
def _load_pix2struct(checkpoint):
    """Load the Pix2Struct model and processor for *checkpoint*, once.

    Previously every request called ``from_pretrained`` again, re-reading the
    full weights from disk per call. The cache key space is the fixed set of
    checkpoint names used by the ``infer_*`` functions, so it stays bounded.

    NOTE(review): if the GPU decorator executes calls in a separate worker
    process, this cache may not persist between requests — confirm.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint)
    processor = Pix2StructProcessor.from_pretrained(checkpoint)
    return model, processor


def _answer(checkpoint, image, question):
    """Run one Pix2Struct generation pass and return the decoded text.

    Args:
        checkpoint: Hugging Face Hub id of the Pix2Struct checkpoint.
        image: Image passed straight to the processor (whatever the Gradio
            ``gr.Image`` component supplies — presumably a numpy array; verify).
        question: Text prompt passed to the processor alongside the image.
    """
    model, processor = _load_pix2struct(checkpoint)
    # NOTE(review): neither the model nor the inputs are moved to a CUDA
    # device, so generation runs where from_pretrained placed the model
    # (CPU by default) even though callers are decorated with @spaces.GPU —
    # confirm the intended device.
    inputs = processor(images=image, text=question, return_tensors="pt")
    predictions = model.generate(**inputs)
    return processor.decode(predictions[0], skip_special_tokens=True)


@spaces.GPU
def infer_infographics(image, question):
    """Answer *question* about *image* with ``google/pix2struct-ai2d-base``.

    NOTE(review): despite the function name, this uses the AI2D checkpoint,
    not an infographics-VQA one — confirm that is intended.
    """
    return _answer("google/pix2struct-ai2d-base", image, question)


@spaces.GPU
def infer_ui(image, question):
    """Run ``google/pix2struct-screen2words-base`` on *image* with *question* as text."""
    return _answer("google/pix2struct-screen2words-base", image, question)


@spaces.GPU
def infer_chart(image, question):
    """Answer *question* about a chart *image* with ``google/pix2struct-chartqa-base``."""
    return _answer("google/pix2struct-chartqa-base", image, question)


@spaces.GPU
def infer_doc(image, question):
    """Answer *question* about a document *image* with ``google/pix2struct-docvqa-base``."""
    return _answer("google/pix2struct-docvqa-base", image, question)
# Stylesheet passed to gr.Blocks: gives the element with id "mkd" a fixed
# 500px height, scrolling overflow, and a light grey border.
# NOTE(review): no component in the layout below sets elem_id="mkd", so this
# rule appears to match nothing — confirm whether it is still needed.
css = "\n".join(
    [
        "",
        "#mkd {",
        "height: 500px;",
        "overflow: auto;",
        "border: 1px solid #ccc;",
        "}",
        "",
    ]
)
# Gradio UI: an image + question input column with a Submit button, an answer
# box, and one cached example. Everything is wired to ``infer_doc`` (the
# DocVQA checkpoint); the other ``infer_*`` functions are not exposed here.
with gr.Blocks(css=css) as demo:
    # Headings: closing tags are now well-formed (previously ``<center><h1>``
    # and ``<h3>`` appeared where ``</center></h1>`` / ``</center></h3>`` were
    # meant, leaving each heading unclosed and nesting into the next).
    gr.HTML("<h1><center>Pix2Struct 📄</center></h1>")
    gr.HTML("<h3><center>Pix2Struct is a powerful backbone for visual question answering. ⚡</center></h3>")
    gr.HTML("<h3><center>This app has base version of the model. For better performance, use large checkpoints.</center></h3>")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Input Document")
            question = gr.Text(label="Question")
            submit_btn = gr.Button(value="Submit")
        output = gr.Text(label="Answer")
    # cache_examples=True has Gradio run infer_doc on the example ahead of
    # time so clicking it returns a stored answer.
    # Assumes docvqa_example.png ships next to this script — TODO confirm.
    gr.Examples(
        [["docvqa_example.png", "How many items are sold?"]],
        inputs=[input_img, question],
        outputs=[output],
        fn=infer_doc,
        cache_examples=True,
        label='Click on any Examples below to get Document Question Answering results quickly 👇',
    )
    submit_btn.click(infer_doc, [input_img, question], [output])

demo.launch(debug=True)