# Captured from a Hugging Face Spaces page (Space status at capture time: Sleeping).
import spaces | |
import gradio as gr | |
from huggingface_hub import list_models | |
from typing import List | |
import torch | |
from transformers import DonutProcessor, VisionEncoderDecoderModel | |
from PIL import Image | |
import json | |
import re | |
import logging | |
from datasets import load_dataset | |
# Logging configuration: INFO level on the root logger, plus a module logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Lazily initialised module-level singletons. They start as None and are
# filled in on first use by get_donut() (model, processor) and
# load_merit_dataset() (dataset), so each is loaded at most once per process.
donut_model = None
donut_processor = None
dataset = None
def load_merit_dataset():
    """Return the MERIT test split, loading it the first time it is needed.

    The result is kept in the module-level ``dataset`` variable, so
    ``load_dataset`` is only called while that variable is still ``None``.

    Returns:
        Whatever ``load_dataset`` returns for ``"de-Rodrigo/merit"`` with
        ``name="en-digital-seq"`` and ``split="test"``.
    """
    global dataset
    if dataset is not None:
        # Already loaded earlier in this process; reuse it.
        return dataset
    dataset = load_dataset(
        "de-Rodrigo/merit",
        name="en-digital-seq",
        split="test",
        num_proc=8,
    )
    return dataset
def get_image_from_dataset(index):
    """Fetch the ``"image"`` field of a single dataset row.

    Args:
        index: Row position. It is coerced with ``int()`` before indexing,
            so numeric values such as ``3.0`` are accepted.

    Returns:
        The value stored under the ``"image"`` key of the selected row.
    """
    # load_merit_dataset() stores its result in the module-level ``dataset``,
    # so calling it here has the same effect as rebinding the global.
    rows = load_merit_dataset() if dataset is None else dataset
    row = rows[int(index)]
    return row["image"]
def get_collection_models(tag: str) -> List[str]:
    """List the ids of the ``de-Rodrigo`` models that carry a given tag.

    Despite the name, this does not query a Hugging Face *collection*: it
    lists every model whose author is ``de-Rodrigo`` and keeps those whose
    tag list contains ``tag``.

    Args:
        tag: Tag that a model must have to be included.

    Returns:
        The ``modelId`` of every matching model, in listing order.
    """
    return [
        model.modelId
        for model in list_models(author="de-Rodrigo")
        # ``tags`` may be None for some entries; treat that as "no tags"
        # instead of failing with ``TypeError`` on the membership test.
        if tag in (model.tags or ())
    ]
def get_donut():
    """Return the shared Donut model and processor, loading them on first use.

    Both objects are cached in the module-level ``donut_model`` and
    ``donut_processor`` variables. They are only published to those globals
    once *every* loading step has succeeded, so a failure part-way through
    (for example while moving the model to the device) never leaves a
    half-initialised cache behind for later calls to pick up.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        Exception: Any error raised while loading is logged with its
            traceback and re-raised unchanged.
    """
    global donut_model, donut_processor
    if donut_model is not None and donut_processor is not None:
        return donut_model, donut_processor

    try:
        model = VisionEncoderDecoderModel.from_pretrained("de-Rodrigo/donut-merit")
        processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
        # Prefer the GPU, but fall back to CPU so the app still starts on
        # hardware without CUDA instead of crashing in ``.to("cuda")``.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
    except Exception as e:
        logger.exception("Error loading Donut model: %s", e)
        raise

    donut_model, donut_processor = model, processor
    logger.info("Donut model loaded successfully on %s", device)
    return donut_model, donut_processor
