# Spaces:
# Sleeping
# Sleeping
import base64
import io
import logging
import os
import uuid
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field
from transformers import AutoProcessor, AutoModelForCausalLM, Blip2ForConditionalGeneration
from ultralytics import YOLO
# Type definitions
class ProcessResponse(BaseModel):
    # Payload returned for a processed image; every field is a required
    # string (the `...` default marks a field as required).

    # Base64 text of the processed image.
    image: str = Field(default=..., description="Base64 encoded processed image")
    # Parsed content items; ImageProcessor._create_response joins them with "\n".
    parsed_content_list: str = Field(default=..., description="List of parsed content")
    # Label coordinates; ImageProcessor._create_response passes a dict through str().
    label_coordinates: str = Field(default=..., description="Coordinates of detected labels")
class ModelManager:
    """Hold the icon-detection and captioning models used by the service.

    All model attributes stay ``None`` until :meth:`load_models` succeeds.
    """

    def __init__(self):
        self.yolo_model = None  # ultralytics YOLO detector (weights/icon_detect/best.pt)
        self.processor = None   # AutoProcessor for "microsoft/Florence-2-base"
        self.model = None       # generation model loaded from "banao-tech/OmniParser"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_models(self):
        """Initialize all required models.

        Models are published on ``self`` only after every one of them has
        loaded, so a failed call leaves the manager fully unloaded rather
        than half-initialized.

        Returns:
            True if all models loaded; False otherwise (the error and its
            traceback are logged).
        """
        try:
            # Load YOLO model
            weights_path = Path("weights/icon_detect/best.pt")
            if not weights_path.exists():
                raise FileNotFoundError(f"YOLO weights not found at {weights_path}")
            yolo_model = YOLO(str(weights_path)).to(self.device)

            # Load processor and model
            processor = AutoProcessor.from_pretrained(
                "microsoft/Florence-2-base",
                trust_remote_code=True
            )
            # Half precision is only requested on GPU: many CPU kernels have
            # no float16 implementation, so CPU inference would fail or crawl.
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            # NOTE(review): a Florence-2 processor is paired with a BLIP-2 model
            # class here; confirm the "banao-tech/OmniParser" checkpoint's
            # architecture matches (AutoModelForCausalLM is imported but unused).
            model = Blip2ForConditionalGeneration.from_pretrained(
                "banao-tech/OmniParser",
                torch_dtype=dtype,
                trust_remote_code=True,
            ).to(self.device)
        except Exception:
            # Best-effort by design: report failure via the return value, but
            # keep the traceback instead of printing only the message.
            logging.getLogger(__name__).exception("Error loading models")
            return False

        self.yolo_model = yolo_model
        self.processor = processor
        self.model = model
        return True
