Spaces:
Sleeping
Sleeping
File size: 6,849 Bytes
cc59622 4e33759 f36b296 4e33759 f36b296 4e33759 f36b296 4e33759 1ec8443 52027db f36b296 cc59622 f36b296 1ec8443 0141f51 f36b296 52027db f36b296 4e33759 cc59622 4e33759 f36b296 4e33759 f36b296 4e33759 f36b296 4e33759 f36b296 4b6cfea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
import base64
import io
import logging
import os
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoProcessor, Blip2ForConditionalGeneration
from ultralytics import YOLO

# Type definitions
class ProcessResponse(BaseModel):
    """Response body returned by ``POST /process_image``.

    Every field is a plain string, so the payload serializes as flat JSON.
    """

    # Processed (labeled) image as base64 text.
    image: str = Field(default=..., description="Base64 encoded processed image")
    # Parsed content items flattened into a single string.
    parsed_content_list: str = Field(default=..., description="List of parsed content")
    # Textual rendering of the detected labels' coordinates.
    label_coordinates: str = Field(default=..., description="Coordinates of detected labels")
class ModelManager:
    """Load and hold the detection and captioning models used by the API.

    Construction does not load anything; ``load_models`` must be called once
    (the app does so in its startup hook) before the models are used.
    """

    def __init__(self, device: Optional[str] = None):
        """Create a manager with no models loaded yet.

        Args:
            device: Torch device string to place models on. When omitted,
                ``"cuda"`` is used if available, otherwise ``"cpu"``.
        """
        self.yolo_model = None
        self.processor = None
        self.model = None
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    def _model_dtype(self) -> torch.dtype:
        """Return the dtype to load the caption model with on ``self.device``.

        float16 halves GPU memory, but many CPU kernels have no half-precision
        implementation, so CPU execution uses float32 instead.
        """
        return torch.float16 if str(self.device).startswith("cuda") else torch.float32

    def load_models(
        self,
        weights_path: str = "weights/icon_detect/best.pt",
        processor_name: str = "microsoft/Florence-2-base",
        model_name: str = "banao-tech/OmniParser",
    ) -> bool:
        """Initialize all required models.

        Args:
            weights_path: Local path of the YOLO icon-detection weights.
            processor_name: Hub id / path passed to ``AutoProcessor.from_pretrained``.
            model_name: Hub id / path of the caption model checkpoint.

        Returns:
            True when every model loaded. False on any error; the error is
            logged with its traceback, and all model attributes are reset to
            None so no half-initialized state is left behind.
        """
        try:
            yolo_weights = Path(weights_path)
            if not yolo_weights.exists():
                raise FileNotFoundError(f"YOLO weights not found at {yolo_weights}")
            self.yolo_model = YOLO(str(yolo_weights)).to(self.device)

            # NOTE(review): a Florence-2 processor is paired with a BLIP-2 model
            # class here; confirm this matches the checkpoint in `model_name`.
            self.processor = AutoProcessor.from_pretrained(
                processor_name,
                trust_remote_code=True,
            )
            self.model = Blip2ForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=self._model_dtype(),
                trust_remote_code=True,
            ).to(self.device)
            return True
        except Exception as e:
            # Keep the traceback: the caller only sees a boolean.
            logging.getLogger(__name__).exception("Error loading models: %s", e)
            self.yolo_model = None
            self.processor = None
            self.model = None
            return False
