jens commited on
Commit
f76bf44
·
1 Parent(s): 2d1b836
Files changed (1) hide show
  1. app.py +8 -18
app.py CHANGED
@@ -70,34 +70,24 @@ with block:
70
  input_image.upload(on_input_image_upload, [input_image], [raw_image, point_coords, point_labels], queue=False)
71
 
72
  # event - set coords
73
- def on_input_image_select(raw_image, input_image, image_edit_trigger, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
74
-
75
- if image_edit_trigger:
76
- unedited_image = input_image.copy()
77
- image_edit_trigger = False
78
- else:
79
- unedited_image = raw_image
80
-
81
  x, y = evt.index
82
  color = red if point_label_radio == 0 else blue
83
- img = np.array(input_image)
 
84
  cv2.circle(img, (x, y), 5, color, -1)
85
  point_coords.append([x,y])
86
  point_labels.append(point_label_radio)
87
  generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
88
- img[generated_mask] = (1.0, 0.0, 0.0)
89
  img = Image.fromarray(img)
90
-
91
  return [ img,
92
- unedited_image,
93
- img,
94
  point_coords,
95
- point_labels,
96
- image_edit_trigger]
97
 
98
- prompt_image.select(on_input_image_select,
99
- [raw_image, input_image, image_edit_trigger, point_coords, point_labels, point_label_radio],
100
- [prompt_image, raw_image, input_image, point_coords, point_labels, image_edit_trigger], queue=False)
101
 
102
  def on_click_sam_encode_btn(inputs):
103
  print("encoding")
 
70
  input_image.upload(on_input_image_upload, [input_image], [raw_image, point_coords, point_labels], queue=False)
71
 
72
  # event - set coords
73
+ def on_prompt_image_select(input_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
 
 
 
 
 
 
 
74
  x, y = evt.index
75
  color = red if point_label_radio == 0 else blue
76
+ img = input_image.copy()
77
+ img = np.array(img)
78
  cv2.circle(img, (x, y), 5, color, -1)
79
  point_coords.append([x,y])
80
  point_labels.append(point_label_radio)
81
  generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
82
+ img[generated_mask] = (255.0, 0.0, 0.0)
83
  img = Image.fromarray(img)
 
84
  return [ img,
 
 
85
  point_coords,
86
+ point_labels ]
 
87
 
88
+ prompt_image.select(on_prompt_image_select,
89
+ [input_image, point_coords, point_labels, point_label_radio],
90
+ [prompt_image, point_coords, point_labels], queue=False)
91
 
92
  def on_click_sam_encode_btn(inputs):
93
  print("encoding")