s194649 committed · Commit 640f5b4 · 1 Parent(s): ae1d3bb

encoder decoder setup

Browse files:
- app.py +9 -7
- inference.py +70 -6
app.py CHANGED

@@ -10,6 +10,7 @@ from utils import generate_PCL, PCL3, point_cloud
 
 
 sam = SegmentPredictor()
+sam_cpu = SegmentPredictor(device='cpu')
 dpt = DepthPredictor()
 red = (255,0,0)
 blue = (0,0,255)
@@ -30,6 +31,7 @@ with block:
     cutout_idx = gr.State(set())
     pred_masks = gr.State([])
     prompt_masks = gr.State([])
+    embedding = gr.State()
 
     # UI
     with gr.Column():
@@ -73,7 +75,7 @@ with block:
         sam_decode_btn = gr.Button('Predict using points!', variant = 'primary')
 
     # components
-    components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image,
+    components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image, embedding,
                   point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
                   sam_decode_btn, depth_reconstruction_btn, prompt_image, lbl_image, n_samples, max_depth, min_depth, cube_size, selected_masks_image}
 
@@ -88,7 +90,7 @@ with block:
        return input_image, point_coords_empty(), point_labels_empty(), None, []
    reset_btn.click(on_reset_btn_click, [input_image], [input_image, point_coords, point_labels], queue=False)

-   def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks, evt: gr.SelectData):
+   def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks, embedding, evt: gr.SelectData):
        x, y = evt.index
        color = red if point_label_radio == 0 else blue
        if prompt_image is None:
@@ -97,7 +99,7 @@ with block:
        cv2.circle(prompt_image, (x, y), 5, color, -1)
        point_coords.append([x,y])
        point_labels.append(point_label_radio)
-       sam_masks = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
+       sam_masks = sam_cpu.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels), embedding=embedding)
        return [ prompt_image,
                 (input_image, sam_masks),
                 point_coords,
@@ -105,7 +107,7 @@ with block:
                 sam_masks ]

    prompt_image.select(on_prompt_image_select,
-                       [input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks],
+                       [input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks, embedding],
                        [prompt_image, lbl_image, point_coords, point_labels, pred_masks], queue=False)


@@ -139,10 +141,10 @@ with block:
    def on_click_sam_encode_btn(inputs):
        print("encoding")
        # encode image on click
-       sam.encode(inputs[input_image])
+       embedding = sam.encode(inputs[input_image]).cpu()
        print("encoding done")
-       return
+       return [inputs[input_image], embedding]
-   sam_encode_btn.click(on_click_sam_encode_btn, components, [prompt_image], queue=False)
+   sam_encode_btn.click(on_click_sam_encode_btn, components, [prompt_image, embedding], queue=False)

    def on_click_sam_dencode_btn(inputs):
        print("inferencing")
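The app.py change follows a common Gradio caching pattern: run the expensive encoder once in a button handler, store the result in a gr.State, and let every later point-click handler reuse it instead of re-encoding. A minimal, self-contained sketch of that pattern (the names fake_encode and demo are illustrative, not from app.py):

    import gradio as gr
    import numpy as np

    def fake_encode(image):
        # Stand-in for the expensive sam.encode(...) call: any costly,
        # image-dependent computation.
        return None if image is None else np.asarray(image).mean(axis=(0, 1))

    def on_encode(image):
        # Runs once per click; the return value is written into gr.State.
        return fake_encode(image)

    def on_query(cached):
        # Later handlers read the cached value instead of re-encoding.
        return "no embedding yet" if cached is None else f"cached: {cached}"

    with gr.Blocks() as demo:
        image = gr.Image()
        cached = gr.State()        # session-scoped, persists across events
        encode_btn = gr.Button("Encode")
        query_btn = gr.Button("Use cached value")
        out = gr.Textbox()
        encode_btn.click(on_encode, [image], [cached])
        query_btn.click(on_query, [cached], [out])

    if __name__ == "__main__":
        demo.launch()

Since gr.State is session-scoped, each visitor gets an independent cached value; keeping the tensor on CPU (the .cpu() call in the commit) also frees GPU memory between clicks.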
inference.py CHANGED

@@ -11,6 +11,10 @@ import pandas as pd
 import plotly.express as px
 import matplotlib.pyplot as plt
 
+
+
+
+
 def map_image_range(image, min_value, max_value):
     """
     Maps the values of a numpy image array to a specified range.
@@ -188,26 +192,86 @@ class DepthPredictor:
 
 
 
+import numpy as np
+from typing import Optional, Tuple
+
+class CustomSamPredictor(SamPredictor):
+    def __init__(
+        self,
+        sam_model,
+    ) -> None:
+        super().__init__(sam_model)
+
+    def encode_image(self, image: np.ndarray, image_format: str = "RGB") -> torch.Tensor:
+        """
+        Encodes the image and returns its embedding.
+
+        Arguments:
+          image (np.ndarray): The image for which to calculate the embedding.
+          image_format (str): The color format of the image, in ['RGB', 'BGR'].
+
+        Returns:
+          torch.Tensor: The image embedding with shape 1xCxHxW.
+        """
+        self.set_image(image, image_format)
+        return self.get_image_embedding()
+
+    def decode_and_predict(
+        self,
+        embedding: torch.Tensor,
+        point_coords: Optional[np.ndarray] = None,
+        point_labels: Optional[np.ndarray] = None,
+        box: Optional[np.ndarray] = None,
+        mask_input: Optional[np.ndarray] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Decodes the provided image embedding and makes mask predictions based on prompts.
+
+        Arguments:
+          embedding (torch.Tensor): The image embedding to decode.
+          ... (other arguments from the predict function)
+
+        Returns:
+          (np.ndarray): The output masks in CxHxW format.
+          (np.ndarray): An array of quality predictions for each mask.
+          (np.ndarray): Low resolution mask logits for subsequent iterations.
+        """
+        self.set_torch_image(embedding, (embedding.shape[-2], embedding.shape[-1]))
+        return self.predict(
+            point_coords=point_coords,
+            point_labels=point_labels,
+            box=box,
+            mask_input=mask_input,
+            multimask_output=multimask_output,
+            return_logits=return_logits,
+        )
+
 
 class SegmentPredictor:
-    def __init__(self):
+    def __init__(self, device=None):
         MODEL_TYPE = "vit_h"
         checkpoint = "sam_vit_h_4b8939.pth"
         sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint)
         # Select device
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        if device is None:
+            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        else:
+            self.device = device
         sam.to(device=self.device)
         self.mask_generator = SamAutomaticMaskGenerator(sam)
-        self.conditioned_pred = SamPredictor(sam)
+        self.conditioned_pred = CustomSamPredictor(sam)
 
     def encode(self, image):
         image = np.array(image)
-        self.conditioned_pred.set_image(image)
+        return self.conditioned_pred.encode_image(image)
 
-    def cond_pred(self, pts, lbls):
+    def cond_pred(self, embedding, pts, lbls):
         lbls = np.array(lbls)
         pts = np.array(pts)
-        masks, _, _ = self.conditioned_pred.predict(
+        masks, _, _ = self.conditioned_pred.decode_and_predict(
+            embedding,
             point_coords=pts,
             point_labels=lbls,
             multimask_output=True
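A caveat on decode_and_predict: in the upstream segment_anything library, SamPredictor.set_torch_image(transformed_image, original_image_size) expects a transformed BCHW image (it asserts 3 channels at the encoder's input size) and runs the heavy image encoder again, so passing a 1xCxHxW embedding to it may not skip the encoder as intended. A sketch of one way to reuse a cached embedding instead, assuming the upstream predictor's internal attributes (features, original_size, input_size, is_image_set) and that the caller saved the image sizes at encode time; set_cached_embedding is a hypothetical helper, not part of this commit:

    import torch
    from segment_anything import SamPredictor

    def set_cached_embedding(predictor: SamPredictor,
                             embedding: torch.Tensor,
                             original_size: tuple,
                             input_size: tuple) -> None:
        # Restore the state that set_image() would normally compute, so that
        # predictor.predict(...) runs only the prompt encoder + mask decoder.
        predictor.reset_image()
        predictor.features = embedding.to(predictor.device)
        predictor.original_size = original_size   # (H, W) of the raw image
        predictor.input_size = input_size         # (H, W) after the resize transform
        predictor.is_image_set = True

With the embedding restored this way, the CPU-side sam_cpu instance only has to run the lightweight mask decoder per click, which is the point of the encoder/decoder split this commit sets up.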