banao-tech committed on
Commit
a6945bb
·
verified ·
1 Parent(s): 0141f51

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +132 -212
main.py CHANGED
@@ -1,212 +1,132 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException, Query
2
- from fastapi.responses import JSONResponse
3
- from pydantic import BaseModel, Field
4
- from typing import List, Dict, Tuple, Optional
5
- import base64
6
- import io
7
- import os
8
- from pathlib import Path
9
- import torch
10
- import numpy as np
11
- from PIL import Image
12
- from ultralytics import YOLO
13
- from transformers import AutoProcessor, AutoModelForCausalLM,Blip2ForConditionalGeneration
14
-
15
- # Type definitions
16
class ProcessResponse(BaseModel):
    # Response schema used by the /process_image endpoint.
    # Every field is a required string.

    # Processed image, carried as base64 text.
    image: str = Field(
        default=...,
        description="Base64 encoded processed image",
    )

    # Parsed content entries (joined into one string by the producer).
    parsed_content_list: str = Field(
        default=...,
        description="List of parsed content",
    )

    # String rendering of the detected label coordinates.
    label_coordinates: str = Field(
        default=...,
        description="Coordinates of detected labels",
    )
20
-
21
class ModelManager:
    """Hold the detection/caption models and the device they run on.

    ``yolo_model``, ``processor`` and ``model`` stay ``None`` until
    :meth:`load_models` succeeds.
    """

    def __init__(self):
        self.yolo_model = None
        self.processor = None
        self.model = None
        # Prefer the GPU when PyTorch can see one.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_models(self):
        """Load the YOLO detector, the processor and the caption model.

        Returns:
            ``True`` when everything loaded, ``False`` otherwise. The
            attributes are only assigned after *all* loads succeeded, so a
            failure never leaves the manager half-initialised.
        """
        import logging

        logger = logging.getLogger(__name__)

        try:
            # Fail early with a clear message when the weights are missing.
            weights_path = Path("weights/icon_detect/best.pt")
            if not weights_path.exists():
                raise FileNotFoundError(f"YOLO weights not found at {weights_path}")
            yolo_model = YOLO(str(weights_path)).to(self.device)

            # SECURITY: trust_remote_code=True runs Python code shipped with
            # the checkpoint; only use it with repositories you trust.
            processor = AutoProcessor.from_pretrained(
                "microsoft/Florence-2-base",
                trust_remote_code=True,
            )

            # Half precision is only requested on CUDA.
            # NOTE(review): float16 inference on CPU is frequently unsupported
            # or very slow - confirm float32 is acceptable there.
            dtype = torch.float16 if self.device == "cuda" else torch.float32

            # NOTE(review): a Florence-2 processor is paired with a BLIP-2
            # model class here - confirm this combination is intended.
            model = Blip2ForConditionalGeneration.from_pretrained(
                "banao-tech/OmniParser",
                torch_dtype=dtype,
                trust_remote_code=True,
            ).to(self.device)
        except Exception as e:
            # Keep the boolean contract, but record the traceback.
            logger.exception("Error loading models: %s", e)
            return False

        self.yolo_model = yolo_model
        self.processor = processor
        self.model = model
        return True
