abiyyufahri committed
Commit 563f88c · verified · 1 Parent(s): 6f88ac4

Update app.py

Files changed (1)
  1. app.py +141 -155
app.py CHANGED
@@ -1,14 +1,13 @@
- from fastapi import FastAPI, HTTPException
- from fastapi.responses import JSONResponse
- from pydantic import BaseModel
+ import os
+ import spaces
+ import gradio as gr
  from PIL import Image
- from io import BytesIO
- import base64
  import torch
  import re
  import logging
- import asyncio
- from contextlib import asynccontextmanager
+ from typing import Tuple, List
+ import base64
+ from io import BytesIO

  # Configure logging
  logging.basicConfig(level=logging.INFO)
@@ -21,7 +20,7 @@ tokenizer = None
  model_name = "microsoft/GUI-Actor-2B-Qwen2-VL"
  model_loaded = False

- async def load_model():
+ def load_model():
      """Load model with proper error handling and fallback strategies"""
      global model, processor, tokenizer, model_loaded

@@ -40,8 +39,8 @@ async def load_model():

          model = Qwen2VLForConditionalGeneration.from_pretrained(
              model_name,
-             torch_dtype=torch.float32,
-             device_map=None,  # CPU only
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             device_map="auto" if torch.cuda.is_available() else None,
              trust_remote_code=True,
              low_cpu_mem_usage=True
          ).eval()
@@ -53,17 +52,17 @@
          logger.info("Trying AutoProcessor and AutoModel fallback...")

          try:
-             from transformers import AutoProcessor, AutoModel
+             from transformers import AutoProcessor, AutoModelForVision2Seq

              processor = AutoProcessor.from_pretrained(
                  model_name,
                  trust_remote_code=True
              )

-             model = AutoModel.from_pretrained(
+             model = AutoModelForVision2Seq.from_pretrained(
                  model_name,
-                 torch_dtype=torch.float32,
-                 device_map=None,
+                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                 device_map="auto" if torch.cuda.is_available() else None,
                  trust_remote_code=True,
                  low_cpu_mem_usage=True
              ).eval()
@@ -75,7 +74,7 @@
              logger.info("Trying generic transformers approach...")

              # Last fallback - try loading as generic model
-             from transformers import AutoConfig, AutoTokenizer
+             from transformers import AutoConfig, AutoProcessor
              import transformers

              config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -97,8 +96,8 @@
              model = ModelClass.from_pretrained(
                  model_name,
                  config=config,
-                 torch_dtype=torch.float32,
-                 device_map=None,
+                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                 device_map="auto" if torch.cuda.is_available() else None,
                  trust_remote_code=True,
                  low_cpu_mem_usage=True
              ).eval()
@@ -117,30 +116,8 @@ async def load_model():
          model_loaded = False
          return False

- @asynccontextmanager
- async def lifespan(app: FastAPI):
-     # Startup
-     logger.info("Starting up GUI-Actor API...")
-     await load_model()
-     yield
-     # Shutdown
-     logger.info("Shutting down GUI-Actor API...")
-
- # Initialize FastAPI app with lifespan
- app = FastAPI(
-     title="GUI-Actor API",
-     version="1.0.0",
-     lifespan=lifespan
- )
-
- class Base64Request(BaseModel):
-     image_base64: str
-     instruction: str
-
- def extract_coordinates(text):
-     """
-     Extract coordinates from model output text
-     """
+ def extract_coordinates(text: str) -> List[Tuple[float, float]]:
+     """Extract coordinates from model output text"""
      # Patterns to match coordinates in various formats
      patterns = [
          r'click\s*\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)',  # click(x, y)
@@ -166,11 +143,38 @@ def extract_coordinates(text):
      # Default to the center if no coordinates are found
      return [(0.5, 0.5)]

- def cpu_inference(conversation, model, tokenizer, processor):
-     """
-     Inference function for CPU with better error handling
-     """
+ @spaces.GPU  # decorator to use the GPU on Hugging Face Spaces
+ def inference(pil_image: Image.Image, instruction: str):
+     """Inference function with Spaces GPU support"""
+     if not model_loaded:
+         return "Model not loaded properly", 0.5, 0.5
+
      try:
+         conversation = [
+             {
+                 "role": "system",
+                 "content": [
+                     {
+                         "type": "text",
+                         "text": "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task. Please provide the click coordinates.",
+                     }
+                 ]
+             },
+             {
+                 "role": "user",
+                 "content": [
+                     {
+                         "type": "image",
+                         "image": pil_image,
+                     },
+                     {
+                         "type": "text",
+                         "text": instruction,
+                     },
+                 ],
+             },
+         ]
+
          # Apply chat template
          text = processor.apply_chat_template(
              conversation,
@@ -186,11 +190,15 @@ def cpu_inference(conversation, model, tokenizer, processor):
              text=[text],
              images=[image],
              return_tensors="pt",
-             padding=True,  # Enable padding
-             truncation=True,  # Enable truncation for long texts
-             max_length=512  # Set reasonable max length
+             padding=True,
+             truncation=True,
+             max_length=512
          )

+         # Move inputs to the same device as the model
+         if torch.cuda.is_available():
+             inputs = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
+
          # Generate response with proper error handling
          with torch.no_grad():
              try:
@@ -218,119 +226,97 @@ def cpu_inference(conversation, model, tokenizer, processor):

          # Extract coordinates
          coordinates = extract_coordinates(response)
+         px, py = coordinates[0]

-         return {
-             "topk_points": coordinates,
-             "response": response,
-             "success": True
-         }
+         return response, round(px, 4), round(py, 4)

      except Exception as e:
          logger.error(f"Inference error: {e}")
-         return {
-             "topk_points": [(0.5, 0.5)],
-             "response": f"Error during inference: {str(e)}",
-             "success": False
-         }
+         return f"Error during inference: {str(e)}", 0.5, 0.5

- @app.get("/")
- async def root():
-     return {
-         "message": "GUI-Actor API is running",
-         "status": "healthy",
-         "model_loaded": model_loaded,
-         "model_name": model_name
-     }
-
- @app.post("/click/base64")
- async def predict_click_base64(data: Base64Request):
-     if not model_loaded:
-         raise HTTPException(
-             status_code=503,
-             detail="Model not loaded properly"
-         )
-
-     try:
-         # Decode base64 to image
-         try:
-             # Handle data URL format
-             if "," in data.image_base64:
-                 image_data = base64.b64decode(data.image_base64.split(",")[-1])
-             else:
-                 image_data = base64.b64decode(data.image_base64)
-         except Exception as e:
-             raise HTTPException(status_code=400, detail=f"Invalid base64 image: {e}")
-
-         try:
-             pil_image = Image.open(BytesIO(image_data)).convert("RGB")
-         except Exception as e:
-             raise HTTPException(status_code=400, detail=f"Invalid image format: {e}")
-
-         conversation = [
-             {
-                 "role": "system",
-                 "content": [
-                     {
-                         "type": "text",
-                         "text": "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task. Please provide the click coordinates.",
-                     }
-                 ]
-             },
-             {
-                 "role": "user",
-                 "content": [
-                     {
-                         "type": "image",
-                         "image": pil_image,
-                     },
-                     {
-                         "type": "text",
-                         "text": data.instruction,
-                     },
-                 ],
-             },
-         ]
-
-         # Run inference
-         pred = cpu_inference(conversation, model, tokenizer, processor)
-         px, py = pred["topk_points"][0]
-
-         return JSONResponse(content={
-             "x": round(px, 4),
-             "y": round(py, 4),
-             "response": pred["response"],
-             "success": pred["success"]
-         })
-
-     except HTTPException:
-         raise
-     except Exception as e:
-         logger.error(f"Prediction error: {e}")
-         raise HTTPException(
-             status_code=500,
-             detail=f"Internal server error: {str(e)}"
-         )
-
- @app.get("/health")
- async def health_check():
-     return {
-         "status": "healthy" if model_loaded else "unhealthy",
-         "model": model_name,
-         "device": "cpu",
-         "torch_dtype": "float32",
-         "model_loaded": model_loaded
-     }
-
- @app.get("/debug")
- async def debug_info():
-     """Debug endpoint to check model loading status"""
-     import transformers
-     available_classes = [attr for attr in dir(transformers) if 'Qwen' in attr or 'VL' in attr]
-
-     return {
-         "model_loaded": model_loaded,
-         "processor_type": type(processor).__name__ if processor else None,
-         "model_type": type(model).__name__ if model else None,
-         "available_qwen_classes": available_classes,
-         "transformers_version": transformers.__version__
-     }
+ def process_image(image: Image.Image, instruction: str):
+     """Process the uploaded image and instruction"""
+     if image is None:
+         return "Please upload an image", 0.5, 0.5
+
+     if not instruction.strip():
+         return "Please provide an instruction", 0.5, 0.5
+
+     # Convert image to RGB if needed
+     if image.mode != "RGB":
+         image = image.convert("RGB")
+
+     # Run inference
+     response, x, y = inference(image, instruction)
+
+     return response, x, y
+
+ # Load model on startup
+ logger.info("Loading model...")
+ load_model()
+
+ # Create Gradio interface
+ with gr.Blocks(title="GUI-Actor Click Prediction", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# GUI-Actor Click Prediction")
+     gr.Markdown("Upload a screenshot and provide instructions to get click coordinates prediction.")
+
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(
+                 type="pil",
+                 label="Upload Screenshot",
+                 height=400
+             )
+             instruction_input = gr.Textbox(
+                 label="Instruction",
+                 placeholder="e.g., Click on the login button",
+                 lines=3
+             )
+             submit_btn = gr.Button("Predict Click Location", variant="primary")
+
+         with gr.Column():
+             response_output = gr.Textbox(
+                 label="Model Response",
+                 lines=5,
+                 interactive=False
+             )
+             with gr.Row():
+                 x_output = gr.Number(
+                     label="X Coordinate (normalized)",
+                     precision=4,
+                     interactive=False
+                 )
+                 y_output = gr.Number(
+                     label="Y Coordinate (normalized)",
+                     precision=4,
+                     interactive=False
+                 )
+
+     # Status indicator
+     with gr.Row():
+         gr.Markdown(f"**Model Status:** {'✅ Loaded' if model_loaded else '❌ Not Loaded'}")
+         gr.Markdown(f"**Device:** {'GPU' if torch.cuda.is_available() else 'CPU'}")
+
+     # Examples
+     gr.Examples(
+         examples=[
+             ["Click on the search button", None],
+             ["Select the dropdown menu", None],
+             ["Click on the submit form", None],
+         ],
+         inputs=[instruction_input, image_input],
+     )
+
+     # Event handlers
+     submit_btn.click(
+         fn=process_image,
+         inputs=[image_input, instruction_input],
+         outputs=[response_output, x_output, y_output]
+     )
+
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=True
+     )
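
To sanity-check the coordinate parsing outside the app, here is a minimal standalone sketch exercising the first pattern from extract_coordinates (the click(x, y) form); the sample strings are made up for illustration.

import re

# First pattern from extract_coordinates: "click(x, y)" with int or float arguments
pattern = r'click\s*\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)'

for sample in ["click(0.42, 0.87)", "click(120, 340)", "no coordinates here"]:
    m = re.search(pattern, sample)
    if m:
        print(sample, "->", (float(m.group(1)), float(m.group(2))))
    else:
        print(sample, "-> fallback (0.5, 0.5)")  # mirrors the function's default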
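
A design note on the device move added before generation: the dict comprehension hard-codes .cuda(), which works here but ties the code to CUDA. A more device-agnostic sketch, assuming inputs is the BatchFeature returned by the processor (transformers batch containers implement .to()):

# Device-agnostic equivalent of the dict comprehension above;
# model.device reflects whatever device_map resolved to at load time.
inputs = inputs.to(model.device)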
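
Once the app is running, the click handler can also be exercised programmatically. A minimal sketch, assuming a recent gradio_client release and that the endpoint keeps the default API name derived from the handler function (/process_image); the URL and file path are hypothetical:

from gradio_client import Client, handle_file

# Point at the running app (hypothetical URL; a Space id such as "user/space" also works)
client = Client("http://localhost:7860")

# Outputs arrive in the order wired to submit_btn.click: response, x, y
response, x, y = client.predict(
    handle_file("screenshot.png"),   # image_input (hypothetical local file)
    "Click on the login button",     # instruction_input
    api_name="/process_image",
)
print(x, y)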