banao-tech commited on
Commit
d9070a8
·
verified ·
1 Parent(s): 7c2d0c5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +90 -132
main.py CHANGED
@@ -1,132 +1,90 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from fastapi.responses import JSONResponse
3
- from pydantic import BaseModel
4
- from typing import Optional
5
- import base64
6
- import io
7
- from PIL import Image
8
- import torch
9
- import numpy as np
10
- import os
11
-
12
- # Existing imports
13
- import numpy as np
14
- import torch
15
- from PIL import Image
16
- import io
17
-
18
- from utils import (
19
- check_ocr_box,
20
- get_yolo_model,
21
- get_caption_model_processor,
22
- get_som_labeled_img,
23
- )
24
- import torch
25
-
26
- # yolo_model = get_yolo_model(model_path='/data/icon_detect/best.pt')
27
- # caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="/data/icon_caption_florence")
28
-
29
- from ultralytics import YOLO
30
-
31
- # if not os.path.exists("/data/icon_detect"):
32
- # os.makedirs("/data/icon_detect")
33
-
34
- try:
35
- yolo_model = YOLO("weights/icon_detect/best.pt").to("cuda")
36
- except:
37
- yolo_model = YOLO("weights/icon_detect/best.pt")
38
-
39
- from transformers import AutoProcessor, AutoModelForCausalLM
40
-
41
- processor = AutoProcessor.from_pretrained(
42
- "microsoft/Florence-2-base", trust_remote_code=True
43
- )
44
-
45
- try:
46
- model = AutoModelForCausalLM.from_pretrained(
47
- "weights/icon_caption_florence",
48
- torch_dtype=torch.float16,
49
- trust_remote_code=True,
50
- ).to("cuda")
51
- except:
52
- model = AutoModelForCausalLM.from_pretrained(
53
- "weights/icon_caption_florence",
54
- torch_dtype=torch.float16,
55
- trust_remote_code=True,
56
- )
57
- caption_model_processor = {"processor": processor, "model": model}
58
- print("finish loading model!!!")
59
-
60
- app = FastAPI()
61
-
62
-
63
- class ProcessResponse(BaseModel):
64
- image: str # Base64 encoded image
65
- parsed_content_list: str
66
- label_coordinates: str
67
-
68
-
69
- def process(
70
- image_input: Image.Image, box_threshold: float, iou_threshold: float
71
- ) -> ProcessResponse:
72
- image_save_path = "imgs/saved_image_demo.png"
73
- image_input.save(image_save_path)
74
- image = Image.open(image_save_path)
75
- box_overlay_ratio = image.size[0] / 3200
76
- draw_bbox_config = {
77
- "text_scale": 0.8 * box_overlay_ratio,
78
- "text_thickness": max(int(2 * box_overlay_ratio), 1),
79
- "text_padding": max(int(3 * box_overlay_ratio), 1),
80
- "thickness": max(int(3 * box_overlay_ratio), 1),
81
- }
82
-
83
- ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
84
- image_save_path,
85
- display_img=False,
86
- output_bb_format="xyxy",
87
- goal_filtering=None,
88
- easyocr_args={"paragraph": False, "text_threshold": 0.9},
89
- use_paddleocr=True,
90
- )
91
- text, ocr_bbox = ocr_bbox_rslt
92
- dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
93
- image_save_path,
94
- yolo_model,
95
- BOX_TRESHOLD=box_threshold,
96
- output_coord_in_ratio=True,
97
- ocr_bbox=ocr_bbox,
98
- draw_bbox_config=draw_bbox_config,
99
- caption_model_processor=caption_model_processor,
100
- ocr_text=text,
101
- iou_threshold=iou_threshold,
102
- )
103
- image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
104
- print("finish processing")
105
- parsed_content_list_str = "\n".join(parsed_content_list)
106
-
107
- # Encode image to base64
108
- buffered = io.BytesIO()
109
- image.save(buffered, format="PNG")
110
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
111
-
112
- return ProcessResponse(
113
- image=img_str,
114
- parsed_content_list=str(parsed_content_list_str),
115
- label_coordinates=str(label_coordinates),
116
- )
117
-
118
-
119
- @app.post("/process_image", response_model=ProcessResponse)
120
- async def process_image(
121
- image_file: UploadFile = File(...),
122
- box_threshold: float = 0.05,
123
- iou_threshold: float = 0.1,
124
- ):
125
- try:
126
- contents = await image_file.read()
127
- image_input = Image.open(io.BytesIO(contents)).convert("RGB")
128
- except Exception as e:
129
- raise HTTPException(status_code=400, detail="Invalid image file")
130
-
131
- response = process(image_input, box_threshold, iou_threshold)
132
- return response
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ from pydantic import BaseModel
4
+ import base64
5
+ import io
6
+ from PIL import Image
7
+ import torch
8
+ from ultralytics import YOLO
9
+ from transformers import AutoProcessor, AutoModelForCausalLM
10
+ import os
11
+
12
+ # Import utility functions
13
+ from utils import check_ocr_box, get_som_labeled_img
14
+
15
+ # Initialize models and processor
16
+ try:
17
+ yolo_model = YOLO("weights/icon_detect/best.pt").to("cuda")
18
+ except Exception as e:
19
+ raise RuntimeError(f"Error loading YOLO model: {e}")
20
+
21
+ processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
22
+ try:
23
+ model = AutoModelForCausalLM.from_pretrained(
24
+ "weights/icon_caption_florence", torch_dtype=torch.float16, trust_remote_code=True
25
+ ).to("cuda")
26
+ except Exception as e:
27
+ raise RuntimeError(f"Error loading captioning model: {e}")
28
+
29
+ caption_model_processor = {"processor": processor, "model": model}
30
+
31
+ # FastAPI app initialization
32
+ app = FastAPI()
33
+
34
+
35
+ class ProcessResponse(BaseModel):
36
+ image: str # Base64 encoded image
37
+ parsed_content_list: str
38
+ label_coordinates: str
39
+
40
+
41
+ def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
42
+ image_save_path = "imgs/saved_image_demo.png"
43
+ image_input.save(image_save_path)
44
+
45
+ # Image processing and OCR
46
+ ocr_bbox_rslt, _ = check_ocr_box(
47
+ image_save_path, display_img=False, output_bb_format="xyxy", use_paddleocr=True
48
+ )
49
+ text, ocr_bbox = ocr_bbox_rslt
50
+
51
+ # Labeling the image with YOLO and captioning
52
+ dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
53
+ image_save_path,
54
+ yolo_model,
55
+ BOX_TRESHOLD=box_threshold,
56
+ output_coord_in_ratio=True,
57
+ ocr_bbox=ocr_bbox,
58
+ caption_model_processor=caption_model_processor,
59
+ ocr_text=text,
60
+ iou_threshold=iou_threshold,
61
+ )
62
+
63
+ # Convert labeled image to base64
64
+ image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
65
+ buffered = io.BytesIO()
66
+ image.save(buffered, format="PNG")
67
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
68
+
69
+ parsed_content_str = "\n".join(parsed_content_list)
70
+
71
+ return ProcessResponse(
72
+ image=img_str,
73
+ parsed_content_list=parsed_content_str,
74
+ label_coordinates=str(label_coordinates),
75
+ )
76
+
77
+
78
+ @app.post("/process_image", response_model=ProcessResponse)
79
+ async def process_image(
80
+ image_file: UploadFile = File(...),
81
+ box_threshold: float = 0.05,
82
+ iou_threshold: float = 0.1,
83
+ ):
84
+ try:
85
+ contents = await image_file.read()
86
+ image_input = Image.open(io.BytesIO(contents)).convert("RGB")
87
+ except Exception as e:
88
+ raise HTTPException(status_code=400, detail="Invalid image file")
89
+
90
+ return process(image_input, box_threshold, iou_threshold)