import re
import gradio as gr
from PIL import Image, ImageDraw
import math
import torch
import html
from transformers import DonutProcessor, VisionEncoderDecoderModel

# module-level state shared across Gradio calls
model = None
processor = None
device = None
loaded_revision = None


def load_model(pretrained_revision: str = 'main'):
    global model, loaded_revision, processor, device
    pretrained_repo_name = 'ivelin/donut-refexp-click'
    # revision can be a git commit hash, branch or tag
    # use 'main' for the latest revision
    print(
        f"Loading model checkpoint from repo: {pretrained_repo_name}, revision: {pretrained_revision}")
    if processor is None or loaded_revision is None or loaded_revision != pretrained_revision:
        loaded_revision = pretrained_revision
        processor = DonutProcessor.from_pretrained(
            pretrained_repo_name, revision=pretrained_revision)  # , use_auth_token="...")
        processor.image_processor.do_align_long_axis = False
        # do not manipulate image size and position
        processor.image_processor.do_resize = False
        processor.image_processor.do_thumbnail = False
        processor.image_processor.do_pad = False
        # processor.image_processor.do_rescale = False
        processor.image_processor.do_normalize = True
        print(f'processor image size: {processor.image_processor.size}')
        model = VisionEncoderDecoderModel.from_pretrained(
            pretrained_repo_name, revision=pretrained_revision)  # use_auth_token="...",
        print('model checkpoint loaded')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
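
# Illustrative usage sketch (not invoked by the app itself): the revision argument may be
# a branch, tag, or commit hash, and 'main' loads the latest checkpoint. Repeated calls
# with the same revision reuse the already loaded processor and model.
# load_model('main')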


def prepare_image_for_encoder(image=None, output_image_size=None):
    """
    Resizes the input image to fill as much of the output image size as possible
    while preserving aspect ratio, positions the resized image at (0, 0), and fills
    the remaining space in the output image with black (0).
    Args:
        image: PIL image
        output_image_size: (width, height) tuple
    """
    assert image is not None
    assert output_image_size is not None
    img2 = image.copy()
    img2.thumbnail(output_image_size)
    oimg = Image.new(mode=img2.mode, size=output_image_size, color=0)
    oimg.paste(img2, box=(0, 0))
    return oimg
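
# Worked example (hypothetical sizes, for illustration only): with a 1000x2000 portrait
# screenshot and a 1280x960 encoder canvas, thumbnail() shrinks the image to 480x960
# (scale factor 0.48); it is pasted at (0, 0) and the remaining 800-pixel-wide strip on
# the right stays black padding:
# padded = prepare_image_for_encoder(Image.new("RGB", (1000, 2000)), output_image_size=(1280, 960))
# assert padded.size == (1280, 960)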


def translate_point_coords_from_out_to_in(point=None, input_image_size=None, output_image_size=None):
    """
    Convert relative prediction coordinates from the resized encoder tensor image
    back to the original input image size.
    Args:
        point: {x, y} point coordinates in the [0..1] range, relative to the encoder (output) image
        input_image_size: (width, height) tuple
        output_image_size: (width, height) tuple
    """
    assert point is not None
    assert input_image_size is not None
    assert output_image_size is not None
    print(
        f"point={point}, input_image_size={input_image_size}, output_image_size={output_image_size}")
    input_width, input_height = input_image_size
    output_width, output_height = output_image_size
    ratio = min(output_width/input_width, output_height/input_height)
    resized_height = int(input_height*ratio)
    resized_width = int(input_width*ratio)
    print(f'>>> resized_width={resized_width}')
    print(f'>>> resized_height={resized_height}')
    if resized_height == input_height and resized_width == input_width:
        return
    # translation of the relative positioning is only needed for dimensions that have padding
    if resized_width < output_width:
        # adjust for padding pixels
        point['x'] *= (output_width / resized_width)
    if resized_height < output_height:
        # adjust for padding pixels
        point['y'] *= (output_height / resized_height)
    print(
        f"translated point={point}, resized_image_size: {resized_width, resized_height}")


def process_refexp(image, prompt: str, model_revision: str = 'main', return_annotated_image: bool = True):

    print(f"(image, prompt): {image}, {prompt}")

    if not model_revision:
        model_revision = 'main'
    print(f"model checkpoint revision: {model_revision}")
    load_model(model_revision)

    # trim prompt to 80 characters and normalize to lowercase
    prompt = prompt[:80].lower()

    # prepare encoder inputs
    out_size = (
        processor.image_processor.size['width'], processor.image_processor.size['height'])
    in_size = image.size
    prepped_image = prepare_image_for_encoder(
        image, output_image_size=out_size)
    pixel_values = processor(prepped_image, return_tensors="pt").pixel_values

    # prepare decoder inputs
    task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_center>"
    prompt = task_prompt.replace("{user_input}", prompt)
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    print(f"predicted decoder sequence: {html.escape(sequence)}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # remove first task start token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    print(
        f"predicted decoder sequence before token2json: {html.escape(sequence)}")
    seqjson = processor.token2json(sequence)

    # safeguard in case the predicted sequence does not include a target_center token
    center_point = seqjson.get('target_center')
    if center_point is None:
        print(
            f"predicted sequence has no target_center, seq: {sequence}")
        center_point = {"x": 0, "y": 0}
        # return both outputs expected by the Gradio interface
        return image, center_point

    print(f"predicted center_point with text coordinates: {center_point}")

    # safeguard in case the text prediction is missing some center point coordinates
    # or the coordinates are not valid numeric values
    try:
        x = float(center_point.get("x", 0))
    except ValueError:
        x = 0
    try:
        y = float(center_point.get("y", 0))
    except ValueError:
        y = 0

    # replace str with float coords
    center_point = {"x": x, "y": y,
                    "decoder output sequence (before x,y adjustment)": sequence}
    print(f"predicted center_point with float coordinates: {center_point}")

    print(f"input image size: {in_size}")
    print(f"processed prompt: {prompt}")

    # convert coordinates from the encoder tensor image size to the input image size
    out_size = (
        processor.image_processor.size['width'], processor.image_processor.size['height'])
    translate_point_coords_from_out_to_in(
        point=center_point, input_image_size=in_size, output_image_size=out_size)

    width, height = in_size
    x = math.floor(width*center_point["x"])
    y = math.floor(height*center_point["y"])

    print(
        f"to image pixel values: x, y: {x, y}")

    if return_annotated_image:
        # draw center point circle
        img1 = ImageDraw.Draw(image)
        r = 30
        shape = [(x-r, y-r), (x+r, y+r)]
        img1.ellipse(shape, outline="green", width=20)
        img1.ellipse(shape, outline="white", width=10)
    else:
        # do not return the image for API calls, to save bandwidth
        image = None

    return image, center_point
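
# Minimal local usage sketch (illustrative only, bypassing the Gradio UI; assumes
# example_1.jpg sits next to this script, as it does for the demo examples below):
# img = Image.open("example_1.jpg")
# annotated, point = process_refexp(img, "select the menu icon right of cloud icon at the top")
# print(point)
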
title = "Demo: GuardianUI RefExp Click"
description = "Gradio Demo for the Donut RefExp task: an instance of `VisionEncoderDecoderModel` fine-tuned on the [UIBert RefExp](https://huggingface.co/datasets/ivelin/ui_refexp_saved) dataset (UI Referring Expressions). To use it, upload an image, type a prompt, and click 'Submit', or click one of the examples to load them. Optionally enter a value for the model git revision; the latest checkpoint is used by default."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
examples = [["example_1.jpg", "select the menu icon right of cloud icon at the top", "", True],
["example_1.jpg", "click on down arrow beside the entertainment", "", True],
["example_1.jpg", "select the down arrow button beside lifestyle", "", True],
["example_1.jpg", "click on the image beside the option traffic", "", True],
["example_3.jpg", "select the third row first image", "", True],
["example_3.jpg", "click the tick mark on the first image", "", True],
["example_3.jpg", "select the ninth image", "", True],
["example_3.jpg", "select the add icon", "", True],
["example_3.jpg", "click the first image", "", True],
["val-image-4.jpg", 'select 4153365454', "", True],
['val-image-4.jpg', 'go to cell', "", True],
['val-image-4.jpg', 'select number above cell', "", True],
["val-image-1.jpg", "select calendar option", "", True],
["val-image-1.jpg", "select photos&videos option", "", True],
["val-image-2.jpg", "click on change store", "", True],
["val-image-2.jpg", "click on shop menu at the bottom", "", True],
["val-image-3.jpg", "click on image above short meow", "", True],
["val-image-3.jpg", "go to cat sounds", "", True],
["example_2.jpg", "click on green color button", "", True],
["example_2.jpg", "click on text which is beside call now", "", True],
["example_2.jpg", "click on more button", "", True],
["example_2.jpg", "enter the text field next to the name", "", True],
]

demo = gr.Interface(fn=process_refexp,
                    inputs=[gr.Image(type="pil"), "text", "text", gr.Checkbox(
                        value=True, label="Return Annotated Image", visible=False)],
                    outputs=[gr.Image(type="pil"), "json"],
                    title=title,
                    description=description,
                    article=article,
                    examples=examples,
                    # caching example inference takes too long and delays the Space start after an app change commit
                    cache_examples=False
                    )
# share=True when running in a Jupyter Notebook
demo.launch(server_name="0.0.0.0")