class ImageProcessor:
    """Run the OCR and labeling pipeline on one image per call.

    Models live on the shared ``ModelManager``. This class owns only a
    scratch directory for the per-call image files it writes.
    """

    def __init__(self, model_manager: ModelManager):
        self.model_manager = model_manager
        # Scratch directory, resolved against the current working directory.
        self.temp_dir = Path("temp")
        self.temp_dir.mkdir(exist_ok=True)

    async def process_image(
        self,
        image: Image.Image,
        box_threshold: float = 0.05,
        iou_threshold: float = 0.1
    ) -> ProcessResponse:
        """Process the input image and return results.

        The image is saved as a uniquely named PNG in ``temp_dir``, run
        through OCR and labeling, and the file is deleted afterwards whether
        or not processing succeeded.

        Args:
            image: Image to process.
            box_threshold: Forwarded to ``_get_labeled_image``.
            iou_threshold: Forwarded to ``_get_labeled_image``.

        Returns:
            The ``ProcessResponse`` built by ``_create_response``.

        Raises:
            HTTPException: status 500 wrapping any error raised while processing.
        """
        # One file per call: a fixed shared filename let concurrent requests
        # overwrite (or unlink) each other's input while it was being read.
        temp_image_path = self.temp_dir / f"temp_image_{uuid.uuid4().hex}.png"
        try:
            image.save(temp_image_path)
            # Drawing sizes scale linearly with image width (3200 px -> 1.0).
            box_overlay_ratio = image.size[0] / 3200
            draw_config = self._get_draw_config(box_overlay_ratio)
            # NOTE(review): everything below runs synchronously inside this
            # coroutine; once real model inference lands here it will block
            # the event loop — consider offloading it to a worker thread.
            ocr_results = self._perform_ocr(temp_image_path)
            labeled_results = self._get_labeled_image(
                temp_image_path,
                ocr_results,
                box_threshold,
                iou_threshold,
                draw_config,
            )
            return self._create_response(labeled_results)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Image processing failed: {str(e)}"
            ) from e
        finally:
            # Runs on success and failure alike, so scratch files never accumulate.
            temp_image_path.unlink(missing_ok=True)

    def _get_draw_config(self, ratio: float) -> Dict:
        """Return drawing parameters scaled by ``ratio``.

        Integer sizes are floored at 1 so small images still get visible marks.
        """
        return {
            "text_scale": 0.8 * ratio,
            "text_thickness": max(int(2 * ratio), 1),
            "text_padding": max(int(3 * ratio), 1),
            "thickness": max(int(3 * ratio), 1),
        }

    def _perform_ocr(self, image_path: Path) -> Tuple[List[str], List]:
        """Perform OCR on the image at ``image_path``.

        Placeholder: currently returns two empty lists.
        """
        # TODO: implement actual OCR logic.
        return [], []

    def _get_labeled_image(
        self,
        image_path: Path,
        ocr_results: Tuple[List[str], List],
        box_threshold: float,
        iou_threshold: float,
        draw_config: Dict
    ) -> Tuple[str, Dict, List[str]]:
        """Return ``(labeled_image, coordinates, content_list)`` for the image.

        Placeholder: currently returns an empty string, dict and list.
        """
        # TODO: implement actual labeling logic.
        return "", {}, []

    def _create_response(
        self,
        labeled_results: Tuple[str, Dict, List[str]]
    ) -> ProcessResponse:
        """Build the API response from ``(image, coordinates, content_list)``.

        The content list is joined with newlines; the coordinates dict is
        rendered with ``str()`` (Python repr, not JSON).
        """
        labeled_image, coordinates, content_list = labeled_results
        return ProcessResponse(
            image=labeled_image,
            parsed_content_list="\n".join(content_list),
            label_coordinates=str(coordinates)
        )
# Application object; title, description and version populate the OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="Image Processing API",
    description="API for processing and analyzing images",
)

# Module-level singletons shared by every request. startup_event() loads the
# manager's models; constructing the processor creates its "temp" directory
# at import time.
model_manager = ModelManager()
image_processor = ImageProcessor(model_manager=model_manager)
# Registered as a startup hook; without the decorator the function was never
# called and the models were never loaded.
@app.on_event("startup")
async def startup_event():
    """Initialize models on startup.

    Raises:
        RuntimeError: if ``model_manager.load_models()`` reports failure, so
            the startup hook fails instead of serving without models.
    """
    if not model_manager.load_models():
        raise RuntimeError("Failed to load required models")
# Registered as a POST route; without the decorator the handler was unreachable.
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = Query(0.05, ge=0, le=1),
    iou_threshold: float = Query(0.1, ge=0, le=1)
):
    """
    Process an uploaded image file and return the results.
    Parameters:
    - image_file: The image file to process
    - box_threshold: Threshold for box detection (0-1)
    - iou_threshold: IOU threshold for overlap detection (0-1)
    Returns:
    - ProcessResponse containing the processed image and results
    Raises:
    - 400 if the upload is not declared as an image or cannot be decoded
    - 500 for any other failure
    """
    try:
        # content_type can be None when the client omits it; treat that as
        # "not an image" (400) rather than crashing on .startswith (500).
        if not (image_file.content_type or "").startswith("image/"):
            raise HTTPException(
                status_code=400,
                detail="File must be an image"
            )
        # Read and decode; convert("RGB") forces a full decode so truncated
        # or corrupt data is rejected here rather than deeper in the pipeline.
        contents = await image_file.read()
        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail="Invalid image format"
            ) from e
        return await image_processor.process_image(
            image,
            box_threshold,
            iou_threshold
        )
    except HTTPException:
        # Already carries the intended status code.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        ) from e
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run directly.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")