import re
import gradio as gr
from PIL import Image, ImageDraw
import math
import torch
import html
from transformers import DonutProcessor, VisionEncoderDecoderModel


# model state shared across Gradio calls; populated lazily by load_model()
model = None
processor = None
device = None
loaded_revision = None


def load_model(pretrained_revision: str = 'main'):
    global model, loaded_revision, processor, device
    pretrained_repo_name = 'ivelin/donut-refexp-click'
    # revision can be git commit hash, branch or tag
    # use 'main' for latest revision
    print(
        f"Loading model checkpoint from repo: {pretrained_repo_name}, revision: {pretrained_revision}")
    if processor is None or loaded_revision is None or loaded_revision != pretrained_revision:
        loaded_revision = pretrained_revision
        processor = DonutProcessor.from_pretrained(
            pretrained_repo_name, revision=pretrained_revision)  # , use_auth_token="...")
        processor.image_processor.do_align_long_axis = False
        # do not manipulate image size and position
        processor.image_processor.do_resize = False
        processor.image_processor.do_thumbnail = False
        processor.image_processor.do_pad = False
        # processor.image_processor.do_rescale = False
        processor.image_processor.do_normalize = True
        print(f'processor image size: {processor.image_processor.size}')
        model = VisionEncoderDecoderModel.from_pretrained(
            pretrained_repo_name, revision=pretrained_revision)  # use_auth_token="...",
        print('model checkpoint loaded')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)


def prepare_image_for_encoder(image=None, output_image_size=None):
    """
    First, resizes the input image to fill as much as possible of the output image size
    while preserving aspect ratio. Positions the resized image at (0,0) and fills
    the rest of the gap space in the output image with black(0).
    Args:
        image: PIL image
        output_image_size: (width, height) tuple
    """
    assert image is not None
    assert output_image_size is not None
    img2 = image.copy()
    img2.thumbnail(output_image_size)
    oimg = Image.new(mode=img2.mode, size=output_image_size, color=0)
    oimg.paste(img2, box=(0, 0))
    return oimg
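
# A minimal sanity-check sketch (assumed sizes; not part of the app's call path):
# a 2000x1000 input letterboxed into a 1280x960 encoder canvas keeps its aspect
# ratio, so thumbnail() shrinks it to 1280x640 and the bottom 320 rows stay black.
#
#   from PIL import Image
#   img = Image.new("RGB", (2000, 1000), color="white")
#   prepped = prepare_image_for_encoder(img, output_image_size=(1280, 960))
#   assert prepped.size == (1280, 960)
#   assert prepped.getpixel((0, 0)) == (255, 255, 255)   # resized image region
#   assert prepped.getpixel((0, 700)) == (0, 0, 0)       # black padding region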


def translate_point_coords_from_out_to_in(point=None, input_image_size=None, output_image_size=None):
    """
    Convert relative prediction coordinates from resized encoder tensor image
    to original input image size.
    Args:
        original_point: x, y coordinates of the point coordinates in [0..1] range in the original image
        input_image_size: (width, height) tuple
        output_image_size: (width, height) tuple
    """
    assert point is not None
    assert input_image_size is not None
    assert output_image_size is not None
    print(
        f"point={point}, input_image_size={input_image_size}, output_image_size={output_image_size}")
    input_width, input_height = input_image_size
    output_width, output_height = output_image_size

    ratio = min(output_width/input_width, output_height/input_height)

    resized_height = int(input_height*ratio)
    resized_width = int(input_width*ratio)
    print(f'>>> resized_width={resized_width}')
    print(f'>>> resized_height={resized_height}')

    if resized_height == input_height and resized_width == input_width:
        return

    # translation of the relative coordinates is only needed for dimensions that have padding
    if resized_width < output_width:
        # adjust for padding pixels
        point['x'] *= (output_width / resized_width)
    if resized_height < output_height:
        # adjust for padding pixels
        point['y'] *= (output_height / resized_height)
    print(
        f"translated point={point}, resized_image_size: {resized_width, resized_height}")


def process_refexp(image, prompt: str, model_revision: str = 'main', return_annotated_image: bool = True):

    print(f"(image, prompt): {image}, {prompt}")

    if not model_revision:
        model_revision = 'main'

    print(f"model checkpoint revision: {model_revision}")

    load_model(model_revision)

    # trim prompt to 80 characters and normalize to lowercase
    prompt = prompt[:80].lower()

    # prepare encoder inputs
    out_size = (
        processor.image_processor.size['width'], processor.image_processor.size['height'])
    in_size = image.size
    prepped_image = prepare_image_for_encoder(
        image, output_image_size=out_size)
    pixel_values = processor(prepped_image, return_tensors="pt").pixel_values

    # prepare decoder inputs
    task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_center>"
    prompt = task_prompt.replace("{user_input}", prompt)
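    # e.g. for the user prompt "click the menu icon" the decoder is primed with:
    # "<s_refexp><s_prompt>click the menu icon</s_prompt><s_target_center>"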
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    print(fr"predicted decoder sequence: {html.escape(sequence)}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # remove first task start token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    print(
        fr"predicted decoder sequence before token2json: {html.escape(sequence)}")
    seqjson = processor.token2json(sequence)
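    # token2json is expected to return a nested dict roughly of the form
    # {"prompt": "...", "target_center": {"x": "0.42", "y": "0.31"}} (illustrative
    # values; coordinates arrive as strings in the 0..1 range, and the exact keys
    # depend on the checkpoint's tag vocabulary).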

    # safeguard in case predicted sequence does not include a target_center token
    center_point = seqjson.get('target_center')
    if center_point is None:
        print(
            f"predicted sequence has no target_center, seq: {sequence}")
        center_point = {"x": 0, "y": 0}
        # return the unannotated image so the output signature stays (image, point)
        return image, center_point

    print(f"predicted center_point with text coordinates: {center_point}")
    # safeguard in case text prediction is missing some center point coordinates
    # or coordinates are not valid numeric values
    try:
        x = float(center_point.get("x", 0))
    except (ValueError, TypeError):
        x = 0
    try:
        y = float(center_point.get("y", 0))
    except (ValueError, TypeError):
        y = 0
    # replace str with float coords
    center_point = {"x": x, "y": y,
                    "decoder output sequence (before x,y adjustment)": sequence}
    print(f"predicted center_point with float coordinates: {center_point}")

    print(f"input image size: {in_size}")
    print(f"processed prompt: {prompt}")

    # convert coordinates from the encoder tensor image size back to the input image size
    # (out_size was computed above from the processor's image size and is unchanged)
    translate_point_coords_from_out_to_in(
        point=center_point, input_image_size=in_size, output_image_size=out_size)

    width, height = in_size
    x = math.floor(width*center_point["x"])
    y = math.floor(height*center_point["y"])

    print(
        f"to image pixel values: x, y: {x, y}")

    if return_annotated_image:
        # draw center point circle
        img1 = ImageDraw.Draw(image)
        r = 30
        shape = [(x-r, y-r), (x+r, y+r)]
        img1.ellipse(shape, outline="green", width=20)
        img1.ellipse(shape, outline="white", width=10)
    else:
        # do not return the image if it's an API call, to save bandwidth
        image = None

    return image, center_point
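
# Illustrative direct call, bypassing the Gradio UI (assumes one of the bundled
# example images is available locally):
#
#   from PIL import Image
#   img = Image.open("example_1.jpg")
#   annotated, point = process_refexp(
#       img, "select the menu icon right of cloud icon at the top", "main", True)
#   # `annotated` is the input image with a circle drawn at the predicted click
#   # point; `point` holds the relative x, y coordinates and the raw decoder output.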


title = "Demo: GuardianUI RefExp Click"
description = "Gradio Demo for Donut RefExp task, an instance of `VisionEncoderDecoderModel` fine-tuned on [UIBert RefExp](https://huggingface.co/datasets/ivelin/ui_refexp_saved) Dataset (UI Referring Expression). To use it, simply upload your image and type a prompt and click 'submit', or click one of the examples to load them. Optionally enter value for model git revision; latest checkpoint will be used by default."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
examples = [["example_1.jpg", "select the menu icon right of cloud icon at the top", "", True],
            ["example_1.jpg", "click on down arrow beside the entertainment", "", True],
            ["example_1.jpg", "select the down arrow button beside lifestyle", "", True],
            ["example_1.jpg", "click on the image beside the option traffic", "", True],
            ["example_3.jpg", "select the third row first image", "", True],
            ["example_3.jpg", "click the tick mark on the first image", "", True],
            ["example_3.jpg", "select the ninth image", "", True],
            ["example_3.jpg", "select the add icon", "", True],
            ["example_3.jpg", "click the first image", "", True],
            ["val-image-4.jpg", 'select 4153365454', "", True],
            ['val-image-4.jpg', 'go to cell', "", True],
            ['val-image-4.jpg', 'select number above cell', "", True],
            ["val-image-1.jpg", "select calendar option", "", True],
            ["val-image-1.jpg", "select photos&videos option", "", True],
            ["val-image-2.jpg", "click on change store", "", True],
            ["val-image-2.jpg", "click on shop menu at the bottom", "", True],
            ["val-image-3.jpg", "click on image above short meow", "", True],
            ["val-image-3.jpg", "go to cat sounds", "", True],
            ["example_2.jpg", "click on green color button", "", True],
            ["example_2.jpg", "click on text which is beside call now", "", True],
            ["example_2.jpg", "click on more button", "", True],
            ["example_2.jpg", "enter the text field next to the name", "", True],
            ]

demo = gr.Interface(fn=process_refexp,
                    inputs=[gr.Image(type="pil"), "text", "text", gr.Checkbox(
                        value=True, label="Return Annotated Image", visible=False)],
                    outputs=[gr.Image(type="pil"), "json"],
                    title=title,
                    description=description,
                    article=article,
                    examples=examples,
                    # caching examples inference takes too long to start space after app change commit
                    cache_examples=False
                    )

# pass share=True to launch() when running in a Jupyter Notebook
demo.launch(server_name="0.0.0.0")