# NOTE(review): the original extraction included Hugging Face Spaces UI
# status text ("Spaces: Sleeping") here; preserved as a comment so the
# file remains valid Python.
import json
import logging
import re
from typing import List

import gradio as gr
import torch
from datasets import load_dataset
from huggingface_hub import list_models
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Lazily-initialized module-level caches: the Donut model/processor and the
# MERIT dataset are loaded on first use and reused for the process lifetime.
donut_model = None
donut_processor = None
dataset = None
def load_merit_dataset():
    """Load the MERIT dataset train split, caching it in the module global.

    Returns:
        The ``de-Rodrigo/merit`` dataset (config ``en-digital-seq``,
        ``train`` split). Repeated calls reuse the cached ``dataset``
        global instead of re-downloading.
    """
    global dataset
    if dataset is None:
        dataset = load_dataset(
            "de-Rodrigo/merit", name="en-digital-seq", split="train"
        )
    return dataset
def get_image_from_dataset(index):
    """Return the image stored at position *index* of the MERIT dataset.

    Args:
        index: Any value convertible to ``int`` (e.g. a float coming from
            a Gradio slider).

    Returns:
        The ``"image"`` field of the selected dataset record (a PIL image
        as produced by the ``datasets`` library).
    """
    global dataset
    if dataset is None:
        # load_merit_dataset() also assigns the global; the assignment here
        # is kept for clarity and local use.
        dataset = load_merit_dataset()
    return dataset[int(index)]["image"]
def get_collection_models(tag: str) -> List[str]:
    """List Hugging Face model IDs by author ``de-Rodrigo`` tagged with *tag*.

    Args:
        tag: Tag string that must appear in a model's ``tags`` list.

    Returns:
        Model IDs (``"author/name"``) of every matching model.
    """
    models = list_models(author="de-Rodrigo")
    return [model.modelId for model in models if tag in model.tags]
def get_donut():
    """Load (once) and return the Donut model and processor.

    Downloads ``de-Rodrigo/donut-merit`` on first call, caches both objects
    in module globals, and moves the model to CUDA when available.

    Returns:
        Tuple ``(model, processor)``.

    Raises:
        Exception: Re-raises any error from ``from_pretrained`` after
            logging it, so the caller sees the failure.
    """
    global donut_model, donut_processor
    if donut_model is None or donut_processor is None:
        try:
            donut_model = VisionEncoderDecoderModel.from_pretrained(
                "de-Rodrigo/donut-merit"
            )
            donut_processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
            if torch.cuda.is_available():
                donut_model = donut_model.to("cuda")
            logger.info("Donut model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading Donut model: {str(e)}")
            raise
    return donut_model, donut_processor
def process_image_donut(model, processor, image):
    """Run Donut inference on *image* and return the parsed result as JSON.

    Args:
        model: A ``VisionEncoderDecoderModel`` (moved to CUDA if available).
        processor: The matching ``DonutProcessor``.
        image: A PIL image, or an array convertible via
            ``PIL.Image.fromarray``.

    Returns:
        A pretty-printed JSON string of the parsed document fields, or an
        ``"Error: ..."`` string if processing fails (errors are logged,
        not raised, so the UI can display them).
    """
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        pixel_values = processor(image, return_tensors="pt").pixel_values
        if torch.cuda.is_available():
            pixel_values = pixel_values.to("cuda")

        # Donut is prompted with a task token; decoding starts from it.
        task_prompt = "<s_cord-v2>"
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )["input_ids"]

        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            # Forbid the unknown token so it never appears in the output.
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        # Strip the leading task-prompt token before converting to JSON.
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

        result = processor.token2json(sequence)
        return json.dumps(result, indent=2)
    except Exception as e:
        logger.error(f"Error processing image with Donut: {str(e)}")
        return f"Error: {str(e)}"
def process_image(model_name, image=None, dataset_image_index=None):
    """Dispatch an image to the selected model and return (image, result).

    Args:
        model_name: Model identifier chosen in the UI dropdown.
        image: Optional uploaded image; ignored when *dataset_image_index*
            is given.
        dataset_image_index: Optional index into the MERIT dataset; when
            provided, it takes precedence over the uploaded image.

    Returns:
        Tuple ``(image, result)`` where *result* is the model output (a
        JSON string for Donut) or a not-implemented message.
    """
    if dataset_image_index is not None:
        image = get_image_from_dataset(dataset_image_index)

    if model_name == "de-Rodrigo/donut-merit":
        model, processor = get_donut()
        result = process_image_donut(model, processor, image)
    else:
        # Here you should implement processing for other models
        result = f"Processing for model {model_name} not implemented"

    return image, result
if __name__ == "__main__":
    # Load the dataset eagerly so the slider below can use its length.
    load_merit_dataset()

    # Offer all saliency-tagged models by this author, plus the Donut model.
    models = get_collection_models("saliency")
    models.append("de-Rodrigo/donut-merit")

    demo = gr.Interface(
        fn=process_image,
        inputs=[
            gr.Dropdown(choices=models, label="Select Model"),
            gr.Image(type="pil", label="Upload Image"),
            gr.Slider(
                # `dataset` is populated by load_merit_dataset() above.
                minimum=0, maximum=len(dataset) - 1, step=1, label="Dataset Image Index"
            ),
        ],
        outputs=[gr.Image(label="Processed Image"), gr.Textbox(label="Result")],
        title="Document Understanding with Donut",
        description="Upload an image or select one from the dataset to process with the selected model.",
    )

    demo.launch()