banao-tech commited on
Commit
f36b296
·
verified ·
1 Parent(s): 141751d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +197 -114
main.py CHANGED
@@ -1,132 +1,215 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
- from pydantic import BaseModel
4
- from typing import Optional
5
  import base64
6
  import io
7
- from PIL import Image
8
- import torch
9
- import numpy as np
10
  import os
11
-
12
- # Existing imports
13
- import numpy as np
14
  import torch
 
15
  from PIL import Image
16
- import io
17
-
18
- from utils import (
19
- check_ocr_box,
20
- get_yolo_model,
21
- get_caption_model_processor,
22
- get_som_labeled_img,
23
- )
24
- import torch
25
-
26
- # yolo_model = get_yolo_model(model_path='/data/icon_detect/best.pt')
27
- # caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="/data/icon_caption_florence")
28
-
29
  from ultralytics import YOLO
30
-
31
- # if not os.path.exists("/data/icon_detect"):
32
- # os.makedirs("/data/icon_detect")
33
-
34
- try:
35
- yolo_model = YOLO("weights/icon_detect/best.pt").to("cpu")
36
- except:
37
- yolo_model = YOLO("weights/icon_detect/best.pt")
38
-
39
  from transformers import AutoProcessor, AutoModelForCausalLM
40
 
41
- processor = AutoProcessor.from_pretrained(
42
- "microsoft/Florence-2-base", trust_remote_code=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  )
44
 
45
- try:
46
- model = AutoModelForCausalLM.from_pretrained(
47
- "banao-tech/OmniParse",
48
- torch_dtype=torch.float16,
49
- trust_remote_code=True,
50
- ).to("cpu")
51
- except:
52
- model = AutoModelForCausalLM.from_pretrained(
53
- "banao-tech/OmniParse",
54
- torch_dtype=torch.float16,
55
- trust_remote_code=True,
56
- )
57
- caption_model_processor = {"processor": processor, "model": model}
58
- print("finish loading model!!!")
59
-
60
- app = FastAPI()
61
-
62
-
63
- class ProcessResponse(BaseModel):
64
- image: str # Base64 encoded image
65
- parsed_content_list: str
66
- label_coordinates: str
67
-
68
-
69
- def process(
70
- image_input: Image.Image, box_threshold: float, iou_threshold: float
71
- ) -> ProcessResponse:
72
- image_save_path = "imgs/saved_image_demo.png"
73
- image_input.save(image_save_path)
74
- image = Image.open(image_save_path)
75
- box_overlay_ratio = image.size[0] / 3200
76
- draw_bbox_config = {
77
- "text_scale": 0.8 * box_overlay_ratio,
78
- "text_thickness": max(int(2 * box_overlay_ratio), 1),
79
- "text_padding": max(int(3 * box_overlay_ratio), 1),
80
- "thickness": max(int(3 * box_overlay_ratio), 1),
81
- }
82
-
83
- ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
84
- image_save_path,
85
- display_img=False,
86
- output_bb_format="xyxy",
87
- goal_filtering=None,
88
- easyocr_args={"paragraph": False, "text_threshold": 0.9},
89
- use_paddleocr=True,
90
- )
91
- text, ocr_bbox = ocr_bbox_rslt
92
- dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
93
- image_save_path,
94
- yolo_model,
95
- BOX_TRESHOLD=box_threshold,
96
- output_coord_in_ratio=True,
97
- ocr_bbox=ocr_bbox,
98
- draw_bbox_config=draw_bbox_config,
99
- caption_model_processor=caption_model_processor,
100
- ocr_text=text,
101
- iou_threshold=iou_threshold,
102
- )
103
- image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
104
- print("finish processing")
105
- parsed_content_list_str = "\n".join(parsed_content_list)
106
-
107
- # Encode image to base64
108
- buffered = io.BytesIO()
109
- image.save(buffered, format="PNG")
110
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
111
-
112
- return ProcessResponse(
113
- image=img_str,
114
- parsed_content_list=str(parsed_content_list_str),
115
- label_coordinates=str(label_coordinates),
116
- )
117
-
118
-
119
- @app.post("/process_image", response_model=ProcessResponse)
120
  async def process_image(
121
  image_file: UploadFile = File(...),
122
- box_threshold: float = 0.05,
123
- iou_threshold: float = 0.1,
124
  ):
 
 
 
 
 
 
 
 
 
 
 
