OmniPar / main.py
banao-tech's picture
Update main.py
0141f51 verified
raw
history blame
6.85 kB
import base64
import io
import logging
import os
import uuid
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field
from transformers import AutoProcessor, AutoModelForCausalLM, Blip2ForConditionalGeneration
from ultralytics import YOLO
# Type definitions
# Response schema returned by POST /process_image. All three fields are
# required strings (``Field(...)``); each ``description`` is carried into the
# generated JSON/OpenAPI schema.
class ProcessResponse(BaseModel):
    # Processed/labeled image, described as base64-encoded text.
    image: str = Field(..., description="Base64 encoded processed image")
    # Parsed content entries joined with "\n" into a single string (despite
    # the name, this is a str, not a list).
    parsed_content_list: str = Field(..., description="List of parsed content")
    # str() of the label-coordinate mapping, i.e. a Python repr, not JSON.
    label_coordinates: str = Field(..., description="Coordinates of detected labels")
class ModelManager:
    """Load and hold the models used by the image-processing API.

    Every model attribute stays ``None`` until :meth:`load_models` succeeds.
    """

    def __init__(self):
        self.yolo_model = None  # YOLO icon detector (weights/icon_detect)
        self.processor = None  # AutoProcessor for microsoft/Florence-2-base
        self.model = None  # BLIP-2 conditional-generation model
        # Prefer the GPU whenever PyTorch can see one.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def _model_dtype(self) -> torch.dtype:
        """Return the dtype in which to load the generation model's weights.

        float16 is only used on CUDA. Half-precision support on CPU is
        limited (some ops raise "not implemented for 'Half'") and slow, so
        CPU inference uses float32.
        """
        return torch.float16 if self.device == "cuda" else torch.float32

    def load_models(self) -> bool:
        """Initialize all required models.

        The models are attached to the instance only after every one of them
        has loaded. A failure part-way through therefore leaves all model
        attributes ``None`` instead of producing a half-initialized manager.

        Returns:
            True on success. False if anything failed; the error is logged
            together with its traceback.
        """
        try:
            # YOLO weights are expected at this path, relative to the cwd.
            weights_path = Path("weights/icon_detect/best.pt")
            if not weights_path.exists():
                raise FileNotFoundError(f"YOLO weights not found at {weights_path}")
            yolo_model = YOLO(str(weights_path)).to(self.device)

            # NOTE(review): the processor comes from Florence-2 while the
            # model class is BLIP-2; confirm they are compatible for
            # "banao-tech/OmniParser".
            processor = AutoProcessor.from_pretrained(
                "microsoft/Florence-2-base",
                trust_remote_code=True,
            )
            model = Blip2ForConditionalGeneration.from_pretrained(
                "banao-tech/OmniParser",
                torch_dtype=self._model_dtype(),
                trust_remote_code=True,
            ).to(self.device)
        except Exception as e:
            # Callers only see False, so keep the full traceback in the log
            # (print() kept only the message).
            logging.getLogger(__name__).exception("Error loading models: %s", e)
            return False

        self.yolo_model = yolo_model
        self.processor = processor
        self.model = model
        return True
