|
|
import os |
|
|
import numpy as np |
|
|
import torch |
|
|
import matplotlib.pyplot as plt |
|
|
from PIL import Image |
|
|
import cv2 |
|
|
|
|
|
# Select the compute device. Fall back to CPU when CUDA is unavailable so that
# `device` is always defined: it is checked just below and later passed to
# build_sam2. Without the fallback, a CPU-only machine hits a NameError.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

if device.type == "cuda":
    # Enter a bfloat16 autocast context and intentionally never exit it, so
    # mixed precision stays active for the rest of the script.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()

    # Allow TF32 matmul / cuDNN kernels on GPUs with compute capability >= 8
    # (Ampere and newer), where TF32 is supported.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

# Seed NumPy's global RNG so the random mask colours drawn by show_anns are
# reproducible between runs.
np.random.seed(3)
|
|
|
|
|
def show_anns(anns, borders=True):
    """Paint mask annotations as a translucent overlay on the current axes.

    Annotations are drawn from the largest ``'area'`` to the smallest, so
    smaller masks end up on top. Each mask gets a random RGB colour (drawn
    from NumPy's global RNG) at alpha 0.5.

    Args:
        anns: sequence of annotation dicts with ``'area'`` and
            ``'segmentation'`` keys. The segmentation is used both for its
            first two dimensions and as an index into the overlay, so it is
            presumably a boolean H x W mask (confirm against the generator).
            Nothing is drawn when ``anns`` is empty.
        borders: when true, also outline each mask with its simplified
            external contours using OpenCV.
    """
    if len(anns) == 0:
        return

    # Largest first: later (smaller) masks overwrite earlier ones.
    ordered = sorted(anns, key=lambda a: a['area'], reverse=True)

    axes = plt.gca()
    axes.set_autoscale_on(False)

    # RGBA canvas sized from the largest mask, initially fully transparent.
    first_shape = ordered[0]['segmentation'].shape
    overlay = np.ones((first_shape[0], first_shape[1], 4))
    overlay[:, :, 3] = 0

    for annotation in ordered:
        mask = annotation['segmentation']
        rgba = np.concatenate([np.random.random(3), [0.5]])
        overlay[mask] = rgba
        if not borders:
            continue

        import cv2
        outlines, _ = cv2.findContours(
            mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        # Simplify each outline slightly to smooth the drawn border.
        outlines = [
            cv2.approxPolyDP(outline, epsilon=0.01, closed=True)
            for outline in outlines
        ]
        cv2.drawContours(overlay, outlines, -1, (0, 0, 1, 0.4), thickness=1)

    axes.imshow(overlay)
|
|
|
|
|
|
|
|
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Model checkpoint and its matching config (Hiera base+ variant, per the names).
sam2_checkpoint = "./checkpoints/sam2_hiera_base_plus.pt"
model_cfg = "configs/sam2/sam2_hiera_b+.yaml"

# `device` is selected by the device-setup code earlier in this file.
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
mask_generator = SAM2AutomaticMaskGenerator(sam2)

# Load the input as an RGB NumPy array. The `with` block closes the underlying
# file handle once the pixels have been copied out of the PIL image.
with Image.open('/home/yuqian_fu/Projects/sam2/DSCF0669.JPG') as pil_image:
    image = np.array(pil_image.convert("RGB"))

masks = mask_generator.generate(image)

save_dir = "/home/yuqian_fu/Projects/sam2/results"
i = 8  # output file stem: the mask is written to <save_dir>/<i>.png
mask_index = 1  # which of the generated masks to export

# cv2.imwrite does not raise when the target directory is missing; it only
# returns False. Create the directory up front and check the result below.
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, str(i) + ".png")

if len(masks) <= mask_index:
    raise IndexError(
        f"expected at least {mask_index + 1} masks, "
        f"but the generator produced {len(masks)}"
    )
ann = masks[mask_index]["segmentation"]

# Presumably a boolean mask (TODO confirm): True -> 255, False -> 0, one channel.
binary_mask = (ann.astype(np.uint8)) * 255
if not cv2.imwrite(save_path, binary_mask):
    raise OSError(f"failed to write mask image to {save_path}")
|
|
|
|
|
|
|
|
|