jhj0517 commited on
Commit
63be95c
·
1 Parent(s): 13ba185

Enable box and points input

Browse files
Files changed (1) hide show
  1. modules/sam_inference.py +10 -9
modules/sam_inference.py CHANGED
@@ -144,20 +144,21 @@ class SamInference:
144
  if len(prompt) == 0:
145
  return [image], []
146
 
147
- is_prompt_point = prompt[0][-1] == 4.0
148
 
149
- if is_prompt_point:
150
- point_labels = np.array([1 if is_left_click else 0 for x1, y1, is_left_click, x2, y2, _ in prompt])
151
- prompt = np.array([[x1, y1] for x1, y1, is_left_click, x2, y2, _ in prompt])
152
- else:
153
- prompt = np.array([[x1, y1, x2, y2] for x1, y1, is_left_click, x2, y2, _ in prompt])
 
154
 
155
  predicted_masks, scores, logits = self.predict_image(
156
  image=image,
157
  model_type=model_type,
158
- box=prompt if not is_prompt_point else None,
159
- point_coords=prompt if is_prompt_point else None,
160
- point_labels=point_labels if is_prompt_point else None,
161
  multimask_output=hparams["multimask_output"]
162
  )
163
  generated_masks = self.format_to_auto_result(predicted_masks)
 
144
  if len(prompt) == 0:
145
  return [image], []
146
 
147
+ point_labels, point_coords, box = [], [], []
148
 
149
+ for x1, y1, left_click_indicator, x2, y2, point_indicator in prompt:
150
+ if point_indicator == 4.0:
151
+ point_labels.append(left_click_indicator)
152
+ point_coords.append([x1, y1])
153
+ else:
154
+ box.append([x1, y1, x2, y2])
155
 
156
  predicted_masks, scores, logits = self.predict_image(
157
  image=image,
158
  model_type=model_type,
159
+ box=np.array(box) if box else None,
160
+ point_coords=np.array(point_coords) if point_coords else None,
161
+ point_labels=np.array(point_labels) if point_labels else None,
162
  multimask_output=hparams["multimask_output"]
163
  )
164
  generated_masks = self.format_to_auto_result(predicted_masks)