from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from PIL import Image
from io import BytesIO
import base64
import torch
import re
import logging
from contextlib import asynccontextmanager

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize global variables
model = None
processor = None
tokenizer = None
model_name = "microsoft/GUI-Actor-2B-Qwen2-VL"
model_loaded = False

async def load_model():
    """Load the model with proper error handling and fallback strategies."""
    global model, processor, tokenizer, model_loaded
    try:
        logger.info("Starting model loading...")
        # Try the Qwen2-VL-specific classes first
        try:
            logger.info("Attempting to load with Qwen2VL specific classes...")
            from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
            processor = Qwen2VLProcessor.from_pretrained(
                model_name,
                trust_remote_code=True
            )
            # Configure padding for the processor
            if hasattr(processor, 'tokenizer'):
                processor.tokenizer.padding_side = "left"  # Important for Qwen2-VL
                if processor.tokenizer.pad_token is None:
                    processor.tokenizer.pad_token = processor.tokenizer.eos_token
            model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                device_map=None,  # CPU only
                trust_remote_code=True,
                low_cpu_mem_usage=True
            ).eval()
            logger.info("Successfully loaded with Qwen2VL specific classes")
        except Exception as e1:
            logger.warning(f"Failed with Qwen2VL classes: {e1}")
            logger.info("Trying AutoProcessor and AutoModel fallback...")
            try:
                from transformers import AutoProcessor, AutoModel
                processor = AutoProcessor.from_pretrained(
                    model_name,
                    trust_remote_code=True
                )
                # Configure padding for the processor
                if hasattr(processor, 'tokenizer'):
                    processor.tokenizer.padding_side = "left"
                    if processor.tokenizer.pad_token is None:
                        processor.tokenizer.pad_token = processor.tokenizer.eos_token
                model = AutoModel.from_pretrained(
                    model_name,
                    torch_dtype=torch.float32,
                    device_map=None,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True
                ).eval()
                logger.info("Successfully loaded with Auto classes")
            except Exception as e2:
                logger.warning(f"Failed with Auto classes: {e2}")
                logger.info("Trying generic transformers approach...")
                # Last fallback: load via the config and a manually selected model class
                from transformers import AutoConfig, AutoProcessor  # re-import AutoProcessor so this branch is self-contained
                import transformers

                config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
                logger.info(f"Model config type: {type(config)}")
                # Try to find the right model class
                if hasattr(transformers, 'Qwen2VLForConditionalGeneration'):
                    ModelClass = transformers.Qwen2VLForConditionalGeneration
                elif hasattr(transformers, 'AutoModelForVision2Seq'):
                    ModelClass = transformers.AutoModelForVision2Seq
                else:
                    raise Exception("No suitable model class found")
                processor = AutoProcessor.from_pretrained(
                    model_name,
                    trust_remote_code=True
                )
                # Configure padding
                if hasattr(processor, 'tokenizer'):
                    processor.tokenizer.padding_side = "left"
                    if processor.tokenizer.pad_token is None:
                        processor.tokenizer.pad_token = processor.tokenizer.eos_token
                model = ModelClass.from_pretrained(
                    model_name,
                    config=config,
                    torch_dtype=torch.float32,
                    device_map=None,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True
                ).eval()

        # Verify that both processor and model are loaded
        if processor is None or model is None:
            raise Exception("Failed to load processor or model")
        tokenizer = processor.tokenizer
        logger.info("Model and processor loaded successfully!")
        model_loaded = True
        return True
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        model_loaded = False
        return False

@asynccontextmanager  # required so FastAPI can use this as a lifespan handler
async def lifespan(app: FastAPI):
    # Startup
    logger.info("Starting up GUI-Actor API...")
    await load_model()
    yield
    # Shutdown
    logger.info("Shutting down GUI-Actor API...")

# Initialize FastAPI app with lifespan
app = FastAPI(
    title="GUI-Actor API",
    version="1.0.0",
    lifespan=lifespan
)

class Base64Request(BaseModel):
    image_base64: str
    instruction: str
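# Example request body (illustrative; any base64-encoded PNG/JPEG works,
# with or without a "data:image/png;base64," data-URL prefix):
# {"image_base64": "<base64 string>", "instruction": "Click the OK button"}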

def extract_coordinates(text):
    """Extract click coordinates from the model's output text."""
    # Patterns for matching coordinates in various formats
    patterns = [
        r'click\s*\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)',  # click(x, y)
        r'\[\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\]',          # [x, y]
        r'(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)',                    # x, y
        r'point:\s*\(\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*\)', # point: (x, y)
    ]
    for pattern in patterns:
        matches = re.findall(pattern, text.lower())
        if matches:
            try:
                x, y = float(matches[0][0]), float(matches[0][1])
                # Normalize if a coordinate is > 1 (assumed to be pixel coordinates)
                if x > 1 or y > 1:
                    # Assume a 1920x1080 resolution for normalization
                    x = x / 1920 if x > 1 else x
                    y = y / 1080 if y > 1 else y
                return [(x, y)]
            except (ValueError, IndexError):
                continue
    # Default to the screen center if no coordinates are found
    return [(0.5, 0.5)]
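# Illustrative behavior of the helper above (values follow from the
# hard-coded 1920x1080 normalization assumption):
#   extract_coordinates("click(960, 540)")  -> [(0.5, 0.5)]
#   extract_coordinates("[0.25, 0.75]")     -> [(0.25, 0.75)]
#   extract_coordinates("no coordinates")   -> [(0.5, 0.5)]  (center fallback)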

def cpu_inference(conversation, model, tokenizer, processor):
    try:
        # Apply chat template
        prompt = processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )
        image = conversation[1]["content"][0]["image"]
        # FIXED: process inputs with correct padding
        inputs = processor(
            text=[prompt],   # wrap in a list for batch processing
            images=[image],  # wrap in a list for batch processing
            return_tensors="pt",
            padding=True,    # enable padding
            truncation=True,
            max_length=512
        )
        # FIXED: log input tensor shapes to verify batch dimensions are consistent
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                logger.debug(f"Input {key} shape: {value.shape}")
        # FIXED: set pad_token_id if it is not defined yet
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id
        if pad_token_id is None:
            pad_token_id = 0  # last-resort fallback
        with torch.no_grad():
            outputs = model.generate(
                **inputs,  # attention_mask is already included here; passing it again as a keyword would raise a duplicate-argument error
                max_new_tokens=256,
                do_sample=True,
                temperature=0.3,
                top_p=0.8,
                pad_token_id=pad_token_id
            )
        # FIXED: extract only the newly generated tokens
        input_length = inputs["input_ids"].shape[1]
        generated_ids = outputs[0][input_length:]
        response = tokenizer.decode(generated_ids, skip_special_tokens=True)
        coordinates = extract_coordinates(response)
        return {
            "topk_points": coordinates,
            "response": response.strip(),
            "success": True
        }
    except Exception as e:
        logger.error(f"Inference error: {e}")
        # FIXED: more detailed error logging
        import traceback
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return {
            "topk_points": [(0.5, 0.5)],
            "response": f"Error during inference: {str(e)}",
            "success": False
        }

# NOTE: the route decorators below were missing from the source; the paths are assumed.
@app.get("/")
async def root():
    return {
        "message": "GUI-Actor API is running",
        "status": "healthy",
        "model_loaded": model_loaded,
        "model_name": model_name
    }

@app.post("/predict/click")  # assumed path
async def predict_click_base64(data: Base64Request):
    if not model_loaded:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded properly"
        )
    try:
        # Decode base64 to image
        try:
            # Handle data URL format
            if "," in data.image_base64:
                image_data = base64.b64decode(data.image_base64.split(",")[-1])
            else:
                image_data = base64.b64decode(data.image_base64)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid base64 image: {e}")
        try:
            pil_image = Image.open(BytesIO(image_data)).convert("RGB")
            # FIXED: log image dimensions for debugging
            logger.debug(f"Image dimensions: {pil_image.size}")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image format: {e}")
        # FIXED: improved conversation structure
        conversation = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task. Please provide the click coordinates.",
                    }
                ]
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": pil_image,
                    },
                    {
                        "type": "text",
                        "text": data.instruction,
                    },
                ],
            },
        ]
        # Run inference
        pred = cpu_inference(conversation, model, tokenizer, processor)
        if not pred["success"]:
            logger.warning(f"Inference failed: {pred['response']}")
        px, py = pred["topk_points"][0]
        return JSONResponse(content={
            "x": round(px, 4),
            "y": round(py, 4),
            "response": pred["response"],
            "success": pred["success"]
        })
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Prediction error: {e}")
        import traceback
        logger.error(f"Full traceback: {traceback.format_exc()}")
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        )

@app.get("/health")  # assumed path
async def health_check():
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model": model_name,
        "device": "cpu",
        "torch_dtype": "float32",
        "model_loaded": model_loaded
    }

@app.get("/debug")  # assumed path
async def debug_info():
    """Debug endpoint to check model loading status."""
    import transformers
    available_classes = [attr for attr in dir(transformers) if 'Qwen' in attr or 'VL' in attr]
    info = {  # renamed from debug_info to avoid shadowing the function name
        "model_loaded": model_loaded,
        "processor_type": type(processor).__name__ if processor else None,
        "model_type": type(model).__name__ if model else None,
        "available_qwen_classes": available_classes,
        "transformers_version": transformers.__version__
    }
    # FIXED: add tokenizer info for debugging
    if processor and hasattr(processor, 'tokenizer'):
        info.update({
            "tokenizer_type": type(processor.tokenizer).__name__,
            "pad_token": processor.tokenizer.pad_token,
            "pad_token_id": processor.tokenizer.pad_token_id,
            "eos_token": processor.tokenizer.eos_token,
            "eos_token_id": processor.tokenizer.eos_token_id,
            "padding_side": processor.tokenizer.padding_side
        })
    return info
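
# Local-run sketch (assumptions: uvicorn is installed, and port 7860 -- the
# Hugging Face Spaces convention -- is free; adjust host/port as needed).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example client call (hypothetical; uses the assumed /predict/click route):
#   import base64, requests
#   with open("screenshot.png", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode()
#   r = requests.post("http://localhost:7860/predict/click",
#                     json={"image_base64": b64, "instruction": "Click the OK button"})
#   print(r.json())  # -> {"x": 0.5, "y": 0.5, "response": "...", "success": true}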