Spaces:
Sleeping
Sleeping
File size: 5,447 Bytes
1dd7eb5 ec0c384 00b05e0 d0d6669 e76a04b d0d6669 e76a04b d0d6669 e76a04b d0d6669 e76a04b 00b05e0 e76a04b a62b0d6 e76a04b 00b05e0 d0d6669 00b05e0 e76a04b 00b05e0 d0d6669 4948600 ec0c384 4948600 e76a04b ec0c384 1dd7eb5 d0d6669 1dd7eb5 d0d6669 1dd7eb5 d0d6669 1dd7eb5 d0d6669 1dd7eb5 d0d6669 1dd7eb5 a62b0d6 d0d6669 a62b0d6 d0d6669 a62b0d6 d0d6669 a62b0d6 e76a04b a62b0d6 d0d6669 62028bb a62b0d6 62028bb a62b0d6 62028bb a62b0d6 62028bb a62b0d6 62028bb a62b0d6 62028bb a62b0d6 62028bb d0d6669 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import spaces
import gradio as gr
from huggingface_hub import list_models
from typing import List
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import json
import re
import logging
from datasets import load_dataset
# Logging setup: INFO and above go through the root handler, and this module
# logs under its own name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Lazily populated module state. All three names start out as None and are
# filled in on first use by load_merit_dataset() (dataset) and get_donut()
# (donut_model / donut_processor).
donut_model = donut_processor = dataset = None
def load_merit_dataset():
    """Return the MERIT ``test`` split, loading it on the first call only.

    The loaded split is kept in the module-level ``dataset`` variable, so
    subsequent calls return the cached object without calling
    ``load_dataset`` again.
    """
    global dataset
    # Fast path: already loaded by an earlier call.
    if dataset is not None:
        return dataset
    dataset = load_dataset(
        "de-Rodrigo/merit",
        name="en-digital-seq",
        split="test",
        num_proc=8,
    )
    return dataset
def get_image_from_dataset(index):
    """Return the ``"image"`` field of the dataset row at ``index``.

    The dataset is loaded on demand if it has not been loaded yet, and
    ``index`` is passed through ``int()`` before the lookup.
    """
    global dataset
    if dataset is None:
        dataset = load_merit_dataset()
    row = dataset[int(index)]
    return row["image"]
def get_collection_models(tag: str) -> List[str]:
    """Return the IDs of the "de-Rodrigo" Hub models that carry ``tag``.

    Despite the name, this does not query a Hub collection: it lists every
    model whose author is "de-Rodrigo" and keeps those tagged with ``tag``.

    Args:
        tag: Tag a model must have to be included.

    Returns:
        The matching model IDs, in the order ``list_models`` yields them.
    """
    models = list_models(author="de-Rodrigo")
    # ``tags`` is optional on the Hub's model info objects; treat a missing
    # value as "no tags" instead of failing on ``tag in None``.
    return [model.modelId for model in models if tag in (model.tags or ())]
@spaces.GPU
def get_donut():
    """Return the shared Donut model and processor, loading them on first use.

    Both objects are loaded from "de-Rodrigo/donut-merit", the model is moved
    to "cuda", and the pair is cached in the module-level ``donut_model`` and
    ``donut_processor`` variables for later calls.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        Exception: Any error raised while loading or moving the model is
            logged and re-raised unchanged.
    """
    global donut_model, donut_processor
    if donut_model is None or donut_processor is None:
        try:
            # Build everything in locals first so that a failure part-way
            # through (for example the move to CUDA) never leaves a
            # half-initialised model in the globals for a later call to
            # return as if it were ready.
            model = VisionEncoderDecoderModel.from_pretrained(
                "de-Rodrigo/donut-merit"
            )
            processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
            model = model.to("cuda")
        except Exception as e:
            # logger.exception keeps the original message and adds the
            # traceback to the log record.
            logger.exception("Error loading Donut model: %s", e)
            raise
        # Publish only after every step succeeded.
        donut_model, donut_processor = model, processor
        logger.info("Donut model loaded successfully on GPU")
    return donut_model, donut_processor
