programmnix-askui committed
Commit 96cec35 · Parent(s): 46d803c

Add OS-Atlas

Files changed (1):
  1. app.py +132 -7

app.py CHANGED
@@ -1,14 +1,139 @@
 import gradio as gr
 import spaces
+from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from qwen_vl_utils import process_vision_info
 import torch
+import base64
+from PIL import Image, ImageDraw
+from io import BytesIO
+import re
+
+
+models = {
+    "OS-Copilot/OS-Atlas-Base-7B": Qwen2VLForConditionalGeneration.from_pretrained("OS-Copilot/OS-Atlas-Base-7B", torch_dtype="auto", device_map="auto"),
+}
+
+processors = {
+    "OS-Copilot/OS-Atlas-Base-7B": AutoProcessor.from_pretrained("OS-Copilot/OS-Atlas-Base-7B")
+}
+
+
+def image_to_base64(image):
+    buffered = BytesIO()
+    image.save(buffered, format="PNG")
+    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+    return img_str
+
+
+def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=2):
+    draw = ImageDraw.Draw(image)
+    for box in bounding_boxes:
+        xmin, ymin, xmax, ymax = box
+        draw.rectangle([xmin, ymin, xmax, ymax], outline=outline_color, width=line_width)
+    return image
+
+
+def rescale_bounding_boxes(bounding_boxes, original_width, original_height, scaled_width=1000, scaled_height=1000):
+    x_scale = original_width / scaled_width
+    y_scale = original_height / scaled_height
+    rescaled_boxes = []
+    for box in bounding_boxes:
+        xmin, ymin, xmax, ymax = box
+        rescaled_box = [
+            xmin * x_scale,
+            ymin * y_scale,
+            xmax * x_scale,
+            ymax * y_scale
+        ]
+        rescaled_boxes.append(rescaled_box)
+    return rescaled_boxes
 
-zero = torch.Tensor([0]).cuda()
-print(zero.device) # <-- 'cpu' 🤔
 
 @spaces.GPU
-def greet(n):
-    print(zero.device) # <-- 'cuda:0' 🤗
-    return f"Hello {zero + n} Tensor"
+def run_example(image, text_input, model_id="OS-Copilot/OS-Atlas-Base-7B"):
+    model = models[model_id].eval()
+    processor = processors[model_id]
+    prompt = f"In this UI screenshot, what is the position of the element corresponding to the command \"{text_input}\" (with bbox)?"
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": f"data:image;base64,{image_to_base64(image)}"},
+                {"type": "text", "text": prompt},
+            ],
+        }
+    ]
+
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    image_inputs, video_inputs = process_vision_info(messages)
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to("cuda")
+
+    generated_ids = model.generate(**inputs, max_new_tokens=128)
+    generated_ids_trimmed = [
+        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False
+    )
+    print(output_text)
+    text = output_text[0]
+
+    object_ref_pattern = r"<\|object_ref_start\|>(.*?)<\|object_ref_end\|>"
+    box_pattern = r"<\|box_start\|>(.*?)<\|box_end\|>"
+
+    object_ref = re.search(object_ref_pattern, text).group(1)
+    box_content = re.search(box_pattern, text).group(1)
+
+    boxes = [tuple(map(int, pair.strip("()").split(','))) for pair in box_content.split("),(")]
+    boxes = [[boxes[0][0], boxes[0][1], boxes[1][0], boxes[1][1]]]
+
+    scaled_boxes = rescale_bounding_boxes(boxes, image.width, image.height)
+    return object_ref, scaled_boxes, draw_bounding_boxes(image, scaled_boxes)
+
+css = """
+#output {
+    height: 500px;
+    overflow: auto;
+    border: 1px solid #ccc;
+}
+"""
+with gr.Blocks(css=css) as demo:
+    gr.Markdown(
+        """
+        # Demo for OS-ATLAS: A Foundation Action Model For Generalist GUI Agents
+        """)
+    with gr.Row():
+        with gr.Column():
+            input_img = gr.Image(label="Input Image", type="pil")
+            model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="OS-Copilot/OS-Atlas-Base-7B")
+            text_input = gr.Textbox(label="User Prompt")
+            submit_btn = gr.Button(value="Submit")
+        with gr.Column():
+            model_output_text = gr.Textbox(label="Model Output Text")
+            model_output_box = gr.Textbox(label="Model Output Box")
+            annotated_image = gr.Image(label="Annotated Image")
+
+    gr.Examples(
+        examples=[
+            ["assets/web_6f93090a-81f6-489e-bb35-1a2838b18c01.png", "select search textfield"],
+            ["assets/web_6f93090a-81f6-489e-bb35-1a2838b18c01.png", "switch to discussions"],
+        ],
+        inputs=[input_img, text_input],
+        outputs=[model_output_text, model_output_box, annotated_image],
+        fn=run_example,
+        cache_examples=True,
+        label="Try examples"
+    )
+
+    submit_btn.click(run_example, [input_img, text_input, model_selector], [model_output_text, model_output_box, annotated_image])
 
-demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
-demo.launch()
+demo.launch(debug=True)
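
Note on the post-processing: OS-Atlas inherits Qwen2-VL's grounding format, so the raw generation wraps the element name in <|object_ref_start|>/<|object_ref_end|> and the box corners in <|box_start|>/<|box_end|>, with coordinates on a 1000x1000 grid. That is why run_example decodes with skip_special_tokens=False and why rescale_bounding_boxes divides by 1000 before multiplying by the real image size. A minimal standalone sketch of that path, using a made-up response string in place of a real generation:

import re

# Hypothetical model output, shaped like the strings run_example parses.
text = "<|object_ref_start|>search textfield<|object_ref_end|><|box_start|>(488,105),(624,141)<|box_end|><|im_end|>"

object_ref = re.search(r"<\|object_ref_start\|>(.*?)<\|object_ref_end\|>", text).group(1)
box_content = re.search(r"<\|box_start\|>(.*?)<\|box_end\|>", text).group(1)

# "(488,105),(624,141)" -> [(488, 105), (624, 141)]
corners = [tuple(map(int, pair.strip("()").split(","))) for pair in box_content.split("),(")]
(xmin, ymin), (xmax, ymax) = corners

# Map the 1000x1000 grid onto the real screenshot size, here 1920x1080.
width, height = 1920, 1080
pixel_box = [xmin * width / 1000, ymin * height / 1000,
             xmax * width / 1000, ymax * height / 1000]

print(object_ref)  # search textfield
print(pixel_box)   # [936.96, 113.4, 1198.08, 152.28]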
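
For driving the Space programmatically, something like the following should work with gradio_client; the Space id and the endpoint name are assumptions here (Gradio normally derives the endpoint from the function name, and the exact values are listed on the Space's "Use via API" page):

from gradio_client import Client, handle_file

# "<user>/<space>" is a placeholder for the actual Space id.
client = Client("<user>/<space>")
object_ref, boxes, annotated_path = client.predict(
    handle_file("screenshot.png"),   # input_img
    "select search textfield",       # text_input
    "OS-Copilot/OS-Atlas-Base-7B",   # model_selector
    api_name="/run_example",         # assumed default endpoint name
)
print(object_ref, boxes)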