50
-
51
class ImageProcessor:
    """Run the OCR and labelling steps for one image.

    The image is written to a scratch file under ``temp/`` because the OCR
    and labelling helpers take a file path.
    """

    def __init__(self, model_manager: ModelManager):
        self.model_manager = model_manager
        self.temp_dir = Path("temp")
        self.temp_dir.mkdir(exist_ok=True)

    async def process_image(
        self,
        image: Image.Image,
        box_threshold: float = 0.05,
        iou_threshold: float = 0.1
    ) -> ProcessResponse:
        """Process the input image and return results.

        Args:
            image: Image to analyse.
            box_threshold: Forwarded to the labelling step.
            iou_threshold: Forwarded to the labelling step.

        Returns:
            The ``ProcessResponse`` built from the labelling results.

        Raises:
            HTTPException: status 500 when any processing step fails; an
                ``HTTPException`` raised by a step is passed through as is.
        """
        import tempfile

        temp_image_path = None
        try:
            # One scratch file per call: a shared fixed name would let
            # concurrent requests overwrite each other's image.
            fd, temp_name = tempfile.mkstemp(suffix=".png", dir=self.temp_dir)
            os.close(fd)
            temp_image_path = Path(temp_name)
            image.save(temp_image_path)

            # Drawing sizes scale with the image width (3200 px == ratio 1).
            box_overlay_ratio = image.size[0] / 3200
            draw_config = self._get_draw_config(box_overlay_ratio)

            ocr_results = self._perform_ocr(temp_image_path)
            labeled_results = self._get_labeled_image(
                temp_image_path,
                ocr_results,
                box_threshold,
                iou_threshold,
                draw_config
            )

            return self._create_response(labeled_results)

        except HTTPException:
            # Do not re-wrap an already well-formed HTTP error as a 500.
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Image processing failed: {str(e)}"
            ) from e
        finally:
            # Remove the scratch file on success and on failure.
            if temp_image_path is not None:
                temp_image_path.unlink(missing_ok=True)

    def _get_draw_config(self, ratio: float) -> Dict:
        """Generate drawing configuration based on image ratio.

        Integer settings are clamped to at least 1.
        """
        return {
            "text_scale": 0.8 * ratio,
            "text_thickness": max(int(2 * ratio), 1),
            "text_padding": max(int(3 * ratio), 1),
            "thickness": max(int(3 * ratio), 1),
        }

    def _perform_ocr(self, image_path: Path) -> Tuple[List[str], List]:
        """Perform OCR on the image.

        Placeholder: currently returns two empty lists.
        """
        # Implement OCR logic here
        # This is a placeholder - implement actual OCR logic
        return [], []

    def _get_labeled_image(
        self,
        image_path: Path,
        ocr_results: Tuple[List[str], List],
        box_threshold: float,
        iou_threshold: float,
        draw_config: Dict
    ) -> Tuple[str, Dict, List[str]]:
        """Get labeled image with detected objects.

        Placeholder: currently returns ``("", {}, [])``.
        """
        # Implement labeling logic here
        # This is a placeholder - implement actual labeling logic
        return "", {}, []

    def _create_response(
        self,
        labeled_results: Tuple[str, Dict, List[str]]
    ) -> ProcessResponse:
        """Create API response from processing results.

        The content list is joined with newlines and the coordinates are
        rendered with ``str()``.
        """
        labeled_image, coordinates, content_list = labeled_results

        return ProcessResponse(
            image=labeled_image,
            parsed_content_list="\n".join(content_list),
            label_coordinates=str(coordinates)
        )
137
-
138
- # Initialize FastAPI app
139
# Shared singletons: the model holder and the processor built on top of it.
model_manager = ModelManager()
image_processor = ImageProcessor(model_manager=model_manager)

# The FastAPI application with its API metadata.
app = FastAPI(
    version="1.0.0",
    title="Image Processing API",
    description="API for processing and analyzing images",
)
148
-
149
@app.on_event("startup")
async def startup_event():
    """Load the models on the startup event; raise if loading fails."""
    models_ready = model_manager.load_models()
    if models_ready:
        return
    raise RuntimeError("Failed to load required models")
154
-
155
@app.post(
    "/process_image",
    response_model=ProcessResponse,
    summary="Process an uploaded image",
    response_description="Processed image results"
)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = Query(0.05, ge=0, le=1),
    iou_threshold: float = Query(0.1, ge=0, le=1)
):
    """
    Process an uploaded image file and return the results.

    Parameters:
    - image_file: The image file to process
    - box_threshold: Threshold for box detection (0-1)
    - iou_threshold: IOU threshold for overlap detection (0-1)

    Returns:
    - ProcessResponse containing the processed image and results

    Raises:
    - HTTPException 400 when the upload is not a readable image
    - HTTPException 500 for any other failure
    """
    try:
        # The content type may be missing on an upload; treat that as
        # "not an image" rather than failing on None.startswith.
        content_type = image_file.content_type or ""
        if not content_type.startswith('image/'):
            raise HTTPException(
                status_code=400,
                detail="File must be an image"
            )

        # Read and validate image
        contents = await image_file.read()
        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail="Invalid image format"
            ) from e

        # Process image
        return await image_processor.process_image(
            image,
            box_threshold,
            iou_threshold
        )

    except HTTPException:
        # Already a well-formed HTTP error: pass it through unchanged.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        ) from e
209
-
210
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app=app, port=8000, host="0.0.0.0")
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ from pydantic import BaseModel
4
+ from typing import Optional
5
+ import base64
6
+ import io
7
+ from PIL import Image
8
+ import torch
9
+ import numpy as np
10
+ import os
11
+
12
+ # Existing imports
13
+ import numpy as np
14
+ import torch
15
+ from PIL import Image
16
+ import io
17
+
18
+ from utils import (
19
+ check_ocr_box,
20
+ get_yolo_model,
21
+ get_caption_model_processor,
22
+ get_som_labeled_img,
23
+ )
24
+ import torch
25
+
26
+ # yolo_model = get_yolo_model(model_path='/data/icon_detect/best.pt')
27
+ # caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="/data/icon_caption_florence")
28
+
29
+ from ultralytics import YOLO
30
+
31
+ # if not os.path.exists("/data/icon_detect"):
32
+ # os.makedirs("/data/icon_detect")
33
+
34
from transformers import AutoProcessor, AutoModelForCausalLM


