Spaces:
Runtime error
Runtime error
jhj0517
commited on
Commit
·
002d880
1
Parent(s):
7036f3f
Add `cb_multimask_output` and fix result type
Browse files- modules/sam_inference.py +19 -18
modules/sam_inference.py
CHANGED
|
@@ -103,25 +103,27 @@ class SamInference:
|
|
| 103 |
output_file_name = f"result-{timestamp}.psd"
|
| 104 |
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
if input_mode == AUTOMATIC_MODE:
|
| 107 |
image = image_input
|
| 108 |
-
maskgen_hparams = {
|
| 109 |
-
'points_per_side': int(params[0]),
|
| 110 |
-
'points_per_batch': int(params[1]),
|
| 111 |
-
'pred_iou_thresh': float(params[2]),
|
| 112 |
-
'stability_score_thresh': float(params[3]),
|
| 113 |
-
'stability_score_offset': float(params[4]),
|
| 114 |
-
'crop_n_layers': int(params[5]),
|
| 115 |
-
'box_nms_thresh': float(params[6]),
|
| 116 |
-
'crop_n_points_downscale_factor': int(params[7]),
|
| 117 |
-
'min_mask_region_area': int(params[8]),
|
| 118 |
-
'use_m2m': bool(params[9])
|
| 119 |
-
}
|
| 120 |
|
| 121 |
generated_masks = self.generate_mask(
|
| 122 |
image=image,
|
| 123 |
model_type=model_type,
|
| 124 |
-
**
|
| 125 |
)
|
| 126 |
|
| 127 |
elif input_mode == BOX_PROMPT_MODE:
|
|
@@ -129,15 +131,12 @@ class SamInference:
|
|
| 129 |
image = np.array(image.convert("RGB"))
|
| 130 |
box = image_prompt_input_data["points"]
|
| 131 |
box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in box])
|
| 132 |
-
predict_image_hparams = {
|
| 133 |
-
"multimask_output": params[0]
|
| 134 |
-
}
|
| 135 |
|
| 136 |
predicted_masks, scores, logits = self.predict_image(
|
| 137 |
image=image,
|
| 138 |
model_type=model_type,
|
| 139 |
box=box,
|
| 140 |
-
|
| 141 |
)
|
| 142 |
generated_masks = self.format_to_auto_result(predicted_masks)
|
| 143 |
|
|
@@ -152,5 +151,7 @@ class SamInference:
|
|
| 152 |
masks: np.ndarray
|
| 153 |
):
|
| 154 |
place_holder = 0
|
| 155 |
-
|
|
|
|
|
|
|
| 156 |
return result
|
|
|
|
| 103 |
output_file_name = f"result-{timestamp}.psd"
|
| 104 |
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
| 105 |
|
| 106 |
+
hparams = {
|
| 107 |
+
'points_per_side': int(params[0]),
|
| 108 |
+
'points_per_batch': int(params[1]),
|
| 109 |
+
'pred_iou_thresh': float(params[2]),
|
| 110 |
+
'stability_score_thresh': float(params[3]),
|
| 111 |
+
'stability_score_offset': float(params[4]),
|
| 112 |
+
'crop_n_layers': int(params[5]),
|
| 113 |
+
'box_nms_thresh': float(params[6]),
|
| 114 |
+
'crop_n_points_downscale_factor': int(params[7]),
|
| 115 |
+
'min_mask_region_area': int(params[8]),
|
| 116 |
+
'use_m2m': bool(params[9]),
|
| 117 |
+
'multimask_output': bool(params[10])
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
if input_mode == AUTOMATIC_MODE:
|
| 121 |
image = image_input
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
generated_masks = self.generate_mask(
|
| 124 |
image=image,
|
| 125 |
model_type=model_type,
|
| 126 |
+
**hparams
|
| 127 |
)
|
| 128 |
|
| 129 |
elif input_mode == BOX_PROMPT_MODE:
|
|
|
|
| 131 |
image = np.array(image.convert("RGB"))
|
| 132 |
box = image_prompt_input_data["points"]
|
| 133 |
box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in box])
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
predicted_masks, scores, logits = self.predict_image(
|
| 136 |
image=image,
|
| 137 |
model_type=model_type,
|
| 138 |
box=box,
|
| 139 |
+
multimask_output=hparams["multimask_output"]
|
| 140 |
)
|
| 141 |
generated_masks = self.format_to_auto_result(predicted_masks)
|
| 142 |
|
|
|
|
| 151 |
masks: np.ndarray
|
| 152 |
):
|
| 153 |
place_holder = 0
|
| 154 |
+
if len(masks) == 1:
|
| 155 |
+
return [{"segmentation": mask, "area": place_holder} for mask in masks]
|
| 156 |
+
result = [{"segmentation": mask[0], "area": place_holder} for mask in masks]
|
| 157 |
return result
|