import spaces
import gradio as gr
from huggingface_hub import list_models
from typing import List
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import json
import re
import logging
from datasets import load_dataset

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for Donut model, processor, and dataset
donut_model = None
donut_processor = None
dataset = None


def load_merit_dataset():
    global dataset
    if dataset is None:
        dataset = load_dataset(
            "de-Rodrigo/merit", name="en-digital-seq", split="test", num_proc=8
        )
    return dataset


def get_image_from_dataset(index):
    global dataset
    if dataset is None:
        dataset = load_merit_dataset()
    image_data = dataset[int(index)]["image"]
    return image_data


def get_collection_models(tag: str) -> List[str]:
    """Get a list of models from a specific Hugging Face collection."""
    models = list_models(author="de-Rodrigo")
    return [model.modelId for model in models if tag in model.tags]


@spaces.GPU
def get_donut():
    global donut_model, donut_processor
    if donut_model is None or donut_processor is None:
        try:
            donut_model = VisionEncoderDecoderModel.from_pretrained(
                "de-Rodrigo/donut-merit"
            )
            donut_processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
            donut_model = donut_model.to("cuda")
            logger.info("Donut model loaded successfully on GPU")
        except Exception as e:
            logger.error(f"Error loading Donut model: {str(e)}")
            raise
    return donut_model, donut_processor


@spaces.GPU
def process_image_donut(model, processor, image):
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        pixel_values = processor(image, return_tensors="pt").pixel_values.to("cuda")

        # Note: Donut checkpoints usually expect a model-specific task start token
        # here (e.g. "<s_...>"); an empty prompt may need adjusting for this model.
        task_prompt = ""
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )["input_ids"].to("cuda")

        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        # Remove the first task start token before converting to JSON
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

        result = processor.token2json(sequence)
        return json.dumps(result, indent=2)
    except Exception as e:
        logger.error(f"Error processing image with Donut: {str(e)}")
        return f"Error: {str(e)}"


@spaces.GPU
def process_image(model_name, image=None, dataset_image_index=None):
    # If a dataset index is provided, it takes precedence over an uploaded image
    if dataset_image_index is not None:
        image = get_image_from_dataset(dataset_image_index)

    if model_name == "de-Rodrigo/donut-merit":
        model, processor = get_donut()
        result = process_image_donut(model, processor, image)
    else:
        # Here you should implement processing for other models
        result = f"Processing for model {model_name} not implemented"

    return image, result


def update_image(dataset_image_index):
    return get_image_from_dataset(dataset_image_index)


if __name__ == "__main__":
    # Load the dataset
    load_merit_dataset()

    models = get_collection_models("saliency")
    models.append("de-Rodrigo/donut-merit")

    with gr.Blocks() as demo:
        gr.Markdown("# Document Understanding with Donut")
        gr.Markdown(
            "Select a model and an image from the dataset, or upload your own image."
        )

        with gr.Row():
            with gr.Column():
                model_dropdown = gr.Dropdown(choices=models, label="Select Model")
                dataset_slider = gr.Slider(
                    minimum=0,
                    maximum=len(dataset) - 1,
                    step=1,
                    label="Dataset Image Index",
                )
                upload_image = gr.Image(type="pil", label="Or Upload Your Own Image")
            preview_image = gr.Image(label="Selected/Uploaded Image")

        process_button = gr.Button("Process Image")

        with gr.Row():
            output_image = gr.Image(label="Processed Image")
            output_text = gr.Textbox(label="Result")

        # Update preview image when slider changes
        dataset_slider.change(
            fn=update_image, inputs=[dataset_slider], outputs=[preview_image]
        )

        # Update preview image when an image is uploaded
        upload_image.change(
            fn=lambda x: x, inputs=[upload_image], outputs=[preview_image]
        )

        # Process image when button is clicked
        process_button.click(
            fn=process_image,
            inputs=[model_dropdown, upload_image, dataset_slider],
            outputs=[output_image, output_text],
        )

    demo.launch()