|
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits |
|
|
|
|
|
from PIL import Image |
|
from torchvision import transforms |
|
import torch |
|
import numpy as np |
|
import zipfile |
|
|
|
|
|
|
|
# Registry of the EfficientSAM variants to run, keyed by the short name that
# is printed during inference and embedded in each output file name.
models = {'efficientsam_ti': build_efficient_sam_vitt()}

# Unpack the zipped checkpoint archive into weights/ before the second
# variant is built.
# NOTE(review): presumably build_efficient_sam_vits() reads the extracted
# file from weights/ -- confirm against the builder.
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", mode='r') as checkpoint_zip:
    checkpoint_zip.extractall(path="weights")

models['efficientsam_s'] = build_efficient_sam_vits()
|
|
|
|
|
|
|
|
|
|
|
# Decode the example photo once into a NumPy array; the same array is reused
# after inference to paint the predicted mask onto the picture.
sample_image_np = np.array(Image.open("figs/examples/dogs.jpg"))

# Convert the array into the tensor that is fed to the models (the batch axis
# is added at call time).
# NOTE(review): for a uint8 H x W x C array ToTensor is documented to return
# a C x H x W float tensor scaled to [0, 1] -- confirm the image decodes to
# uint8 RGB.
to_tensor = transforms.ToTensor()
sample_image_tensor = to_tensor(sample_image_np)
|
|
|
|
|
# Two point prompts on the example image, packed as a (1, 1, 2, 2) tensor.
# NOTE(review): presumably the axes are (batch, query, point, xy) with xy in
# pixel coordinates -- confirm against the model's forward signature.
input_points = torch.tensor([[580, 350], [650, 350]]).reshape(1, 1, 2, 2)

# One integer label per point, shape (1, 1, 2); both labels are 1.
# NOTE(review): 1 presumably marks a foreground click -- confirm.
input_labels = torch.ones((1, 1, 2), dtype=torch.int64)
|
|
|
|
|
# Run every registered model on the same image and point prompts, then save
# the highest-scoring mask painted onto the photo as
# figs/examples/dogs_<model_name>_mask.png.

# The batch axis is loop-invariant, so add it once up front.
batched_image = sample_image_tensor[None, ...]

for model_name, model in models.items():
    print('Running inference using ', model_name)

    # Pure inference: disable autograd so no graph is recorded for the
    # forward pass or the sorting below (saves memory and time, and the
    # outputs can be converted to NumPy without an explicit detach()).
    # NOTE(review): assumes the builders return models already in eval
    # mode -- confirm.
    with torch.no_grad():
        predicted_logits, predicted_iou = model(
            batched_image,
            input_points,
            input_labels,
        )

        # Reorder the candidates along axis 2 so the one with the highest
        # predicted IoU comes first; the same permutation is applied to the
        # logits, broadcast over their two trailing (mask) axes.
        sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
        predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
        predicted_logits = torch.take_along_dim(
            predicted_logits, sorted_ids[..., None, None], dim=2
        )

    # Take the first image, the first entry on axis 1 and the best-ranked
    # candidate; a logit >= 0 counts as "inside the mask".
    mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().numpy()

    # astype() already returns a fresh array, so no separate copy() is
    # needed; multiplying by the boolean mask zeroes every pixel outside it.
    masked_image_np = sample_image_np.astype(np.uint8) * mask[:, :, None]
    Image.fromarray(masked_image_np).save(f"figs/examples/dogs_{model_name}_mask.png")
|
|