import re
import gradio as gr
from PIL import Image, ImageDraw
import math
import torch
import html
from transformers import DonutProcessor, VisionEncoderDecoderModel
pretrained_repo_name = 'ivelin/donut-refexp-click'
pretrained_revision = 'main'
# revision can be git commit hash, branch or tag
# use 'main' for latest revision
print(f"Loading model checkpoint: {pretrained_repo_name}")
processor = DonutProcessor.from_pretrained(pretrained_repo_name, revision=pretrained_revision)
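# DonutProcessor bundles the image preprocessor and the tokenizer for this checkpoint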
model = VisionEncoderDecoderModel.from_pretrained(pretrained_repo_name, revision=pretrained_revision)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
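# run inference on the GPU when one is available; otherwise fall back to CPU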
def process_refexp(image: Image.Image, prompt: str):

    print(f"(image, prompt): {image}, {prompt}")

    # trim prompt to 80 characters and normalize to lowercase
    prompt = prompt[:80].lower()

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values
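    # pixel_values is a single-image batch tensor; the Donut processor resizes and
    # normalizes the screenshot to the input size the encoder expects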
    # prepare decoder inputs
    task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_center>"
    prompt = task_prompt.replace("{user_input}", prompt)
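    # the decoder is primed with the refexp task start token, the user prompt and an
    # opening <s_target_center> tag, so generation continues with the center prediction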
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
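    # greedy decoding (num_beams=1) capped at the decoder's max sequence length;
    # <unk> tokens are suppressed via bad_words_ids and generation stops at EOS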
    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    print(fr"predicted decoder sequence: {html.escape(sequence)}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # remove first task start token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    print(
        fr"predicted decoder sequence before token2json: {html.escape(sequence)}")
    seqjson = processor.token2json(sequence)
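    # token2json parses the XML-style tag sequence into a dict; for this task it is
    # expected to contain a "target_center" entry with relative "x"/"y" coordinates
    # (as strings) that are converted to floats and scaled to pixels below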
    # safeguard in case predicted sequence does not include a target_center token
    center_point = seqjson.get('target_center')
    if center_point is None:
        print(
            f"predicted sequence has no target_center, seq: {sequence}")
        center_point = {"x": 0, "y": 0}
        return image, center_point

    print(f"predicted center_point with text coordinates: {center_point}")
    # safeguard in case text prediction is missing some center point coordinates
    # or coordinates are not valid numeric values
    try:
        x = float(center_point.get("x", 0))
    except ValueError:
        x = 0
    try:
        y = float(center_point.get("y", 0))
    except ValueError:
        y = 0
    # replace str with float coords
    center_point = {"x": x, "y": y, "decoder output sequence": sequence}
    print(f"predicted center_point with float coordinates: {center_point}")

    print(f"image object: {image}")
    print(f"image size: {image.size}")
    width, height = image.size
    print(f"image width, height: {width, height}")
    print(f"processed prompt: {prompt}")

    # scale the relative coordinates predicted by the model to absolute image pixel values
    x = math.floor(width*center_point["x"])
    y = math.floor(height*center_point["y"])
    print(
        f"to image pixel values: x, y: {x, y}")

    # draw center point circle
    img1 = ImageDraw.Draw(image)
    r = 1
    shape = [(x-r, y-r), (x+r, y+r)]
    img1.ellipse(shape, outline="green", width=10)
    img1.ellipse(shape, outline="white", width=5)

    return image, center_point
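
# Example of calling the handler directly, e.g. from a Python shell in this Space's
# environment (a hypothetical sketch; the Gradio UI below is the supported path):
#
#   img = Image.open("example_1.jpg")
#   annotated_img, point = process_refexp(img, "select the setting icon from top right corner")
#   print(point)  # {"x": ..., "y": ..., "decoder output sequence": "..."}
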
title = "Demo: Donut 🍩 for UI RefExp - Center Point (by GuardianUI)"
description = "Gradio demo for the Donut RefExp task: a `VisionEncoderDecoderModel` fine-tuned on the [UIBert RefExp](https://huggingface.co/datasets/ivelin/ui_refexp_saved) dataset (UI Referring Expression). To use it, upload an image, type a prompt, and click 'Submit', or click one of the examples to load them. See the model training <a href='https://colab.research.google.com/github/ivelin/donut_ui_refexp/blob/main/Fine_tune_Donut_on_UI_RefExp.ipynb' target='_parent'>Colab Notebook</a> for this Space. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
examples = [["example_1.jpg", "select the setting icon from top right corner"],
            ["example_1.jpg", "click on down arrow beside the entertainment"],
            ["example_1.jpg", "select the down arrow button beside lifestyle"],
            ["example_1.jpg", "click on the image beside the option traffic"],
            ["example_3.jpg", "select the third row first image"],
            ["example_3.jpg", "click the tick mark on the first image"],
            ["example_3.jpg", "select the ninth image"],
            ["example_3.jpg", "select the add icon"],
            ["example_3.jpg", "click the first image"],
            ["val-image-4.jpg", "select 4153365454"],
            ["val-image-4.jpg", "go to cell"],
            ["val-image-4.jpg", "select number above cell"],
            ["val-image-1.jpg", "select calendar option"],
            ["val-image-1.jpg", "select photos&videos option"],
            ["val-image-2.jpg", "click on change store"],
            ["val-image-2.jpg", "click on shop menu at the bottom"],
            ["val-image-3.jpg", "click on image above short meow"],
            ["val-image-3.jpg", "go to cat sounds"],
            ["example_2.jpg", "click on green color button"],
            ["example_2.jpg", "click on text which is beside call now"],
            ["example_2.jpg", "click on more button"],
            ["example_2.jpg", "enter the text field next to the name"],
            ]
demo = gr.Interface(fn=process_refexp,
                    inputs=[gr.Image(type="pil"), "text"],
                    outputs=[gr.Image(type="pil"), "json"],
                    title=title,
                    description=description,
                    article=article,
                    examples=examples,
                    # caching example inference takes too long, which delays the
                    # Space restart after every app change commit
                    cache_examples=False
                    )
demo.launch()