def process_image_donut(model, processor, image):
    """Run Donut on one image and return the parsed output as JSON text.

    Args:
        model: The loaded ``VisionEncoderDecoderModel``.
        processor: The matching ``DonutProcessor``.
        image: A ``PIL.Image.Image`` or an array accepted by
            ``Image.fromarray``. ``None`` is reported as an error.

    Returns:
        A pretty-printed JSON string of ``processor.token2json(...)`` on
        success, or a string starting with ``"Error: "`` on failure. This
        function never raises; failures are logged and returned as text.
    """
    if image is None:
        # Nothing selected or uploaded: say so explicitly instead of
        # surfacing the opaque failure from ``Image.fromarray(None)``.
        return "Error: no image provided"

    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Uploads may be RGBA, grayscale or palette images; normalise to RGB
        # before handing the image to the processor.
        image = image.convert("RGB")

        # Put the inputs wherever the model actually lives rather than
        # assuming CUDA, so model and inputs can never end up on different
        # devices.
        device = model.device
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

        # NOTE(review): "<s_cord-v2>" is the CORD-v2 task token; confirm it is
        # the start token this checkpoint was fine-tuned with.
        task_prompt = "<s_cord-v2>"
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )["input_ids"].to(device)

        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        sequence = processor.batch_decode(outputs.sequences)[0]
        # Strip the end-of-sequence and padding tokens from the decoded text.
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        # Drop only the first "<...>" tag, i.e. the leading task prompt token.
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

        result = processor.token2json(sequence)
        # ensure_ascii=False keeps non-ASCII document text readable in the UI.
        return json.dumps(result, indent=2, ensure_ascii=False)
    except Exception as e:
        logger.exception("Error processing image with Donut: %s", e)
        return f"Error: {e}"
def process_image(model_name, image=None, dataset_image_index=None):
    """Run the selected model on an uploaded image or on a dataset image.

    Args:
        model_name: Model id chosen in the UI.
        image: Image supplied by the caller (e.g. an upload), or ``None``.
        dataset_image_index: Index of a dataset image to use when no
            ``image`` was supplied, or ``None``.

    Returns:
        A ``(image, result)`` tuple: the image that was processed and the
        model output (or a "not implemented" message) as text.
    """
    # The UI wires the slider into every call, so an index is always present.
    # An explicit upload must therefore take precedence; otherwise it would be
    # overwritten by the dataset image every time and could never be processed.
    if image is None and dataset_image_index is not None:
        image = get_image_from_dataset(dataset_image_index)

    if model_name == "de-Rodrigo/donut-merit":
        model, processor = get_donut()
        result = process_image_donut(model, processor, image)
    else:
        # Here you should implement processing for other models
        result = f"Processing for model {model_name} not implemented"
    return image, result
def update_image(dataset_image_index):
    """Return the dataset image at ``dataset_image_index``.

    Used as the slider's change handler to refresh the preview image.
    """
    selected = get_image_from_dataset(dataset_image_index)
    return selected
if __name__ == "__main__":
    # Load the dataset once and keep the returned handle, instead of relying
    # on the module-level global having been populated as a side effect.
    merit = load_merit_dataset()

    # Offer every "saliency"-tagged model plus the Donut model. The Donut id
    # is only appended when missing so it cannot appear twice in the dropdown.
    default_model = "de-Rodrigo/donut-merit"
    models = get_collection_models("saliency")
    if default_model not in models:
        models.append(default_model)

    # NOTE(review): the original indentation was lost, so the Row/Column
    # nesting below is reconstructed -- confirm it matches the intended layout.
    with gr.Blocks() as demo:
        gr.Markdown("# Document Understanding with Donut")
        gr.Markdown(
            "Select a model and an image from the dataset, or upload your own image."
        )
        with gr.Row():
            with gr.Column():
                # Preselect the implemented model so the first click does not
                # send ``model_name=None``.
                model_dropdown = gr.Dropdown(
                    choices=models, value=default_model, label="Select Model"
                )
                dataset_slider = gr.Slider(
                    minimum=0,
                    maximum=len(merit) - 1,
                    step=1,
                    label="Dataset Image Index",
                )
                upload_image = gr.Image(type="pil", label="Or Upload Your Own Image")
            preview_image = gr.Image(label="Selected/Uploaded Image")
        process_button = gr.Button("Process Image")
        with gr.Row():
            output_image = gr.Image(label="Processed Image")
            output_text = gr.Textbox(label="Result")

        # Update preview image when slider changes
        dataset_slider.change(
            fn=update_image, inputs=[dataset_slider], outputs=[preview_image]
        )
        # Update preview image when an image is uploaded
        upload_image.change(
            fn=lambda x: x, inputs=[upload_image], outputs=[preview_image]
        )
        # Process image when button is clicked
        process_button.click(
            fn=process_image,
            inputs=[model_dropdown, upload_image, dataset_slider],
            outputs=[output_image, output_text],
        )

    demo.launch()