saliencies / app.py
de-Rodrigo's picture
Improve GUI
a62b0d6
raw
history blame
5.45 kB
import spaces
import gradio as gr
from huggingface_hub import list_models
from typing import List
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import json
import re
import logging
from datasets import load_dataset
# Logging: configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-level caches, filled in lazily by the loader functions below.
# ``None`` means "not loaded yet".
donut_model = donut_processor = None
dataset = None
def load_merit_dataset():
    """Return the MERIT test split, loading it on first use.

    The loaded split is stored in the module-level ``dataset`` variable, so
    subsequent calls return the cached object without calling
    ``load_dataset`` again.
    """
    global dataset

    # Fast path: something is already cached.
    if dataset is not None:
        return dataset

    dataset = load_dataset(
        "de-Rodrigo/merit",
        name="en-digital-seq",
        split="test",
        num_proc=8,
    )
    return dataset
def get_image_from_dataset(index):
    """Return the ``"image"`` field of the dataset row at ``index``.

    ``index`` is passed through ``int()`` before the lookup, so any value
    ``int()`` accepts (e.g. a float or a numeric string) can be used. The
    dataset is loaded via ``load_merit_dataset`` if it is not cached yet.
    """
    global dataset
    # Populate the module-level cache on first access.
    dataset = load_merit_dataset() if dataset is None else dataset
    row = dataset[int(index)]
    return row["image"]
def get_collection_models(tag: str) -> List[str]:
    """Return the ids of the models published by "de-Rodrigo" that carry ``tag``.

    Args:
        tag: Tag that must be present in a model's ``tags`` list.

    Returns:
        The ``modelId`` of every matching model, in the order in which
        ``list_models`` yields them. Empty when nothing matches.
    """
    models = list_models(author="de-Rodrigo")
    # A model entry may come back without tags (``None``); treat that as
    # "no tags" instead of letting ``tag in None`` raise TypeError.
    return [model.modelId for model in models if tag in (model.tags or ())]
@spaces.GPU
def get_donut():
    """Return the cached Donut model and processor, loading them on first use.

    The model and processor are both loaded from "de-Rodrigo/donut-merit" and
    the model is moved to CUDA. The pair is stored in the module-level
    ``donut_model`` / ``donut_processor`` variables only after every step has
    succeeded.

    Returns:
        A ``(donut_model, donut_processor)`` tuple.

    Raises:
        Exception: Any error raised while loading or moving the model is
            logged and re-raised unchanged.
    """
    global donut_model, donut_processor
    if donut_model is None or donut_processor is None:
        try:
            # Build into locals first. Assigning the globals before the
            # ``.to("cuda")`` step would leave a CPU-resident model cached if
            # the move failed, and the next call would then return it as if
            # loading had succeeded.
            model = VisionEncoderDecoderModel.from_pretrained(
                "de-Rodrigo/donut-merit"
            )
            processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
            model = model.to("cuda")
        except Exception as e:
            logger.exception("Error loading Donut model: %s", e)
            raise
        # Commit to the cache only once everything is ready.
        donut_model, donut_processor = model, processor
        logger.info("Donut model loaded successfully on GPU")
    return donut_model, donut_processor
@spaces.GPU
def process_image_donut(model, processor, image):
    """Run Donut generation on one image and return the parsed result as JSON.

    Args:
        model: The Donut model; must expose ``generate``, ``device`` and
            ``decoder.config.max_position_embeddings``.
        processor: The matching processor; used for image preprocessing,
            tokenization, decoding and ``token2json``.
        image: A ``PIL.Image.Image`` or an array accepted by
            ``Image.fromarray``.

    Returns:
        ``json.dumps(..., indent=2)`` of ``processor.token2json`` applied to
        the cleaned generated sequence, or the string ``"Error: <message>"``
        if any step raises.
    """
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Uploaded files can be RGBA, palette or grayscale; normalise to
        # three-channel RGB before handing the image to the processor.
        image = image.convert("RGB")

        # Put the inputs wherever the model actually lives instead of
        # assuming "cuda", so a device mismatch cannot occur.
        device = model.device
        tokenizer = processor.tokenizer

        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
        task_prompt = "<s_cord-v2>"
        decoder_input_ids = tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )["input_ids"].to(device)

        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        sequence = processor.batch_decode(outputs.sequences)[0]
        # Strip end-of-sequence and padding markers from the decoded text.
        sequence = sequence.replace(tokenizer.eos_token, "").replace(
            tokenizer.pad_token, ""
        )
        # Drop only the first <...> tag (the task prompt token).
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

        result = processor.token2json(sequence)
        return json.dumps(result, indent=2)
    except Exception as e:
        # Errors are reported to the caller as text rather than raised, so
        # the UI can display them in the result box.
        logger.error(f"Error processing image with Donut: {str(e)}")
        return f"Error: {str(e)}"
