from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import List, Dict, Tuple, Optional
import base64
import io
import os
from pathlib import Path

import torch
import numpy as np
from PIL import Image
from ultralytics import YOLO
from transformers import AutoProcessor, AutoModelForCausalLM, Blip2ForConditionalGeneration


# Type definitions
class ProcessResponse(BaseModel):
    image: str = Field(..., description="Base64 encoded processed image")
    parsed_content_list: str = Field(..., description="List of parsed content")
    label_coordinates: str = Field(..., description="Coordinates of detected labels")


class ModelManager:
    def __init__(self):
        self.yolo_model = None
        self.processor = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_models(self):
        """Initialize all required models"""
        try:
            # Load YOLO model
            weights_path = Path("weights/icon_detect/best.pt")
            if not weights_path.exists():
                raise FileNotFoundError(f"YOLO weights not found at {weights_path}")
            self.yolo_model = YOLO(str(weights_path)).to(self.device)

            # Load processor and model
            self.processor = AutoProcessor.from_pretrained(
                "microsoft/Florence-2-base", trust_remote_code=True
            )
            self.model = Blip2ForConditionalGeneration.from_pretrained(
                "banao-tech/OmniParser",
                torch_dtype=torch.float16,
                trust_remote_code=True,
            ).to(self.device)
            return True
        except Exception as e:
            print(f"Error loading models: {str(e)}")
            return False


class ImageProcessor:
    def __init__(self, model_manager: ModelManager):
        self.model_manager = model_manager
        self.temp_dir = Path("temp")
        self.temp_dir.mkdir(exist_ok=True)

    async def process_image(
        self,
        image: Image.Image,
        box_threshold: float = 0.05,
        iou_threshold: float = 0.1,
    ) -> ProcessResponse:
        """Process the input image and return results"""
        try:
            # Save temporary image
            temp_image_path = self.temp_dir / "temp_image.png"
            image.save(temp_image_path)

            # Calculate overlay ratio
            box_overlay_ratio = image.size[0] / 3200
            draw_config = self._get_draw_config(box_overlay_ratio)

            # Process image
            ocr_results = self._perform_ocr(temp_image_path)
            labeled_results = self._get_labeled_image(
                temp_image_path, ocr_results, box_threshold, iou_threshold, draw_config
            )

            # Create response
            response = self._create_response(labeled_results)

            # Cleanup
            temp_image_path.unlink(missing_ok=True)

            return response
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Image processing failed: {str(e)}"
            )

    def _get_draw_config(self, ratio: float) -> Dict:
        """Generate drawing configuration based on image ratio"""
        return {
            "text_scale": 0.8 * ratio,
            "text_thickness": max(int(2 * ratio), 1),
            "text_padding": max(int(3 * ratio), 1),
            "thickness": max(int(3 * ratio), 1),
        }

    def _perform_ocr(self, image_path: Path) -> Tuple[List[str], List]:
        """Perform OCR on the image"""
        # This is a placeholder - implement actual OCR logic here
        # (one possible approach is sketched just below this class)
        return [], []

    def _get_labeled_image(
        self,
        image_path: Path,
        ocr_results: Tuple[List[str], List],
        box_threshold: float,
        iou_threshold: float,
        draw_config: Dict,
    ) -> Tuple[str, Dict, List[str]]:
        """Get labeled image with detected objects"""
        # This is a placeholder - implement actual labeling logic here
        return "", {}, []

    def _create_response(
        self, labeled_results: Tuple[str, Dict, List[str]]
    ) -> ProcessResponse:
        """Create API response from processing results"""
        labeled_image, coordinates, content_list = labeled_results
        return ProcessResponse(
            image=labeled_image,
            parsed_content_list="\n".join(content_list),
            label_coordinates=str(coordinates),
        )
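
# --- Optional OCR backend sketch ---------------------------------------------
# The _perform_ocr placeholder above could, for example, be backed by EasyOCR.
# This is only an illustrative sketch under the assumption that the `easyocr`
# package is installed; it is not the project's actual OCR pipeline. It returns
# (texts, boxes) in the shape the placeholder promises.
def _easyocr_sketch(image_path: Path) -> Tuple[List[str], List]:
    import easyocr  # hypothetical optional dependency; imported lazily

    reader = easyocr.Reader(["en"], gpu=torch.cuda.is_available())
    results = reader.readtext(str(image_path))  # [(bbox, text, confidence), ...]
    texts = [text for _, text, _ in results]
    boxes = [bbox for bbox, _, _ in results]
    return texts, boxes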

# Initialize FastAPI app
app = FastAPI(
    title="Image Processing API",
    description="API for processing and analyzing images",
    version="1.0.0",
)

# Initialize model manager and image processor
model_manager = ModelManager()
image_processor = ImageProcessor(model_manager)


@app.on_event("startup")
async def startup_event():
    """Initialize models on startup"""
    if not model_manager.load_models():
        raise RuntimeError("Failed to load required models")


@app.post(
    "/process_image",
    response_model=ProcessResponse,
    summary="Process an uploaded image",
    response_description="Processed image results",
)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = Query(0.05, ge=0, le=1),
    iou_threshold: float = Query(0.1, ge=0, le=1),
):
    """
    Process an uploaded image file and return the results.

    Parameters:
    - image_file: The image file to process
    - box_threshold: Threshold for box detection (0-1)
    - iou_threshold: IOU threshold for overlap detection (0-1)

    Returns:
    - ProcessResponse containing the processed image and results
    """
    try:
        # Validate file type
        if not image_file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read and validate image
        contents = await image_file.read()
        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid image format")

        # Process image
        return await image_processor.process_image(image, box_threshold, iou_threshold)

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
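
# Example client call (illustrative sketch; assumes the server is running
# locally on port 8000, the `requests` package is installed, and a file named
# "screenshot.png" exists next to the script):
#
#   import base64
#   import requests
#
#   with open("screenshot.png", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/process_image",
#           files={"image_file": ("screenshot.png", f, "image/png")},
#           params={"box_threshold": 0.05, "iou_threshold": 0.1},
#       )
#   resp.raise_for_status()
#   payload = resp.json()
#   # `image` holds the base64-encoded annotated image returned by the API
#   with open("labeled.png", "wb") as out:
#       out.write(base64.b64decode(payload["image"]))
#   print(payload["parsed_content_list"])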