UniChart / app.py
ahmed-masry's picture
Update app.py
c1c7155 verified
raw
history blame
3.05 kB
import gradio as gr
from transformers import DonutProcessor, VisionEncoderDecoderModel
import requests
from PIL import Image
import torch, os, re, json
import spaces
# Download the two ChartQA sample charts that the demo lists as examples.
for _example_url, _example_file in (
    (
        'https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/74801584018932.png',
        'chart_example_1.png',
    ),
    (
        'https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png',
        'chart_example_2.png',
    ),
):
    torch.hub.download_url_to_file(_example_url, _example_file)

# Load the checkpoint and its matching processor once, at import time,
# so every call to predict() reuses the same objects.
model_name = "ahmed-masry/unichart-base-960"
model = VisionEncoderDecoderModel.from_pretrained(model_name)
processor = DonutProcessor.from_pretrained(model_name)
@spaces.GPU
def predict(image, input_prompt):
    """Run the UniChart model on a chart image and return the decoded text.

    Args:
        image: The chart image handed over by the Gradio image input, or
            ``None`` when the user submitted without providing one.
        input_prompt: Task prompt typed by the user, e.g.
            ``<summarize_chart>`` or ``<opencqa> question``.

    Returns:
        The generated sequence with EOS/pad tokens removed, the first two
        ``<...>`` tags stripped, and surrounding whitespace trimmed.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an opaque failure from the processor.
    """
    if image is None:
        raise gr.Error("Please provide a chart image before submitting.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # The decoder is primed with the task prompt followed by the answer tag.
    prompt = f"{input_prompt} <s_answer>"
    tokenizer = processor.tokenizer
    decoder_input_ids = tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=4,
        # Never let the decoder emit the unknown token.
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    # Drop only the first two tags; any later tags stay in the output.
    sequence = re.sub(r"<.*?>", "", sequence, count=2).strip()
    return sequence
# Markdown description rendered under the demo title.
# A raw string (not an f-string: there are no placeholders) keeps the
# backslash-escaped angle brackets byte-for-byte without relying on the
# invalid escape sequences "\<" and "\>", which newer Python versions warn on.
instructions = r"""
Demo of the [UniChart Base](https://huggingface.co/ahmed-masry/unichart-base-960) Model
Learn more about the model by reading [our paper](https://arxiv.org/abs/2305.14761) and explore the [code](https://github.com/vis-nlp/UniChart)
You can use UniChart for the following tasks:
| Task | Input Prompt |
| ------------- | ------------- |
| Chart Summarization | \<summarize_chart\> |
| Chart to Table | \<extract_data_table\> |
| Open Chart Question Answering | \<opencqa\> question |
"""
# Input and output widgets of the demo.
image = gr.Image(type="pil", label="Chart Image")
input_prompt = gr.Textbox(label="Input Prompt")
model_output = gr.Textbox(label="Model Output")

# Each example pairs one of the downloaded sample charts with a task prompt.
examples = [
    ["chart_example_1.png", "<summarize_chart>"],
    ["chart_example_2.png", "<extract_data_table>"],
]

title = "Interactive Gradio Demo for UniChart-base-960 model"

# Wire predict() to the widgets and start serving the app.
interface = gr.Interface(
    fn=predict,
    inputs=[image, input_prompt],
    outputs=model_output,
    examples=examples,
    title=title,
    description=instructions,
    theme="gradio/soft",
)
interface.launch()