@spaces.GPU
def process_image(model_name, image=None, dataset_image_index=None):
    """Process an uploaded or dataset image with the selected model.

    Args:
        model_name: Model identifier; only "de-Rodrigo/donut-merit" is
            implemented.
        image: Image supplied by the caller (e.g. an upload). Takes priority
            when it is not ``None``.
        dataset_image_index: Index into the dataset, used only when ``image``
            is ``None``.

    Returns:
        A ``(image, result)`` tuple: the image that was actually used and the
        result text (JSON, an error message, or a "not implemented" notice).
    """
    # Fall back to the dataset only when no image was supplied. Looking the
    # index up unconditionally would overwrite every uploaded image whenever
    # the caller also passes an index (a slider input normally always
    # carries a value), making the upload path unreachable.
    if image is None and dataset_image_index is not None:
        image = get_image_from_dataset(dataset_image_index)

    if model_name == "de-Rodrigo/donut-merit":
        model, processor = get_donut()
        result = process_image_donut(model, processor, image)
    else:
        # Processing for other models is not implemented yet.
        result = f"Processing for model {model_name} not implemented"
    return image, result
def update_image(dataset_image_index):
    """Return the dataset image at ``dataset_image_index`` for the preview."""
    preview = get_image_from_dataset(dataset_image_index)
    return preview
# Script entry point: load the data, collect the model list, build the Gradio
# UI and launch it.
if __name__ == "__main__":
# Load the dataset up front; len(dataset) is needed below to size the slider.
load_merit_dataset()
# Models tagged "saliency", plus the Donut model which is always offered.
models = get_collection_models("saliency")
models.append("de-Rodrigo/donut-merit")
with gr.Blocks() as demo:
gr.Markdown("# Document Understanding with Donut")
gr.Markdown(
"Select a model and an image from the dataset, or upload your own image."
)
# Input controls.
# NOTE(review): which of the components below sit inside this Row/Column
# depends on the block nesting — confirm the intended layout.
with gr.Row():
with gr.Column():
# No default value is given, so the model name is None until the user picks
# one; process_image then reports "not implemented".
model_dropdown = gr.Dropdown(choices=models, label="Select Model")
# One slider step per dataset row: valid indices are 0 .. len(dataset) - 1.
dataset_slider = gr.Slider(
minimum=0,
maximum=len(dataset) - 1,
step=1,
label="Dataset Image Index",
)
upload_image = gr.Image(type="pil", label="Or Upload Your Own Image")
preview_image = gr.Image(label="Selected/Uploaded Image")
process_button = gr.Button("Process Image")
# Output widgets: the image that was processed and the textual result.
with gr.Row():
output_image = gr.Image(label="Processed Image")
output_text = gr.Textbox(label="Result")
# Update preview image when slider changes
dataset_slider.change(
fn=update_image, inputs=[dataset_slider], outputs=[preview_image]
)
# Update preview image when an image is uploaded (identity function: the
# uploaded image is shown as-is)
upload_image.change(
fn=lambda x: x, inputs=[upload_image], outputs=[preview_image]
)
# Process image when button is clicked. Both the uploaded image and the
# slider value are passed; process_image decides which one is used.
process_button.click(
fn=process_image,
inputs=[model_dropdown, upload_image, dataset_slider],
outputs=[output_image, output_text],
)
demo.launch()