abiyyufahri committed
Commit 6b36184 · verified · 1 Parent(s): 89b8ede

Update app.py

Files changed (1)
  1. app.py +155 -141
app.py CHANGED
@@ -1,13 +1,14 @@
-import os
-import spaces
-import gradio as gr
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel
 from PIL import Image
+from io import BytesIO
+import base64
 import torch
 import re
 import logging
-from typing import Tuple, List
-import base64
-from io import BytesIO
+import asyncio
+from contextlib import asynccontextmanager
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -20,7 +21,7 @@ tokenizer = None
 model_name = "microsoft/GUI-Actor-2B-Qwen2-VL"
 model_loaded = False
 
-def load_model():
+async def load_model():
     """Load model with proper error handling and fallback strategies"""
     global model, processor, tokenizer, model_loaded
 
@@ -39,8 +40,8 @@ def load_model():
 
         model = Qwen2VLForConditionalGeneration.from_pretrained(
             model_name,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            device_map="auto" if torch.cuda.is_available() else None,
+            torch_dtype=torch.float32,
+            device_map=None,  # CPU only
             trust_remote_code=True,
             low_cpu_mem_usage=True
         ).eval()
@@ -52,17 +53,17 @@ def load_model():
         logger.info("Trying AutoProcessor and AutoModel fallback...")
 
         try:
-            from transformers import AutoProcessor, AutoModelForVision2Seq
+            from transformers import AutoProcessor, AutoModel
 
             processor = AutoProcessor.from_pretrained(
                 model_name,
                 trust_remote_code=True
             )
 
-            model = AutoModelForVision2Seq.from_pretrained(
+            model = AutoModel.from_pretrained(
                 model_name,
-                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                device_map="auto" if torch.cuda.is_available() else None,
+                torch_dtype=torch.float32,
+                device_map=None,
                 trust_remote_code=True,
                 low_cpu_mem_usage=True
             ).eval()
@@ -74,7 +75,7 @@ def load_model():
             logger.info("Trying generic transformers approach...")
 
             # Last fallback - try loading as generic model
-            from transformers import AutoConfig, AutoProcessor
+            from transformers import AutoConfig, AutoTokenizer
             import transformers
 
             config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -96,8 +97,8 @@ def load_model():
             model = ModelClass.from_pretrained(
                 model_name,
                 config=config,
-                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                device_map="auto" if torch.cuda.is_available() else None,
+                torch_dtype=torch.float32,
+                device_map=None,
                 trust_remote_code=True,
                 low_cpu_mem_usage=True
             ).eval()
@@ -116,8 +117,30 @@ def load_model():
         model_loaded = False
         return False
 
-def extract_coordinates(text: str) -> List[Tuple[float, float]]:
-    """Extract coordinates from model output text"""
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Startup
+    logger.info("Starting up GUI-Actor API...")
+    await load_model()
+    yield
+    # Shutdown
+    logger.info("Shutting down GUI-Actor API...")
+
+# Initialize FastAPI app with lifespan
+app = FastAPI(
+    title="GUI-Actor API",
+    version="1.0.0",
+    lifespan=lifespan
+)
+
+class Base64Request(BaseModel):
+    image_base64: str
+    instruction: str
+
+def extract_coordinates(text):
+    """
+    Extract coordinates from model output text
+    """
     # Patterns to find coordinates in various formats
     patterns = [
         r'click\s*\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)',  # click(x, y)
@@ -143,38 +166,11 @@ def extract_coordinates(text: str) -> List[Tuple[float, float]]:
     # Default to center if nothing is found
     return [(0.5, 0.5)]
 
-@spaces.GPU  # Decorator to use the GPU on Hugging Face Spaces
-def inference(pil_image: Image.Image, instruction: str):
-    """Inference function with Spaces GPU support"""
-    if not model_loaded:
-        return "Model not loaded properly", 0.5, 0.5
-
+def cpu_inference(conversation, model, tokenizer, processor):
+    """
+    Inference function for CPU with better error handling
+    """
     try:
-        conversation = [
-            {
-                "role": "system",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task. Please provide the click coordinates.",
-                    }
-                ]
-            },
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image",
-                        "image": pil_image,
-                    },
-                    {
-                        "type": "text",
-                        "text": instruction,
-                    },
-                ],
-            },
-        ]
-
         # Apply chat template
         text = processor.apply_chat_template(
             conversation,
@@ -190,15 +186,11 @@ def inference(pil_image: Image.Image, instruction: str):
             text=[text],
             images=[image],
             return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=512
+            padding=True,  # Enable padding
+            truncation=True,  # Enable truncation for long texts
+            max_length=512  # Set reasonable max length
         )
 
-        # Move inputs to the same device as model
-        if torch.cuda.is_available():
-            inputs = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
-
         # Generate response with proper error handling
         with torch.no_grad():
             try:
@@ -226,97 +218,119 @@ def inference(pil_image: Image.Image, instruction: str):
 
         # Extract coordinates
        coordinates = extract_coordinates(response)
-        px, py = coordinates[0]
 
-        return response, round(px, 4), round(py, 4)
+        return {
+            "topk_points": coordinates,
+            "response": response,
+            "success": True
+        }
 
     except Exception as e:
         logger.error(f"Inference error: {e}")
-        return f"Error during inference: {str(e)}", 0.5, 0.5
+        return {
+            "topk_points": [(0.5, 0.5)],
+            "response": f"Error during inference: {str(e)}",
+            "success": False
+        }
 
-def process_image(image: Image.Image, instruction: str):
-    """Process the uploaded image and instruction"""
-    if image is None:
-        return "Please upload an image", 0.5, 0.5
-
-    if not instruction.strip():
-        return "Please provide an instruction", 0.5, 0.5
-
-    # Convert image to RGB if needed
-    if image.mode != "RGB":
-        image = image.convert("RGB")
-
-    # Run inference
-    response, x, y = inference(image, instruction)
+@app.get("/")
+async def root():
+    return {
+        "message": "GUI-Actor API is running",
+        "status": "healthy",
+        "model_loaded": model_loaded,
+        "model_name": model_name
+    }
+
+@app.post("/click/base64")
+async def predict_click_base64(data: Base64Request):
+    if not model_loaded:
+        raise HTTPException(
+            status_code=503,
+            detail="Model not loaded properly"
+        )
 
-    return response, x, y
+    try:
+        # Decode base64 to image
+        try:
+            # Handle data URL format
+            if "," in data.image_base64:
+                image_data = base64.b64decode(data.image_base64.split(",")[-1])
+            else:
+                image_data = base64.b64decode(data.image_base64)
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=f"Invalid base64 image: {e}")
+
+        try:
+            pil_image = Image.open(BytesIO(image_data)).convert("RGB")
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=f"Invalid image format: {e}")
 
-# Load model on startup
-logger.info("Loading model...")
-load_model()
+        conversation = [
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task. Please provide the click coordinates.",
+                    }
+                ]
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": pil_image,
+                    },
+                    {
+                        "type": "text",
+                        "text": data.instruction,
+                    },
+                ],
+            },
+        ]
 
-# Create Gradio interface
-with gr.Blocks(title="GUI-Actor Click Prediction", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# GUI-Actor Click Prediction")
-    gr.Markdown("Upload a screenshot and provide instructions to get click coordinates prediction.")
-
-    with gr.Row():
-        with gr.Column():
-            image_input = gr.Image(
-                type="pil",
-                label="Upload Screenshot",
-                height=400
-            )
-            instruction_input = gr.Textbox(
-                label="Instruction",
-                placeholder="e.g., Click on the login button",
-                lines=3
-            )
-            submit_btn = gr.Button("Predict Click Location", variant="primary")
-
-        with gr.Column():
-            response_output = gr.Textbox(
-                label="Model Response",
-                lines=5,
-                interactive=False
-            )
-            with gr.Row():
-                x_output = gr.Number(
-                    label="X Coordinate (normalized)",
-                    precision=4,
-                    interactive=False
-                )
-                y_output = gr.Number(
-                    label="Y Coordinate (normalized)",
-                    precision=4,
-                    interactive=False
-                )
-
-    # Status indicator
-    with gr.Row():
-        gr.Markdown(f"**Model Status:** {'✅ Loaded' if model_loaded else '❌ Not Loaded'}")
-        gr.Markdown(f"**Device:** {'GPU' if torch.cuda.is_available() else 'CPU'}")
-
-    # Examples
-    gr.Examples(
-        examples=[
-            ["Click on the search button", None],
-            ["Select the dropdown menu", None],
-            ["Click on the submit form", None],
-        ],
-        inputs=[instruction_input, image_input],
-    )
-
-    # Event handlers
-    submit_btn.click(
-        fn=process_image,
-        inputs=[image_input, instruction_input],
-        outputs=[response_output, x_output, y_output]
-    )
+        # Run inference
+        pred = cpu_inference(conversation, model, tokenizer, processor)
+        px, py = pred["topk_points"][0]
+
+        return JSONResponse(content={
+            "x": round(px, 4),
+            "y": round(py, 4),
+            "response": pred["response"],
+            "success": pred["success"]
+        })
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Prediction error: {e}")
+        raise HTTPException(
+            status_code=500,
+            detail=f"Internal server error: {str(e)}"
+        )
 
-if __name__ == "__main__":
-    demo.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        share=True
-    )
+@app.get("/health")
+async def health_check():
+    return {
+        "status": "healthy" if model_loaded else "unhealthy",
+        "model": model_name,
+        "device": "cpu",
+        "torch_dtype": "float32",
+        "model_loaded": model_loaded
+    }
+
+@app.get("/debug")
+async def debug_info():
+    """Debug endpoint to check model loading status"""
+    import transformers
+    available_classes = [attr for attr in dir(transformers) if 'Qwen' in attr or 'VL' in attr]
+
+    return {
+        "model_loaded": model_loaded,
+        "processor_type": type(processor).__name__ if processor else None,
+        "model_type": type(model).__name__ if model else None,
+        "available_qwen_classes": available_classes,
+        "transformers_version": transformers.__version__
+    }
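For reference, a minimal client sketch against the new API. It assumes the app is served with something like `uvicorn app:app --host 0.0.0.0 --port 7860` (this commit removes the old Gradio launch block and adds no replacement, so host and port are assumptions), and `screenshot.png` is a hypothetical input file; only the request and response shapes come from app.py above.

# Hypothetical client for POST /click/base64; host, port, and file name are
# assumptions — the JSON fields match Base64Request and the JSONResponse above.
import base64
import requests

with open("screenshot.png", "rb") as f:  # hypothetical screenshot file
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://localhost:7860/click/base64",
    json={"image_base64": image_b64, "instruction": "Click on the login button"},
    timeout=300,  # float32 CPU inference can be slow
)
resp.raise_for_status()
pred = resp.json()
# x and y are normalized to [0, 1]; (0.5, 0.5) is the fallback center.
print(pred["x"], pred["y"], pred["success"], pred["response"])

Multiply x and y by the screenshot's pixel width and height to get an absolute click position.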