@spaces.GPU
def process_image_donut(model, processor, image):
    """Run Donut on ``image`` and return the parsed output as a JSON string.

    Args:
        model: The Donut model used for generation (as returned by
            ``get_donut``).
        processor: The matching Donut processor.
        image: A PIL image, or array data accepted by ``Image.fromarray``.

    Returns:
        The ``processor.token2json`` result serialised with two-space
        indentation, or a string of the form ``"Error: <message>"`` if any
        step fails. Errors are logged, never raised.
    """
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Send the inputs to wherever the model's weights live instead of
        # assuming "cuda", so inputs and model can never end up on different
        # devices.
        device = model.device
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
        task_prompt = "<s_cord-v2>"
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )["input_ids"].to(device)
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            # Never let the decoder emit the unknown token.
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
        sequence = processor.batch_decode(outputs.sequences)[0]
        # Drop end-of-sequence and padding markers from the decoded text.
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        # Remove only the first <...> token; presumably this is the task
        # prompt that the generated sequence starts with.
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
        result = processor.token2json(sequence)
        return json.dumps(result, indent=2)
    except Exception as e:
        # Best-effort by design: the UI shows the message instead of crashing.
        logger.exception("Error processing image with Donut: %s", e)
        return f"Error: {str(e)}"
@spaces.GPU
def process_image(model_name, image=None, dataset_image_index=None):
    """Process an uploaded or dataset image with the selected model.

    Args:
        model_name: Hub ID of the model to run. Only
            "de-Rodrigo/donut-merit" is implemented.
        image: Image supplied by the caller (e.g. an upload), or ``None``.
        dataset_image_index: Index of a dataset image to use when no image
            was supplied.

    Returns:
        A ``(image, result)`` tuple: the image that was processed and the
        result text (Donut's JSON output, an error string, or a
        "not implemented" message for other models).
    """
    # The UI slider always supplies an index, so the dataset must only be a
    # fallback; otherwise an uploaded image would be overwritten every time.
    if image is None and dataset_image_index is not None:
        image = get_image_from_dataset(dataset_image_index)
    if model_name == "de-Rodrigo/donut-merit":
        model, processor = get_donut()
        result = process_image_donut(model, processor, image)
    else:
        # Here you should implement processing for other models
        result = f"Processing for model {model_name} not implemented"
    return image, result
def update_image(dataset_image_index):
    """Return the dataset image at ``dataset_image_index`` for previewing."""
    selected = get_image_from_dataset(dataset_image_index)
    return selected
if __name__ == "__main__":
    # Load the dataset up front; the slider range below depends on its size.
    merit = load_merit_dataset()

    donut_id = "de-Rodrigo/donut-merit"
    models = get_collection_models("saliency")
    # Offer the Donut checkpoint exactly once, even if the tag query already
    # returned it.
    if donut_id not in models:
        models.append(donut_id)

    # NOTE(review): the nesting of the layout containers below was inferred;
    # confirm that the preview belongs next to the controls column.
    with gr.Blocks() as demo:
        gr.Markdown("# Document Understanding with Donut")
        gr.Markdown(
            "Select a model and an image from the dataset, or upload your own image."
        )
        with gr.Row():
            with gr.Column():
                # Preselect the only model that has processing implemented so
                # "Process Image" works without an extra selection step.
                model_dropdown = gr.Dropdown(
                    choices=models,
                    value=donut_id,
                    label="Select Model",
                )
                dataset_slider = gr.Slider(
                    minimum=0,
                    maximum=len(merit) - 1,
                    step=1,
                    label="Dataset Image Index",
                )
                upload_image = gr.Image(type="pil", label="Or Upload Your Own Image")
            preview_image = gr.Image(label="Selected/Uploaded Image")
        process_button = gr.Button("Process Image")
        with gr.Row():
            output_image = gr.Image(label="Processed Image")
            output_text = gr.Textbox(label="Result")

        # Update preview image when slider changes
        dataset_slider.change(
            fn=update_image, inputs=[dataset_slider], outputs=[preview_image]
        )
        # Update preview image when an image is uploaded
        upload_image.change(
            fn=lambda x: x, inputs=[upload_image], outputs=[preview_image]
        )
        # Process image when button is clicked
        process_button.click(
            fn=process_image,
            inputs=[model_dropdown, upload_image, dataset_slider],
            outputs=[output_image, output_text],
        )

    demo.launch()
|