Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import spaces
|
2 |
import multiprocessing as mp
|
3 |
import numpy as np
|
4 |
-
from PIL import Image
|
5 |
import torch
|
6 |
try:
|
7 |
import detectron2
|
@@ -32,6 +32,49 @@ def setup_cfg(config_file):
|
|
32 |
cfg.freeze()
|
33 |
return cfg
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
@spaces.GPU
|
36 |
@torch.no_grad()
|
37 |
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
@@ -60,13 +103,12 @@ def inference_automatic(input_img, class_names):
|
|
60 |
@spaces.GPU
|
61 |
@torch.no_grad()
|
62 |
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
63 |
-
def inference_point(input_img,
|
64 |
|
65 |
|
66 |
mp.set_start_method("spawn", force=True)
|
67 |
|
68 |
-
|
69 |
-
points = [[x, y]]
|
70 |
print(f"Selected point: {points}")
|
71 |
|
72 |
config_file = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
|
@@ -74,10 +116,8 @@ def inference_point(input_img, evt: gr.SelectData,):
|
|
74 |
|
75 |
demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model,mask_adapter)
|
76 |
|
77 |
-
img = read_image(input_img, format="BGR")
|
78 |
-
|
79 |
|
80 |
-
_, visualized_output = demo.run_on_image_with_points(img, points)
|
81 |
return visualized_output
|
82 |
|
83 |
|
@@ -85,6 +125,21 @@ sam2_model = None
|
|
85 |
clip_model = None
|
86 |
mask_adapter = None
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
|
89 |
cfg = setup_cfg(cfg)
|
90 |
global sam2_model, clip_model, mask_adapter
|
@@ -107,6 +162,16 @@ def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
|
|
107 |
mask_adapter.load_state_dict(adapter_state_dict)
|
108 |
print("Mask Adapter model initialized.")
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
111 |
sam_path = './sam2.1_hiera_large.pt'
|
112 |
adapter_pth = './model_0279999_with_sem_new.pth'
|
@@ -162,21 +227,35 @@ with gr.Blocks() as demo:
|
|
162 |
gr.Examples(examples=examples, inputs=[input_image, class_names], outputs=output_image)
|
163 |
|
164 |
with gr.TabItem("Point Mode"):
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
|
182 |
|
|
|
1 |
import spaces
|
2 |
import multiprocessing as mp
|
3 |
import numpy as np
|
4 |
+
from PIL import Image,ImageDraw
|
5 |
import torch
|
6 |
try:
|
7 |
import detectron2
|
|
|
32 |
cfg.freeze()
|
33 |
return cfg
|
34 |
|
35 |
+
class IMGState:
|
36 |
+
def __init__(self):
|
37 |
+
self.img = None
|
38 |
+
self.img_feat = None
|
39 |
+
self.selected_points = []
|
40 |
+
self.selected_points_labels = []
|
41 |
+
self.selected_bboxes = []
|
42 |
+
|
43 |
+
self.available_to_set = True
|
44 |
+
|
45 |
+
def set_img(self, img, img_feat):
|
46 |
+
self.img = img
|
47 |
+
self.img_feat = img_feat
|
48 |
+
|
49 |
+
self.available_to_set = False
|
50 |
+
|
51 |
+
def clear(self):
|
52 |
+
self.img = None
|
53 |
+
self.img_feat = None
|
54 |
+
self.selected_points = []
|
55 |
+
self.selected_points_labels = []
|
56 |
+
self.selected_bboxes = []
|
57 |
+
|
58 |
+
self.available_to_set = True
|
59 |
+
|
60 |
+
def clean(self):
|
61 |
+
self.selected_points = []
|
62 |
+
self.selected_points_labels = []
|
63 |
+
self.selected_bboxes = []
|
64 |
+
|
65 |
+
def to_device(self, device=torch.device("cuda")):
|
66 |
+
if self.img_feat is not None:
|
67 |
+
for k in self.img_feat:
|
68 |
+
if isinstance(self.img_feat[k], torch.Tensor):
|
69 |
+
self.img_feat[k] = self.img_feat[k].to(device)
|
70 |
+
elif isinstance(self.img_feat[k], tuple):
|
71 |
+
self.img_feat[k] = tuple(v.to(device) for v in self.img_feat[k])
|
72 |
+
|
73 |
+
@property
|
74 |
+
def available(self):
|
75 |
+
return self.available_to_set
|
76 |
+
|
77 |
+
|
78 |
@spaces.GPU
|
79 |
@torch.no_grad()
|
80 |
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
|
|
103 |
@spaces.GPU
|
104 |
@torch.no_grad()
|
105 |
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
106 |
+
def inference_point(input_img, img_state,):
|
107 |
|
108 |
|
109 |
mp.set_start_method("spawn", force=True)
|
110 |
|
111 |
+
points = img_state.selected_points
|
|
|
112 |
print(f"Selected point: {points}")
|
113 |
|
114 |
config_file = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
|
|
|
116 |
|
117 |
demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model,mask_adapter)
|
118 |
|
|
|
|
|
119 |
|
120 |
+
_, visualized_output = demo.run_on_image_with_points(img_state.img, points)
|
121 |
return visualized_output
|
122 |
|
123 |
|
|
|
125 |
clip_model = None
|
126 |
mask_adapter = None
|
127 |
|
128 |
+
def get_points_with_draw(image, img_state, evt: gr.SelectData):
|
129 |
+
label = 'Add Mask'
|
130 |
+
|
131 |
+
x, y = evt.index[0], evt.index[1]
|
132 |
+
point_radius, point_color = 10, (97, 217, 54) if label == "Add Mask" else (237, 34, 13)
|
133 |
+
|
134 |
+
img_state.selected_points.append([x, y])
|
135 |
+
img_state.selected_points_labels.append(1 if label == "Add Mask" else 0)
|
136 |
+
img_state.set_img(np.array(image), None)
|
137 |
+
draw = ImageDraw.Draw(image)
|
138 |
+
draw.ellipse(
|
139 |
+
[(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
|
140 |
+
fill=point_color,
|
141 |
+
)
|
142 |
+
return img_state, image
|
143 |
def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
|
144 |
cfg = setup_cfg(cfg)
|
145 |
global sam2_model, clip_model, mask_adapter
|
|
|
162 |
mask_adapter.load_state_dict(adapter_state_dict)
|
163 |
print("Mask Adapter model initialized.")
|
164 |
|
165 |
+
def clear_everything(img_state):
|
166 |
+
img_state.clear()
|
167 |
+
return img_state, None, None
|
168 |
+
|
169 |
+
|
170 |
+
def clean_prompts(img_state):
|
171 |
+
img_state.clean()
|
172 |
+
return img_state, Image.fromarray(img_state.img), None
|
173 |
+
|
174 |
+
# 初始化配置和模型
|
175 |
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
176 |
sam_path = './sam2.1_hiera_large.pt'
|
177 |
adapter_pth = './model_0279999_with_sem_new.pth'
|
|
|
227 |
gr.Examples(examples=examples, inputs=[input_image, class_names], outputs=output_image)
|
228 |
|
229 |
with gr.TabItem("Point Mode"):
|
230 |
+
img_state_points = gr.State(value=IMGState())
|
231 |
+
with gr.Row(): # 水平排列
|
232 |
+
with gr.Column(scale=1):
|
233 |
+
input_image = gr.Image( label="Input Image", type="pil")
|
234 |
+
with gr.Column(scale=1): # 第二列:分割图输出
|
235 |
+
output_image_point = gr.Image(type="pil", label='Segmentation Map',interactive=False) # 输出分割图
|
236 |
+
|
237 |
+
input_image.select(
|
238 |
+
get_points_with_draw,
|
239 |
+
[input_image, img_state_points],
|
240 |
+
outputs=[img_state_points, input_image]
|
241 |
+
).then(
|
242 |
+
inference_point,
|
243 |
+
inputs=[input_image, img_state_points],
|
244 |
+
outputs=[output_image_point]
|
245 |
+
)
|
246 |
+
clear_prompt_button_point = gr.Button("Clean Prompt")
|
247 |
+
clear_prompt_button_point.click(
|
248 |
+
clean_prompts,
|
249 |
+
inputs=[img_state_points],
|
250 |
+
outputs=[img_state_points, input_image, output_image_point]
|
251 |
+
)
|
252 |
+
clear_button_point = gr.Button("Restart")
|
253 |
+
clear_button_point.click(
|
254 |
+
clear_everything,
|
255 |
+
inputs=[img_state_points],
|
256 |
+
outputs=[img_state_points, input_image, output_image_point]
|
257 |
+
)
|
258 |
+
|
259 |
|
260 |
|
261 |
|