from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import List, Dict, Tuple, Optional
import base64
import io
import os
from pathlib import Path
import torch
import numpy as np
from PIL import Image
from ultralytics import YOLO
from transformers import AutoProcessor, AutoModelForCausalLM, Blip2ForConditionalGeneration

# Type definitions
class ProcessResponse(BaseModel):
    image: str = Field(..., description="Base64 encoded processed image")
    parsed_content_list: str = Field(..., description="List of parsed content")
    label_coordinates: str = Field(..., description="Coordinates of detected labels")

class ModelManager:
    def __init__(self):
        self.yolo_model = None
        self.processor = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        
    def load_models(self):
        """Initialize all required models"""
        try:
            # Load YOLO model
            weights_path = Path("weights/icon_detect/best.pt")
            if not weights_path.exists():
                raise FileNotFoundError(f"YOLO weights not found at {weights_path}")
            self.yolo_model = YOLO(str(weights_path)).to(self.device)

            # Load processor and model
            self.processor = AutoProcessor.from_pretrained(
                "microsoft/Florence-2-base",
                trust_remote_code=True
            )
            
            self.model = Blip2ForConditionalGeneration.from_pretrained(
                "banao-tech/OmniParser",
                torch_dtype=torch.float16,
                trust_remote_code=True
            ).to(self.device)
            
            return True
        except Exception as e:
            print(f"Error loading models: {str(e)}")
            return False

class ImageProcessor:
    def __init__(self, model_manager: ModelManager):
        self.model_manager = model_manager
        self.temp_dir = Path("temp")
        self.temp_dir.mkdir(exist_ok=True)

    async def process_image(
        self,
        image: Image.Image,
        box_threshold: float = 0.05,
        iou_threshold: float = 0.1
    ) -> ProcessResponse:
        """Process the input image and return results"""
        try:
            # Save temporary image
            temp_image_path = self.temp_dir / "temp_image.png"
            image.save(temp_image_path)

            # Calculate overlay ratio
            box_overlay_ratio = image.size[0] / 3200
            draw_config = self._get_draw_config(box_overlay_ratio)

            # Process image
            ocr_results = self._perform_ocr(temp_image_path)
            labeled_results = self._get_labeled_image(
                temp_image_path,
                ocr_results,
                box_threshold,
                iou_threshold,
                draw_config
            )

            # Create response
            response = self._create_response(labeled_results)
            
            # Cleanup
            temp_image_path.unlink(missing_ok=True)
            
            return response

        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Image processing failed: {str(e)}"
            )

    def _get_draw_config(self, ratio: float) -> Dict:
        """Generate drawing configuration based on image ratio"""
        return {
            "text_scale": 0.8 * ratio,
            "text_thickness": max(int(2 * ratio), 1),
            "text_padding": max(int(3 * ratio), 1),
            "thickness": max(int(3 * ratio), 1),
        }

    def _perform_ocr(self, image_path: Path) -> Tuple[List[str], List]:
        """Perform OCR on the image"""
        # Implement OCR logic here
        # This is a placeholder - implement actual OCR logic
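        # A minimal sketch of one way to fill this in (assumes the optional
        # `easyocr` package; it is not a dependency of this service as written):
        #
        #   import easyocr
        #   reader = easyocr.Reader(["en"])
        #   detections = reader.readtext(str(image_path))  # [(bbox, text, confidence), ...]
        #   return [text for _, text, _ in detections], [bbox for bbox, _, _ in detections]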
        return [], []

    def _get_labeled_image(
        self,
        image_path: Path,
        ocr_results: Tuple[List[str], List],
        box_threshold: float,
        iou_threshold: float,
        draw_config: Dict
    ) -> Tuple[str, Dict, List[str]]:
        """Get labeled image with detected objects"""
        # Implement labeling logic here
        # This is a placeholder - implement actual labeling logic
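        # A rough sketch of the detection step only (standard ultralytics API; the
        # drawing and captioning steps that would use draw_config and the caption
        # model held by ModelManager are omitted here):
        #
        #   results = self.model_manager.yolo_model(
        #       str(image_path), conf=box_threshold, iou=iou_threshold
        #   )
        #   boxes = results[0].boxes.xyxy.cpu().numpy()  # (N, 4) pixel coordinates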
        return "", {}, []

    def _create_response(
        self,
        labeled_results: Tuple[str, Dict, List[str]]
    ) -> ProcessResponse:
        """Create API response from processing results"""
        labeled_image, coordinates, content_list = labeled_results
        
        return ProcessResponse(
            image=labeled_image,
            parsed_content_list="\n".join(content_list),
            label_coordinates=str(coordinates)
        )

# Initialize FastAPI app
app = FastAPI(
    title="Image Processing API",
    description="API for processing and analyzing images",
    version="1.0.0"
)

# Initialize model manager and image processor
model_manager = ModelManager()
image_processor = ImageProcessor(model_manager)

@app.on_event("startup")
async def startup_event():
    """Initialize models on startup"""
    if not model_manager.load_models():
        raise RuntimeError("Failed to load required models")

@app.post(
    "/process_image",
    response_model=ProcessResponse,
    summary="Process an uploaded image",
    response_description="Processed image results"
)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = Query(0.05, ge=0, le=1),
    iou_threshold: float = Query(0.1, ge=0, le=1)
):
    """
    Process an uploaded image file and return the results.
    
    Parameters:
    - image_file: The image file to process
    - box_threshold: Threshold for box detection (0-1)
    - iou_threshold: IOU threshold for overlap detection (0-1)
    
    Returns:
    - ProcessResponse containing the processed image and results
    """
    try:
        # Validate file type
        if not (image_file.content_type or "").startswith('image/'):
            raise HTTPException(
                status_code=400,
                detail="File must be an image"
            )

        # Read and validate image
        contents = await image_file.read()
        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except Exception:
            raise HTTPException(
                status_code=400,
                detail="Invalid image format"
            )

        # Process image
        return await image_processor.process_image(
            image,
            box_threshold,
            iou_threshold
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        )

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
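
# Example client call (illustrative sketch; assumes the server is running locally
# on port 8000 and that "screenshot.png" exists):
#
#   import requests
#
#   with open("screenshot.png", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/process_image",
#           files={"image_file": ("screenshot.png", f, "image/png")},
#           params={"box_threshold": 0.05, "iou_threshold": 0.1},
#       )
#   resp.raise_for_status()
#   print(resp.json()["parsed_content_list"])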