125
  try:
 
 
 
 
 
 
 
 
126
  contents = await image_file.read()
127
- image_input = Image.open(io.BytesIO(contents)).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  except Exception as e:
129
- raise HTTPException(status_code=400, detail="Invalid image file")
130
-
131
- response = process(image_input, box_threshold, iou_threshold)
132
- return response
 
 
 
 
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
+ from pydantic import BaseModel, Field
4
+ from typing import List, Dict, Tuple, Optional
5
  import base64
6
  import io
 
 
 
7
  import os
8
+ from pathlib import Path
 
 
9
  import torch
10
+ import numpy as np
11
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from ultralytics import YOLO
 
 
 
 
 
 
 
 
 
13
  from transformers import AutoProcessor, AutoModelForCausalLM
14
 
15
+ # Type definitions
16
+ class ProcessResponse(BaseModel):
17
+ image: str = Field(..., description="Base64 encoded processed image")
18
+ parsed_content_list: str = Field(..., description="List of parsed content")
19
+ label_coordinates: str = Field(..., description="Coordinates of detected labels")
20
+
21
+ class ModelManager:
22
+ def __init__(self):
23
+ self.yolo_model = None
24
+ self.processor = None
25
+ self.model = None
26
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
+
28
+ def load_models(self):
29
+ """Initialize all required models"""
30
+ try:
31
+ # Load YOLO model
32
+ weights_path = Path("weights/icon_detect/best.pt")
33
+ if not weights_path.exists():
34
+ raise FileNotFoundError(f"YOLO weights not found at {weights_path}")
35
+ self.yolo_model = YOLO(str(weights_path)).to(self.device)
36
+
37
+ # Load processor and model
38
+ self.processor = AutoProcessor.from_pretrained(
39
+ "microsoft/Florence-2-base",
40
+ trust_remote_code=True
41
+ )
42
+
43
+ self.model = AutoModelForCausalLM.from_pretrained(
44
+ "banao-tech/OmniParse",
45
+ torch_dtype=torch.float16,
46
+ trust_remote_code=True
47
+ ).to(self.device)
48
+
49
+ return True
50
+ except Exception as e:
51
+ print(f"Error loading models: {str(e)}")
52
+ return False
53
+
54
+ class ImageProcessor:
55
+ def __init__(self, model_manager: ModelManager):
56
+ self.model_manager = model_manager
57
+ self.temp_dir = Path("temp")
58
+ self.temp_dir.mkdir(exist_ok=True)
59
+
60
+ async def process_image(
61
+ self,
62
+ image: Image.Image,
63
+ box_threshold: float = 0.05,
64
+ iou_threshold: float = 0.1
65
+ ) -> ProcessResponse:
66
+ """Process the input image and return results"""
67
+ try:
68
+ # Save temporary image
69
+ temp_image_path = self.temp_dir / "temp_image.png"
70
+ image.save(temp_image_path)
71
+
72
+ # Calculate overlay ratio
73
+ box_overlay_ratio = image.size[0] / 3200
74
+ draw_config = self._get_draw_config(box_overlay_ratio)
75
+
76
+ # Process image
77
+ ocr_results = self._perform_ocr(temp_image_path)
78
+ labeled_results = self._get_labeled_image(
79
+ temp_image_path,
80
+ ocr_results,
81
+ box_threshold,
82
+ iou_threshold,
83
+ draw_config
84
+ )
85
+
86
+ # Create response
87
+ response = self._create_response(labeled_results)
88
+
89
+ # Cleanup
90
+ temp_image_path.unlink(missing_ok=True)
91
+
92
+ return response
93
+
94
+ except Exception as e:
95
+ raise HTTPException(
96
+ status_code=500,
97
+ detail=f"Image processing failed: {str(e)}"
98
+ )
99
+
100
+ def _get_draw_config(self, ratio: float) -> Dict:
101
+ """Generate drawing configuration based on image ratio"""
102
+ return {
103
+ "text_scale": 0.8 * ratio,
104
+ "text_thickness": max(int(2 * ratio), 1),
105
+ "text_padding": max(int(3 * ratio), 1),
106
+ "thickness": max(int(3 * ratio), 1),
107
+ }
108
+
109
+ def _perform_ocr(self, image_path: Path) -> Tuple[List[str], List]:
110
+ """Perform OCR on the image"""
111
+ # Implement OCR logic here
112
+ # This is a placeholder - implement actual OCR logic
113
+ return [], []
114
+
115
+ def _get_labeled_image(
116
+ self,
117
+ image_path: Path,
118
+ ocr_results: Tuple[List[str], List],
119
+ box_threshold: float,
120
+ iou_threshold: float,
121
+ draw_config: Dict
122
+ ) -> Tuple[str, Dict, List[str]]:
123
+ """Get labeled image with detected objects"""
124
+ # Implement labeling logic here
125
+ # This is a placeholder - implement actual labeling logic
126
+ return "", {}, []
127
+
128
+ def _create_response(
129
+ self,
130
+ labeled_results: Tuple[str, Dict, List[str]]
131
+ ) -> ProcessResponse:
132
+ """Create API response from processing results"""
133
+ labeled_image, coordinates, content_list = labeled_results
134
+
135
+ return ProcessResponse(
136
+ image=labeled_image,
137
+ parsed_content_list="\n".join(content_list),
138
+ label_coordinates=str(coordinates)
139
+ )
140
+
141
+ # Initialize FastAPI app
142
+ app = FastAPI(
143
+ title="Image Processing API",
144
+ description="API for processing and analyzing images",
145
+ version="1.0.0"
146
  )