def _load_preferring_cuda(load):
    """Return ``load(device)``, preferring CUDA and falling back to CPU.

    CUDA is only attempted when PyTorch reports it as available. A
    ``RuntimeError`` during the CUDA attempt triggers the CPU fallback; any
    other error (for example missing weights) propagates unchanged instead
    of being hidden by a bare ``except``.
    """
    if torch.cuda.is_available():
        try:
            return load("cuda")
        except RuntimeError as exc:
            print(f"CUDA load failed ({exc}); falling back to CPU")
    return load("cpu")


def _load_caption_model(device):
    """Load the caption model from the local weights onto ``device``."""
    # Half precision is only requested on CUDA.
    # NOTE(review): float16 on CPU is frequently unsupported or very slow -
    # confirm float32 is what the captioning helper expects on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    # SECURITY: trust_remote_code=True runs Python code shipped with the
    # checkpoint; only use it with weights you trust.
    return AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=dtype,
        trust_remote_code=True,
    ).to(device)


# Icon detector.
yolo_model = _load_preferring_cuda(
    lambda device: YOLO("weights/icon_detect/best.pt").to(device)
)

# Caption processor and model, bundled the way get_som_labeled_img receives them.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)
model = _load_preferring_cuda(_load_caption_model)
caption_model_processor = {"processor": processor, "model": model}
print("finish loading model!!!")

app = FastAPI()
61
+
62
+
63
class ProcessResponse(BaseModel):
    # Response body of POST /process_image; all fields are plain strings.

    # Base64 encoded image
    image: str
    # Parsed content entries joined with newlines.
    parsed_content_list: str
    # str() rendering of the label coordinates.
    label_coordinates: str
67
+
68
+
69
def process(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    """Run OCR and icon labelling on ``image_input`` and package the result.

    Args:
        image_input: Image to analyse.
        box_threshold: Passed to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Passed to ``get_som_labeled_img`` as ``iou_threshold``.

    Returns:
        A ``ProcessResponse`` with the labelled image as base64 PNG, the
        parsed content joined with newlines, and ``str()`` of the label
        coordinates.
    """
    import tempfile

    # Drawing sizes scale with the image width (3200 px == ratio 1). The
    # width is read from the input directly instead of re-opening the saved
    # copy, which left an unclosed file handle behind.
    box_overlay_ratio = image_input.size[0] / 3200
    draw_bbox_config = {
        "text_scale": 0.8 * box_overlay_ratio,
        "text_thickness": max(int(2 * box_overlay_ratio), 1),
        "text_padding": max(int(3 * box_overlay_ratio), 1),
        "thickness": max(int(3 * box_overlay_ratio), 1),
    }

    # The helpers take a file path. Use a unique scratch file per call: the
    # previous fixed "imgs/saved_image_demo.png" required the directory to
    # exist and was shared by every request.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as scratch:
        image_save_path = scratch.name
    try:
        image_input.save(image_save_path)

        ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
            image_save_path,
            display_img=False,
            output_bb_format="xyxy",
            goal_filtering=None,
            easyocr_args={"paragraph": False, "text_threshold": 0.9},
            use_paddleocr=True,
        )
        text, ocr_bbox = ocr_bbox_rslt
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            image_save_path,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
        )
    finally:
        # Always remove the scratch file, even when a helper raised.
        # NOTE(review): assumes the helpers do not need the file after they
        # return - confirm against utils.
        try:
            os.remove(image_save_path)
        except FileNotFoundError:
            pass

    # The helper returns the labelled image as base64; decode it and
    # re-encode as PNG so the response format is always PNG.
    image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
    print("finish processing")

    # map(str, ...) keeps the result identical for string entries and avoids
    # a TypeError should an entry not be a string.
    parsed_content_list_str = "\n".join(map(str, parsed_content_list))

    # Encode image to base64
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_list_str,
        label_coordinates=str(label_coordinates),
    )
117
+
118
+
119
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Process an uploaded image and return the labelled result.

    Parameters:
    - image_file: The image file to process
    - box_threshold: Threshold for box detection (0-1)
    - iou_threshold: IOU threshold for overlap detection (0-1)

    Raises:
    - HTTPException 422 when a threshold is outside 0-1
    - HTTPException 400 when the upload cannot be read as an image
    """
    # Both thresholds are fractions; reject out-of-range values (and NaN)
    # before running the models.
    for name, value in (
        ("box_threshold", box_threshold),
        ("iou_threshold", iou_threshold),
    ):
        if not 0.0 <= value <= 1.0:
            raise HTTPException(
                status_code=422, detail=f"{name} must be between 0 and 1"
            )

    try:
        contents = await image_file.read()
        image_input = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as exc:
        # Keep the original cause attached for server-side debugging.
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

    return process(image_input, box_threshold, iou_threshold)