# sam2 / mask_generation.py
# Uploaded by YuqianFu using huggingface_hub (commit 1867b21, verified).
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
import cv2
# Pick the compute device: CUDA when available, otherwise fall back to CPU so
# that `device` is always defined (it is passed to build_sam2 further down).
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

if device.type == "cuda":
    # Use bfloat16 autocast for the entire script. The context is entered and
    # deliberately never exited, so it stays active for everything that follows.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # Turn on TF32 for Ampere (compute capability >= 8) GPUs:
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

# Seed NumPy's global RNG so the random mask colours in show_anns are reproducible.
np.random.seed(3)
def show_anns(anns, borders=True):
    """Overlay automatic-mask-generator annotations on the current matplotlib axes.

    Each mask is painted with a random colour at alpha 0.5. Masks are sorted
    by area, largest first, so smaller masks are painted on top and stay
    visible. When *borders* is true, the external contour of each mask is also
    drawn in blue.

    Args:
        anns: List of annotation dicts. Each must have an ``'area'`` value and
            a 2-D ``'segmentation'`` array, which is used directly as a boolean
            index into the RGBA canvas.
        borders: Whether to outline each mask with its simplified contour.

    Returns:
        None. Returns immediately, without touching matplotlib, when *anns*
        is empty.
    """
    if not anns:
        return
    # Paint the largest masks first so smaller ones stay visible on top.
    sorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    height, width = sorted_anns[0]['segmentation'].shape[:2]
    # RGBA canvas that is fully transparent until a mask is painted onto it.
    img = np.ones((height, width, 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.5]])
        img[m] = color_mask
        if borders:
            # cv2 is imported at module level; no per-iteration import needed.
            contours, _ = cv2.findContours(
                m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
            )
            # Slightly simplify the contours to smooth jagged pixel edges.
            contours = [
                cv2.approxPolyDP(contour, epsilon=0.01, closed=True)
                for contour in contours
            ]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)
    ax.imshow(img)
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
sam2_checkpoint = "./checkpoints/sam2_hiera_base_plus.pt"
model_cfg = "configs/sam2/sam2_hiera_b+.yaml"
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
mask_generator = SAM2AutomaticMaskGenerator(sam2)

# Open the image in a context manager so the file handle is released once the
# pixel data has been copied into a NumPy array.
with Image.open('/home/yuqian_fu/Projects/sam2/DSCF0669.JPG') as pil_image:
    image = np.array(pil_image.convert("RGB"))
masks = mask_generator.generate(image)

results_dir = "/home/yuqian_fu/Projects/sam2/results"
i = 8  # output file index: the mask is written to "<results_dir>/8.png"
mask_index = 1  # position in `masks` of the mask to export
if len(masks) <= mask_index:
    raise RuntimeError(
        f"Expected at least {mask_index + 1} masks, "
        f"but the generator produced {len(masks)}"
    )
# cv2.imwrite does not create missing directories; it just returns False.
os.makedirs(results_dir, exist_ok=True)
save_path = os.path.join(results_dir, str(i) + ".png")
ann = masks[mask_index]["segmentation"]
# Boolean mask -> 0/255 single-channel image.
binary_mask = (ann.astype(np.uint8)) * 255
if not cv2.imwrite(save_path, binary_mask):
    raise IOError(f"cv2.imwrite failed to write {save_path}")