147
 
148
+ # Initialize model manager and image processor
149
+ model_manager = ModelManager()
150
+ image_processor = ImageProcessor(model_manager)
151
+
152
+ @app.on_event("startup")
153
+ async def startup_event():
154
+ """Initialize models on startup"""
155
+ if not model_manager.load_models():
156
+ raise RuntimeError("Failed to load required models")
157
+
158
+ @app.post(
159
+ "/process_image",
160
+ response_model=ProcessResponse,
161
+ summary="Process an uploaded image",
162
+ response_description="Processed image results"
163
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  async def process_image(
165
  image_file: UploadFile = File(...),
166
+ box_threshold: float = Field(0.05, ge=0, le=1),
167
+ iou_threshold: float = Field(0.1, ge=0, le=1)
168
  ):
169
+ """
170
+ Process an uploaded image file and return the results.
171
+
172
+ Parameters:
173
+ - image_file: The image file to process
174
+ - box_threshold: Threshold for box detection (0-1)
175
+ - iou_threshold: IOU threshold for overlap detection (0-1)
176
+
177
+ Returns:
178
+ - ProcessResponse containing the processed image and results
179
+ """
180
  try:
181
+ # Validate file type
182
+ if not image_file.content_type.startswith('image/'):
183
+ raise HTTPException(
184
+ status_code=400,
185
+ detail="File must be an image"
186
+ )
187
+
188
+ # Read and validate image
189
  contents = await image_file.read()
190
+ try:
191
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
192
+ except Exception as e:
193
+ raise HTTPException(
194
+ status_code=400,
195
+ detail="Invalid image format"
196
+ )
197
+
198
+ # Process image
199
+ return await image_processor.process_image(
200
+ image,
201
+ box_threshold,
202
+ iou_threshold
203
+ )
204
+
205
+ except HTTPException:
206
+ raise
207
  except Exception as e:
208
+ raise HTTPException(
209
+ status_code=500,
210
+ detail=f"Internal server error: {str(e)}"
211
+ )
212
+
213
+ if __name__ == "__main__":
214
+ import uvicorn
215
+ uvicorn.run(app, host="0.0.0.0", port=8000)