class ImageProcessor:
    """Run the image-parsing pipeline and package its output as a ProcessResponse."""

    def __init__(self, model_manager: ModelManager):
        """Bind to *model_manager* and make sure the scratch directory exists.

        Args:
            model_manager: Holder of the models used by the pipeline steps.
        """
        self.model_manager = model_manager
        # Scratch directory for per-request image files (relative to the CWD).
        self.temp_dir = Path("temp")
        self.temp_dir.mkdir(exist_ok=True)

    async def process_image(
        self,
        image: Image.Image,
        box_threshold: float = 0.05,
        iou_threshold: float = 0.1
    ) -> ProcessResponse:
        """Process the input image and return results.

        Args:
            image: Decoded input image.
            box_threshold: Forwarded to ``_get_labeled_image``.
            iou_threshold: Forwarded to ``_get_labeled_image``.

        Returns:
            The ProcessResponse built from the labeling results.

        Raises:
            HTTPException: status 500 wrapping any failure in the pipeline.
        """
        # NOTE(review): this coroutine never awaits, so all work runs on the event
        # loop; once real inference is implemented consider a thread pool.
        # A unique file per call so concurrent requests (or several worker
        # processes sharing this directory) never overwrite each other's image.
        temp_image_path = self.temp_dir / f"temp_image_{uuid.uuid4().hex}.png"
        try:
            image.save(temp_image_path)
            # Drawing sizes scale with image width relative to a 3200px reference.
            box_overlay_ratio = image.size[0] / 3200
            draw_config = self._get_draw_config(box_overlay_ratio)
            ocr_results = self._perform_ocr(temp_image_path)
            labeled_results = self._get_labeled_image(
                temp_image_path,
                ocr_results,
                box_threshold,
                iou_threshold,
                draw_config
            )
            return self._create_response(labeled_results)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Image processing failed: {str(e)}"
            ) from e
        finally:
            # Clean up on success *and* failure so errors don't leak temp files.
            temp_image_path.unlink(missing_ok=True)

    def _get_draw_config(self, ratio: float) -> Dict:
        """Generate drawing configuration based on image ratio.

        Integer sizes are truncated and clamped to at least 1 so that drawing
        stays visible on small images.
        """
        return {
            "text_scale": 0.8 * ratio,
            "text_thickness": max(int(2 * ratio), 1),
            "text_padding": max(int(3 * ratio), 1),
            "thickness": max(int(3 * ratio), 1),
        }

    def _perform_ocr(self, image_path: Path) -> Tuple[List[str], List]:
        """Perform OCR on the image stored at *image_path*.

        Placeholder: currently returns two empty lists. The pair is presumably
        (recognized texts, their boxes) — TODO confirm when implementing.
        """
        return [], []

    def _get_labeled_image(
        self,
        image_path: Path,
        ocr_results: Tuple[List[str], List],
        box_threshold: float,
        iou_threshold: float,
        draw_config: Dict
    ) -> Tuple[str, Dict, List[str]]:
        """Get labeled image with detected objects.

        Placeholder: currently returns ``("", {}, [])``. The tuple is consumed as
        (labeled image string, label -> coordinates mapping, content list).
        """
        return "", {}, []

    def _create_response(
        self,
        labeled_results: Tuple[str, Dict, List[str]]
    ) -> ProcessResponse:
        """Create API response from processing results.

        Content items are joined with newlines; coordinates are stringified.
        """
        labeled_image, coordinates, content_list = labeled_results
        return ProcessResponse(
            image=labeled_image,
            parsed_content_list="\n".join(content_list),
            # NOTE(review): str() yields a Python repr, not JSON; clients must
            # parse it accordingly.
            label_coordinates=str(coordinates)
        )
# Application wiring: the FastAPI app and the shared model/processing objects.
app = FastAPI(
    version="1.0.0",
    title="Image Processing API",
    description="API for processing and analyzing images",
)

# One ModelManager instance, shared by the startup hook and the image processor.
model_manager = ModelManager()
image_processor = ImageProcessor(model_manager=model_manager)
@app.on_event("startup")
async def startup_event():
    """Load every model once when the server starts.

    Raises:
        RuntimeError: if ``model_manager.load_models()`` reports failure.
    """
    # NOTE: FastAPI documents `on_event` as deprecated in favour of lifespan handlers.
    models_loaded = model_manager.load_models()
    if models_loaded:
        return
    raise RuntimeError("Failed to load required models")
@app.post(
    "/process_image",
    response_model=ProcessResponse,
    summary="Process an uploaded image",
    response_description="Processed image results"
)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = Query(0.05, ge=0, le=1),
    iou_threshold: float = Query(0.1, ge=0, le=1)
):
    """
    Process an uploaded image file and return the results.

    Parameters:
    - image_file: The image file to process
    - box_threshold: Threshold for box detection (0-1)
    - iou_threshold: IOU threshold for overlap detection (0-1)

    Returns:
    - ProcessResponse containing the processed image and results

    Errors:
    - 400 if the upload is not declared as an image or cannot be decoded
    - 500 for any failure during processing
    """
    try:
        # content_type is None when the client omits the part's Content-Type;
        # treat that as "not an image" (400) rather than crashing into a 500.
        content_type = image_file.content_type or ""
        if not content_type.startswith("image/"):
            raise HTTPException(
                status_code=400,
                detail="File must be an image"
            )

        contents = await image_file.read()
        try:
            # convert() forces a full decode, so corrupt/truncated data fails here.
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail="Invalid image format"
            ) from e

        return await image_processor.process_image(
            image,
            box_threshold,
            iou_threshold
        )
    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        ) from e
if __name__ == "__main__":
    import uvicorn

    # Defaults to 0.0.0.0:8000. HOST / PORT environment variables override the
    # bind address, e.g. to match a hosting platform's expected port.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )
|