jens commited on
Commit
9e6e225
·
1 Parent(s): 0579ca3
Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -36,7 +36,16 @@ with block:
36
  input_image = gr.Image(label='Input', type='pil', tool=None) # mirror_webcam = False
37
  sam_encode_btn = gr.Button('Encode', variant='primary')
38
  sam_encode_status = gr.Label('Not encoded yet')
39
- prompt_image = gr.Image(label='Segments')
 
 
 
 
 
 
 
 
 
40
  with gr.Row():
41
  with gr.Column():
42
  pcl_figure = gr.Model3D(label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0])
@@ -71,24 +80,25 @@ with block:
71
  input_image.upload(on_input_image_upload, [input_image], [input_image, point_coords, point_labels], queue=False)
72
 
73
  # event - set coords
74
- def on_prompt_image_select(input_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
75
  x, y = evt.index
76
  color = red if point_label_radio == 0 else blue
77
- img = input_image.copy()
78
- img = np.array(img)
79
- cv2.circle(img, (x, y), 5, color, -1)
 
80
  point_coords.append([x,y])
81
  point_labels.append(point_label_radio)
82
  generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
83
- img[generated_mask] = (255.0, 0.0, 0.0)
84
- img = Image.fromarray(img)
85
- return [ img,
86
  point_coords,
87
  point_labels ]
88
-
89
  prompt_image.select(on_prompt_image_select,
90
  [input_image, point_coords, point_labels, point_label_radio],
91
- [prompt_image, point_coords, point_labels], queue=False)
92
 
93
  def on_click_sam_encode_btn(inputs):
94
  print("encoding")
 
36
  input_image = gr.Image(label='Input', type='pil', tool=None) # mirror_webcam = False
37
  sam_encode_btn = gr.Button('Encode', variant='primary')
38
  sam_encode_status = gr.Label('Not encoded yet')
39
+ with gr.Row():
40
+ with gr.Tab("Select with points"):
41
+ with gr.Column():
42
+ prompt_image = gr.Image(label='Segments')
43
+ prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
44
+
45
+
46
+ with gr.Tab("Select from segmented map"):
47
+ everything_image = gr.AnnotatedImage(label='Everything')
48
+
49
  with gr.Row():
50
  with gr.Column():
51
  pcl_figure = gr.Model3D(label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0])
 
80
  input_image.upload(on_input_image_upload, [input_image], [input_image, point_coords, point_labels], queue=False)
81
 
82
  # event - set coords
83
+ def on_prompt_image_select(input_image, prompt, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
84
  x, y = evt.index
85
  color = red if point_label_radio == 0 else blue
86
+ if prompt is None:
87
+ prompt = np.array(input_image.copy())
88
+
89
+ cv2.circle(prompt, (x, y), 5, color, -1)
90
  point_coords.append([x,y])
91
  point_labels.append(point_label_radio)
92
  generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
93
+
94
+ return [ prompt,
95
+ (input_image, [(generated_mask, "Mask")]),
96
  point_coords,
97
  point_labels ]
98
+
99
  prompt_image.select(on_prompt_image_select,
100
  [input_image, point_coords, point_labels, point_label_radio],
101
+ [prompt_image, prompt_lbl_image, point_coords, point_labels], queue=False)
102
 
103
  def on_click_sam_encode_btn(inputs):
104
  print("encoding")