Commit 640f5b4 by s194649
Parent: ae1d3bb

encoder decoder setup

Files changed (2):
1. app.py +9 -7
2. inference.py +70 -6
app.py CHANGED
@@ -10,6 +10,7 @@ from utils import generate_PCL, PCL3, point_cloud
 
 
 sam = SegmentPredictor()
+sam_cpu = SegmentPredictor(device='cpu')
 dpt = DepthPredictor()
 red = (255,0,0)
 blue = (0,0,255)
@@ -30,6 +31,7 @@ with block:
     cutout_idx = gr.State(set())
     pred_masks = gr.State([])
     prompt_masks = gr.State([])
+    embedding = gr.State()
 
     # UI
     with gr.Column():
@@ -73,7 +75,7 @@ with block:
     sam_decode_btn = gr.Button('Predict using points!', variant = 'primary')
 
     # components
-    components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image,
+    components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image, embedding,
                   point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
                   sam_decode_btn, depth_reconstruction_btn, prompt_image, lbl_image, n_samples, max_depth, min_depth, cube_size, selected_masks_image}
 
@@ -88,7 +90,7 @@ with block:
         return input_image, point_coords_empty(), point_labels_empty(), None, []
     reset_btn.click(on_reset_btn_click, [input_image], [input_image, point_coords, point_labels], queue=False)
 
-    def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks, evt: gr.SelectData):
+    def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks, embedding, evt: gr.SelectData):
         x, y = evt.index
         color = red if point_label_radio == 0 else blue
         if prompt_image is None:
@@ -97,7 +99,7 @@
         cv2.circle(prompt_image, (x, y), 5, color, -1)
         point_coords.append([x,y])
         point_labels.append(point_label_radio)
-        sam_masks = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
+        sam_masks = sam_cpu.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels), embedding=embedding)
         return [ prompt_image,
                  (input_image, sam_masks),
                  point_coords,
@@ -105,7 +107,7 @@
                  sam_masks ]
 
     prompt_image.select(on_prompt_image_select,
-                        [input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks],
+                        [input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks, embedding],
                         [prompt_image, lbl_image, point_coords, point_labels, pred_masks], queue=False)
 
 
@@ -139,10 +141,10 @@
     def on_click_sam_encode_btn(inputs):
         print("encoding")
         # encode image on click
-        sam.encode(inputs[input_image])
+        embedding = sam.encode(inputs[input_image]).cpu()
         print("encoding done")
-        return {prompt_image: inputs[input_image]}
+        return [inputs[input_image], embedding]
-    sam_encode_btn.click(on_click_sam_encode_btn, components, [prompt_image], queue=False)
+    sam_encode_btn.click(on_click_sam_encode_btn, components, [prompt_image, embedding], queue=False)
 
     def on_click_sam_dencode_btn(inputs):
         print("inferencing")
inference.py CHANGED
@@ -11,6 +11,10 @@ import pandas as pd
 import plotly.express as px
 import matplotlib.pyplot as plt
 
+
+
+
+
 def map_image_range(image, min_value, max_value):
     """
     Maps the values of a numpy image array to a specified range.
@@ -188,26 +192,86 @@ class DepthPredictor:
 
 
 
+import numpy as np
+from typing import Optional, Tuple
+
+class CustomSamPredictor(SamPredictor):
+    def __init__(
+        self,
+        sam_model,
+    ) -> None:
+        super().__init__(sam_model)
+
+    def encode_image(self, image: np.ndarray, image_format: str = "RGB") -> torch.Tensor:
+        """
+        Encodes the image and returns its embedding.
+
+        Arguments:
+          image (np.ndarray): The image for which to calculate the embedding.
+          image_format (str): The color format of the image, in ['RGB', 'BGR'].
+
+        Returns:
+          torch.Tensor: The image embedding with shape 1xCxHxW.
+        """
+        self.set_image(image, image_format)
+        return self.get_image_embedding()
+
+    def decode_and_predict(
+        self,
+        embedding: torch.Tensor,
+        point_coords: Optional[np.ndarray] = None,
+        point_labels: Optional[np.ndarray] = None,
+        box: Optional[np.ndarray] = None,
+        mask_input: Optional[np.ndarray] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Decodes the provided image embedding and makes mask predictions based on prompts.
+
+        Arguments:
+          embedding (torch.Tensor): The image embedding to decode.
+          ... (other arguments from the predict function)
+
+        Returns:
+          (np.ndarray): The output masks in CxHxW format.
+          (np.ndarray): An array of quality predictions for each mask.
+          (np.ndarray): Low resolution mask logits for subsequent iterations.
+        """
+        self.set_torch_image(embedding, (embedding.shape[-2], embedding.shape[-1]))
+        return self.predict(
+            point_coords=point_coords,
+            point_labels=point_labels,
+            box=box,
+            mask_input=mask_input,
+            multimask_output=multimask_output,
+            return_logits=return_logits,
+        )
+
 
 class SegmentPredictor:
-    def __init__(self):
+    def __init__(self, device=None):
         MODEL_TYPE = "vit_h"
         checkpoint = "sam_vit_h_4b8939.pth"
         sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint)
         # Select device
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        if device is None:
+            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        else:
+            self.device = device
         sam.to(device=self.device)
         self.mask_generator = SamAutomaticMaskGenerator(sam)
-        self.conditioned_pred = SamPredictor(sam)
+        self.conditioned_pred = CustomSamPredictor(sam)
 
     def encode(self, image):
         image = np.array(image)
-        self.conditioned_pred.set_image(image)
+        return self.conditioned_pred.encode_image(image)
 
-    def cond_pred(self, pts, lbls):
+    def cond_pred(self, embedding, pts, lbls):
         lbls = np.array(lbls)
         pts = np.array(pts)
-        masks, _, _ = self.conditioned_pred.predict(
+        masks, _, _ = self.conditioned_pred.decode_and_predict(
+            embedding,
             point_coords=pts,
             point_labels=lbls,
             multimask_output=True
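
One caveat on the decoder path: in the upstream segment-anything package, SamPredictor.set_torch_image expects a transformed RGB image (it asserts a 4-D, 3-channel input at the encoder's resolution and then re-runs the image encoder), so routing the 1xCxHxW embedding through it, as decode_and_predict does above, will trip that assert. A common workaround is to install the cached features on the predictor directly and carry the two image sizes alongside the embedding. A minimal sketch under that assumption; cache_embedding is a hypothetical helper, not part of this repo:

import torch
from segment_anything import SamPredictor

def cache_embedding(predictor: SamPredictor,
                    embedding: torch.Tensor,
                    original_size: tuple,
                    input_size: tuple) -> None:
    # Install a precomputed 1xCxHxW image embedding on a (possibly CPU-only)
    # predictor so predict() can run without re-encoding the image.
    # original_size: (H, W) of the raw image; input_size: (H, W) after the
    # ResizeLongestSide transform, i.e. the values set_image() would record.
    predictor.reset_image()
    predictor.features = embedding.to(predictor.device)
    predictor.original_size = original_size
    predictor.input_size = input_size
    predictor.is_image_set = True

# Intended GPU-encode / CPU-decode handoff (usage sketch):
#   gpu_pred.set_image(image)                    # heavy: runs the ViT image encoder
#   emb = gpu_pred.get_image_embedding().cpu()   # 1xCxHxW tensor, safe to stash in gr.State
#   cache_embedding(cpu_pred, emb, gpu_pred.original_size, gpu_pred.input_size)
#   masks, scores, logits = cpu_pred.predict(point_coords=pts, point_labels=lbls,
#                                            multimask_output=True)

This keeps the one-time encode on the GPU instance intact while letting predict() on the CPU instance transform point coordinates against the true image size.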