Spaces:
Sleeping
Sleeping
File size: 4,459 Bytes
ec0c384 00b05e0 d0d6669 e76a04b d0d6669 e76a04b d0d6669 e76a04b d0d6669 e76a04b 00b05e0 e76a04b 00b05e0 d0d6669 00b05e0 e76a04b 00b05e0 d0d6669 4948600 ec0c384 4948600 e76a04b ec0c384 d0d6669 e76a04b d0d6669 e76a04b d0d6669 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import gradio as gr
from huggingface_hub import list_models
from typing import List
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import json
import re
import logging
from datasets import load_dataset
# Application-wide logging at INFO level, plus this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Lazily-initialised singletons: the Donut model, its processor and the MERIT
# dataset all start as None and are filled in on first use by get_donut() /
# load_merit_dataset().
donut_model = donut_processor = dataset = None
def load_merit_dataset():
    """Return the MERIT ``en-digital-seq`` train split, loading it on first call.

    The loaded split is stored in the module-level ``dataset`` global, so later
    calls return that same object without calling ``load_dataset`` again.
    """
    global dataset
    if dataset is not None:
        return dataset
    dataset = load_dataset("de-Rodrigo/merit", name="en-digital-seq", split="train")
    return dataset
def get_image_from_dataset(index):
    """Return the ``"image"`` field of MERIT row ``int(index)``.

    If the dataset has not been loaded yet, it is loaded through
    ``load_merit_dataset`` and cached in the ``dataset`` global.
    """
    global dataset
    rows = load_merit_dataset() if dataset is None else dataset
    dataset = rows
    # int() because the Gradio slider may deliver the index as a float.
    return rows[int(index)]["image"]
def get_collection_models(tag: str, author: str = "de-Rodrigo") -> List[str]:
    """Return the ids of Hub models published by *author* that carry *tag*.

    Args:
        tag: Hub tag a model must have to be included (e.g. ``"saliency"``).
        author: Hub user/organisation whose models are listed. Defaults to
            ``"de-Rodrigo"``, the value this function previously hard-coded.

    Returns:
        Repository ids (``"owner/name"``) in the order ``list_models`` yields them.
    """
    models = list_models(author=author)
    # ``ModelInfo.tags`` is Optional in huggingface_hub; treat a missing tag
    # list as empty instead of raising TypeError on ``tag in None``.
    return [model.id for model in models if tag in (model.tags or ())]
def get_donut():
    """Return the Donut model and processor, loading them on first use.

    Both objects are cached in the module-level ``donut_model`` /
    ``donut_processor`` globals. The model is moved to CUDA when
    ``torch.cuda.is_available()`` is true.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        Exception: whatever loading raises. It is logged with a traceback and
            re-raised. The globals are left unchanged, so the next call retries
            a clean load.
    """
    global donut_model, donut_processor
    if donut_model is None or donut_processor is None:
        repo_id = "de-Rodrigo/donut-merit"
        try:
            # Load into locals first so a failure part-way through (e.g. the
            # processor download) never leaves the globals half-initialised.
            model = VisionEncoderDecoderModel.from_pretrained(repo_id)
            processor = DonutProcessor.from_pretrained(repo_id)
            if torch.cuda.is_available():
                model = model.to("cuda")
        except Exception as e:
            logger.exception("Error loading Donut model: %s", e)
            raise
        donut_model, donut_processor = model, processor
        logger.info("Donut model loaded successfully")
    return donut_model, donut_processor
def process_image_donut(model, processor, image, task_prompt="<s_cord-v2>"):
    """Run Donut on *image* and return the parsed output as indented JSON text.

    Args:
        model: A ``VisionEncoderDecoderModel`` (as returned by ``get_donut``).
        processor: The matching ``DonutProcessor``.
        image: A PIL image, or anything ``Image.fromarray`` accepts
            (e.g. a NumPy array).
        task_prompt: Decoder prompt fed to ``generate``. Defaults to the
            CORD-v2 prompt the code has always used.
            NOTE(review): confirm this matches the prompt
            ``de-Rodrigo/donut-merit`` was fine-tuned with.

    Returns:
        ``processor.token2json`` output serialized with ``json.dumps(indent=2)``,
        or an ``"Error: ..."`` string if any step fails (the failure is logged).
    """
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Donut's image processor is normalised for 3 channels; uploads may
        # be RGBA, grayscale or palette images.
        image = image.convert("RGB")

        # Every input tensor must live on the model's device. Moving only
        # pixel_values to CUDA leaves decoder_input_ids on CPU, and generate()
        # then fails with a device-mismatch error.
        device = model.device
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )["input_ids"].to(device)

        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            # Never emit <unk>.
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        # Drop the first tag in the decoded text (presumably the echoed task
        # prompt) before converting the token sequence to JSON.
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
        result = processor.token2json(sequence)
        return json.dumps(result, indent=2)
    except Exception as e:
        logger.exception("Error processing image with Donut: %s", e)
        return f"Error: {str(e)}"
def process_image(model_name, image=None, dataset_image_index=None):
    """Run *model_name* on an uploaded image or on a MERIT dataset sample.

    An uploaded *image* takes precedence. *dataset_image_index* is only used
    when no image was uploaded. A Gradio Slider always submits a number, so
    giving the index priority would silently discard every upload.

    Args:
        model_name: Hub id of the model selected in the UI.
        image: Uploaded image, or ``None``.
        dataset_image_index: Row of the MERIT dataset to use when *image* is
            ``None``.

    Returns:
        ``(image, result)``: the image that was processed (``None`` if there
        was none) and the model output or an error/status string.
    """
    if image is None and dataset_image_index is not None:
        image = get_image_from_dataset(dataset_image_index)
    if image is None:
        return None, "Error: no image provided"

    if model_name == "de-Rodrigo/donut-merit":
        model, processor = get_donut()
        result = process_image_donut(model, processor, image)
    else:
        # TODO: implement processing for the other models.
        result = f"Processing for model {model_name} not implemented"
    return image, result
if __name__ == "__main__":
    # Load the dataset up front; its length sizes the index slider.
    merit = load_merit_dataset()

    donut_id = "de-Rodrigo/donut-merit"
    models = get_collection_models("saliency")
    # Donut is the only model process_image implements. Add it unless the Hub
    # listing already contains it, and preselect it below so a submit without
    # touching the dropdown does not send model_name=None.
    if donut_id not in models:
        models.append(donut_id)

    demo = gr.Interface(
        fn=process_image,
        inputs=[
            gr.Dropdown(choices=models, value=donut_id, label="Select Model"),
            gr.Image(type="pil", label="Upload Image"),
            gr.Slider(
                minimum=0, maximum=len(merit) - 1, step=1, label="Dataset Image Index"
            ),
        ],
        outputs=[gr.Image(label="Processed Image"), gr.Textbox(label="Result")],
        title="Document Understanding with Donut",
        description="Upload an image or select one from the dataset to process with the selected model.",
    )
    demo.launch()
|