class ImageProcessor:
    """Run the image-parsing pipeline on a single image per call.

    Each call stages the input image on disk, runs OCR, produces a labeled
    image plus detected-element data, and packs the result into a
    ``ProcessResponse``. The OCR and labeling steps are still placeholders.
    """

    def __init__(self, model_manager: ModelManager):
        """Store the model manager and ensure the scratch directory exists.

        Args:
            model_manager: Holder of the loaded models. It is stored but not
                yet used by the placeholder pipeline steps.
        """
        self.model_manager = model_manager
        # Scratch directory for staged input images, relative to the cwd.
        self.temp_dir = Path("temp")
        self.temp_dir.mkdir(exist_ok=True)

    async def process_image(
        self,
        image: Image.Image,
        box_threshold: float = 0.05,
        iou_threshold: float = 0.1,
    ) -> ProcessResponse:
        """Process the input image and return results.

        Args:
            image: Image to analyze.
            box_threshold: Detection threshold forwarded to the labeling step.
            iou_threshold: Overlap (IoU) threshold forwarded to the labeling step.

        Returns:
            A ``ProcessResponse`` built from the labeling step's output.

        Raises:
            HTTPException: 500 if any step fails. An HTTPException raised by a
                step itself is propagated unchanged.
        """
        # NOTE(review): the steps below run synchronously inside this
        # coroutine. Once real model inference is plugged in, consider
        # offloading it to a thread pool so the event loop is not blocked.

        # Give every call its own staging file. A single shared file name let
        # concurrent requests overwrite (and delete) each other's image.
        temp_image_path = self.temp_dir / f"temp_image_{uuid.uuid4().hex}.png"
        try:
            image.save(temp_image_path)

            # Drawing parameters scale with the image width (size[0]),
            # relative to a 3200-pixel-wide reference.
            box_overlay_ratio = image.size[0] / 3200
            draw_config = self._get_draw_config(box_overlay_ratio)

            ocr_results = self._perform_ocr(temp_image_path)
            labeled_results = self._get_labeled_image(
                temp_image_path,
                ocr_results,
                box_threshold,
                iou_threshold,
                draw_config,
            )
            return self._create_response(labeled_results)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Image processing failed: {str(e)}"
            ) from e
        finally:
            # Runs on success and failure alike, so failed requests no longer
            # leave staged files behind.
            temp_image_path.unlink(missing_ok=True)

    def _get_draw_config(self, ratio: float) -> Dict:
        """Generate drawing configuration scaled by ``ratio``.

        The integer sizes are clamped to at least 1, because ``int()``
        truncates and small ratios would otherwise yield 0.
        """
        return {
            "text_scale": 0.8 * ratio,
            "text_thickness": max(int(2 * ratio), 1),
            "text_padding": max(int(3 * ratio), 1),
            "thickness": max(int(3 * ratio), 1),
        }

    def _perform_ocr(self, image_path: Path) -> Tuple[List[str], List]:
        """Perform OCR on the image stored at ``image_path``.

        Placeholder: always returns two empty lists. Presumably this will be
        (texts, boxes) once implemented (TODO confirm).
        """
        return [], []

    def _get_labeled_image(
        self,
        image_path: Path,
        ocr_results: Tuple[List[str], List],
        box_threshold: float,
        iou_threshold: float,
        draw_config: Dict,
    ) -> Tuple[str, Dict, List[str]]:
        """Get the labeled image and detected-element data.

        Placeholder: always returns ``("", {}, [])``. ``_create_response``
        unpacks the result as (labeled_image, coordinates, content_list).
        """
        return "", {}, []

    def _create_response(
        self,
        labeled_results: Tuple[str, Dict, List[str]],
    ) -> ProcessResponse:
        """Create the API response from the labeling step's results.

        The content list is joined with newlines. The coordinates are
        rendered with ``str()``, which is a Python repr, not JSON.
        """
        labeled_image, coordinates, content_list = labeled_results
        return ProcessResponse(
            image=labeled_image,
            parsed_content_list="\n".join(content_list),
            label_coordinates=str(coordinates)
        )
# ASGI application; title, description and version populate the OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="Image Processing API",
    description="API for processing and analyzing images",
)

# Process-wide singletons shared by every request. Constructing the manager
# does not load any models; that happens in the startup hook.
model_manager = ModelManager()
image_processor = ImageProcessor(model_manager=model_manager)
@app.on_event("startup")
async def startup_event():
    """Load every model before the application starts serving requests.

    Raises:
        RuntimeError: if ``model_manager.load_models()`` reports failure,
            which makes application startup fail.
    """
    # NOTE: ``on_event`` is deprecated in recent FastAPI in favour of
    # lifespan handlers.
    models_ready = model_manager.load_models()
    if models_ready:
        return
    raise RuntimeError("Failed to load required models")
@app.post(
    "/process_image",
    response_model=ProcessResponse,
    summary="Process an uploaded image",
    response_description="Processed image results",
)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = Query(0.05, ge=0, le=1),
    iou_threshold: float = Query(0.1, ge=0, le=1),
):
    """
    Process an uploaded image file and return the results.

    Parameters:
    - image_file: The image file to process
    - box_threshold: Threshold for box detection (0-1)
    - iou_threshold: IOU threshold for overlap detection (0-1)

    Returns:
    - ProcessResponse containing the processed image and results

    Errors:
    - 400 if the upload is not declared as an image or cannot be decoded
    - 500 for any other failure during processing
    """
    try:
        # content_type is None when the client omits the part's Content-Type
        # header. Treat that as "not an image" (400) rather than crashing with
        # AttributeError and reporting a 500.
        if not (image_file.content_type or "").startswith("image/"):
            raise HTTPException(
                status_code=400,
                detail="File must be an image"
            )

        contents = await image_file.read()
        try:
            # convert() forces a full decode, so corrupt data fails here too.
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail="Invalid image format"
            ) from e

        return await image_processor.process_image(
            image,
            box_threshold,
            iou_threshold
        )
    except HTTPException:
        # Already carries the intended status code; don't rewrap as 500.
        raise
    except Exception as e:
        # NOTE(review): this echoes internal error text back to the client.
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        ) from e
if __name__ == "__main__":
    # Direct-run entry point (`python main.py`): serve the API on every
    # interface, port 8000.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)