Spaces:
Sleeping
Sleeping
Upload 22 files
Browse files- sam_extension/distillation_models/__init__.py +4 -0
- sam_extension/distillation_models/__pycache__/__init__.cpython-38.pyc +0 -0
- sam_extension/distillation_models/__pycache__/dino.cpython-38.pyc +0 -0
- sam_extension/distillation_models/__pycache__/fastertinyvit.cpython-38.pyc +0 -0
- sam_extension/distillation_models/__pycache__/fastervit.cpython-38.pyc +0 -0
- sam_extension/distillation_models/__pycache__/sam.cpython-38.pyc +0 -0
- sam_extension/distillation_models/dino.py +122 -0
- sam_extension/distillation_models/fastertinyvit.py +233 -0
- sam_extension/distillation_models/fastervit.py +659 -0
- sam_extension/distillation_models/sam.py +369 -0
- sam_extension/pipeline/__init__.py +4 -0
- sam_extension/pipeline/__pycache__/__init__.cpython-38.pyc +0 -0
- sam_extension/pipeline/__pycache__/base.cpython-38.pyc +0 -0
- sam_extension/pipeline/__pycache__/groundingdino.cpython-38.pyc +0 -0
- sam_extension/pipeline/__pycache__/owlvit.cpython-38.pyc +0 -0
- sam_extension/pipeline/__pycache__/sam.cpython-38.pyc +0 -0
- sam_extension/pipeline/base.py +20 -0
- sam_extension/pipeline/groundingdino.py +97 -0
- sam_extension/pipeline/owlvit.py +372 -0
- sam_extension/pipeline/sam.py +722 -0
- sam_extension/utils/__init__.py +175 -0
- sam_extension/utils/__pycache__/__init__.cpython-38.pyc +0 -0
sam_extension/distillation_models/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .dino import DINO
|
| 2 |
+
from .sam import SAMEncoderViT, DINOSAMViT
|
| 3 |
+
from .fastertinyvit import FasterTinyViT
|
| 4 |
+
# from .flashvision_transformer import FlashVisionTransformer
|
sam_extension/distillation_models/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (322 Bytes). View file
|
|
|
sam_extension/distillation_models/__pycache__/dino.cpython-38.pyc
ADDED
|
Binary file (4.72 kB). View file
|
|
|
sam_extension/distillation_models/__pycache__/fastertinyvit.cpython-38.pyc
ADDED
|
Binary file (6.26 kB). View file
|
|
|
sam_extension/distillation_models/__pycache__/fastervit.cpython-38.pyc
ADDED
|
Binary file (18 kB). View file
|
|
|
sam_extension/distillation_models/__pycache__/sam.cpython-38.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
sam_extension/distillation_models/dino.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import PIL
|
| 2 |
+
from PIL.Image import Image
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
from sklearn.decomposition import PCA
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torchvision import transforms as tfs
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
MEAN = [0.485, 0.456, 0.406]
|
| 13 |
+
STD = [0.229, 0.224, 0.225]
|
| 14 |
+
DINO_MODEL_HUB = 'facebookresearch/dino:main'
|
| 15 |
+
DINO_MODEL_TYPE = ['dino_vits16',
|
| 16 |
+
'dino_vits8',
|
| 17 |
+
'dino_vitb16',
|
| 18 |
+
'dino_vitb8',
|
| 19 |
+
'dino_xcit_small_12_p16',
|
| 20 |
+
'dino_xcit_small_12_p8',
|
| 21 |
+
'dino_xcit_medium_24_p16',
|
| 22 |
+
'dino_xcit_medium_24_p8',
|
| 23 |
+
'dino_resnet50']
|
| 24 |
+
|
| 25 |
+
DINOV2_MODEL_HUB = 'facebookresearch/dinov2:main'
|
| 26 |
+
DINOV2_MODEL_TYPE = ['dinov2_vits14',
|
| 27 |
+
'dinov2_vitb14',
|
| 28 |
+
'dinov2_vitl14',
|
| 29 |
+
'dinov2_vitg14']
|
| 30 |
+
|
| 31 |
+
class DINO(nn.Module):
|
| 32 |
+
def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
|
| 33 |
+
super(DINO, self).__init__()
|
| 34 |
+
assert model_type in DINO_MODEL_TYPE, 'Given DINO model type must in DINO_MODEL_TYPE!'
|
| 35 |
+
self.model = torch.hub.load(DINO_MODEL_HUB, model_type).to(device)
|
| 36 |
+
self.device = device
|
| 37 |
+
for param in self.model.parameters():
|
| 38 |
+
param.requires_grad = False
|
| 39 |
+
self.model.eval()
|
| 40 |
+
self.img_size = img_size
|
| 41 |
+
self.pca_dim = pca_dim
|
| 42 |
+
self.pca = self.set_pca(pca_dim) if pca_dim else None
|
| 43 |
+
def set_pca(self, dim=64):
|
| 44 |
+
return PCA(n_components=dim)
|
| 45 |
+
@torch.no_grad()
|
| 46 |
+
def extract_features(
|
| 47 |
+
self, img: Union[Image, torch.Tensor], transform=True, size=None
|
| 48 |
+
):
|
| 49 |
+
if transform and isinstance(img, Image):
|
| 50 |
+
img = self.transform(img, self.img_size).unsqueeze(0) # Nx3xHxW
|
| 51 |
+
with torch.no_grad():
|
| 52 |
+
out = self.model.get_intermediate_layers(img.to(self.device), n=1)[0]
|
| 53 |
+
out = out[:, 1:, :] # we discard the [CLS] token
|
| 54 |
+
h, w = int(img.shape[2] / self.model.patch_embed.patch_size), int(
|
| 55 |
+
img.shape[3] / self.model.patch_embed.patch_size
|
| 56 |
+
)
|
| 57 |
+
dim = out.shape[-1]
|
| 58 |
+
out = out.reshape(-1, h, w, dim)
|
| 59 |
+
dtype = out.dtype
|
| 60 |
+
if size is not None:
|
| 61 |
+
out = torch.nn.functional.interpolate(out.permute(0, 3, 1, 2), size=size, mode='bilinear').permute(0, 2, 3, 1)
|
| 62 |
+
if self.pca:
|
| 63 |
+
B, H, W, C = out.shape
|
| 64 |
+
out = out.view(-1, C).cpu().numpy()
|
| 65 |
+
out = self.pca.fit_transform(out)
|
| 66 |
+
out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
|
| 67 |
+
return out
|
| 68 |
+
def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
|
| 69 |
+
return self.extract_features(img, transform, size)
|
| 70 |
+
@staticmethod
|
| 71 |
+
def transform(img, image_size):
|
| 72 |
+
transforms = tfs.Compose(
|
| 73 |
+
[tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
|
| 74 |
+
)
|
| 75 |
+
img = transforms(img)
|
| 76 |
+
return img
|
| 77 |
+
|
| 78 |
+
class DINOV2(nn.Module):
|
| 79 |
+
def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
|
| 80 |
+
super(DINOV2, self).__init__()
|
| 81 |
+
assert model_type in DINOV2_MODEL_TYPE, 'Given DINO model type must in DINO_MODEL_TYPE!'
|
| 82 |
+
self.model = torch.hub.load(DINOV2_MODEL_HUB, model_type).to(device)
|
| 83 |
+
self.device = device
|
| 84 |
+
for param in self.model.parameters():
|
| 85 |
+
param.requires_grad = False
|
| 86 |
+
self.model.eval()
|
| 87 |
+
self.img_size = img_size
|
| 88 |
+
self.pca_dim = pca_dim
|
| 89 |
+
self.pca = self.set_pca(pca_dim) if pca_dim else None
|
| 90 |
+
def set_pca(self, dim=64):
|
| 91 |
+
return PCA(n_components=dim)
|
| 92 |
+
@torch.no_grad()
|
| 93 |
+
def extract_features(
|
| 94 |
+
self, img: Union[Image, torch.Tensor], transform=True, size=None
|
| 95 |
+
):
|
| 96 |
+
if transform and isinstance(img, Image):
|
| 97 |
+
img = self.transform(img, self.img_size).unsqueeze(0) # Nx3xHxW
|
| 98 |
+
with torch.no_grad():
|
| 99 |
+
out = self.model.forward_features(img.to(self.device))['x_norm_patchtokens']
|
| 100 |
+
h, w = int(img.shape[2] / self.model.patch_size), int(
|
| 101 |
+
img.shape[3] / self.model.patch_size
|
| 102 |
+
)
|
| 103 |
+
dim = out.shape[-1]
|
| 104 |
+
out = out.reshape(-1, h, w, dim)
|
| 105 |
+
dtype = out.dtype
|
| 106 |
+
if size is not None:
|
| 107 |
+
out = torch.nn.functional.interpolate(out.permute(0, 3, 1, 2), size=size, mode='bilinear').permute(0, 2, 3, 1)
|
| 108 |
+
if self.pca:
|
| 109 |
+
B, H, W, C = out.shape
|
| 110 |
+
out = out.view(-1, C).cpu().numpy()
|
| 111 |
+
out = self.pca.fit_transform(out)
|
| 112 |
+
out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
|
| 113 |
+
return out
|
| 114 |
+
def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
|
| 115 |
+
return self.extract_features(img, transform, size)
|
| 116 |
+
@staticmethod
|
| 117 |
+
def transform(img, image_size):
|
| 118 |
+
transforms = tfs.Compose(
|
| 119 |
+
[tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
|
| 120 |
+
)
|
| 121 |
+
img = transforms(img)
|
| 122 |
+
return img
|
sam_extension/distillation_models/fastertinyvit.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, List, Union
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.utils.checkpoint import checkpoint
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from timm.models.layers import trunc_normal_
|
| 7 |
+
from sam_extension.distillation_models.fastervit import FasterViTLayer
|
| 8 |
+
from segment_anything.mobile_encoder.tiny_vit_sam import PatchEmbed, Conv2d_BN, LayerNorm2d, MBConv
|
| 9 |
+
class PatchMerging(nn.Module):
|
| 10 |
+
def __init__(self, input_resolution, dim, out_dim, activation):
|
| 11 |
+
super().__init__()
|
| 12 |
+
|
| 13 |
+
self.input_resolution = input_resolution
|
| 14 |
+
self.dim = dim
|
| 15 |
+
self.out_dim = out_dim
|
| 16 |
+
self.act = activation()
|
| 17 |
+
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
|
| 18 |
+
stride_c=2
|
| 19 |
+
if(out_dim==320 or out_dim==448 or out_dim==576):#handongshen 576
|
| 20 |
+
stride_c=1
|
| 21 |
+
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
|
| 22 |
+
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
| 23 |
+
|
| 24 |
+
def forward(self, x):
|
| 25 |
+
if x.ndim == 3:
|
| 26 |
+
H, W = self.input_resolution
|
| 27 |
+
B = len(x)
|
| 28 |
+
# (B, C, H, W)
|
| 29 |
+
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
|
| 30 |
+
|
| 31 |
+
x = self.conv1(x)
|
| 32 |
+
x = self.act(x)
|
| 33 |
+
|
| 34 |
+
x = self.conv2(x)
|
| 35 |
+
x = self.act(x)
|
| 36 |
+
x = self.conv3(x)
|
| 37 |
+
return x
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ConvLayer(nn.Module):
|
| 41 |
+
def __init__(self, dim, input_resolution, depth,
|
| 42 |
+
activation,
|
| 43 |
+
drop_path=0., downsample=None, use_checkpoint=False,
|
| 44 |
+
out_dim=None,
|
| 45 |
+
conv_expand_ratio=4.,
|
| 46 |
+
):
|
| 47 |
+
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.dim = dim
|
| 50 |
+
self.input_resolution = input_resolution
|
| 51 |
+
self.depth = depth
|
| 52 |
+
self.use_checkpoint = use_checkpoint
|
| 53 |
+
|
| 54 |
+
# build blocks
|
| 55 |
+
self.blocks = nn.ModuleList([
|
| 56 |
+
MBConv(dim, dim, conv_expand_ratio, activation,
|
| 57 |
+
drop_path[i] if isinstance(drop_path, list) else drop_path,
|
| 58 |
+
)
|
| 59 |
+
for i in range(depth)])
|
| 60 |
+
|
| 61 |
+
# patch merging layer
|
| 62 |
+
if downsample is not None:
|
| 63 |
+
self.downsample = downsample(
|
| 64 |
+
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
| 65 |
+
else:
|
| 66 |
+
self.downsample = None
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
for blk in self.blocks:
|
| 70 |
+
if self.use_checkpoint:
|
| 71 |
+
x = checkpoint.checkpoint(blk, x)
|
| 72 |
+
else:
|
| 73 |
+
x = blk(x)
|
| 74 |
+
if self.downsample is not None:
|
| 75 |
+
x = self.downsample(x)
|
| 76 |
+
return x
|
| 77 |
+
|
| 78 |
+
class FasterTinyViT(nn.Module):
|
| 79 |
+
def __init__(self, img_size=224,
|
| 80 |
+
in_chans=3,
|
| 81 |
+
out_chans=256,
|
| 82 |
+
embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
|
| 83 |
+
num_heads=[3, 6, 12, 24],
|
| 84 |
+
window_sizes=[7, 7, 14, 7],
|
| 85 |
+
mlp_ratio=4.,
|
| 86 |
+
drop_rate=0.,
|
| 87 |
+
drop_path_rate=0.1,
|
| 88 |
+
use_checkpoint=False,
|
| 89 |
+
mbconv_expand_ratio=4.0,
|
| 90 |
+
ct_size=2,
|
| 91 |
+
conv=False,
|
| 92 |
+
multi_scale=False,
|
| 93 |
+
output_shape=None,
|
| 94 |
+
):
|
| 95 |
+
super().__init__()
|
| 96 |
+
self.img_size = img_size
|
| 97 |
+
self.depths = depths
|
| 98 |
+
self.num_layers = len(depths)
|
| 99 |
+
self.mlp_ratio = mlp_ratio
|
| 100 |
+
self.multi_scale = multi_scale
|
| 101 |
+
self.output_shape = tuple(output_shape) if output_shape else None
|
| 102 |
+
|
| 103 |
+
activation = nn.GELU
|
| 104 |
+
|
| 105 |
+
self.patch_embed = PatchEmbed(in_chans=in_chans,
|
| 106 |
+
embed_dim=embed_dims[0],
|
| 107 |
+
resolution=img_size,
|
| 108 |
+
activation=activation)
|
| 109 |
+
|
| 110 |
+
patches_resolution = self.patch_embed.patches_resolution
|
| 111 |
+
self.patches_resolution = patches_resolution
|
| 112 |
+
|
| 113 |
+
# stochastic depth
|
| 114 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
|
| 115 |
+
sum(depths))] # stochastic depth decay rule
|
| 116 |
+
|
| 117 |
+
# build layers
|
| 118 |
+
self.layers = nn.ModuleList()
|
| 119 |
+
for i_layer in range(self.num_layers):
|
| 120 |
+
kwargs_0 = dict(dim=embed_dims[i_layer],
|
| 121 |
+
input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
| 122 |
+
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
|
| 123 |
+
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
| 124 |
+
# patches_resolution[1] // (2 ** i_layer)),
|
| 125 |
+
depth=depths[i_layer],
|
| 126 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
| 127 |
+
downsample=PatchMerging if (
|
| 128 |
+
i_layer < self.num_layers - 1) else None,
|
| 129 |
+
use_checkpoint=use_checkpoint,
|
| 130 |
+
out_dim=embed_dims[min(
|
| 131 |
+
i_layer + 1, len(embed_dims) - 1)],
|
| 132 |
+
activation=activation,
|
| 133 |
+
)
|
| 134 |
+
kwargs_1 = dict(dim=embed_dims[i_layer],
|
| 135 |
+
out_dim=embed_dims[i_layer+1] if (
|
| 136 |
+
i_layer < self.num_layers - 1) else embed_dims[i_layer],
|
| 137 |
+
input_resolution=patches_resolution[0] // (2 ** i_layer),
|
| 138 |
+
depth=depths[i_layer],
|
| 139 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
| 140 |
+
downsample=True if (i_layer < self.num_layers - 1) else False,
|
| 141 |
+
ct_size=ct_size,
|
| 142 |
+
conv=conv,
|
| 143 |
+
)
|
| 144 |
+
if i_layer == 0:
|
| 145 |
+
layer = ConvLayer(
|
| 146 |
+
conv_expand_ratio=mbconv_expand_ratio,
|
| 147 |
+
**kwargs_0,
|
| 148 |
+
)
|
| 149 |
+
else:
|
| 150 |
+
layer = FasterViTLayer(
|
| 151 |
+
num_heads=num_heads[i_layer],
|
| 152 |
+
window_size=window_sizes[i_layer],
|
| 153 |
+
mlp_ratio=self.mlp_ratio,
|
| 154 |
+
drop=drop_rate,
|
| 155 |
+
**kwargs_1)
|
| 156 |
+
self.layers.append(layer)
|
| 157 |
+
|
| 158 |
+
# init weights
|
| 159 |
+
self.apply(self._init_weights)
|
| 160 |
+
|
| 161 |
+
self.neck = nn.Sequential(
|
| 162 |
+
nn.Conv2d(
|
| 163 |
+
sum(embed_dims)+embed_dims[-1] if self.multi_scale and self.output_shape else embed_dims[-1],
|
| 164 |
+
out_chans,
|
| 165 |
+
kernel_size=1,
|
| 166 |
+
bias=False,
|
| 167 |
+
),
|
| 168 |
+
LayerNorm2d(out_chans),
|
| 169 |
+
nn.Conv2d(
|
| 170 |
+
out_chans,
|
| 171 |
+
out_chans,
|
| 172 |
+
kernel_size=3,
|
| 173 |
+
padding=1,
|
| 174 |
+
bias=False,
|
| 175 |
+
),
|
| 176 |
+
LayerNorm2d(out_chans),
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
def _init_weights(self, m):
|
| 180 |
+
if isinstance(m, nn.Linear):
|
| 181 |
+
trunc_normal_(m.weight, std=.02)
|
| 182 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 183 |
+
nn.init.constant_(m.bias, 0)
|
| 184 |
+
elif isinstance(m, nn.LayerNorm):
|
| 185 |
+
nn.init.constant_(m.bias, 0)
|
| 186 |
+
nn.init.constant_(m.weight, 1.0)
|
| 187 |
+
|
| 188 |
+
@torch.jit.ignore
|
| 189 |
+
def no_weight_decay_keywords(self):
|
| 190 |
+
return {'attention_biases'}
|
| 191 |
+
|
| 192 |
+
def forward_features(self, x):
|
| 193 |
+
if self.multi_scale and self.output_shape:
|
| 194 |
+
output_list = []
|
| 195 |
+
# x: (N, C, H, W)
|
| 196 |
+
x = self.patch_embed(x)
|
| 197 |
+
output_list.append(F.interpolate(x, size=self.output_shape, mode='bilinear'))
|
| 198 |
+
for layer in self.layers:
|
| 199 |
+
x = layer(x)
|
| 200 |
+
output_list.append(F.interpolate(x, size=self.output_shape, mode='bilinear'))
|
| 201 |
+
x = self.neck(torch.cat(output_list, dim=1))
|
| 202 |
+
|
| 203 |
+
else:
|
| 204 |
+
x = self.patch_embed(x)
|
| 205 |
+
for layer in self.layers:
|
| 206 |
+
x = layer(x)
|
| 207 |
+
x = self.neck(x)
|
| 208 |
+
return x
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def forward(self, x):
|
| 212 |
+
x = self.forward_features(x)
|
| 213 |
+
|
| 214 |
+
return x
|
| 215 |
+
|
| 216 |
+
if __name__ == '__main__':
|
| 217 |
+
from distillation.utils import get_parameter_number
|
| 218 |
+
x = torch.randn(1, 3, 1024, 1024).cuda()
|
| 219 |
+
fastertinyvit = FasterTinyViT(img_size=1024, in_chans=3,
|
| 220 |
+
embed_dims=[64, 128, 256],
|
| 221 |
+
depths=[1, 2, 1],
|
| 222 |
+
num_heads=[2, 4, 8],
|
| 223 |
+
window_sizes=[8, 8, 8],
|
| 224 |
+
mlp_ratio=4.,
|
| 225 |
+
drop_rate=0.,
|
| 226 |
+
drop_path_rate=0.0,
|
| 227 |
+
use_checkpoint=False,
|
| 228 |
+
mbconv_expand_ratio=4.0,
|
| 229 |
+
multi_scale=False,
|
| 230 |
+
output_shape='').cuda()
|
| 231 |
+
print(fastertinyvit(x).shape)
|
| 232 |
+
print(get_parameter_number(fastertinyvit))
|
| 233 |
+
# torch.save(fastertinyvit, 'fastertinyvit.pt')
|
sam_extension/distillation_models/fastervit.py
ADDED
|
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from timm.models.layers import DropPath, LayerNorm2d
|
| 5 |
+
def window_partition(x, window_size):
|
| 6 |
+
B, C, H, W = x.shape
|
| 7 |
+
x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
|
| 8 |
+
windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
|
| 9 |
+
return windows
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def window_reverse(windows, window_size, H, W, B):
|
| 13 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
| 14 |
+
x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, windows.shape[2], H, W)
|
| 15 |
+
return x
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def ct_dewindow(ct, W, H, window_size):
|
| 19 |
+
bs = ct.shape[0]
|
| 20 |
+
N=ct.shape[2]
|
| 21 |
+
ct2 = ct.view(-1, W//window_size, H//window_size, window_size, window_size, N).permute(0, 5, 1, 3, 2, 4)
|
| 22 |
+
ct2 = ct2.reshape(bs, N, W*H).transpose(1, 2)
|
| 23 |
+
return ct2
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def ct_window(ct, W, H, window_size):
|
| 27 |
+
bs = ct.shape[0]
|
| 28 |
+
N = ct.shape[2]
|
| 29 |
+
ct = ct.view(bs, H // window_size, window_size, W // window_size, window_size, N)
|
| 30 |
+
ct = ct.permute(0, 1, 3, 2, 4, 5)
|
| 31 |
+
return ct
|
| 32 |
+
|
| 33 |
+
class PosEmbMLPSwinv2D(nn.Module):
|
| 34 |
+
def __init__(self,
|
| 35 |
+
window_size,
|
| 36 |
+
pretrained_window_size,
|
| 37 |
+
num_heads, seq_length,
|
| 38 |
+
ct_correct=False,
|
| 39 |
+
no_log=False):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.window_size = window_size
|
| 42 |
+
self.num_heads = num_heads
|
| 43 |
+
self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
|
| 44 |
+
nn.ReLU(inplace=True),
|
| 45 |
+
nn.Linear(512, num_heads, bias=False))
|
| 46 |
+
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
|
| 47 |
+
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
|
| 48 |
+
relative_coords_table = torch.stack(
|
| 49 |
+
torch.meshgrid([relative_coords_h,
|
| 50 |
+
relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
|
| 51 |
+
if pretrained_window_size[0] > 0:
|
| 52 |
+
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
|
| 53 |
+
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
|
| 54 |
+
else:
|
| 55 |
+
relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
|
| 56 |
+
relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
|
| 57 |
+
|
| 58 |
+
if not no_log:
|
| 59 |
+
relative_coords_table *= 8 # normalize to -8, 8
|
| 60 |
+
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
|
| 61 |
+
torch.abs(relative_coords_table) + 1.0) / np.log2(8)
|
| 62 |
+
|
| 63 |
+
self.register_buffer("relative_coords_table", relative_coords_table)
|
| 64 |
+
coords_h = torch.arange(self.window_size[0])
|
| 65 |
+
coords_w = torch.arange(self.window_size[1])
|
| 66 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
|
| 67 |
+
coords_flatten = torch.flatten(coords, 1)
|
| 68 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
| 69 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
| 70 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1
|
| 71 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 72 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
| 73 |
+
relative_position_index = relative_coords.sum(-1)
|
| 74 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
| 75 |
+
self.grid_exists = False
|
| 76 |
+
self.pos_emb = None
|
| 77 |
+
self.deploy = False
|
| 78 |
+
relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)
|
| 79 |
+
self.seq_length = seq_length
|
| 80 |
+
self.register_buffer("relative_bias", relative_bias)
|
| 81 |
+
self.ct_correct=ct_correct
|
| 82 |
+
|
| 83 |
+
def switch_to_deploy(self):
|
| 84 |
+
self.deploy = True
|
| 85 |
+
|
| 86 |
+
def forward(self, input_tensor, local_window_size):
|
| 87 |
+
if self.deploy:
|
| 88 |
+
input_tensor += self.relative_bias
|
| 89 |
+
return input_tensor
|
| 90 |
+
else:
|
| 91 |
+
self.grid_exists = False
|
| 92 |
+
|
| 93 |
+
if not self.grid_exists:
|
| 94 |
+
self.grid_exists = True
|
| 95 |
+
|
| 96 |
+
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
|
| 97 |
+
relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
| 98 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1],
|
| 99 |
+
-1)
|
| 100 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
|
| 101 |
+
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
|
| 102 |
+
n_global_feature = input_tensor.shape[2] - local_window_size
|
| 103 |
+
if n_global_feature > 0 and self.ct_correct:
|
| 104 |
+
|
| 105 |
+
step_for_ct=self.window_size[0]/(n_global_feature**0.5+1)
|
| 106 |
+
seq_length = int(n_global_feature ** 0.5)
|
| 107 |
+
indices = []
|
| 108 |
+
for i in range(seq_length):
|
| 109 |
+
for j in range(seq_length):
|
| 110 |
+
ind = (i+1)*step_for_ct*self.window_size[0] + (j+1)*step_for_ct
|
| 111 |
+
indices.append(int(ind))
|
| 112 |
+
|
| 113 |
+
top_part = relative_position_bias[:, indices, :]
|
| 114 |
+
lefttop_part = relative_position_bias[:, indices, :][:, :, indices]
|
| 115 |
+
left_part = relative_position_bias[:, :, indices]
|
| 116 |
+
relative_position_bias = torch.nn.functional.pad(relative_position_bias, (n_global_feature,
|
| 117 |
+
0,
|
| 118 |
+
n_global_feature,
|
| 119 |
+
0)).contiguous()
|
| 120 |
+
if n_global_feature>0 and self.ct_correct:
|
| 121 |
+
relative_position_bias = relative_position_bias*0.0
|
| 122 |
+
relative_position_bias[:, :n_global_feature, :n_global_feature] = lefttop_part
|
| 123 |
+
relative_position_bias[:, :n_global_feature, n_global_feature:] = top_part
|
| 124 |
+
relative_position_bias[:, n_global_feature:, :n_global_feature] = left_part
|
| 125 |
+
|
| 126 |
+
self.pos_emb = relative_position_bias.unsqueeze(0)
|
| 127 |
+
self.relative_bias = self.pos_emb
|
| 128 |
+
|
| 129 |
+
input_tensor += self.pos_emb
|
| 130 |
+
return input_tensor
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class PosEmbMLPSwinv1D(nn.Module):
|
| 134 |
+
def __init__(self,
|
| 135 |
+
dim,
|
| 136 |
+
rank=2,
|
| 137 |
+
seq_length=4,
|
| 138 |
+
conv=False):
|
| 139 |
+
super().__init__()
|
| 140 |
+
self.rank = rank
|
| 141 |
+
if not conv:
|
| 142 |
+
self.cpb_mlp = nn.Sequential(nn.Linear(self.rank, 512, bias=True),
|
| 143 |
+
nn.ReLU(),
|
| 144 |
+
nn.Linear(512, dim, bias=False))
|
| 145 |
+
else:
|
| 146 |
+
self.cpb_mlp = nn.Sequential(nn.Conv1d(self.rank, 512, 1,bias=True),
|
| 147 |
+
nn.ReLU(),
|
| 148 |
+
nn.Conv1d(512, dim, 1,bias=False))
|
| 149 |
+
self.grid_exists = False
|
| 150 |
+
self.pos_emb = None
|
| 151 |
+
self.deploy = False
|
| 152 |
+
relative_bias = torch.zeros(1,seq_length, dim)
|
| 153 |
+
self.register_buffer("relative_bias", relative_bias)
|
| 154 |
+
self.conv = conv
|
| 155 |
+
|
| 156 |
+
def switch_to_deploy(self):
|
| 157 |
+
self.deploy = True
|
| 158 |
+
|
| 159 |
+
def forward(self, input_tensor):
|
| 160 |
+
seq_length = input_tensor.shape[1] if not self.conv else input_tensor.shape[2]
|
| 161 |
+
if self.deploy:
|
| 162 |
+
return input_tensor + self.relative_bias
|
| 163 |
+
else:
|
| 164 |
+
self.grid_exists = False
|
| 165 |
+
if not self.grid_exists:
|
| 166 |
+
self.grid_exists = True
|
| 167 |
+
if self.rank == 1:
|
| 168 |
+
relative_coords_h = torch.arange(0, seq_length, device=input_tensor.device, dtype = input_tensor.dtype)
|
| 169 |
+
relative_coords_h -= seq_length//2
|
| 170 |
+
relative_coords_h /= (seq_length//2)
|
| 171 |
+
relative_coords_table = relative_coords_h
|
| 172 |
+
self.pos_emb = self.cpb_mlp(relative_coords_table.unsqueeze(0).unsqueeze(2))
|
| 173 |
+
self.relative_bias = self.pos_emb
|
| 174 |
+
else:
|
| 175 |
+
seq_length = int(seq_length**0.5)
|
| 176 |
+
relative_coords_h = torch.arange(0, seq_length, device=input_tensor.device, dtype = input_tensor.dtype)
|
| 177 |
+
relative_coords_w = torch.arange(0, seq_length, device=input_tensor.device, dtype = input_tensor.dtype)
|
| 178 |
+
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])).contiguous().unsqueeze(0)
|
| 179 |
+
relative_coords_table -= seq_length // 2
|
| 180 |
+
relative_coords_table /= (seq_length // 2)
|
| 181 |
+
if not self.conv:
|
| 182 |
+
self.pos_emb = self.cpb_mlp(relative_coords_table.flatten(2).transpose(1,2))
|
| 183 |
+
else:
|
| 184 |
+
self.pos_emb = self.cpb_mlp(relative_coords_table.flatten(2))
|
| 185 |
+
self.relative_bias = self.pos_emb
|
| 186 |
+
input_tensor = input_tensor + self.pos_emb
|
| 187 |
+
return input_tensor
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class Mlp(nn.Module):
|
| 191 |
+
"""
|
| 192 |
+
Multi-Layer Perceptron (MLP) block
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
def __init__(self,
|
| 196 |
+
in_features,
|
| 197 |
+
hidden_features=None,
|
| 198 |
+
out_features=None,
|
| 199 |
+
act_layer=nn.GELU,
|
| 200 |
+
drop=0.):
|
| 201 |
+
"""
|
| 202 |
+
Args:
|
| 203 |
+
in_features: input features dimension.
|
| 204 |
+
hidden_features: hidden features dimension.
|
| 205 |
+
out_features: output features dimension.
|
| 206 |
+
act_layer: activation function.
|
| 207 |
+
drop: dropout rate.
|
| 208 |
+
"""
|
| 209 |
+
|
| 210 |
+
super().__init__()
|
| 211 |
+
out_features = out_features or in_features
|
| 212 |
+
hidden_features = hidden_features or in_features
|
| 213 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 214 |
+
self.act = act_layer()
|
| 215 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 216 |
+
self.drop = nn.Dropout(drop)
|
| 217 |
+
|
| 218 |
+
def forward(self, x):
|
| 219 |
+
x_size = x.size()
|
| 220 |
+
x = x.view(-1, x_size[-1])
|
| 221 |
+
x = self.fc1(x)
|
| 222 |
+
x = self.act(x)
|
| 223 |
+
x = self.drop(x)
|
| 224 |
+
x = self.fc2(x)
|
| 225 |
+
x = self.drop(x)
|
| 226 |
+
x = x.view(x_size)
|
| 227 |
+
return x
|
| 228 |
+
|
| 229 |
+
class Downsample(nn.Module):
|
| 230 |
+
"""
|
| 231 |
+
Down-sampling block based on: "Hatamizadeh et al.,
|
| 232 |
+
FasterViT: Fast Vision Transformers with Hierarchical Attention
|
| 233 |
+
"""
|
| 234 |
+
|
| 235 |
+
def __init__(self,
|
| 236 |
+
dim,
|
| 237 |
+
out_dim,
|
| 238 |
+
keep_dim=False,
|
| 239 |
+
stride=2,
|
| 240 |
+
):
|
| 241 |
+
"""
|
| 242 |
+
Args:
|
| 243 |
+
dim: feature size dimension.
|
| 244 |
+
norm_layer: normalization layer.
|
| 245 |
+
keep_dim: bool argument for maintaining the resolution.
|
| 246 |
+
"""
|
| 247 |
+
|
| 248 |
+
super().__init__()
|
| 249 |
+
if keep_dim:
|
| 250 |
+
out_dim = dim
|
| 251 |
+
self.norm = LayerNorm2d(dim)
|
| 252 |
+
self.reduction = nn.Sequential(
|
| 253 |
+
nn.Conv2d(dim, out_dim, 3, stride, 1, bias=False),
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
def forward(self, x):
|
| 257 |
+
x = self.norm(x)
|
| 258 |
+
x = self.reduction(x)
|
| 259 |
+
return x
|
| 260 |
+
class PatchEmbed(nn.Module):
|
| 261 |
+
"""
|
| 262 |
+
Patch embedding block based on: "Hatamizadeh et al.,
|
| 263 |
+
FasterViT: Fast Vision Transformers with Hierarchical Attention
|
| 264 |
+
"""
|
| 265 |
+
|
| 266 |
+
def __init__(self, in_chans=3, in_dim=64, dim=96):
|
| 267 |
+
"""
|
| 268 |
+
Args:
|
| 269 |
+
in_chans: number of input channels.
|
| 270 |
+
dim: feature size dimension.
|
| 271 |
+
"""
|
| 272 |
+
super().__init__()
|
| 273 |
+
self.proj = nn.Identity()
|
| 274 |
+
self.conv_down = nn.Sequential(
|
| 275 |
+
nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False),
|
| 276 |
+
nn.BatchNorm2d(in_dim, eps=1e-4),
|
| 277 |
+
nn.ReLU(),
|
| 278 |
+
nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False),
|
| 279 |
+
nn.BatchNorm2d(dim, eps=1e-4),
|
| 280 |
+
nn.ReLU()
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
def forward(self, x):
|
| 284 |
+
x = self.proj(x)
|
| 285 |
+
x = self.conv_down(x)
|
| 286 |
+
return x
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
class ConvBlock(nn.Module):
|
| 290 |
+
"""
|
| 291 |
+
Conv block based on: "Hatamizadeh et al.,
|
| 292 |
+
FasterViT: Fast Vision Transformers with Hierarchical Attention
|
| 293 |
+
"""
|
| 294 |
+
|
| 295 |
+
def __init__(self, dim,
|
| 296 |
+
drop_path=0.,
|
| 297 |
+
layer_scale=None,
|
| 298 |
+
kernel_size=3):
|
| 299 |
+
super().__init__()
|
| 300 |
+
"""
|
| 301 |
+
Args:
|
| 302 |
+
drop_path: drop path.
|
| 303 |
+
layer_scale: layer scale coefficient.
|
| 304 |
+
kernel_size: kernel size.
|
| 305 |
+
"""
|
| 306 |
+
self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
|
| 307 |
+
self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
|
| 308 |
+
self.act1 = nn.GELU()
|
| 309 |
+
self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
|
| 310 |
+
self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
|
| 311 |
+
self.layer_scale = layer_scale
|
| 312 |
+
if layer_scale is not None and type(layer_scale) in [int, float]:
|
| 313 |
+
self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
|
| 314 |
+
self.layer_scale = True
|
| 315 |
+
else:
|
| 316 |
+
self.layer_scale = False
|
| 317 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 318 |
+
|
| 319 |
+
def forward(self, x, global_feature=None):
|
| 320 |
+
input = x
|
| 321 |
+
x = self.conv1(x)
|
| 322 |
+
x = self.norm1(x)
|
| 323 |
+
x = self.act1(x)
|
| 324 |
+
x = self.conv2(x)
|
| 325 |
+
x = self.norm2(x)
|
| 326 |
+
if self.layer_scale:
|
| 327 |
+
x = x * self.gamma.view(1, -1, 1, 1)
|
| 328 |
+
x = input + self.drop_path(x)
|
| 329 |
+
return x, global_feature
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class WindowAttention(nn.Module):
|
| 333 |
+
"""
|
| 334 |
+
Window attention based on: "Hatamizadeh et al.,
|
| 335 |
+
FasterViT: Fast Vision Transformers with Hierarchical Attention
|
| 336 |
+
"""
|
| 337 |
+
def __init__(self,
|
| 338 |
+
dim,
|
| 339 |
+
num_heads=8,
|
| 340 |
+
qkv_bias=False,
|
| 341 |
+
qk_scale=None,
|
| 342 |
+
attn_drop=0.,
|
| 343 |
+
proj_drop=0.,
|
| 344 |
+
resolution=0,
|
| 345 |
+
seq_length=0):
|
| 346 |
+
super().__init__()
|
| 347 |
+
"""
|
| 348 |
+
Args:
|
| 349 |
+
dim: feature size dimension.
|
| 350 |
+
num_heads: number of attention head.
|
| 351 |
+
qkv_bias: bool argument for query, key, value learnable bias.
|
| 352 |
+
qk_scale: bool argument to scaling query, key.
|
| 353 |
+
attn_drop: attention dropout rate.
|
| 354 |
+
proj_drop: output dropout rate.
|
| 355 |
+
resolution: feature resolution.
|
| 356 |
+
seq_length: sequence length.
|
| 357 |
+
"""
|
| 358 |
+
self.num_heads = num_heads
|
| 359 |
+
head_dim = dim // num_heads
|
| 360 |
+
self.head_dim = dim // num_heads
|
| 361 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 362 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 363 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 364 |
+
self.proj = nn.Linear(dim, dim)
|
| 365 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 366 |
+
# attention positional bias
|
| 367 |
+
self.pos_emb_funct = PosEmbMLPSwinv2D(window_size=[resolution, resolution],
|
| 368 |
+
pretrained_window_size=[resolution, resolution],
|
| 369 |
+
num_heads=num_heads,
|
| 370 |
+
seq_length=seq_length)
|
| 371 |
+
|
| 372 |
+
self.resolution = resolution
|
| 373 |
+
|
| 374 |
+
def forward(self, x):
|
| 375 |
+
B, N, C = x.shape
|
| 376 |
+
qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 377 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 378 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 379 |
+
attn = self.pos_emb_funct(attn, self.resolution ** 2)
|
| 380 |
+
attn = attn.softmax(dim=-1)
|
| 381 |
+
attn = self.attn_drop(attn)
|
| 382 |
+
x = (attn @ v).transpose(1, 2).reshape(B, -1, C)
|
| 383 |
+
x = self.proj(x)
|
| 384 |
+
x = self.proj_drop(x)
|
| 385 |
+
return x
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class HAT(nn.Module):
|
| 389 |
+
"""
|
| 390 |
+
Hierarchical attention (HAT) based on: "Hatamizadeh et al.,
|
| 391 |
+
FasterViT: Fast Vision Transformers with Hierarchical Attention
|
| 392 |
+
"""
|
| 393 |
+
def __init__(self,
|
| 394 |
+
dim,
|
| 395 |
+
num_heads,
|
| 396 |
+
mlp_ratio=4.,
|
| 397 |
+
qkv_bias=False,
|
| 398 |
+
qk_scale=None,
|
| 399 |
+
drop=0.,
|
| 400 |
+
attn_drop=0.,
|
| 401 |
+
drop_path=0.,
|
| 402 |
+
act_layer=nn.GELU,
|
| 403 |
+
norm_layer=nn.LayerNorm,
|
| 404 |
+
sr_ratio=1.,
|
| 405 |
+
window_size=7,
|
| 406 |
+
last=False,
|
| 407 |
+
layer_scale=None,
|
| 408 |
+
ct_size=1,
|
| 409 |
+
do_propagation=False):
|
| 410 |
+
super().__init__()
|
| 411 |
+
"""
|
| 412 |
+
Args:
|
| 413 |
+
dim: feature size dimension.
|
| 414 |
+
num_heads: number of attention head.
|
| 415 |
+
mlp_ratio: MLP ratio.
|
| 416 |
+
qkv_bias: bool argument for query, key, value learnable bias.
|
| 417 |
+
qk_scale: bool argument to scaling query, key.
|
| 418 |
+
drop: dropout rate.
|
| 419 |
+
attn_drop: attention dropout rate.
|
| 420 |
+
proj_drop: output dropout rate.
|
| 421 |
+
act_layer: activation function.
|
| 422 |
+
norm_layer: normalization layer.
|
| 423 |
+
sr_ratio: input to window size ratio.
|
| 424 |
+
window_size: window size.
|
| 425 |
+
last: last layer flag.
|
| 426 |
+
layer_scale: layer scale coefficient.
|
| 427 |
+
ct_size: spatial dimension of carrier token local window.
|
| 428 |
+
do_propagation: enable carrier token propagation.
|
| 429 |
+
"""
|
| 430 |
+
# positional encoding for windowed attention tokens
|
| 431 |
+
self.pos_embed = PosEmbMLPSwinv1D(dim, rank=2, seq_length=window_size**2)
|
| 432 |
+
self.norm1 = norm_layer(dim)
|
| 433 |
+
# number of carrier tokens per every window
|
| 434 |
+
cr_tokens_per_window = ct_size**2 if sr_ratio > 1 else 0
|
| 435 |
+
# total number of carrier tokens
|
| 436 |
+
cr_tokens_total = cr_tokens_per_window*sr_ratio*sr_ratio
|
| 437 |
+
self.cr_window = ct_size
|
| 438 |
+
self.attn = WindowAttention(dim,
|
| 439 |
+
num_heads=num_heads,
|
| 440 |
+
qkv_bias=qkv_bias,
|
| 441 |
+
qk_scale=qk_scale,
|
| 442 |
+
attn_drop=attn_drop,
|
| 443 |
+
proj_drop=drop,
|
| 444 |
+
resolution=window_size,
|
| 445 |
+
seq_length=window_size**2 + cr_tokens_per_window)
|
| 446 |
+
|
| 447 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 448 |
+
self.norm2 = norm_layer(dim)
|
| 449 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 450 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 451 |
+
self.window_size = window_size
|
| 452 |
+
|
| 453 |
+
use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
|
| 454 |
+
self.gamma3 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
|
| 455 |
+
self.gamma4 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
|
| 456 |
+
|
| 457 |
+
self.sr_ratio = sr_ratio
|
| 458 |
+
if sr_ratio > 1:
|
| 459 |
+
# if do hierarchical attention, this part is for carrier tokens
|
| 460 |
+
self.hat_norm1 = norm_layer(dim)
|
| 461 |
+
self.hat_norm2 = norm_layer(dim)
|
| 462 |
+
self.hat_attn = WindowAttention(
|
| 463 |
+
dim,
|
| 464 |
+
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 465 |
+
attn_drop=attn_drop, proj_drop=drop, resolution=int(cr_tokens_total**0.5),
|
| 466 |
+
seq_length=cr_tokens_total)
|
| 467 |
+
|
| 468 |
+
self.hat_mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 469 |
+
self.hat_drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 470 |
+
self.hat_pos_embed = PosEmbMLPSwinv1D(dim, rank=2, seq_length=cr_tokens_total)
|
| 471 |
+
self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
|
| 472 |
+
self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
|
| 473 |
+
self.upsampler = nn.Upsample(size=window_size, mode='nearest')
|
| 474 |
+
|
| 475 |
+
# keep track for the last block to explicitly add carrier tokens to feature maps
|
| 476 |
+
self.last = last
|
| 477 |
+
self.do_propagation = do_propagation
|
| 478 |
+
|
| 479 |
+
def forward(self, x, carrier_tokens):
|
| 480 |
+
B, T, N = x.shape
|
| 481 |
+
ct = carrier_tokens
|
| 482 |
+
x = self.pos_embed(x)
|
| 483 |
+
|
| 484 |
+
if self.sr_ratio > 1:
|
| 485 |
+
# do hierarchical attention via carrier tokens
|
| 486 |
+
# first do attention for carrier tokens
|
| 487 |
+
Bg, Ng, Hg = ct.shape
|
| 488 |
+
|
| 489 |
+
# ct are located quite differently
|
| 490 |
+
ct = ct_dewindow(ct, self.cr_window*self.sr_ratio, self.cr_window*self.sr_ratio, self.cr_window)
|
| 491 |
+
|
| 492 |
+
# positional bias for carrier tokens
|
| 493 |
+
ct = self.hat_pos_embed(ct)
|
| 494 |
+
|
| 495 |
+
# attention plus mlp
|
| 496 |
+
ct = ct + self.hat_drop_path(self.gamma1*self.hat_attn(self.hat_norm1(ct)))
|
| 497 |
+
ct = ct + self.hat_drop_path(self.gamma2*self.hat_mlp(self.hat_norm2(ct)))
|
| 498 |
+
|
| 499 |
+
# ct are put back to windows
|
| 500 |
+
ct = ct_window(ct, self.cr_window * self.sr_ratio, self.cr_window * self.sr_ratio, self.cr_window)
|
| 501 |
+
|
| 502 |
+
ct = ct.reshape(x.shape[0], -1, N)
|
| 503 |
+
# concatenate carrier_tokens to the windowed tokens
|
| 504 |
+
x = torch.cat((ct, x), dim=1)
|
| 505 |
+
|
| 506 |
+
# window attention together with carrier tokens
|
| 507 |
+
x = x + self.drop_path(self.gamma3*self.attn(self.norm1(x)))
|
| 508 |
+
x = x + self.drop_path(self.gamma4*self.mlp(self.norm2(x)))
|
| 509 |
+
|
| 510 |
+
if self.sr_ratio > 1:
|
| 511 |
+
# for hierarchical attention we need to split carrier tokens and window tokens back
|
| 512 |
+
ctr, x = x.split([x.shape[1] - self.window_size*self.window_size, self.window_size*self.window_size], dim=1)
|
| 513 |
+
ct = ctr.reshape(Bg, Ng, Hg) # reshape carrier tokens.
|
| 514 |
+
if self.last and self.do_propagation:
|
| 515 |
+
# propagate carrier token information into the image
|
| 516 |
+
ctr_image_space = ctr.transpose(1, 2).reshape(B, N, self.cr_window, self.cr_window)
|
| 517 |
+
x = x + self.gamma1 * self.upsampler(ctr_image_space.to(dtype=torch.float32)).flatten(2).transpose(1, 2).to(dtype=x.dtype)
|
| 518 |
+
return x, ct
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
class TokenInitializer(nn.Module):
|
| 522 |
+
"""
|
| 523 |
+
Carrier token Initializer based on: "Hatamizadeh et al.,
|
| 524 |
+
FasterViT: Fast Vision Transformers with Hierarchical Attention
|
| 525 |
+
"""
|
| 526 |
+
def __init__(self,
|
| 527 |
+
dim,
|
| 528 |
+
input_resolution,
|
| 529 |
+
window_size,
|
| 530 |
+
ct_size=1):
|
| 531 |
+
"""
|
| 532 |
+
Args:
|
| 533 |
+
dim: feature size dimension.
|
| 534 |
+
input_resolution: input image resolution.
|
| 535 |
+
window_size: window size.
|
| 536 |
+
ct_size: spatial dimension of carrier token local window
|
| 537 |
+
"""
|
| 538 |
+
super().__init__()
|
| 539 |
+
|
| 540 |
+
output_size = int(ct_size * input_resolution/window_size)
|
| 541 |
+
stride_size = int(input_resolution/output_size)
|
| 542 |
+
kernel_size = input_resolution - (output_size - 1) * stride_size
|
| 543 |
+
self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
|
| 544 |
+
to_global_feature = nn.Sequential()
|
| 545 |
+
to_global_feature.add_module("pos", self.pos_embed)
|
| 546 |
+
to_global_feature.add_module("pool", nn.AvgPool2d(kernel_size=kernel_size, stride=stride_size))
|
| 547 |
+
self.to_global_feature = to_global_feature
|
| 548 |
+
self.window_size = ct_size
|
| 549 |
+
|
| 550 |
+
def forward(self, x):
|
| 551 |
+
x = self.to_global_feature(x)
|
| 552 |
+
B, C, H, W = x.shape
|
| 553 |
+
ct = x.view(B, C, H // self.window_size, self.window_size, W // self.window_size, self.window_size)
|
| 554 |
+
ct = ct.permute(0, 2, 4, 3, 5, 1).reshape(-1, H*W, C)
|
| 555 |
+
return ct
|
| 556 |
+
class FasterViTLayer(nn.Module):
|
| 557 |
+
"""
|
| 558 |
+
GCViT layer based on: "Hatamizadeh et al.,
|
| 559 |
+
Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
|
| 560 |
+
"""
|
| 561 |
+
|
| 562 |
+
def __init__(self,
|
| 563 |
+
dim,
|
| 564 |
+
out_dim,
|
| 565 |
+
depth,
|
| 566 |
+
input_resolution,
|
| 567 |
+
num_heads,
|
| 568 |
+
window_size,
|
| 569 |
+
ct_size=1,
|
| 570 |
+
conv=False,
|
| 571 |
+
downsample=True,
|
| 572 |
+
mlp_ratio=4.,
|
| 573 |
+
qkv_bias=True,
|
| 574 |
+
qk_scale=None,
|
| 575 |
+
drop=0.,
|
| 576 |
+
attn_drop=0.,
|
| 577 |
+
drop_path=0.,
|
| 578 |
+
layer_scale=None,
|
| 579 |
+
layer_scale_conv=None,
|
| 580 |
+
only_local=False,
|
| 581 |
+
hierarchy=True,
|
| 582 |
+
do_propagation=False
|
| 583 |
+
):
|
| 584 |
+
"""
|
| 585 |
+
Args:
|
| 586 |
+
dim: feature size dimension.
|
| 587 |
+
depth: layer depth.
|
| 588 |
+
input_resolution: input resolution.
|
| 589 |
+
num_heads: number of attention head.
|
| 590 |
+
window_size: window size.
|
| 591 |
+
ct_size: spatial dimension of carrier token local window.
|
| 592 |
+
conv: conv_based stage flag.
|
| 593 |
+
downsample: downsample flag.
|
| 594 |
+
mlp_ratio: MLP ratio.
|
| 595 |
+
qkv_bias: bool argument for query, key, value learnable bias.
|
| 596 |
+
qk_scale: bool argument to scaling query, key.
|
| 597 |
+
drop: dropout rate.
|
| 598 |
+
attn_drop: attention dropout rate.
|
| 599 |
+
drop_path: drop path rate.
|
| 600 |
+
layer_scale: layer scale coefficient.
|
| 601 |
+
layer_scale_conv: conv layer scale coefficient.
|
| 602 |
+
only_local: local attention flag.
|
| 603 |
+
hierarchy: hierarchical attention flag.
|
| 604 |
+
do_propagation: enable carrier token propagation.
|
| 605 |
+
"""
|
| 606 |
+
super().__init__()
|
| 607 |
+
self.conv = conv
|
| 608 |
+
self.transformer_block = False
|
| 609 |
+
if conv:
|
| 610 |
+
self.blocks = nn.ModuleList([
|
| 611 |
+
ConvBlock(dim=dim,
|
| 612 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
| 613 |
+
layer_scale=layer_scale_conv)
|
| 614 |
+
for i in range(depth)])
|
| 615 |
+
self.transformer_block = False
|
| 616 |
+
else:
|
| 617 |
+
sr_ratio = input_resolution // window_size if not only_local else 1
|
| 618 |
+
self.blocks = nn.ModuleList([
|
| 619 |
+
HAT(dim=dim,
|
| 620 |
+
num_heads=num_heads,
|
| 621 |
+
mlp_ratio=mlp_ratio,
|
| 622 |
+
qkv_bias=qkv_bias,
|
| 623 |
+
qk_scale=qk_scale,
|
| 624 |
+
drop=drop,
|
| 625 |
+
attn_drop=attn_drop,
|
| 626 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
| 627 |
+
sr_ratio=sr_ratio,
|
| 628 |
+
window_size=window_size,
|
| 629 |
+
last=(i == depth-1),
|
| 630 |
+
layer_scale=layer_scale,
|
| 631 |
+
ct_size=ct_size,
|
| 632 |
+
do_propagation=do_propagation,
|
| 633 |
+
)
|
| 634 |
+
for i in range(depth)])
|
| 635 |
+
self.transformer_block = True
|
| 636 |
+
self.downsample = Downsample(dim=dim, out_dim=out_dim, stride=1) if not downsample else Downsample(dim=dim, out_dim=out_dim, stride=2)
|
| 637 |
+
if len(self.blocks) and not only_local and input_resolution // window_size > 1 and hierarchy and not self.conv:
|
| 638 |
+
self.global_tokenizer = TokenInitializer(dim,
|
| 639 |
+
input_resolution,
|
| 640 |
+
window_size,
|
| 641 |
+
ct_size=ct_size)
|
| 642 |
+
self.do_gt = True
|
| 643 |
+
else:
|
| 644 |
+
self.do_gt = False
|
| 645 |
+
|
| 646 |
+
self.window_size = window_size
|
| 647 |
+
|
| 648 |
+
def forward(self, x):
|
| 649 |
+
ct = self.global_tokenizer(x) if self.do_gt else None
|
| 650 |
+
B, C, H, W = x.shape
|
| 651 |
+
if self.transformer_block:
|
| 652 |
+
x = window_partition(x, self.window_size)
|
| 653 |
+
for bn, blk in enumerate(self.blocks):
|
| 654 |
+
x, ct = blk(x, ct)
|
| 655 |
+
if self.transformer_block:
|
| 656 |
+
x = window_reverse(x, self.window_size, H, W, B)
|
| 657 |
+
if self.downsample is None:
|
| 658 |
+
return x
|
| 659 |
+
return self.downsample(x)
|
sam_extension/distillation_models/sam.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import functools
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from huggingface_hub import hf_hub_download
|
| 9 |
+
|
| 10 |
+
from typing import Optional, List, Union, Tuple, Type
|
| 11 |
+
|
| 12 |
+
from segment_anything import build_sam
|
| 13 |
+
from segment_anything.mobile_encoder.tiny_vit_sam import TinyViT
|
| 14 |
+
from segment_anything.modeling import PromptEncoder, MaskDecoder, TwoWayTransformer
|
| 15 |
+
from segment_anything.modeling.image_encoder import ImageEncoderViT, LayerNorm2d, PatchEmbed, Block, Attention
|
| 16 |
+
from segment_anything.mobile_encoder.setup_mobile_sam import load_mobile_sam
|
| 17 |
+
from segment_anything.modeling.sam import Sam
|
| 18 |
+
|
| 19 |
+
from sam_extension.distillation_models.fastertinyvit import FasterTinyViT
|
| 20 |
+
from sam_extension.distillation_models.dino import DINO
|
| 21 |
+
# from sam_extension.distillation_models.flashvision_transformer import FlashVisionTransformer
|
| 22 |
+
|
| 23 |
+
SAM_REPO_ID = 'YouLiXiya/YL-SAM'
|
| 24 |
+
hf_sam_download = functools.partial(hf_hub_download, repo_id=SAM_REPO_ID, local_dir_use_symlinks=True)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SAMImageEncoder(nn.Module):
|
| 28 |
+
def __init__(self,
|
| 29 |
+
sam_checkpoint_path,
|
| 30 |
+
device='cuda'):
|
| 31 |
+
super(SAMImageEncoder, self).__init__()
|
| 32 |
+
sam = build_sam(sam_checkpoint_path).to(device)
|
| 33 |
+
self.image_encoder = sam.image_encoder
|
| 34 |
+
del sam
|
| 35 |
+
torch.cuda.empty_cache()
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
return self.image_encoder(x)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class MobileSAMImageEncoder(nn.Module):
|
| 42 |
+
def __init__(self,
|
| 43 |
+
sam_checkpoint_path,
|
| 44 |
+
device='cuda'):
|
| 45 |
+
super(MobileSAMImageEncoder, self).__init__()
|
| 46 |
+
sam = load_mobile_sam(sam_checkpoint_path, device)
|
| 47 |
+
self.image_encoder = sam.image_encoder
|
| 48 |
+
del sam
|
| 49 |
+
torch.cuda.empty_cache()
|
| 50 |
+
def forward(self, x):
|
| 51 |
+
return self.image_encoder(x)
|
| 52 |
+
|
| 53 |
+
class SAMEncoderViT(nn.Module):
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
img_size: int = 1024,
|
| 57 |
+
patch_size: int = 16,
|
| 58 |
+
in_chans: int = 3,
|
| 59 |
+
embed_dim: int = 768,
|
| 60 |
+
depth: int = 12,
|
| 61 |
+
num_heads: int = 12,
|
| 62 |
+
mlp_ratio: float = 4.0,
|
| 63 |
+
out_chans: int = 256,
|
| 64 |
+
qkv_bias: bool = True,
|
| 65 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
| 66 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
| 67 |
+
use_abs_pos: bool = True,
|
| 68 |
+
use_rel_pos: bool = False,
|
| 69 |
+
rel_pos_zero_init: bool = True,
|
| 70 |
+
window_size: int = 0,
|
| 71 |
+
global_attn_indexes: Tuple[int, ...] = (),
|
| 72 |
+
multi_scale: bool = False,
|
| 73 |
+
output_shape: Union[Tuple, List] = None
|
| 74 |
+
) -> None:
|
| 75 |
+
"""
|
| 76 |
+
Args:
|
| 77 |
+
img_size (int): Input image size.
|
| 78 |
+
patch_size (int): Patch size.
|
| 79 |
+
in_chans (int): Number of input image channels.
|
| 80 |
+
embed_dim (int): Patch embedding dimension.
|
| 81 |
+
depth (int): Depth of ViT.
|
| 82 |
+
num_heads (int): Number of attention heads in each ViT block.
|
| 83 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 84 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
| 85 |
+
norm_layer (nn.Module): Normalization layer.
|
| 86 |
+
act_layer (nn.Module): Activation layer.
|
| 87 |
+
use_abs_pos (bool): If True, use absolute positional embeddings.
|
| 88 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
| 89 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
| 90 |
+
window_size (int): Window size for window attention blocks.
|
| 91 |
+
global_attn_indexes (list): Indexes for blocks using global attention.
|
| 92 |
+
"""
|
| 93 |
+
super().__init__()
|
| 94 |
+
self.img_size = img_size
|
| 95 |
+
self.multi_scale = multi_scale
|
| 96 |
+
self.output_shape = tuple(output_shape) if output_shape else None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
self.patch_embed = PatchEmbed(
|
| 100 |
+
kernel_size=(patch_size, patch_size),
|
| 101 |
+
stride=(patch_size, patch_size),
|
| 102 |
+
in_chans=in_chans,
|
| 103 |
+
embed_dim=embed_dim,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
self.pos_embed: Optional[nn.Parameter] = None
|
| 107 |
+
if use_abs_pos:
|
| 108 |
+
# Initialize absolute positional embedding with pretrain image size.
|
| 109 |
+
self.pos_embed = nn.Parameter(
|
| 110 |
+
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
self.blocks = nn.ModuleList()
|
| 114 |
+
for i in range(depth):
|
| 115 |
+
block = Block(
|
| 116 |
+
dim=embed_dim,
|
| 117 |
+
num_heads=num_heads,
|
| 118 |
+
mlp_ratio=mlp_ratio,
|
| 119 |
+
qkv_bias=qkv_bias,
|
| 120 |
+
norm_layer=norm_layer,
|
| 121 |
+
act_layer=act_layer,
|
| 122 |
+
use_rel_pos=use_rel_pos,
|
| 123 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
| 124 |
+
window_size=window_size if i not in global_attn_indexes else 0,
|
| 125 |
+
input_size=(img_size // patch_size, img_size // patch_size),
|
| 126 |
+
)
|
| 127 |
+
self.blocks.append(block)
|
| 128 |
+
|
| 129 |
+
self.neck = nn.Sequential(
|
| 130 |
+
nn.Conv2d(
|
| 131 |
+
embed_dim*depth if self.multi_scale and self.output_shape else embed_dim,
|
| 132 |
+
out_chans,
|
| 133 |
+
kernel_size=1,
|
| 134 |
+
bias=False,
|
| 135 |
+
),
|
| 136 |
+
LayerNorm2d(out_chans),
|
| 137 |
+
nn.Conv2d(
|
| 138 |
+
out_chans,
|
| 139 |
+
out_chans,
|
| 140 |
+
kernel_size=3,
|
| 141 |
+
padding=1,
|
| 142 |
+
bias=False,
|
| 143 |
+
),
|
| 144 |
+
LayerNorm2d(out_chans),
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 148 |
+
x = self.patch_embed(x)
|
| 149 |
+
if self.pos_embed is not None:
|
| 150 |
+
x = x + self.pos_embed
|
| 151 |
+
|
| 152 |
+
if self.multi_scale and self.output_shape:
|
| 153 |
+
output_list = []
|
| 154 |
+
for blk in self.blocks:
|
| 155 |
+
x = blk(x)
|
| 156 |
+
output_list.append(F.interpolate(x.permute(0, 3, 1, 2), size=self.output_shape, mode='bilinear'))
|
| 157 |
+
|
| 158 |
+
x = self.neck(torch.cat(output_list, dim=1))
|
| 159 |
+
else:
|
| 160 |
+
for blk in self.blocks:
|
| 161 |
+
x = blk(x)
|
| 162 |
+
x = self.neck(x.permute(0, 3, 1, 2))
|
| 163 |
+
return x
|
| 164 |
+
|
| 165 |
+
class SAMEncoderAdaptor(nn.Module):
|
| 166 |
+
def __init__(self,
|
| 167 |
+
img_size: int,
|
| 168 |
+
input_size: Optional[Tuple[int, int]],
|
| 169 |
+
embed_dim: int = 768,
|
| 170 |
+
depth: int = 12,
|
| 171 |
+
num_heads: int = 12,
|
| 172 |
+
mlp_ratio: float = 4.0,
|
| 173 |
+
out_chans: int = 256,
|
| 174 |
+
qkv_bias: bool = True,
|
| 175 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
| 176 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
| 177 |
+
use_abs_pos: bool = True,
|
| 178 |
+
use_rel_pos: bool = False,
|
| 179 |
+
rel_pos_zero_init: bool = True,
|
| 180 |
+
window_size: int = 0,
|
| 181 |
+
global_attn_indexes: Tuple[int, ...] = (),
|
| 182 |
+
multi_scale: bool = False,
|
| 183 |
+
output_shape: Union[Tuple, List] = None):
|
| 184 |
+
super(SAMEncoderAdaptor, self).__init__()
|
| 185 |
+
self.img_size = img_size
|
| 186 |
+
self.multi_scale = multi_scale
|
| 187 |
+
self.output_shape = tuple(output_shape) if output_shape else None
|
| 188 |
+
|
| 189 |
+
self.pos_embed: Optional[nn.Parameter] = None
|
| 190 |
+
if use_abs_pos:
|
| 191 |
+
# Initialize absolute positional embedding with pretrain image size.
|
| 192 |
+
self.pos_embed = nn.Parameter(
|
| 193 |
+
torch.zeros(1, input_size[0], input_size[1], embed_dim)
|
| 194 |
+
)
|
| 195 |
+
self.blocks = nn.ModuleList()
|
| 196 |
+
for i in range(depth):
|
| 197 |
+
block = Block(
|
| 198 |
+
dim=embed_dim,
|
| 199 |
+
num_heads=num_heads,
|
| 200 |
+
mlp_ratio=mlp_ratio,
|
| 201 |
+
qkv_bias=qkv_bias,
|
| 202 |
+
norm_layer=norm_layer,
|
| 203 |
+
act_layer=act_layer,
|
| 204 |
+
use_rel_pos=use_rel_pos,
|
| 205 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
| 206 |
+
window_size=window_size if i not in global_attn_indexes else 0,
|
| 207 |
+
input_size=input_size,
|
| 208 |
+
)
|
| 209 |
+
self.blocks.append(block)
|
| 210 |
+
|
| 211 |
+
self.neck = nn.Sequential(
|
| 212 |
+
nn.Conv2d(
|
| 213 |
+
embed_dim * depth if self.multi_scale and self.output_shape else embed_dim,
|
| 214 |
+
out_chans,
|
| 215 |
+
kernel_size=1,
|
| 216 |
+
bias=False,
|
| 217 |
+
),
|
| 218 |
+
LayerNorm2d(out_chans),
|
| 219 |
+
nn.Conv2d(
|
| 220 |
+
out_chans,
|
| 221 |
+
out_chans,
|
| 222 |
+
kernel_size=3,
|
| 223 |
+
padding=1,
|
| 224 |
+
bias=False,
|
| 225 |
+
),
|
| 226 |
+
LayerNorm2d(out_chans),
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
def forward(self, x: torch.Tensor, original_size: Union[Tuple, List] = None) -> torch.Tensor:
|
| 230 |
+
if original_size:
|
| 231 |
+
original_size = torch.LongTensor(original_size)
|
| 232 |
+
output_shape = x.shape[-2:]
|
| 233 |
+
if original_size.ndim == 1:
|
| 234 |
+
original_size = original_size[None, ...]
|
| 235 |
+
adaptor_inputs = []
|
| 236 |
+
for i in range(original_size.shape[0]):
|
| 237 |
+
h, w = original_size[i]
|
| 238 |
+
if h > w:
|
| 239 |
+
new_h = output_shape[0]
|
| 240 |
+
new_w = int(w * new_h / h)
|
| 241 |
+
else:
|
| 242 |
+
new_w = output_shape[1]
|
| 243 |
+
new_h = int(h * new_w / w)
|
| 244 |
+
encoder_output = x[0].unsqueeze(0)
|
| 245 |
+
encoder_output = F.interpolate(encoder_output, size=(new_h, new_w), mode='bilinear')
|
| 246 |
+
pad_h = output_shape[0] - new_h
|
| 247 |
+
pad_w = output_shape[1] - new_w
|
| 248 |
+
encoder_output = F.pad(encoder_output, (0, pad_w, 0, pad_h))
|
| 249 |
+
adaptor_inputs.append(encoder_output)
|
| 250 |
+
adaptor_inputs = torch.cat(adaptor_inputs, dim=0)
|
| 251 |
+
x = adaptor_inputs.permute(0, 2, 3, 1)
|
| 252 |
+
if self.pos_embed is not None:
|
| 253 |
+
x = x + self.pos_embed
|
| 254 |
+
if self.multi_scale and self.output_shape:
|
| 255 |
+
output_list = []
|
| 256 |
+
for blk in self.blocks:
|
| 257 |
+
x = blk(x)
|
| 258 |
+
output_list.append(F.interpolate(x.permute(0, 3, 1, 2), size=self.output_shape, mode='bilinear'))
|
| 259 |
+
|
| 260 |
+
x = self.neck(torch.cat(output_list, dim=1))
|
| 261 |
+
else:
|
| 262 |
+
for blk in self.blocks:
|
| 263 |
+
x = blk(x)
|
| 264 |
+
x = self.neck(x.permute(0, 3, 1, 2))
|
| 265 |
+
return x
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class DINOSAMViT(nn.Module):
|
| 269 |
+
def __init__(self,
|
| 270 |
+
dino_model_type,
|
| 271 |
+
device='cuda',
|
| 272 |
+
pca_dim=None,
|
| 273 |
+
**kwargs
|
| 274 |
+
):
|
| 275 |
+
super(DINOSAMViT, self).__init__()
|
| 276 |
+
self.img_size = kwargs['img_size']
|
| 277 |
+
if not pca_dim:
|
| 278 |
+
pca_dim = None
|
| 279 |
+
self.dino = DINO(dino_model_type, device, self.img_size, pca_dim)
|
| 280 |
+
self.input_size = tuple(kwargs['output_shape'])
|
| 281 |
+
# input_size = self.dino.model.patch_embed.img_size // self.dino.model.patch_embed.img_size
|
| 282 |
+
# self.input_size = (input_size, input_size)
|
| 283 |
+
embed_dim = pca_dim if pca_dim is not None else self.dino.model.embed_dim
|
| 284 |
+
kwargs.update({'input_size': self.input_size, 'embed_dim': embed_dim})
|
| 285 |
+
self.adaptor = SAMEncoderAdaptor(**kwargs).to(device)
|
| 286 |
+
def extract_dino_features(self, x, transform=False, size = None):
|
| 287 |
+
return self.dino.extract_features(x, transform, size)
|
| 288 |
+
def forward(self, x, transform=False, size = None):
|
| 289 |
+
dino_feature = F.normalize(self.extract_dino_features(x, transform, size), dim=3)
|
| 290 |
+
adaptor_input = F.interpolate(dino_feature.permute(0, 3, 1, 2), size=self.input_size, mode='bilinear').permute(0, 2, 3, 1)
|
| 291 |
+
return self.adaptor(adaptor_input)
|
| 292 |
+
def setup_model(model_config):
|
| 293 |
+
prompt_embed_dim = 256
|
| 294 |
+
image_size = 1024
|
| 295 |
+
vit_patch_size = 16
|
| 296 |
+
image_embedding_size = image_size // vit_patch_size
|
| 297 |
+
model = eval(model_config.pop('type'))(**model_config)
|
| 298 |
+
if model.__class__.__name__ == 'SAMEncoderAdaptor':
|
| 299 |
+
adaptor = model
|
| 300 |
+
image_encoder = load_sam('weights/sam/mobile_sam.pt', 'mobile_sam', 'cpu').image_encoder
|
| 301 |
+
else:
|
| 302 |
+
adaptor = None
|
| 303 |
+
image_encoder = model
|
| 304 |
+
sam = Sam(
|
| 305 |
+
image_encoder=image_encoder,
|
| 306 |
+
prompt_encoder=PromptEncoder(
|
| 307 |
+
embed_dim=prompt_embed_dim,
|
| 308 |
+
image_embedding_size=(image_embedding_size, image_embedding_size),
|
| 309 |
+
input_image_size=(image_size, image_size),
|
| 310 |
+
mask_in_chans=16,
|
| 311 |
+
),
|
| 312 |
+
mask_decoder=MaskDecoder(
|
| 313 |
+
num_multimask_outputs=3,
|
| 314 |
+
transformer=TwoWayTransformer(
|
| 315 |
+
depth=2,
|
| 316 |
+
embedding_dim=prompt_embed_dim,
|
| 317 |
+
mlp_dim=2048,
|
| 318 |
+
num_heads=8,
|
| 319 |
+
),
|
| 320 |
+
transformer_dim=prompt_embed_dim,
|
| 321 |
+
iou_head_depth=3,
|
| 322 |
+
iou_head_hidden_dim=256,
|
| 323 |
+
),
|
| 324 |
+
adaptor=adaptor,
|
| 325 |
+
pixel_mean=[123.675, 116.28, 103.53],
|
| 326 |
+
pixel_std=[58.395, 57.12, 57.375],
|
| 327 |
+
)
|
| 328 |
+
return sam
|
| 329 |
+
|
| 330 |
+
def load_distillation_sam(distillation_sam_ckpt_path,
|
| 331 |
+
device='cuda'):
|
| 332 |
+
ckpt = torch.load(distillation_sam_ckpt_path)
|
| 333 |
+
sam = setup_model(ckpt['model_config'])
|
| 334 |
+
sam.load_state_dict(ckpt['model'])
|
| 335 |
+
return sam.to(device)
|
| 336 |
+
|
| 337 |
+
def load_sam(sam_ckpt_path, sam_version, device):
|
| 338 |
+
if not os.path.exists(sam_ckpt_path):
|
| 339 |
+
parent_dir = os.path.dirname(sam_ckpt_path)
|
| 340 |
+
os.makedirs(parent_dir, exist_ok=True)
|
| 341 |
+
hf_sam_download(filename=os.path.basename(sam_ckpt_path), local_dir=parent_dir)
|
| 342 |
+
if sam_version == 'sam':
|
| 343 |
+
sam = build_sam(sam_ckpt_path).to(device)
|
| 344 |
+
elif sam_version == 'mobile_sam':
|
| 345 |
+
sam = load_mobile_sam(sam_ckpt_path, device)
|
| 346 |
+
elif sam_version == 'distillation_sam':
|
| 347 |
+
sam = load_distillation_sam(sam_ckpt_path, device)
|
| 348 |
+
else:
|
| 349 |
+
raise ValueError('sam version error, please give sam version in [sam, mobile_sam, distillation_sam]')
|
| 350 |
+
return sam
|
| 351 |
+
|
| 352 |
+
if __name__ == '__main__':
|
| 353 |
+
from distillation.utils import get_parameter_number
|
| 354 |
+
vit = SAMEncoderViT(depth=3,
|
| 355 |
+
embed_dim=256,
|
| 356 |
+
img_size=512,
|
| 357 |
+
mlp_ratio=4,
|
| 358 |
+
num_heads=16,
|
| 359 |
+
patch_size=8,
|
| 360 |
+
qkv_bias=True,
|
| 361 |
+
use_rel_pos=True,
|
| 362 |
+
global_attn_indexes=[1],
|
| 363 |
+
window_size=16,
|
| 364 |
+
out_chans=256,
|
| 365 |
+
multi_scale=False,
|
| 366 |
+
output_shape='').cuda()
|
| 367 |
+
x = torch.randn((1, 3, 512, 512)).cuda()
|
| 368 |
+
print(vit(x).shape)
|
| 369 |
+
print(get_parameter_number(vit))
|
sam_extension/pipeline/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base import Pipeline
|
| 2 |
+
from .sam import SAMEncoderPipeline, SAMDecoderPipeline
|
| 3 |
+
from .owlvit import OwlViTVisionEncoderPipeline, OwlViTDecoderPipeline
|
| 4 |
+
from .groundingdino import GroundingDinoPipeline
|
sam_extension/pipeline/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (421 Bytes). View file
|
|
|
sam_extension/pipeline/__pycache__/base.cpython-38.pyc
ADDED
|
Binary file (1.14 kB). View file
|
|
|
sam_extension/pipeline/__pycache__/groundingdino.cpython-38.pyc
ADDED
|
Binary file (3.28 kB). View file
|
|
|
sam_extension/pipeline/__pycache__/owlvit.cpython-38.pyc
ADDED
|
Binary file (10.8 kB). View file
|
|
|
sam_extension/pipeline/__pycache__/sam.cpython-38.pyc
ADDED
|
Binary file (19.6 kB). View file
|
|
|
sam_extension/pipeline/base.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from typing import Union, Dict
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
@dataclass(repr=True)
|
| 7 |
+
class Output:
|
| 8 |
+
pass
|
| 9 |
+
|
| 10 |
+
class Pipeline(nn.Module):
|
| 11 |
+
def __init__(self, *args, **kwargs):
|
| 12 |
+
super(Pipeline, self).__init__()
|
| 13 |
+
self.args = args
|
| 14 |
+
self.kwargs = kwargs
|
| 15 |
+
@classmethod
|
| 16 |
+
def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
|
| 17 |
+
pass
|
| 18 |
+
def forward(self, *args, **kwargs):
|
| 19 |
+
pass
|
| 20 |
+
|
sam_extension/pipeline/groundingdino.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import functools
|
| 3 |
+
import PIL
|
| 4 |
+
from PIL.Image import Image
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import List, Union
|
| 7 |
+
import supervision as sv
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torchvision
|
| 11 |
+
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
+
from sam_extension.pipeline import Pipeline
|
| 14 |
+
from groundingdino.util.inference import Model
|
| 15 |
+
|
| 16 |
+
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
| 17 |
+
GROUNDING_DINO_CHECKPOINT_PATH = "groundingdino_swint_ogc.pth"
|
| 18 |
+
SAM_REPO_ID = 'YouLiXiya/YL-SAM'
|
| 19 |
+
LOCAL_DIR = "weights/groundingdino"
|
| 20 |
+
hf_sam_download = functools.partial(hf_hub_download, repo_id=SAM_REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=True)
|
| 21 |
+
class GroundingDinoPipeline(Pipeline):
|
| 22 |
+
def __init__(self,
|
| 23 |
+
grounding_dino_config_path,
|
| 24 |
+
grounfing_dino_ckpt_path,
|
| 25 |
+
grounding_dino_model,
|
| 26 |
+
device,
|
| 27 |
+
*args,
|
| 28 |
+
**kwargs):
|
| 29 |
+
super(GroundingDinoPipeline, self).__init__(*args, **kwargs)
|
| 30 |
+
self.grounding_dino_config_path = grounding_dino_config_path
|
| 31 |
+
self.grounfing_dino_ckpt_path = grounfing_dino_ckpt_path
|
| 32 |
+
self.grounding_dino_model = grounding_dino_model
|
| 33 |
+
self.device = device
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@classmethod
|
| 37 |
+
def from_pretrained(cls, grounding_dino_config_path, grounfing_dino_ckpt_path,device='cuda', *args, **kwargs):
|
| 38 |
+
if not os.path.exists(grounfing_dino_ckpt_path):
|
| 39 |
+
hf_sam_download(filename=os.path.basename(grounfing_dino_ckpt_path))
|
| 40 |
+
grounding_dino_model = Model(model_config_path=grounding_dino_config_path,
|
| 41 |
+
model_checkpoint_path=grounfing_dino_ckpt_path,
|
| 42 |
+
device=device)
|
| 43 |
+
return cls(grounding_dino_config_path,
|
| 44 |
+
grounfing_dino_ckpt_path,
|
| 45 |
+
grounding_dino_model,
|
| 46 |
+
device,
|
| 47 |
+
*args,
|
| 48 |
+
**kwargs)
|
| 49 |
+
|
| 50 |
+
def visualize_results(self,
|
| 51 |
+
img: Union[Image, np.ndarray],
|
| 52 |
+
class_list: [List],
|
| 53 |
+
box_threshold: float=0.25,
|
| 54 |
+
text_threshold: float=0.25,
|
| 55 |
+
nms_threshold: float=0.8,
|
| 56 |
+
pil: bool=True):
|
| 57 |
+
detections = self.forward(img, class_list, box_threshold, text_threshold)
|
| 58 |
+
box_annotator = sv.BoxAnnotator()
|
| 59 |
+
nms_idx = torchvision.ops.nms(
|
| 60 |
+
torch.from_numpy(detections.xyxy),
|
| 61 |
+
torch.from_numpy(detections.confidence),
|
| 62 |
+
nms_threshold
|
| 63 |
+
).numpy().tolist()
|
| 64 |
+
|
| 65 |
+
detections.xyxy = detections.xyxy[nms_idx]
|
| 66 |
+
detections.confidence = detections.confidence[nms_idx]
|
| 67 |
+
detections.class_id = detections.class_id[nms_idx]
|
| 68 |
+
labels = [
|
| 69 |
+
f"{class_list[class_id]} {confidence:0.2f}"
|
| 70 |
+
for _, _, confidence, class_id, _
|
| 71 |
+
in detections]
|
| 72 |
+
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=labels)
|
| 73 |
+
if pil:
|
| 74 |
+
return PIL.Image.fromarray(annotated_frame[:, :, ::-1]), detections
|
| 75 |
+
else:
|
| 76 |
+
return annotated_frame, detections
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@torch.no_grad()
|
| 80 |
+
def forward(self,
|
| 81 |
+
img: Union[Image, np.ndarray],
|
| 82 |
+
class_list: [List],
|
| 83 |
+
box_threshold: float=0.25,
|
| 84 |
+
text_threshold: float=0.25
|
| 85 |
+
)->sv.Detections:
|
| 86 |
+
if isinstance(img, Image):
|
| 87 |
+
img = np.uint8(img)[:, :, ::-1]
|
| 88 |
+
detections = self.grounding_dino_model.predict_with_classes(
|
| 89 |
+
image=img,
|
| 90 |
+
classes=class_list,
|
| 91 |
+
box_threshold=box_threshold,
|
| 92 |
+
text_threshold=text_threshold
|
| 93 |
+
)
|
| 94 |
+
return detections
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
|
sam_extension/pipeline/owlvit.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple, Union, List
|
| 2 |
+
import numpy as np
|
| 3 |
+
import PIL
|
| 4 |
+
from PIL.Image import Image
|
| 5 |
+
import supervision as sv
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from transformers import OwlViTProcessor, OwlViTForObjectDetection, OwlViTVisionModel
|
| 11 |
+
from transformers.models.owlvit.modeling_owlvit import center_to_corners_format, box_iou, generalized_box_iou, OwlViTObjectDetectionOutput
|
| 12 |
+
|
| 13 |
+
from sam_extension.pipeline.base import Pipeline, Output
|
| 14 |
+
|
| 15 |
+
class OwlViTVisionEncoderPipeline(Pipeline):
|
| 16 |
+
|
| 17 |
+
def __init__(self,
|
| 18 |
+
vision_model,
|
| 19 |
+
layer_norm,
|
| 20 |
+
processor,
|
| 21 |
+
device='cuda',
|
| 22 |
+
*args,
|
| 23 |
+
**kwargs):
|
| 24 |
+
super().__init__(*args, **kwargs)
|
| 25 |
+
self.vision_model = vision_model
|
| 26 |
+
self.layer_norm = layer_norm
|
| 27 |
+
self.processor = processor
|
| 28 |
+
self.device = device
|
| 29 |
+
torch.cuda.empty_cache()
|
| 30 |
+
@classmethod
|
| 31 |
+
def from_pretrained(cls, model_type, device='cuda', *args, **kwargs):
|
| 32 |
+
owlvit_for_object_detection = OwlViTForObjectDetection.from_pretrained(model_type).to(device)
|
| 33 |
+
processor = OwlViTProcessor.from_pretrained(model_type)
|
| 34 |
+
return cls(owlvit_for_object_detection.owlvit.vision_model,
|
| 35 |
+
owlvit_for_object_detection.layer_norm,
|
| 36 |
+
processor,
|
| 37 |
+
device,
|
| 38 |
+
*args,
|
| 39 |
+
**kwargs)
|
| 40 |
+
def process_image(self, image:Image):
|
| 41 |
+
image = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
|
| 42 |
+
return image
|
| 43 |
+
@torch.no_grad()
|
| 44 |
+
def forward(
|
| 45 |
+
self,
|
| 46 |
+
pixel_values: Union[torch.FloatTensor, Image] = None,
|
| 47 |
+
output_attentions: Optional[bool] = None,
|
| 48 |
+
output_hidden_states: Optional[bool] = None,
|
| 49 |
+
return_dict: Optional[bool] = None,
|
| 50 |
+
) -> torch.FloatTensor:
|
| 51 |
+
if isinstance(pixel_values, Image):
|
| 52 |
+
pixel_values = self.process_image(pixel_values)
|
| 53 |
+
pixel_values = pixel_values.to(self.device)
|
| 54 |
+
vision_outputs = self.vision_model(
|
| 55 |
+
pixel_values=pixel_values,
|
| 56 |
+
output_attentions=output_attentions,
|
| 57 |
+
output_hidden_states=output_hidden_states,
|
| 58 |
+
return_dict=return_dict,
|
| 59 |
+
)
|
| 60 |
+
# Get image embeddings
|
| 61 |
+
last_hidden_state = vision_outputs[0]
|
| 62 |
+
image_embeds = self.vision_model.post_layernorm(last_hidden_state)
|
| 63 |
+
new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
|
| 64 |
+
class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
|
| 65 |
+
|
| 66 |
+
# Merge image embedding with class tokens
|
| 67 |
+
image_embeds = image_embeds[:, 1:, :] * class_token_out
|
| 68 |
+
image_embeds = self.layer_norm(image_embeds)
|
| 69 |
+
|
| 70 |
+
# Resize to [batch_size, num_patches, num_patches, hidden_size]
|
| 71 |
+
new_size = (
|
| 72 |
+
image_embeds.shape[0],
|
| 73 |
+
int(np.sqrt(image_embeds.shape[1])),
|
| 74 |
+
int(np.sqrt(image_embeds.shape[1])),
|
| 75 |
+
image_embeds.shape[-1],
|
| 76 |
+
)
|
| 77 |
+
image_embeds = image_embeds.reshape(new_size)
|
| 78 |
+
return image_embeds
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class OwlViTDecoderPipeline(Pipeline):
|
| 83 |
+
prompt_template: str = 'a photo of a '
|
| 84 |
+
def __init__(self,
|
| 85 |
+
owlvit_text,
|
| 86 |
+
text_projection,
|
| 87 |
+
class_head,
|
| 88 |
+
box_head,
|
| 89 |
+
processor,
|
| 90 |
+
device='cuda',
|
| 91 |
+
*args,
|
| 92 |
+
**kwargs):
|
| 93 |
+
super().__init__(*args, **kwargs)
|
| 94 |
+
|
| 95 |
+
self.owlvit_text = owlvit_text
|
| 96 |
+
self.text_projection = text_projection
|
| 97 |
+
self.class_head = class_head
|
| 98 |
+
self.box_head = box_head
|
| 99 |
+
|
| 100 |
+
self.sigmoid = nn.Sigmoid()
|
| 101 |
+
self.processor = processor
|
| 102 |
+
self.device = device
|
| 103 |
+
torch.cuda.empty_cache()
|
| 104 |
+
|
| 105 |
+
@classmethod
|
| 106 |
+
def from_pretrained(cls, model_type, device='cuda', *args, **kwargs):
|
| 107 |
+
owlvit_for_object_detection = OwlViTForObjectDetection.from_pretrained(model_type).to(device)
|
| 108 |
+
processor = OwlViTProcessor.from_pretrained(model_type)
|
| 109 |
+
return cls(owlvit_for_object_detection.owlvit.text_model,
|
| 110 |
+
owlvit_for_object_detection.owlvit.text_projection,
|
| 111 |
+
owlvit_for_object_detection.class_head,
|
| 112 |
+
owlvit_for_object_detection.box_head,
|
| 113 |
+
processor,
|
| 114 |
+
device,
|
| 115 |
+
*args,
|
| 116 |
+
**kwargs)
|
| 117 |
+
def set_template(self, template: str):
|
| 118 |
+
self.prompt_template = template
|
| 119 |
+
def process_text(self, text:List, use_template:bool = True):
|
| 120 |
+
if use_template:
|
| 121 |
+
text = [[self.prompt_template+i for i in text[0]]]
|
| 122 |
+
inputs = self.processor(text=text, return_tensors="pt")
|
| 123 |
+
return inputs
|
| 124 |
+
def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
|
| 125 |
+
# Computes normalized xy corner coordinates from feature_map.
|
| 126 |
+
if not feature_map.ndim == 4:
|
| 127 |
+
raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
|
| 128 |
+
|
| 129 |
+
device = feature_map.device
|
| 130 |
+
num_patches = feature_map.shape[1]
|
| 131 |
+
|
| 132 |
+
box_coordinates = np.stack(
|
| 133 |
+
np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
|
| 134 |
+
).astype(np.float32)
|
| 135 |
+
box_coordinates /= np.array([num_patches, num_patches], np.float32)
|
| 136 |
+
|
| 137 |
+
# Flatten (h, w, 2) -> (h*w, 2)
|
| 138 |
+
box_coordinates = box_coordinates.reshape(
|
| 139 |
+
box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
|
| 140 |
+
)
|
| 141 |
+
box_coordinates = torch.from_numpy(box_coordinates).to(device)
|
| 142 |
+
|
| 143 |
+
return box_coordinates
|
| 144 |
+
|
| 145 |
+
def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
|
| 146 |
+
# The box center is biased to its position on the feature grid
|
| 147 |
+
box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
|
| 148 |
+
box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
|
| 149 |
+
|
| 150 |
+
# Unnormalize xy
|
| 151 |
+
box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
|
| 152 |
+
|
| 153 |
+
# The box size is biased to the patch size
|
| 154 |
+
box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
|
| 155 |
+
box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
|
| 156 |
+
|
| 157 |
+
# Compute box bias
|
| 158 |
+
box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
|
| 159 |
+
return box_bias
|
| 160 |
+
|
| 161 |
+
def box_predictor(
|
| 162 |
+
self,
|
| 163 |
+
image_feats: torch.FloatTensor,
|
| 164 |
+
feature_map: torch.FloatTensor,
|
| 165 |
+
) -> torch.FloatTensor:
|
| 166 |
+
"""
|
| 167 |
+
Args:
|
| 168 |
+
image_feats:
|
| 169 |
+
Features extracted from the image, returned by the `image_text_embedder` method.
|
| 170 |
+
feature_map:
|
| 171 |
+
A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
|
| 172 |
+
Returns:
|
| 173 |
+
pred_boxes:
|
| 174 |
+
List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary.
|
| 175 |
+
"""
|
| 176 |
+
# Bounding box detection head [batch_size, num_boxes, 4].
|
| 177 |
+
pred_boxes = self.box_head(image_feats)
|
| 178 |
+
|
| 179 |
+
# Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
|
| 180 |
+
pred_boxes += self.compute_box_bias(feature_map)
|
| 181 |
+
pred_boxes = self.sigmoid(pred_boxes)
|
| 182 |
+
return pred_boxes
|
| 183 |
+
|
| 184 |
+
def class_predictor(
|
| 185 |
+
self,
|
| 186 |
+
image_feats: torch.FloatTensor,
|
| 187 |
+
query_embeds: Optional[torch.FloatTensor] = None,
|
| 188 |
+
query_mask: Optional[torch.Tensor] = None,
|
| 189 |
+
) -> Tuple[torch.FloatTensor]:
|
| 190 |
+
"""
|
| 191 |
+
Args:
|
| 192 |
+
image_feats:
|
| 193 |
+
Features extracted from the `image_text_embedder`.
|
| 194 |
+
query_embeds:
|
| 195 |
+
Text query embeddings.
|
| 196 |
+
query_mask:
|
| 197 |
+
Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
|
| 198 |
+
"""
|
| 199 |
+
(pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)
|
| 200 |
+
|
| 201 |
+
return (pred_logits, image_class_embeds)
|
| 202 |
+
|
| 203 |
+
def image_text_embedder(
|
| 204 |
+
self,
|
| 205 |
+
input_ids: torch.Tensor,
|
| 206 |
+
image_embeds: torch.FloatTensor,
|
| 207 |
+
attention_mask: torch.Tensor,
|
| 208 |
+
output_attentions: Optional[bool] = None,
|
| 209 |
+
output_hidden_states: Optional[bool] = None,
|
| 210 |
+
) -> Tuple[torch.FloatTensor]:
|
| 211 |
+
|
| 212 |
+
# Encode text and image
|
| 213 |
+
text_outputs = self.owlvit_text(
|
| 214 |
+
input_ids=input_ids,
|
| 215 |
+
attention_mask=attention_mask,
|
| 216 |
+
output_attentions=output_attentions,
|
| 217 |
+
output_hidden_states=output_hidden_states,
|
| 218 |
+
return_dict=True,
|
| 219 |
+
)
|
| 220 |
+
text_embeds = text_outputs[1]
|
| 221 |
+
text_embeds = self.text_projection(text_embeds)
|
| 222 |
+
text_embeds = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
|
| 223 |
+
|
| 224 |
+
return (text_embeds, image_embeds, text_outputs)
|
| 225 |
+
|
| 226 |
+
def embed_image_query(
|
| 227 |
+
self, query_image_features: torch.FloatTensor, query_feature_map: torch.FloatTensor
|
| 228 |
+
) -> torch.FloatTensor:
|
| 229 |
+
|
| 230 |
+
_, class_embeds = self.class_predictor(query_image_features)
|
| 231 |
+
pred_boxes = self.box_predictor(query_image_features, query_feature_map)
|
| 232 |
+
pred_boxes_as_corners = center_to_corners_format(pred_boxes)
|
| 233 |
+
|
| 234 |
+
# Loop over query images
|
| 235 |
+
best_class_embeds = []
|
| 236 |
+
best_box_indices = []
|
| 237 |
+
pred_boxes_device = pred_boxes_as_corners.device
|
| 238 |
+
|
| 239 |
+
for i in range(query_image_features.shape[0]):
|
| 240 |
+
each_query_box = torch.tensor([[0, 0, 1, 1]], device=pred_boxes_device)
|
| 241 |
+
each_query_pred_boxes = pred_boxes_as_corners[i]
|
| 242 |
+
ious, _ = box_iou(each_query_box, each_query_pred_boxes)
|
| 243 |
+
|
| 244 |
+
# If there are no overlapping boxes, fall back to generalized IoU
|
| 245 |
+
if torch.all(ious[0] == 0.0):
|
| 246 |
+
ious = generalized_box_iou(each_query_box, each_query_pred_boxes)
|
| 247 |
+
|
| 248 |
+
# Use an adaptive threshold to include all boxes within 80% of the best IoU
|
| 249 |
+
iou_threshold = torch.max(ious) * 0.8
|
| 250 |
+
|
| 251 |
+
selected_inds = (ious[0] >= iou_threshold).nonzero()
|
| 252 |
+
if selected_inds.numel():
|
| 253 |
+
selected_embeddings = class_embeds[i][selected_inds[0]]
|
| 254 |
+
mean_embeds = torch.mean(class_embeds[i], axis=0)
|
| 255 |
+
mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings)
|
| 256 |
+
best_box_ind = selected_inds[torch.argmin(mean_sim)]
|
| 257 |
+
best_class_embeds.append(class_embeds[i][best_box_ind])
|
| 258 |
+
best_box_indices.append(best_box_ind)
|
| 259 |
+
|
| 260 |
+
if best_class_embeds:
|
| 261 |
+
query_embeds = torch.stack(best_class_embeds)
|
| 262 |
+
box_indices = torch.stack(best_box_indices)
|
| 263 |
+
else:
|
| 264 |
+
query_embeds, box_indices = None, None
|
| 265 |
+
|
| 266 |
+
return query_embeds, box_indices, pred_boxes
|
| 267 |
+
|
| 268 |
+
@torch.no_grad()
|
| 269 |
+
def forward(
|
| 270 |
+
self,
|
| 271 |
+
image_embeds: torch.FloatTensor,
|
| 272 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 273 |
+
text: Optional[List] = None,
|
| 274 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 275 |
+
output_attentions: Optional[bool] = None,
|
| 276 |
+
output_hidden_states: Optional[bool] = None,
|
| 277 |
+
return_dict: Optional[bool] = None,
|
| 278 |
+
) -> OwlViTObjectDetectionOutput:
|
| 279 |
+
if text is not None:
|
| 280 |
+
inputs = self.process_text(text)
|
| 281 |
+
input_ids = inputs.input_ids.to(self.device)
|
| 282 |
+
attention_mask = inputs.attention_mask.to(self.device)
|
| 283 |
+
input_ids = input_ids.to(self.device)
|
| 284 |
+
image_embeds = image_embeds.to(self.device)
|
| 285 |
+
attention_mask = attention_mask.to(self.device)
|
| 286 |
+
output_attentions = output_attentions if output_attentions is not None else False
|
| 287 |
+
output_hidden_states = (
|
| 288 |
+
output_hidden_states if output_hidden_states is not None else False
|
| 289 |
+
)
|
| 290 |
+
return_dict = return_dict if return_dict is not None else True
|
| 291 |
+
|
| 292 |
+
# Embed images and text queries
|
| 293 |
+
query_embeds, feature_map, text_outputs = self.image_text_embedder(
|
| 294 |
+
input_ids=input_ids,
|
| 295 |
+
image_embeds=image_embeds,
|
| 296 |
+
attention_mask=attention_mask,
|
| 297 |
+
output_attentions=output_attentions,
|
| 298 |
+
output_hidden_states=output_hidden_states,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# Text and vision model outputs
|
| 302 |
+
|
| 303 |
+
batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
|
| 304 |
+
image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
|
| 305 |
+
|
| 306 |
+
# Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
|
| 307 |
+
max_text_queries = input_ids.shape[0] // batch_size
|
| 308 |
+
query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1])
|
| 309 |
+
|
| 310 |
+
# If first token is 0, then this is a padded query [batch_size, num_queries].
|
| 311 |
+
input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])
|
| 312 |
+
query_mask = input_ids[..., 0] > 0
|
| 313 |
+
|
| 314 |
+
# Predict object classes [batch_size, num_patches, num_queries+1]
|
| 315 |
+
(pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask)
|
| 316 |
+
|
| 317 |
+
# Predict object boxes
|
| 318 |
+
pred_boxes = self.box_predictor(image_feats, feature_map)
|
| 319 |
+
|
| 320 |
+
if not return_dict:
|
| 321 |
+
output = (
|
| 322 |
+
pred_logits,
|
| 323 |
+
pred_boxes,
|
| 324 |
+
query_embeds,
|
| 325 |
+
feature_map,
|
| 326 |
+
class_embeds,
|
| 327 |
+
text_outputs.to_tuple(),
|
| 328 |
+
None,
|
| 329 |
+
)
|
| 330 |
+
output = tuple(x for x in output if x is not None)
|
| 331 |
+
return output
|
| 332 |
+
|
| 333 |
+
return OwlViTObjectDetectionOutput(
|
| 334 |
+
image_embeds=feature_map,
|
| 335 |
+
text_embeds=query_embeds,
|
| 336 |
+
pred_boxes=pred_boxes.cpu(),
|
| 337 |
+
logits=pred_logits.cpu(),
|
| 338 |
+
class_embeds=class_embeds,
|
| 339 |
+
text_model_output=text_outputs,
|
| 340 |
+
vision_model_output=None,
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
def owlvit_visualize(self,
|
| 344 |
+
image: Image,
|
| 345 |
+
texts: List,
|
| 346 |
+
owlvit_objectdetection_output: OwlViTObjectDetectionOutput,
|
| 347 |
+
score_threshold: float = 0.1,
|
| 348 |
+
pil=True):
|
| 349 |
+
target_sizes = torch.Tensor([image.size[::-1]])
|
| 350 |
+
# Convert outputs (bounding boxes and class logits) to COCO API
|
| 351 |
+
results = self.processor.post_process(outputs=owlvit_objectdetection_output, target_sizes=target_sizes)
|
| 352 |
+
|
| 353 |
+
text = texts[0]
|
| 354 |
+
boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
|
| 355 |
+
boxes_np = []
|
| 356 |
+
labels_list = []
|
| 357 |
+
# Print detected objects and rescaled box coordinates
|
| 358 |
+
for box, score, label in zip(boxes, scores, labels):
|
| 359 |
+
box = [int(i) for i in box.tolist()]
|
| 360 |
+
if score >= score_threshold:
|
| 361 |
+
labels_list.append(f"{text[label]} {round(score.item(), 3)}")
|
| 362 |
+
boxes_np.append(box)
|
| 363 |
+
print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
|
| 364 |
+
boxes_np = np.array(boxes_np)
|
| 365 |
+
detections = sv.Detections(xyxy=boxes_np)
|
| 366 |
+
image_np = np.uint8(image)[:, :, ::-1]
|
| 367 |
+
box_annotator = sv.BoxAnnotator()
|
| 368 |
+
annotated_frame = box_annotator.annotate(scene=image_np.copy(), detections=detections, labels=labels_list)
|
| 369 |
+
if pil:
|
| 370 |
+
return PIL.Image.fromarray(annotated_frame[:, :, ::-1])
|
| 371 |
+
else:
|
| 372 |
+
return annotated_frame[:, :, ::-1]
|
sam_extension/pipeline/sam.py
ADDED
|
@@ -0,0 +1,722 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import PIL
|
| 4 |
+
from PIL.Image import Image
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import Union, Tuple, List, Optional, Callable
|
| 7 |
+
from sklearn.decomposition import PCA
|
| 8 |
+
import supervision as sv
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torchvision
|
| 14 |
+
import torchvision.transforms as T
|
| 15 |
+
|
| 16 |
+
from segment_anything.utils.transforms import ResizeLongestSide
|
| 17 |
+
from segment_anything.predictor import preprocess, postprocess_masks
|
| 18 |
+
from segment_anything import build_sam, load_mobile_sam
|
| 19 |
+
|
| 20 |
+
from sam_extension.utils import add_prompts_tag, get_empty_detections, transform_coords
|
| 21 |
+
from sam_extension.pipeline.base import Pipeline, Output
|
| 22 |
+
from sam_extension.pipeline.groundingdino import GroundingDinoPipeline
|
| 23 |
+
from sam_extension.distillation_models.sam import load_distillation_sam, load_sam
|
| 24 |
+
from sam_extension.distillation_models import *
|
| 25 |
+
|
| 26 |
+
ORIGINAL_SAM_IMG_SIZE: int = 1024
|
| 27 |
+
PIXEL_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
|
| 28 |
+
PIXEL_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
|
| 29 |
+
PREPROCESS = functools.partial(preprocess, ORIGINAL_SAM_IMG_SIZE, PIXEL_MEAN, PIXEL_STD)
|
| 30 |
+
POSTPROCESS_MASKS = functools.partial(postprocess_masks, ORIGINAL_SAM_IMG_SIZE)
|
| 31 |
+
|
| 32 |
+
@dataclass(repr=True)
|
| 33 |
+
class SAMEncoderOutput(Output):
|
| 34 |
+
features: torch.Tensor
|
| 35 |
+
interm_features: List[torch.Tensor]
|
| 36 |
+
original_size: Tuple
|
| 37 |
+
input_size: Tuple
|
| 38 |
+
|
| 39 |
+
@dataclass(repr=True)
|
| 40 |
+
class SAMEncoderProcesImgOutput(Output):
|
| 41 |
+
input_image: torch.Tensor
|
| 42 |
+
original_size: Tuple
|
| 43 |
+
input_size: Tuple
|
| 44 |
+
|
| 45 |
+
@dataclass(repr=True)
|
| 46 |
+
class SAMDecoderPredictOutput(Output):
|
| 47 |
+
masks_np: np.ndarray
|
| 48 |
+
iou_predictions_np: np.ndarray
|
| 49 |
+
low_res_masks_np: np.ndarray
|
| 50 |
+
|
| 51 |
+
@dataclass(repr=True)
|
| 52 |
+
class SAMDecoderPredictTorchOutput(Output):
|
| 53 |
+
masks: torch.Tensor
|
| 54 |
+
iou_predictions: torch.Tensor
|
| 55 |
+
low_res_masks: torch.Tensor
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class SAMEncoderPipeline(Pipeline):
|
| 59 |
+
def __init__(self,
|
| 60 |
+
encoder: nn.Module,
|
| 61 |
+
input_img_size: Tuple,
|
| 62 |
+
multi_output: bool,
|
| 63 |
+
preprocess: Callable,
|
| 64 |
+
transform: ResizeLongestSide,
|
| 65 |
+
device: str,
|
| 66 |
+
*args,
|
| 67 |
+
**kwargs):
|
| 68 |
+
super(SAMEncoderPipeline, self).__init__(*args, **kwargs)
|
| 69 |
+
self.encoder = encoder
|
| 70 |
+
self.input_img_size = input_img_size
|
| 71 |
+
self.multi_output = multi_output
|
| 72 |
+
self.preprocess = preprocess
|
| 73 |
+
self.transform = transform
|
| 74 |
+
self.device = device
|
| 75 |
+
@classmethod
|
| 76 |
+
def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
|
| 77 |
+
if 'sam_version' not in kwargs.keys():
|
| 78 |
+
sam_version = 'sam'
|
| 79 |
+
else:
|
| 80 |
+
sam_version = kwargs['sam_version']
|
| 81 |
+
sam = load_sam(ckpt_path, sam_version, device)
|
| 82 |
+
encoder = sam.image_encoder
|
| 83 |
+
encoder_type = encoder.__class__.__name__
|
| 84 |
+
if encoder_type in ['TinyViT', 'FasterTinyViT', 'SAMEncoderViT', 'DINOSAMViT', 'FlashVisionTransformer']:
|
| 85 |
+
multi_output = False
|
| 86 |
+
if encoder_type in ['FasterTinyViT', 'SAMEncoderViT', 'DINOSAMViT', 'FlashVisionTransformer']:
|
| 87 |
+
input_img_size = (encoder.img_size, encoder.img_size)
|
| 88 |
+
if encoder_type == 'DINOSAMViT':
|
| 89 |
+
encoder = encoder.dino
|
| 90 |
+
else:
|
| 91 |
+
input_img_size = (ORIGINAL_SAM_IMG_SIZE, ORIGINAL_SAM_IMG_SIZE)
|
| 92 |
+
else:
|
| 93 |
+
multi_output = True
|
| 94 |
+
input_img_size = (ORIGINAL_SAM_IMG_SIZE, ORIGINAL_SAM_IMG_SIZE)
|
| 95 |
+
if sam.adaptor is None:
|
| 96 |
+
transform = ResizeLongestSide(ORIGINAL_SAM_IMG_SIZE)
|
| 97 |
+
preprocess_ = functools.partial(preprocess, ORIGINAL_SAM_IMG_SIZE, PIXEL_MEAN.to(device), PIXEL_STD.to(device))
|
| 98 |
+
else:
|
| 99 |
+
transform = T.Compose([
|
| 100 |
+
T.Resize(input_img_size),
|
| 101 |
+
T.ToTensor(),
|
| 102 |
+
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
| 103 |
+
])
|
| 104 |
+
preprocess_ = None
|
| 105 |
+
pipeline = cls(encoder=encoder,
|
| 106 |
+
input_img_size=input_img_size,
|
| 107 |
+
multi_output=multi_output,
|
| 108 |
+
preprocess=preprocess_,
|
| 109 |
+
transform=transform,
|
| 110 |
+
device=device)
|
| 111 |
+
del sam, encoder
|
| 112 |
+
torch.cuda.empty_cache()
|
| 113 |
+
return pipeline
|
| 114 |
+
|
| 115 |
+
def process_img(self, img: Union[Image, np.ndarray]) -> SAMEncoderProcesImgOutput:
|
| 116 |
+
if self.preprocess is not None:
|
| 117 |
+
if isinstance(img, Image):
|
| 118 |
+
img = np.uint8(img)
|
| 119 |
+
input_image = self.transform.apply_image(img)
|
| 120 |
+
input_image_torch = torch.as_tensor(input_image, device=self.device)
|
| 121 |
+
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
|
| 122 |
+
original_size = tuple(img.shape[:2])
|
| 123 |
+
input_size = tuple(input_image_torch.shape[-2:])
|
| 124 |
+
input_image = F.interpolate(self.preprocess(input_image_torch), size=self.input_img_size, mode='bilinear')
|
| 125 |
+
else:
|
| 126 |
+
if isinstance(img, np.ndarray):
|
| 127 |
+
img = PIL.Image.fromarray(img)
|
| 128 |
+
original_size = (img.size[1], img.size[0])
|
| 129 |
+
if original_size[0] > original_size[1]:
|
| 130 |
+
input_h = 1024
|
| 131 |
+
input_w = int((1024 / original_size[0]) * original_size[1])
|
| 132 |
+
else:
|
| 133 |
+
input_w = 1024
|
| 134 |
+
input_h = int((1024 / original_size[1]) * original_size[0])
|
| 135 |
+
input_size = (input_h, input_w)
|
| 136 |
+
input_image = self.transform(img)[None, ...].to(self.device)
|
| 137 |
+
return SAMEncoderProcesImgOutput(input_image, original_size, input_size)
|
| 138 |
+
@torch.no_grad()
|
| 139 |
+
def get_visual_feature(self, x: Union[torch.Tensor, Image, np.ndarray]=None, **kwargs):
|
| 140 |
+
pca_rgb = PCA(n_components=3)
|
| 141 |
+
if 'sam_feature' in kwargs.keys() and 'original_size' in kwargs.keys():
|
| 142 |
+
sam_feature = kwargs['sam_feature']
|
| 143 |
+
original_size = kwargs['original_size']
|
| 144 |
+
else:
|
| 145 |
+
assert x is not None, 'please give x type Union[torch.Tensor, Image, np.ndarray] !'
|
| 146 |
+
sam_encoder_output = self.forward(x, **kwargs)
|
| 147 |
+
sam_feature = sam_encoder_output.features
|
| 148 |
+
original_size = sam_encoder_output.original_size
|
| 149 |
+
assert original_size is not None, 'please give original_size!'
|
| 150 |
+
sam_feature = F.interpolate(sam_feature, size=original_size, mode='bilinear').permute(0, 2, 3, 1)
|
| 151 |
+
b, h, w, c = sam_feature.shape
|
| 152 |
+
sam_feature = sam_feature.view(-1, c).cpu().numpy()
|
| 153 |
+
sam_feature = pca_rgb.fit_transform(sam_feature)
|
| 154 |
+
sam_feature = torch.Tensor(sam_feature.reshape(h, w, 3))
|
| 155 |
+
min_f, _ = sam_feature.min(-1)
|
| 156 |
+
max_f, _ = sam_feature.max(-1)
|
| 157 |
+
sam_feature = (sam_feature - min_f[..., None]) / (max_f[..., None] - min_f[..., None])
|
| 158 |
+
sam_feature = sam_feature.cpu().numpy()
|
| 159 |
+
sam_feature_image = PIL.Image.fromarray((sam_feature * 255).astype(np.uint8))
|
| 160 |
+
return sam_feature_image
|
| 161 |
+
def forward(self, x: Union[torch.Tensor, Image, np.ndarray], **kwargs) -> SAMEncoderOutput:
|
| 162 |
+
if isinstance(x, (Image, np.ndarray)):
|
| 163 |
+
process_img_output = self.process_img(x)
|
| 164 |
+
x = process_img_output.input_image
|
| 165 |
+
original_size = process_img_output.original_size
|
| 166 |
+
input_size = process_img_output.input_size
|
| 167 |
+
else:
|
| 168 |
+
original_size = kwargs.pop('original_size') if 'original_size' in kwargs.keys() else None
|
| 169 |
+
input_size = x.shape[-2:]
|
| 170 |
+
with torch.no_grad():
|
| 171 |
+
if self.multi_output:
|
| 172 |
+
features, interm_features = self.encoder(x, **kwargs)
|
| 173 |
+
else:
|
| 174 |
+
features = self.encoder(x, **kwargs)
|
| 175 |
+
if self.encoder.__class__.__name__ == 'DINO':
|
| 176 |
+
features = features.permute(0, 3, 1, 2)
|
| 177 |
+
interm_features = None
|
| 178 |
+
return SAMEncoderOutput(features, interm_features, original_size, input_size)
|
| 179 |
+
|
| 180 |
+
class SAMDecoderPipeline(Pipeline):
|
| 181 |
+
def __init__(self,
|
| 182 |
+
prompt_encoder: nn.Module,
|
| 183 |
+
mask_decoder: nn.Module,
|
| 184 |
+
adaptor: nn.Module,
|
| 185 |
+
mask_threshold: float,
|
| 186 |
+
transform: ResizeLongestSide,
|
| 187 |
+
postprocess_masks: Callable,
|
| 188 |
+
img_size: int,
|
| 189 |
+
device: str,
|
| 190 |
+
*args,
|
| 191 |
+
**kwargs):
|
| 192 |
+
super(SAMDecoderPipeline, self).__init__(*args, **kwargs)
|
| 193 |
+
self.prompt_encoder = prompt_encoder
|
| 194 |
+
self.mask_decoder = mask_decoder
|
| 195 |
+
self.adaptor = adaptor
|
| 196 |
+
self.mask_threshold = mask_threshold
|
| 197 |
+
self.transform = transform
|
| 198 |
+
self.postprocess_masks = postprocess_masks
|
| 199 |
+
self.img_size = img_size
|
| 200 |
+
self.device = device
|
| 201 |
+
@classmethod
|
| 202 |
+
def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
|
| 203 |
+
if 'sam_version' not in kwargs.keys():
|
| 204 |
+
sam_version = 'sam'
|
| 205 |
+
else:
|
| 206 |
+
sam_version = kwargs['sam_version']
|
| 207 |
+
sam = load_sam(ckpt_path, sam_version, device)
|
| 208 |
+
if sam.image_encoder.__class__.__name__ == 'DINOSAMViT':
|
| 209 |
+
adaptor = sam.image_encoder.adaptor
|
| 210 |
+
elif sam.adaptor is not None:
|
| 211 |
+
adaptor = sam.adaptor
|
| 212 |
+
else:
|
| 213 |
+
adaptor = None
|
| 214 |
+
img_size = sam.image_encoder.img_size
|
| 215 |
+
prompt_encoder = sam.prompt_encoder
|
| 216 |
+
mask_decoder = sam.mask_decoder
|
| 217 |
+
transform = ResizeLongestSide(ORIGINAL_SAM_IMG_SIZE)
|
| 218 |
+
pipeline = cls(prompt_encoder=prompt_encoder,
|
| 219 |
+
mask_decoder=mask_decoder,
|
| 220 |
+
adaptor=adaptor,
|
| 221 |
+
mask_threshold=sam.mask_threshold,
|
| 222 |
+
transform=transform,
|
| 223 |
+
postprocess_masks=POSTPROCESS_MASKS,
|
| 224 |
+
img_size=img_size,
|
| 225 |
+
device=device)
|
| 226 |
+
del sam, prompt_encoder, mask_decoder
|
| 227 |
+
torch.cuda.empty_cache()
|
| 228 |
+
return pipeline
|
| 229 |
+
def visualize_prompt(self,
|
| 230 |
+
img: Union[Image, np.ndarray],
|
| 231 |
+
des_img: Union[Image, np.ndarray] = None,
|
| 232 |
+
point_labels: Union[List[int], np.ndarray] = None,
|
| 233 |
+
point_coords: Union[List[List[int]], np.ndarray] = None,
|
| 234 |
+
boxes: Union[List[List[int]], np.ndarray] = None,
|
| 235 |
+
pil: bool = False
|
| 236 |
+
) -> Union[Image, np.ndarray]:
|
| 237 |
+
if des_img is not None:
|
| 238 |
+
if isinstance(des_img, np.ndarray):
|
| 239 |
+
des_shape = tuple(des_img.shape[:2])
|
| 240 |
+
|
| 241 |
+
else:
|
| 242 |
+
des_shape = (des_img.size[1], des_img.size[0])
|
| 243 |
+
src_shape = (img.size[1], img.size[0])
|
| 244 |
+
point_coords, boxes = transform_coords(src_shape, des_shape, point_coords, boxes)
|
| 245 |
+
return add_prompts_tag(des_img, point_labels, point_coords, boxes, pil)
|
| 246 |
+
else:
|
| 247 |
+
return add_prompts_tag(img, point_labels, point_coords, boxes, pil)
|
| 248 |
+
|
| 249 |
+
def visualize_results(self,
|
| 250 |
+
img: Union[Image, np.ndarray],
|
| 251 |
+
des_img: Union[Image, np.ndarray] = None,
|
| 252 |
+
sam_encoder_output: Optional[SAMEncoderOutput] = None,
|
| 253 |
+
features: Optional[torch.Tensor] = None,
|
| 254 |
+
interm_features: Optional[List[torch.Tensor]] = None,
|
| 255 |
+
original_size: Optional[Tuple] = None,
|
| 256 |
+
input_size: Optional[Tuple] = None,
|
| 257 |
+
point_coords: Optional[np.ndarray] = None,
|
| 258 |
+
point_labels: Optional[np.ndarray] = None,
|
| 259 |
+
boxes: Optional[np.ndarray] = None,
|
| 260 |
+
texts: Optional[List] = None,
|
| 261 |
+
grounding_dino_pipeline: GroundingDinoPipeline = None,
|
| 262 |
+
box_threshold: float = 0.25,
|
| 263 |
+
text_threshold: float = 0.25,
|
| 264 |
+
nms_threshold: float = 0.8,
|
| 265 |
+
detections: Optional[sv.Detections] = None,
|
| 266 |
+
multimask_output: bool = True,
|
| 267 |
+
visualize_promts: bool = True,
|
| 268 |
+
pil: bool = False):
|
| 269 |
+
if isinstance(img, Image):
|
| 270 |
+
img = np.uint8(img)
|
| 271 |
+
if des_img is not None:
|
| 272 |
+
if isinstance(des_img, np.ndarray):
|
| 273 |
+
des_shape = tuple(des_img.shape[:2])
|
| 274 |
+
else:
|
| 275 |
+
des_shape = (des_img.size[1], des_img.size[0])
|
| 276 |
+
src_shape = img.shape[:2]
|
| 277 |
+
if point_coords is not None or boxes is not None:
|
| 278 |
+
des_point_coords, des_boxes = transform_coords(src_shape, des_shape, point_coords, boxes)
|
| 279 |
+
else:
|
| 280 |
+
des_point_coords = None
|
| 281 |
+
des_boxes = None
|
| 282 |
+
else:
|
| 283 |
+
des_point_coords = None
|
| 284 |
+
des_boxes = None
|
| 285 |
+
src_shape = None
|
| 286 |
+
des_shape = None
|
| 287 |
+
detections = get_empty_detections() if detections is None else detections
|
| 288 |
+
mask_annotator = sv.MaskAnnotator()
|
| 289 |
+
result_list = []
|
| 290 |
+
mask_result_list = []
|
| 291 |
+
mask_list = []
|
| 292 |
+
if boxes is None and point_coords is None and point_labels is None and texts is None or \
|
| 293 |
+
(point_coords is not None and point_labels is not None and point_coords.shape[0] != point_labels.shape[0]):
|
| 294 |
+
print('no prompt given!')
|
| 295 |
+
result_list.append(img)
|
| 296 |
+
return result_list
|
| 297 |
+
# if boxes is not None and point_coords is not None and point_labels is not None:
|
| 298 |
+
# multimask_output = False
|
| 299 |
+
def get_annotated_image(mask_annotator,
|
| 300 |
+
detections,
|
| 301 |
+
img,
|
| 302 |
+
point_labels=None,
|
| 303 |
+
point_coords=None,
|
| 304 |
+
boxes=None,
|
| 305 |
+
visualize_promts=True,
|
| 306 |
+
pil=False):
|
| 307 |
+
annotated_image = mask_annotator.annotate(scene=img.copy(), detections=detections)
|
| 308 |
+
if visualize_promts:
|
| 309 |
+
annotated_image = add_prompts_tag(annotated_image, point_labels, point_coords, boxes=boxes, pil=pil)
|
| 310 |
+
else:
|
| 311 |
+
if pil:
|
| 312 |
+
annotated_image = PIL.Image.fromarray(annotated_image)
|
| 313 |
+
return annotated_image
|
| 314 |
+
def get_masked_image(img,
|
| 315 |
+
masks,
|
| 316 |
+
pil=True):
|
| 317 |
+
masked_image_list = []
|
| 318 |
+
for i in range(masks.shape[0]):
|
| 319 |
+
object_rgb = img * (masks[i].reshape(img.shape[0], img.shape[1], 1))
|
| 320 |
+
object_rgb = object_rgb.astype(np.uint8)
|
| 321 |
+
bkgd_mask = np.where(object_rgb == 0, 1, 0)
|
| 322 |
+
bkgd_mask *= 255
|
| 323 |
+
bkgd_mask = bkgd_mask.astype(np.uint8)
|
| 324 |
+
object_rgb += bkgd_mask
|
| 325 |
+
if pil:
|
| 326 |
+
masked_image_list.append(PIL.Image.fromarray(object_rgb))
|
| 327 |
+
else:
|
| 328 |
+
masked_image_list.append(object_rgb)
|
| 329 |
+
return masked_image_list
|
| 330 |
+
def interpolate_mask(mask_np, des_shape):
|
| 331 |
+
mask_tensor = torch.tensor(mask_np, dtype=torch.float32).unsqueeze(0)
|
| 332 |
+
mask_interpolate = F.interpolate(mask_tensor, size=des_shape, mode='bilinear')
|
| 333 |
+
mask_interpolate = (mask_interpolate+0.5).long()
|
| 334 |
+
mask_np = mask_interpolate.squeeze(0).numpy().astype(bool)
|
| 335 |
+
return mask_np
|
| 336 |
+
|
| 337 |
+
if point_coords is not None and point_labels is not None:
|
| 338 |
+
|
| 339 |
+
if src_shape is not None:
|
| 340 |
+
point_result = self.forward(sam_encoder_output,
|
| 341 |
+
features,
|
| 342 |
+
interm_features,
|
| 343 |
+
original_size,
|
| 344 |
+
input_size,
|
| 345 |
+
des_point_coords,
|
| 346 |
+
point_labels)
|
| 347 |
+
masks_np = interpolate_mask(point_result.masks_np, src_shape)
|
| 348 |
+
else:
|
| 349 |
+
point_result = self.forward(sam_encoder_output,
|
| 350 |
+
features,
|
| 351 |
+
interm_features,
|
| 352 |
+
original_size,
|
| 353 |
+
input_size,
|
| 354 |
+
point_coords,
|
| 355 |
+
point_labels)
|
| 356 |
+
masks_np = point_result.masks_np
|
| 357 |
+
if multimask_output:
|
| 358 |
+
for i in range(masks_np.shape[0]):
|
| 359 |
+
detections.mask = masks_np[i][None, ...]
|
| 360 |
+
mask_list.append(masks_np[i])
|
| 361 |
+
result_list.append(get_annotated_image(mask_annotator,
|
| 362 |
+
detections,
|
| 363 |
+
img,
|
| 364 |
+
point_labels=point_labels,
|
| 365 |
+
point_coords=point_coords,
|
| 366 |
+
visualize_promts=visualize_promts,
|
| 367 |
+
pil=pil))
|
| 368 |
+
mask_result_list += get_masked_image(img,
|
| 369 |
+
detections.mask,
|
| 370 |
+
pil=pil)
|
| 371 |
+
else:
|
| 372 |
+
index = np.argmax(point_result.iou_predictions_np)
|
| 373 |
+
detections.mask = masks_np[index][None, ...]
|
| 374 |
+
mask_list.append(masks_np[index])
|
| 375 |
+
result_list.append(get_annotated_image(mask_annotator,
|
| 376 |
+
detections,
|
| 377 |
+
img,
|
| 378 |
+
point_labels=point_labels,
|
| 379 |
+
point_coords=point_coords,
|
| 380 |
+
visualize_promts=visualize_promts,
|
| 381 |
+
pil=pil))
|
| 382 |
+
mask_result_list += get_masked_image(img,
|
| 383 |
+
detections.mask,
|
| 384 |
+
pil=pil)
|
| 385 |
+
|
| 386 |
+
if boxes is not None:
|
| 387 |
+
result_masks = []
|
| 388 |
+
if src_shape is not None:
|
| 389 |
+
boxes_ = des_boxes
|
| 390 |
+
else:
|
| 391 |
+
boxes_ = boxes
|
| 392 |
+
if boxes_.shape[0] > 1:
|
| 393 |
+
for i in range(len(boxes)):
|
| 394 |
+
box_result = self.forward(sam_encoder_output,
|
| 395 |
+
features,
|
| 396 |
+
interm_features,
|
| 397 |
+
original_size,
|
| 398 |
+
input_size,
|
| 399 |
+
box=boxes_[i])
|
| 400 |
+
index = np.argmax(box_result.iou_predictions_np)
|
| 401 |
+
result_masks.append(box_result.masks_np[index])
|
| 402 |
+
mask = np.array(result_masks)
|
| 403 |
+
if src_shape is not None:
|
| 404 |
+
masks_np = interpolate_mask(mask, src_shape)
|
| 405 |
+
else:
|
| 406 |
+
masks_np = mask
|
| 407 |
+
mask_list.append(masks_np)
|
| 408 |
+
detections.mask = masks_np
|
| 409 |
+
result_list.append(get_annotated_image(mask_annotator,
|
| 410 |
+
detections,
|
| 411 |
+
img,
|
| 412 |
+
boxes=boxes,
|
| 413 |
+
visualize_promts=visualize_promts,
|
| 414 |
+
pil=pil))
|
| 415 |
+
mask_result_list += get_masked_image(img,
|
| 416 |
+
detections.mask,
|
| 417 |
+
pil=pil)
|
| 418 |
+
else:
|
| 419 |
+
box_result = self.forward(sam_encoder_output,
|
| 420 |
+
features,
|
| 421 |
+
interm_features,
|
| 422 |
+
original_size,
|
| 423 |
+
input_size,
|
| 424 |
+
box=boxes_)
|
| 425 |
+
if src_shape is not None:
|
| 426 |
+
masks_np = interpolate_mask(box_result.masks_np, src_shape)
|
| 427 |
+
else:
|
| 428 |
+
masks_np = box_result.masks_np
|
| 429 |
+
|
| 430 |
+
if multimask_output:
|
| 431 |
+
for i in range(masks_np.shape[0]):
|
| 432 |
+
detections.mask = masks_np[i][None, ...]
|
| 433 |
+
mask_list.append(masks_np[i])
|
| 434 |
+
result_list.append(get_annotated_image(mask_annotator,
|
| 435 |
+
detections,
|
| 436 |
+
img,
|
| 437 |
+
boxes=boxes,
|
| 438 |
+
visualize_promts=visualize_promts,
|
| 439 |
+
pil=pil))
|
| 440 |
+
mask_result_list += get_masked_image(img,
|
| 441 |
+
detections.mask,
|
| 442 |
+
pil=pil)
|
| 443 |
+
else:
|
| 444 |
+
index = np.argmax(box_result.iou_predictions_np)
|
| 445 |
+
detections.mask = masks_np[index][None, ...]
|
| 446 |
+
mask_list.append(masks_np[index])
|
| 447 |
+
result_list.append(get_annotated_image(mask_annotator, detections, img, boxes=boxes, pil=pil))
|
| 448 |
+
mask_result_list += get_masked_image(img,
|
| 449 |
+
detections.mask,
|
| 450 |
+
pil=pil)
|
| 451 |
+
|
| 452 |
+
if texts is not None and grounding_dino_pipeline is not None:
|
| 453 |
+
detections = grounding_dino_pipeline(img[:, :, ::-1], texts, box_threshold, text_threshold)
|
| 454 |
+
box_annotator = sv.BoxAnnotator()
|
| 455 |
+
nms_idx = torchvision.ops.nms(
|
| 456 |
+
torch.from_numpy(detections.xyxy),
|
| 457 |
+
torch.from_numpy(detections.confidence),
|
| 458 |
+
nms_threshold
|
| 459 |
+
).numpy().tolist()
|
| 460 |
+
|
| 461 |
+
detections.xyxy = detections.xyxy[nms_idx]
|
| 462 |
+
detections.confidence = detections.confidence[nms_idx]
|
| 463 |
+
detections.class_id = detections.class_id[nms_idx]
|
| 464 |
+
labels = [
|
| 465 |
+
f"{texts[class_id]} {confidence:0.2f}"
|
| 466 |
+
for _, _, confidence, class_id, _
|
| 467 |
+
in detections]
|
| 468 |
+
result_masks = []
|
| 469 |
+
if src_shape is not None:
|
| 470 |
+
_, boxes_ = transform_coords(src_shape, des_shape, boxes=detections.xyxy)
|
| 471 |
+
else:
|
| 472 |
+
boxes_ = detections.xyxy
|
| 473 |
+
for box in boxes_:
|
| 474 |
+
box_result = self.forward(sam_encoder_output,
|
| 475 |
+
features,
|
| 476 |
+
interm_features,
|
| 477 |
+
original_size,
|
| 478 |
+
input_size,
|
| 479 |
+
box=box)
|
| 480 |
+
index = np.argmax(box_result.iou_predictions_np)
|
| 481 |
+
result_masks.append(box_result.masks_np[index])
|
| 482 |
+
mask = np.array(result_masks)
|
| 483 |
+
if src_shape is not None:
|
| 484 |
+
detections.mask = interpolate_mask(mask, src_shape)
|
| 485 |
+
else:
|
| 486 |
+
detections.mask = mask
|
| 487 |
+
for i in range(detections.mask.shape[0]):
|
| 488 |
+
mask_list.append(detections.mask[i, ...])
|
| 489 |
+
if visualize_promts:
|
| 490 |
+
annotated_image = mask_annotator.annotate(scene=img[:, :, ::-1].copy(), detections=detections)
|
| 491 |
+
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
|
| 492 |
+
else:
|
| 493 |
+
annotated_image = mask_annotator.annotate(scene=img[:, :, ::-1].copy(), detections=detections)
|
| 494 |
+
|
| 495 |
+
if pil:
|
| 496 |
+
result_list.append(PIL.Image.fromarray(annotated_image[:, :, ::-1]))
|
| 497 |
+
else:
|
| 498 |
+
result_list.append(annotated_image[:, :, ::-1])
|
| 499 |
+
mask_result_list += get_masked_image(img,
|
| 500 |
+
detections.mask,
|
| 501 |
+
pil=pil)
|
| 502 |
+
|
| 503 |
+
return result_list, mask_result_list, mask_list
|
| 504 |
+
|
| 505 |
+
def predict(
|
| 506 |
+
self,
|
| 507 |
+
features: torch.Tensor,
|
| 508 |
+
interm_features: List[torch.Tensor],
|
| 509 |
+
original_size: Tuple,
|
| 510 |
+
input_size: Tuple,
|
| 511 |
+
point_coords: Optional[np.ndarray] = None,
|
| 512 |
+
point_labels: Optional[np.ndarray] = None,
|
| 513 |
+
box: Optional[np.ndarray] = None,
|
| 514 |
+
mask_input: Optional[np.ndarray] = None,
|
| 515 |
+
multimask_output: bool = True,
|
| 516 |
+
return_logits: bool = False,
|
| 517 |
+
hq_token_only: bool = False,
|
| 518 |
+
) -> SAMDecoderPredictOutput:
|
| 519 |
+
"""
|
| 520 |
+
Predict masks for the given input prompts, using the currently set image.
|
| 521 |
+
|
| 522 |
+
Arguments:
|
| 523 |
+
point_coords (np.ndarray or None): A Nx2 array of point prompts to the
|
| 524 |
+
model. Each point is in (X,Y) in pixels.
|
| 525 |
+
point_labels (np.ndarray or None): A length N array of labels for the
|
| 526 |
+
point prompts. 1 indicates a foreground point and 0 indicates a
|
| 527 |
+
background point.
|
| 528 |
+
box (np.ndarray or None): A length 4 array given a box prompt to the
|
| 529 |
+
model, in XYXY format.
|
| 530 |
+
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
| 531 |
+
coming from a previous prediction iteration. Has form 1xHxW, where
|
| 532 |
+
for SAM, H=W=256.
|
| 533 |
+
multimask_output (bool): If true, the model will return three masks.
|
| 534 |
+
For ambiguous input prompts (such as a single click), this will often
|
| 535 |
+
produce better masks than a single prediction. If only a single
|
| 536 |
+
mask is needed, the model's predicted quality score can be used
|
| 537 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
| 538 |
+
input prompts, multimask_output=False can give better results.
|
| 539 |
+
return_logits (bool): If true, returns un-thresholded masks logits
|
| 540 |
+
instead of a binary mask.
|
| 541 |
+
|
| 542 |
+
Returns:
|
| 543 |
+
(np.ndarray): The output masks in CxHxW format, where C is the
|
| 544 |
+
number of masks, and (H, W) is the original image size.
|
| 545 |
+
(np.ndarray): An array of length C containing the model's
|
| 546 |
+
predictions for the quality of each mask.
|
| 547 |
+
(np.ndarray): An array of shape CxHxW, where C is the number
|
| 548 |
+
of masks and H=W=256. These low resolution logits can be passed to
|
| 549 |
+
a subsequent iteration as mask input.
|
| 550 |
+
"""
|
| 551 |
+
# Transform input prompts
|
| 552 |
+
|
| 553 |
+
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
|
| 554 |
+
if point_coords is not None:
|
| 555 |
+
assert (
|
| 556 |
+
point_labels is not None
|
| 557 |
+
), "point_labels must be supplied if point_coords is supplied."
|
| 558 |
+
point_coords = self.transform.apply_coords(point_coords, original_size)
|
| 559 |
+
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
|
| 560 |
+
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
|
| 561 |
+
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
|
| 562 |
+
if box is not None:
|
| 563 |
+
box = self.transform.apply_boxes(box, original_size)
|
| 564 |
+
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
|
| 565 |
+
box_torch = box_torch[None, :]
|
| 566 |
+
if mask_input is not None:
|
| 567 |
+
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
|
| 568 |
+
mask_input_torch = mask_input_torch[None, :, :, :]
|
| 569 |
+
|
| 570 |
+
sam_decoder_predict_torch_output = self.predict_torch(
|
| 571 |
+
features,
|
| 572 |
+
interm_features,
|
| 573 |
+
original_size,
|
| 574 |
+
input_size,
|
| 575 |
+
coords_torch,
|
| 576 |
+
labels_torch,
|
| 577 |
+
box_torch,
|
| 578 |
+
mask_input_torch,
|
| 579 |
+
multimask_output,
|
| 580 |
+
return_logits=return_logits,
|
| 581 |
+
hq_token_only=hq_token_only,
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
masks_np = sam_decoder_predict_torch_output.masks[0].detach().cpu().numpy()
|
| 585 |
+
iou_predictions_np = sam_decoder_predict_torch_output.iou_predictions[0].detach().cpu().numpy()
|
| 586 |
+
low_res_masks_np = sam_decoder_predict_torch_output.low_res_masks[0].detach().cpu().numpy()
|
| 587 |
+
return SAMDecoderPredictOutput(masks_np, iou_predictions_np, low_res_masks_np)
|
| 588 |
+
|
| 589 |
+
@torch.no_grad()
|
| 590 |
+
def predict_torch(
|
| 591 |
+
self,
|
| 592 |
+
features: torch.Tensor,
|
| 593 |
+
interm_features: List[torch.Tensor],
|
| 594 |
+
original_size: Tuple,
|
| 595 |
+
input_size: Tuple,
|
| 596 |
+
point_coords: Optional[torch.Tensor],
|
| 597 |
+
point_labels: Optional[torch.Tensor],
|
| 598 |
+
boxes: Optional[torch.Tensor] = None,
|
| 599 |
+
mask_input: Optional[torch.Tensor] = None,
|
| 600 |
+
multimask_output: bool = True,
|
| 601 |
+
return_logits: bool = False,
|
| 602 |
+
hq_token_only: bool = False,
|
| 603 |
+
) -> SAMDecoderPredictTorchOutput:
|
| 604 |
+
"""
|
| 605 |
+
Predict masks for the given input prompts, using the currently set image.
|
| 606 |
+
Input prompts are batched torch tensors and are expected to already be
|
| 607 |
+
transformed to the input frame using ResizeLongestSide.
|
| 608 |
+
|
| 609 |
+
Arguments:
|
| 610 |
+
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
|
| 611 |
+
model. Each point is in (X,Y) in pixels.
|
| 612 |
+
point_labels (torch.Tensor or None): A BxN array of labels for the
|
| 613 |
+
point prompts. 1 indicates a foreground point and 0 indicates a
|
| 614 |
+
background point.
|
| 615 |
+
boxes (np.ndarray or None): A Bx4 array given a box prompt to the
|
| 616 |
+
model, in XYXY format.
|
| 617 |
+
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
| 618 |
+
coming from a previous prediction iteration. Has form Bx1xHxW, where
|
| 619 |
+
for SAM, H=W=256. Masks returned by a previous iteration of the
|
| 620 |
+
predict method do not need further transformation.
|
| 621 |
+
multimask_output (bool): If true, the model will return three masks.
|
| 622 |
+
For ambiguous input prompts (such as a single click), this will often
|
| 623 |
+
produce better masks than a single prediction. If only a single
|
| 624 |
+
mask is needed, the model's predicted quality score can be used
|
| 625 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
| 626 |
+
input prompts, multimask_output=False can give better results.
|
| 627 |
+
return_logits (bool): If true, returns un-thresholded masks logits
|
| 628 |
+
instead of a binary mask.
|
| 629 |
+
|
| 630 |
+
Returns:
|
| 631 |
+
(torch.Tensor): The output masks in BxCxHxW format, where C is the
|
| 632 |
+
number of masks, and (H, W) is the original image size.
|
| 633 |
+
(torch.Tensor): An array of shape BxC containing the model's
|
| 634 |
+
predictions for the quality of each mask.
|
| 635 |
+
(torch.Tensor): An array of shape BxCxHxW, where C is the number
|
| 636 |
+
of masks and H=W=256. These low res logits can be passed to
|
| 637 |
+
a subsequent iteration as mask input.
|
| 638 |
+
"""
|
| 639 |
+
|
| 640 |
+
if point_coords is not None:
|
| 641 |
+
points = (point_coords, point_labels)
|
| 642 |
+
else:
|
| 643 |
+
points = None
|
| 644 |
+
|
| 645 |
+
# Embed prompts
|
| 646 |
+
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
| 647 |
+
points=points,
|
| 648 |
+
boxes=boxes,
|
| 649 |
+
masks=mask_input,
|
| 650 |
+
)
|
| 651 |
+
|
| 652 |
+
# Predict masks
|
| 653 |
+
low_res_masks, iou_predictions = self.mask_decoder(
|
| 654 |
+
image_embeddings=features,
|
| 655 |
+
image_pe=self.prompt_encoder.get_dense_pe(),
|
| 656 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
| 657 |
+
dense_prompt_embeddings=dense_embeddings,
|
| 658 |
+
multimask_output=multimask_output,
|
| 659 |
+
hq_token_only=hq_token_only,
|
| 660 |
+
interm_embeddings=interm_features,
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
# Upscale the masks to the original image resolution
|
| 664 |
+
# masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
|
| 665 |
+
masks = self.postprocess_masks(low_res_masks, input_size, original_size)
|
| 666 |
+
|
| 667 |
+
if not return_logits:
|
| 668 |
+
masks = masks > self.mask_threshold
|
| 669 |
+
|
| 670 |
+
return SAMDecoderPredictTorchOutput(masks, iou_predictions, low_res_masks)
|
| 671 |
+
def forward(self,
|
| 672 |
+
sam_encoder_output: Optional[SAMEncoderOutput]=None,
|
| 673 |
+
features: Optional[torch.Tensor]=None,
|
| 674 |
+
interm_features: Optional[List[torch.Tensor]]=None,
|
| 675 |
+
original_size: Optional[Tuple]=None,
|
| 676 |
+
input_size: Optional[Tuple]=None,
|
| 677 |
+
point_coords: Optional[np.ndarray] = None,
|
| 678 |
+
point_labels: Optional[np.ndarray] = None,
|
| 679 |
+
box: Optional[np.ndarray] = None,
|
| 680 |
+
mask_input: Optional[np.ndarray] = None,
|
| 681 |
+
multimask_output: bool = True,
|
| 682 |
+
return_logits: bool = False,
|
| 683 |
+
hq_token_only: bool = False,
|
| 684 |
+
dino: bool = False
|
| 685 |
+
) -> SAMDecoderPredictOutput:
|
| 686 |
+
assert sam_encoder_output or (features is not None and original_size is not None and input_size is not None), 'one of sam_encoder_output and four necessary inputs must be given!'
|
| 687 |
+
if sam_encoder_output:
|
| 688 |
+
features = sam_encoder_output.features
|
| 689 |
+
interm_features = sam_encoder_output.interm_features
|
| 690 |
+
original_size = sam_encoder_output.original_size
|
| 691 |
+
input_size = sam_encoder_output.input_size
|
| 692 |
+
if self.adaptor is not None:
|
| 693 |
+
if dino:
|
| 694 |
+
features = F.interpolate(F.normalize(features, dim=1), size=(64, 64), mode='bilinear').permute(0, 2, 3, 1)
|
| 695 |
+
features = self.adaptor(features)
|
| 696 |
+
#
|
| 697 |
+
# else:
|
| 698 |
+
# features = self.adaptor(features, original_size)
|
| 699 |
+
|
| 700 |
+
return self.predict(features,
|
| 701 |
+
interm_features,
|
| 702 |
+
original_size,
|
| 703 |
+
input_size,
|
| 704 |
+
point_coords,
|
| 705 |
+
point_labels,
|
| 706 |
+
box,
|
| 707 |
+
mask_input,
|
| 708 |
+
multimask_output,
|
| 709 |
+
return_logits,
|
| 710 |
+
hq_token_only)
|
| 711 |
+
|
| 712 |
+
'''
|
| 713 |
+
class SAMPipeline(Pipeline):
|
| 714 |
+
@classmethod
|
| 715 |
+
def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
|
| 716 |
+
sam_encoder_pipeline = SAMEncoderPipeline(ckpt_path, device, *args, **kwargs)
|
| 717 |
+
sam_decoder_pipeline = SAMDecoderPipeline(ckpt_path, device, *args, **kwargs)
|
| 718 |
+
pipeline = cls(**dict(sam_encoder_pipeline=sam_encoder_pipeline,
|
| 719 |
+
sam_decoder_pipeline=sam_decoder_pipeline,
|
| 720 |
+
device=device))
|
| 721 |
+
return pipeline
|
| 722 |
+
'''
|
sam_extension/utils/__init__.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import cv2
|
| 3 |
+
import PIL
|
| 4 |
+
import torch
|
| 5 |
+
from PIL.Image import Image
|
| 6 |
+
from typing import Union, Tuple, List, Optional
|
| 7 |
+
import numpy as np
|
| 8 |
+
import supervision as sv
|
| 9 |
+
from sklearn.decomposition import PCA
|
| 10 |
+
|
| 11 |
+
# def add_points_tag(img: Union[Image, np.ndarray],
|
| 12 |
+
# point_labels: Union[List[int], np.ndarray] = None,
|
| 13 |
+
# point_coords: Union[List[List[int]], np.ndarray] = None,
|
| 14 |
+
# pil: bool = False):
|
| 15 |
+
# if point_labels is None or point_coords is None or \
|
| 16 |
+
# not isinstance(point_labels, (List, np.ndarray)) or \
|
| 17 |
+
# not isinstance(point_coords, (List, np.ndarray)):
|
| 18 |
+
# return img
|
| 19 |
+
# if len(point_labels) != len(point_coords):
|
| 20 |
+
# print('length of point_label and point_coordinate must be same!')
|
| 21 |
+
# return img
|
| 22 |
+
# if isinstance(img, Image):
|
| 23 |
+
# img = np.uint8(img)
|
| 24 |
+
# start_angle = 40
|
| 25 |
+
# x = 8
|
| 26 |
+
# y = 2
|
| 27 |
+
# def get_point(angle, d, base):
|
| 28 |
+
# angle = angle / 180.0 * math.pi
|
| 29 |
+
# _x, _y = math.cos(angle) * d, math.sin(angle) * d
|
| 30 |
+
# return [base[0] + _x, base[1] - _y]
|
| 31 |
+
# # assert len(point_labels) == len(point_coords), ''
|
| 32 |
+
# for i in range(len(point_labels)):
|
| 33 |
+
# points = []
|
| 34 |
+
# for j in range(5):
|
| 35 |
+
# _x, _y = math.cos(start_angle), math.sin(start_angle)
|
| 36 |
+
# points.append(get_point(start_angle, x, point_coords[i]))
|
| 37 |
+
# start_angle -= 36
|
| 38 |
+
# points.append(get_point(start_angle, y, point_coords[i]))
|
| 39 |
+
# start_angle -= 36
|
| 40 |
+
# points = np.array([points], np.int32)
|
| 41 |
+
# color = (255, 0, 0) if point_labels[i] == 0 else (0, 255, 0)
|
| 42 |
+
# cv2.fillPoly(img, points, color, cv2.LINE_AA)
|
| 43 |
+
# if pil:
|
| 44 |
+
# img = PIL.Image.fromarray(img)
|
| 45 |
+
# return img
|
| 46 |
+
def add_points_tag(img: Union[Image, np.ndarray],
|
| 47 |
+
point_labels: Union[List[int], np.ndarray] = None,
|
| 48 |
+
point_coords: Union[List[List[int]], np.ndarray] = None,
|
| 49 |
+
pil: bool = False):
|
| 50 |
+
if point_labels is None or point_coords is None or \
|
| 51 |
+
not isinstance(point_labels, (List, np.ndarray)) or \
|
| 52 |
+
not isinstance(point_coords, (List, np.ndarray)):
|
| 53 |
+
return img
|
| 54 |
+
if len(point_labels) != len(point_coords):
|
| 55 |
+
print('length of point_label and point_coordinate must be same!')
|
| 56 |
+
return img
|
| 57 |
+
if isinstance(img, Image):
|
| 58 |
+
img = np.array(img)
|
| 59 |
+
# img.flags.writeable = True
|
| 60 |
+
h, w = img.shape[:2]
|
| 61 |
+
x_start_list, x_end_list = np.where((point_coords[:, 0] - 4) > 0, point_coords[:, 0] - 4, 0), np.where((point_coords[:, 0] + 4) < w, point_coords[:, 0] + 4, w)
|
| 62 |
+
y_start_list, y_end_list = np.where((point_coords[:, 1] - 4) > 0, point_coords[:, 1] - 4, 0), np.where((point_coords[:, 1] + 4) < h, point_coords[:, 1] + 4, h)
|
| 63 |
+
for i in range(len(point_labels)):
|
| 64 |
+
x_start, x_end = x_start_list[i], x_end_list[i]
|
| 65 |
+
y_start, y_end = y_start_list[i], y_end_list[i]
|
| 66 |
+
label = point_labels[i]
|
| 67 |
+
color = [0, 255, 0] if int(label) == 1 else [255, 0, 0]
|
| 68 |
+
for x in range(x_start, x_end):
|
| 69 |
+
for y in range(y_start, y_end):
|
| 70 |
+
img[y, x, :] = color
|
| 71 |
+
if pil:
|
| 72 |
+
img = PIL.Image.fromarray(img)
|
| 73 |
+
return img
|
| 74 |
+
def add_boxes_tag(img: Union[Image, np.ndarray],
|
| 75 |
+
boxes: Union[List[List[int]], np.ndarray] = None,
|
| 76 |
+
pil: bool = False):
|
| 77 |
+
if boxes is None or not isinstance(boxes, (List, np.ndarray)):
|
| 78 |
+
return img
|
| 79 |
+
# if isinstance(boxes, np.ndarray):
|
| 80 |
+
# if not boxes.all():
|
| 81 |
+
# return img
|
| 82 |
+
# else:
|
| 83 |
+
# if not boxes:
|
| 84 |
+
# return img
|
| 85 |
+
if isinstance(img, Image):
|
| 86 |
+
img = np.uint8(img)
|
| 87 |
+
thickness = 2
|
| 88 |
+
for i in range(len(boxes)):
|
| 89 |
+
color = (0, 255, 0)
|
| 90 |
+
img = cv2.rectangle(img, (boxes[i][0], boxes[i][1]), (boxes[i][2], boxes[i][3]), color, thickness)
|
| 91 |
+
if pil:
|
| 92 |
+
img = PIL.Image.fromarray(img)
|
| 93 |
+
return img
|
| 94 |
+
|
| 95 |
+
def add_prompts_tag(img: Union[Image, np.ndarray],
|
| 96 |
+
point_labels: Union[List[int], np.ndarray] = None,
|
| 97 |
+
point_coords: Union[List[List[int]], np.ndarray] = None,
|
| 98 |
+
boxes: Union[List[List[int]], np.ndarray] = None,
|
| 99 |
+
pil: bool = False):
|
| 100 |
+
img = add_points_tag(img, point_labels, point_coords, pil=pil)
|
| 101 |
+
img = add_boxes_tag(img, boxes, pil=pil)
|
| 102 |
+
return img
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def get_empty_detections():
|
| 106 |
+
detections = sv.Detections(xyxy=np.array([0, 0, 0, 0]).reshape(1, 4))
|
| 107 |
+
detections.xyxy = None
|
| 108 |
+
return detections
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def pca_feature(feature: torch.Tensor, dim: int = 3, return_np: bool = True):
|
| 112 |
+
pca = PCA(n_components=dim)
|
| 113 |
+
H, W, C = feature.shape
|
| 114 |
+
feature = feature.view(-1, C).cpu().numpy()
|
| 115 |
+
feature = pca.fit_transform(feature)
|
| 116 |
+
feature = torch.tensor(feature.reshape(H, W, dim))
|
| 117 |
+
if return_np:
|
| 118 |
+
return feature.numpy()
|
| 119 |
+
else:
|
| 120 |
+
return feature
|
| 121 |
+
|
| 122 |
+
def visual_feature_rgb(feature: torch.Tensor, pil:bool = True):
|
| 123 |
+
assert feature.ndim >= 3, 'the dim of feature must >= 3!'
|
| 124 |
+
if feature.ndim == 4:
|
| 125 |
+
feature = feature.squeeze(0)
|
| 126 |
+
if feature.shape[-1] != 3:
|
| 127 |
+
feature = pca_feature(feature, 3, False)
|
| 128 |
+
max_f, _ = feature.max(-1)
|
| 129 |
+
min_f, _ = feature.min(-1)
|
| 130 |
+
feature = (feature - min_f[..., None]) / (max_f[..., None] - min_f[..., None])
|
| 131 |
+
feature = np.uint8((feature*255).cpu().numpy())
|
| 132 |
+
if pil:
|
| 133 |
+
return PIL.Image.fromarray(feature)
|
| 134 |
+
else:
|
| 135 |
+
return feature
|
| 136 |
+
|
| 137 |
+
def transform_coords(src_shape, des_shape, points = None, boxes = None):
|
| 138 |
+
assert points is not None or boxes is not None, 'one of points and boxes must be given!'
|
| 139 |
+
scale_h = des_shape[0] / src_shape[0]
|
| 140 |
+
scale_w = des_shape[1] / src_shape[1]
|
| 141 |
+
if points is not None:
|
| 142 |
+
new_points = np.full_like(points, 0)
|
| 143 |
+
new_points[:, 0] = points[:, 0] * scale_w
|
| 144 |
+
new_points[:, 1] = points[:, 1] * scale_h
|
| 145 |
+
new_points.astype(np.int64)
|
| 146 |
+
else:
|
| 147 |
+
new_points = None
|
| 148 |
+
if boxes is not None:
|
| 149 |
+
new_boxes = np.full_like(boxes, 0)
|
| 150 |
+
new_boxes[:, 0] = boxes[:, 0] * scale_w
|
| 151 |
+
new_boxes[:, 1] = boxes[:, 1] * scale_h
|
| 152 |
+
new_boxes[:, 2] = boxes[:, 2] * scale_w
|
| 153 |
+
new_boxes[:, 3] = boxes[:, 3] * scale_h
|
| 154 |
+
new_boxes.astype(np.int64)
|
| 155 |
+
else:
|
| 156 |
+
new_boxes = None
|
| 157 |
+
return new_points, new_boxes
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def mask2greyimg(mask_list, pil=True):
|
| 161 |
+
grey_img_list = []
|
| 162 |
+
for mask in mask_list:
|
| 163 |
+
if pil:
|
| 164 |
+
grey_img_list.append(PIL.Image.fromarray(np.uint8(mask*255)))
|
| 165 |
+
else:
|
| 166 |
+
grey_img_list.append(np.uint8(mask * 255))
|
| 167 |
+
return grey_img_list
|
| 168 |
+
if __name__ == '__main__':
|
| 169 |
+
src_shape = (100,100)
|
| 170 |
+
des_shape = (200,200)
|
| 171 |
+
points = np.array([[20,20],[40,40]])
|
| 172 |
+
boxes = np.array([[10,10,20,20]])
|
| 173 |
+
new_points, new_boxes = transform_coords(src_shape, des_shape, points, boxes)
|
| 174 |
+
print(new_points, new_boxes)
|
| 175 |
+
|
sam_extension/utils/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (4.51 kB). View file
|
|
|