Spaces:
Runtime error
Runtime error
Commit
·
637f41e
1
Parent(s):
699405e
update script
Browse files- annotate_anything.py +22 -16
annotate_anything.py
CHANGED
|
@@ -33,8 +33,9 @@ def process(
|
|
| 33 |
box_threshold,
|
| 34 |
text_threshold,
|
| 35 |
iou_threshold,
|
| 36 |
-
device,
|
| 37 |
output_dir=None,
|
|
|
|
| 38 |
save_mask=False,
|
| 39 |
):
|
| 40 |
detections = None
|
|
@@ -84,7 +85,7 @@ def process(
|
|
| 84 |
)
|
| 85 |
|
| 86 |
# Save detection image
|
| 87 |
-
if output_dir:
|
| 88 |
# Draw boxes
|
| 89 |
box_annotator = sv.BoxAnnotator()
|
| 90 |
labels = [
|
|
@@ -123,7 +124,7 @@ def process(
|
|
| 123 |
)
|
| 124 |
|
| 125 |
# Save annotated image
|
| 126 |
-
if output_dir:
|
| 127 |
mask_annotator = sv.MaskAnnotator()
|
| 128 |
mask_image, res = show_anns_sv(detections)
|
| 129 |
annotated_image = mask_annotator.annotate(image, detections=detections)
|
|
@@ -197,12 +198,13 @@ def main(args: argparse.Namespace) -> None:
|
|
| 197 |
box_threshold = args.box_threshold
|
| 198 |
text_threshold = args.text_threshold
|
| 199 |
iou_threshold = args.iou_threshold
|
|
|
|
| 200 |
save_mask = args.save_mask
|
| 201 |
|
| 202 |
# load model
|
| 203 |
if task in ["auto", "detection"] and prompt == "":
|
| 204 |
print("Loading Tag2Text model...")
|
| 205 |
-
tag2text_type = args.
|
| 206 |
tag2text_checkpoint = os.path.join(
|
| 207 |
abs_weight_dir, tag2text_dict[tag2text_type]["checkpoint_file"]
|
| 208 |
)
|
|
@@ -225,7 +227,7 @@ def main(args: argparse.Namespace) -> None:
|
|
| 225 |
|
| 226 |
if task in ["auto", "detection"] or prompt != "":
|
| 227 |
print("Loading Grounding Dino model...")
|
| 228 |
-
dino_type = args.
|
| 229 |
dino_checkpoint = os.path.join(
|
| 230 |
abs_weight_dir, dino_dict[dino_type]["checkpoint_file"]
|
| 231 |
)
|
|
@@ -253,7 +255,7 @@ def main(args: argparse.Namespace) -> None:
|
|
| 253 |
|
| 254 |
if task in ["auto", "segment"]:
|
| 255 |
print("Loading SAM...")
|
| 256 |
-
sam_type = args.
|
| 257 |
sam_checkpoint = os.path.join(
|
| 258 |
abs_weight_dir, sam_dict[sam_type]["checkpoint_file"]
|
| 259 |
)
|
|
@@ -292,6 +294,7 @@ def main(args: argparse.Namespace) -> None:
|
|
| 292 |
iou_threshold=iou_threshold,
|
| 293 |
device=device,
|
| 294 |
output_dir=args.output,
|
|
|
|
| 295 |
save_mask=save_mask,
|
| 296 |
)
|
| 297 |
|
|
@@ -319,34 +322,31 @@ if __name__ == "__main__":
|
|
| 319 |
"-o",
|
| 320 |
type=str,
|
| 321 |
required=True,
|
| 322 |
-
help=
|
| 323 |
-
"Path to the directory where masks will be output. Output will be either a folder "
|
| 324 |
-
"of PNGs per image or a single json with COCO-style masks."
|
| 325 |
-
),
|
| 326 |
)
|
| 327 |
|
| 328 |
parser.add_argument(
|
| 329 |
-
"--sam",
|
| 330 |
type=str,
|
| 331 |
default=default_sam,
|
| 332 |
choices=sam_dict.keys(),
|
| 333 |
-
help="The type of SA model
|
| 334 |
)
|
| 335 |
|
| 336 |
parser.add_argument(
|
| 337 |
-
"--tag2text",
|
| 338 |
type=str,
|
| 339 |
default=default_tag2text,
|
| 340 |
choices=tag2text_dict.keys(),
|
| 341 |
-
help="The
|
| 342 |
)
|
| 343 |
|
| 344 |
parser.add_argument(
|
| 345 |
-
"--dino",
|
| 346 |
type=str,
|
| 347 |
default=default_dino,
|
| 348 |
choices=dino_dict.keys(),
|
| 349 |
-
help="The
|
| 350 |
)
|
| 351 |
|
| 352 |
parser.add_argument(
|
|
@@ -373,6 +373,12 @@ if __name__ == "__main__":
|
|
| 373 |
"--iou-threshold", type=float, default=0.5, help="iou threshold"
|
| 374 |
)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
parser.add_argument(
|
| 377 |
"--save-mask",
|
| 378 |
action="store_true",
|
|
|
|
| 33 |
box_threshold,
|
| 34 |
text_threshold,
|
| 35 |
iou_threshold,
|
| 36 |
+
device="cuda",
|
| 37 |
output_dir=None,
|
| 38 |
+
save_ann=True,
|
| 39 |
save_mask=False,
|
| 40 |
):
|
| 41 |
detections = None
|
|
|
|
| 85 |
)
|
| 86 |
|
| 87 |
# Save detection image
|
| 88 |
+
if output_dir and save_ann:
|
| 89 |
# Draw boxes
|
| 90 |
box_annotator = sv.BoxAnnotator()
|
| 91 |
labels = [
|
|
|
|
| 124 |
)
|
| 125 |
|
| 126 |
# Save annotated image
|
| 127 |
+
if output_dir and save_ann:
|
| 128 |
mask_annotator = sv.MaskAnnotator()
|
| 129 |
mask_image, res = show_anns_sv(detections)
|
| 130 |
annotated_image = mask_annotator.annotate(image, detections=detections)
|
|
|
|
| 198 |
box_threshold = args.box_threshold
|
| 199 |
text_threshold = args.text_threshold
|
| 200 |
iou_threshold = args.iou_threshold
|
| 201 |
+
save_ann = not args.no_save_ann
|
| 202 |
save_mask = args.save_mask
|
| 203 |
|
| 204 |
# load model
|
| 205 |
if task in ["auto", "detection"] and prompt == "":
|
| 206 |
print("Loading Tag2Text model...")
|
| 207 |
+
tag2text_type = args.tag2text_type
|
| 208 |
tag2text_checkpoint = os.path.join(
|
| 209 |
abs_weight_dir, tag2text_dict[tag2text_type]["checkpoint_file"]
|
| 210 |
)
|
|
|
|
| 227 |
|
| 228 |
if task in ["auto", "detection"] or prompt != "":
|
| 229 |
print("Loading Grounding Dino model...")
|
| 230 |
+
dino_type = args.dino_type
|
| 231 |
dino_checkpoint = os.path.join(
|
| 232 |
abs_weight_dir, dino_dict[dino_type]["checkpoint_file"]
|
| 233 |
)
|
|
|
|
| 255 |
|
| 256 |
if task in ["auto", "segment"]:
|
| 257 |
print("Loading SAM...")
|
| 258 |
+
sam_type = args.sam_type
|
| 259 |
sam_checkpoint = os.path.join(
|
| 260 |
abs_weight_dir, sam_dict[sam_type]["checkpoint_file"]
|
| 261 |
)
|
|
|
|
| 294 |
iou_threshold=iou_threshold,
|
| 295 |
device=device,
|
| 296 |
output_dir=args.output,
|
| 297 |
+
save_ann=save_ann,
|
| 298 |
save_mask=save_mask,
|
| 299 |
)
|
| 300 |
|
|
|
|
| 322 |
"-o",
|
| 323 |
type=str,
|
| 324 |
required=True,
|
| 325 |
+
help="Path to the directory where masks will be output.",
|
|
|
|
|
|
|
|
|
|
| 326 |
)
|
| 327 |
|
| 328 |
parser.add_argument(
|
| 329 |
+
"--sam-type",
|
| 330 |
type=str,
|
| 331 |
default=default_sam,
|
| 332 |
choices=sam_dict.keys(),
|
| 333 |
+
help="The type of SA model use for segmentation.",
|
| 334 |
)
|
| 335 |
|
| 336 |
parser.add_argument(
|
| 337 |
+
"--tag2text-type",
|
| 338 |
type=str,
|
| 339 |
default=default_tag2text,
|
| 340 |
choices=tag2text_dict.keys(),
|
| 341 |
+
help="The type of Tag2Text model use for tags and caption generation.",
|
| 342 |
)
|
| 343 |
|
| 344 |
parser.add_argument(
|
| 345 |
+
"--dino-type",
|
| 346 |
type=str,
|
| 347 |
default=default_dino,
|
| 348 |
choices=dino_dict.keys(),
|
| 349 |
+
help="The type of Grounding Dino model use for promptable object detection.",
|
| 350 |
)
|
| 351 |
|
| 352 |
parser.add_argument(
|
|
|
|
| 373 |
"--iou-threshold", type=float, default=0.5, help="iou threshold"
|
| 374 |
)
|
| 375 |
|
| 376 |
+
parser.add_argument(
|
| 377 |
+
"--no-save-ann",
|
| 378 |
+
action="store_true",
|
| 379 |
+
default=False,
|
| 380 |
+
help="If False, save original image with blended masks and detection boxes.",
|
| 381 |
+
)
|
| 382 |
parser.add_argument(
|
| 383 |
"--save-mask",
|
| 384 |
action="store_true",
|