import os
import functools
from typing import Optional, List, Union, Tuple, Type

import torch
from torch import nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from segment_anything import build_sam
from segment_anything.mobile_encoder.tiny_vit_sam import TinyViT
from segment_anything.modeling import PromptEncoder, MaskDecoder, TwoWayTransformer
from segment_anything.modeling.image_encoder import ImageEncoderViT, LayerNorm2d, PatchEmbed, Block, Attention
from segment_anything.mobile_encoder.setup_mobile_sam import load_mobile_sam
from segment_anything.modeling.sam import Sam
from sam_extension.distillation_models.fastertinyvit import FasterTinyViT
from sam_extension.distillation_models.dino import DINO
# from sam_extension.distillation_models.flashvision_transformer import FlashVisionTransformer

SAM_REPO_ID = 'YouLiXiya/YL-SAM'
hf_sam_download = functools.partial(hf_hub_download, repo_id=SAM_REPO_ID, local_dir_use_symlinks=True)


class SAMImageEncoder(nn.Module):
    """Wraps the image encoder of a full SAM model loaded from a checkpoint."""
    def __init__(self,
                 sam_checkpoint_path,
                 device='cuda'):
        super(SAMImageEncoder, self).__init__()
        sam = build_sam(sam_checkpoint_path).to(device)
        self.image_encoder = sam.image_encoder
        # Keep only the encoder; free the prompt encoder / mask decoder weights.
        del sam
        torch.cuda.empty_cache()

    def forward(self, x):
        return self.image_encoder(x)


class MobileSAMImageEncoder(nn.Module):
    """Wraps the image encoder of a MobileSAM model loaded from a checkpoint."""
    def __init__(self,
                 sam_checkpoint_path,
                 device='cuda'):
        super(MobileSAMImageEncoder, self).__init__()
        sam = load_mobile_sam(sam_checkpoint_path, device)
        self.image_encoder = sam.image_encoder
        # Keep only the encoder; free the prompt encoder / mask decoder weights.
        del sam
        torch.cuda.empty_cache()

    def forward(self, x):
        return self.image_encoder(x)


class SAMEncoderViT(nn.Module):
    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
        multi_scale: bool = False,
        output_shape: Optional[Union[Tuple, List]] = None,
    ) -> None:
| """ | |
| Args: | |
| img_size (int): Input image size. | |
| patch_size (int): Patch size. | |
| in_chans (int): Number of input image channels. | |
| embed_dim (int): Patch embedding dimension. | |
| depth (int): Depth of ViT. | |
| num_heads (int): Number of attention heads in each ViT block. | |
| mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. | |
| qkv_bias (bool): If True, add a learnable bias to query, key, value. | |
| norm_layer (nn.Module): Normalization layer. | |
| act_layer (nn.Module): Activation layer. | |
| use_abs_pos (bool): If True, use absolute positional embeddings. | |
| use_rel_pos (bool): If True, add relative positional embeddings to the attention map. | |
| rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. | |
| window_size (int): Window size for window attention blocks. | |
| global_attn_indexes (list): Indexes for blocks using global attention. | |
| """ | |
        super().__init__()
        self.img_size = img_size
        self.multi_scale = multi_scale
        self.output_shape = tuple(output_shape) if output_shape else None

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
            )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        # The neck fuses either a single block output (embed_dim channels) or, in the
        # multi-scale case, the concatenation of all block outputs (embed_dim * depth channels).
        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim * depth if self.multi_scale and self.output_shape else embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        if self.multi_scale and self.output_shape:
            # Collect every block's output, resized to a common spatial shape,
            # and let the neck fuse the concatenated feature maps.
            output_list = []
            for blk in self.blocks:
                x = blk(x)
                output_list.append(F.interpolate(x.permute(0, 3, 1, 2), size=self.output_shape, mode='bilinear'))
            x = self.neck(torch.cat(output_list, dim=1))
        else:
            for blk in self.blocks:
                x = blk(x)
            x = self.neck(x.permute(0, 3, 1, 2))
        return x
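

# A minimal usage sketch of the multi-scale path (values are illustrative, not taken from
# any training config in this repo): each block output is resized to `output_shape`,
# concatenated along channels (embed_dim * depth), and fused by the 1x1 conv in the neck.
#
#   vit_ms = SAMEncoderViT(depth=3, embed_dim=256, num_heads=8, img_size=512,
#                          patch_size=8, window_size=16, global_attn_indexes=(1,),
#                          multi_scale=True, output_shape=(64, 64))
#   feats = vit_ms(torch.randn(1, 3, 512, 512))   # -> (1, 256, 64, 64)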


class SAMEncoderAdaptor(nn.Module):
    """Transformer adaptor that refines an upstream encoder's features into a SAM-style image embedding."""
    def __init__(self,
                 img_size: int,
                 input_size: Optional[Tuple[int, int]],
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.0,
                 out_chans: int = 256,
                 qkv_bias: bool = True,
                 norm_layer: Type[nn.Module] = nn.LayerNorm,
                 act_layer: Type[nn.Module] = nn.GELU,
                 use_abs_pos: bool = True,
                 use_rel_pos: bool = False,
                 rel_pos_zero_init: bool = True,
                 window_size: int = 0,
                 global_attn_indexes: Tuple[int, ...] = (),
                 multi_scale: bool = False,
                 output_shape: Optional[Union[Tuple, List]] = None):
        super(SAMEncoderAdaptor, self).__init__()
        self.img_size = img_size
        self.multi_scale = multi_scale
        self.output_shape = tuple(output_shape) if output_shape else None

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, input_size[0], input_size[1], embed_dim)
            )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=input_size,
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim * depth if self.multi_scale and self.output_shape else embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    def forward(self, x: torch.Tensor, original_size: Union[Tuple, List] = None) -> torch.Tensor:
        if original_size:
            # x is expected in NCHW here. Rescale each feature map to its image's original
            # aspect ratio, then zero-pad back to the adaptor's input size.
            original_size = torch.LongTensor(original_size)
            output_shape = x.shape[-2:]
            if original_size.ndim == 1:
                original_size = original_size[None, ...]
            adaptor_inputs = []
            for i in range(original_size.shape[0]):
                h, w = original_size[i]
                if h > w:
                    new_h = output_shape[0]
                    new_w = int(w * new_h / h)
                else:
                    new_w = output_shape[1]
                    new_h = int(h * new_w / w)
                # Index per sample so each feature map is paired with its own original size.
                encoder_output = x[i].unsqueeze(0)
                encoder_output = F.interpolate(encoder_output, size=(new_h, new_w), mode='bilinear')
                pad_h = output_shape[0] - new_h
                pad_w = output_shape[1] - new_w
                encoder_output = F.pad(encoder_output, (0, pad_w, 0, pad_h))
                adaptor_inputs.append(encoder_output)
            adaptor_inputs = torch.cat(adaptor_inputs, dim=0)
            x = adaptor_inputs.permute(0, 2, 3, 1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        if self.multi_scale and self.output_shape:
            output_list = []
            for blk in self.blocks:
                x = blk(x)
                output_list.append(F.interpolate(x.permute(0, 3, 1, 2), size=self.output_shape, mode='bilinear'))
            x = self.neck(torch.cat(output_list, dim=1))
        else:
            for blk in self.blocks:
                x = blk(x)
            x = self.neck(x.permute(0, 3, 1, 2))
        return x
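

# A minimal usage sketch (shapes are assumptions, not from a config): the adaptor takes an
# upstream encoder's NCHW feature map plus the original image sizes, undoes the
# aspect-ratio distortion per sample, and outputs a SAM-style (B, 256, H, W) embedding.
#
#   adaptor = SAMEncoderAdaptor(img_size=1024, input_size=(64, 64),
#                               embed_dim=256, num_heads=8, depth=2)
#   feats = torch.randn(2, 256, 64, 64)
#   out = adaptor(feats, original_size=[(480, 640), (640, 480)])   # -> (2, 256, 64, 64)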


class DINOSAMViT(nn.Module):
    """Uses a DINO backbone as feature extractor and an adaptor to produce SAM image embeddings."""
    def __init__(self,
                 dino_model_type,
                 device='cuda',
                 pca_dim=None,
                 **kwargs):
        super(DINOSAMViT, self).__init__()
        self.img_size = kwargs['img_size']
        if not pca_dim:
            pca_dim = None
        self.dino = DINO(dino_model_type, device, self.img_size, pca_dim)
        self.input_size = tuple(kwargs['output_shape'])
        # input_size = self.dino.model.patch_embed.img_size // self.dino.model.patch_embed.img_size
        # self.input_size = (input_size, input_size)
        embed_dim = pca_dim if pca_dim is not None else self.dino.model.embed_dim
        kwargs.update({'input_size': self.input_size, 'embed_dim': embed_dim})
        self.adaptor = SAMEncoderAdaptor(**kwargs).to(device)

    def extract_dino_features(self, x, transform=False, size=None):
        return self.dino.extract_features(x, transform, size)

    def forward(self, x, transform=False, size=None):
        # Normalize the DINO features, resize them to the adaptor's input size (NHWC),
        # then refine them into a SAM-style image embedding.
        dino_feature = F.normalize(self.extract_dino_features(x, transform, size), dim=3)
        adaptor_input = F.interpolate(dino_feature.permute(0, 3, 1, 2), size=self.input_size, mode='bilinear').permute(0, 2, 3, 1)
        return self.adaptor(adaptor_input)
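

# Hypothetical usage sketch: the DINO model type string and shapes below are assumptions,
# not values used elsewhere in this repo. `kwargs` must contain 'img_size' and
# 'output_shape'; 'input_size' and 'embed_dim' are filled in from the DINO backbone.
#
#   dino_sam = DINOSAMViT('dino_vits8', device='cuda',
#                         img_size=512, output_shape=(64, 64))
#   embedding = dino_sam(images)   # -> (B, 256, 64, 64)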


def setup_model(model_config):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    # 'type' names the encoder/adaptor class to build; the remaining keys are its kwargs.
    model = eval(model_config.pop('type'))(**model_config)
    if model.__class__.__name__ == 'SAMEncoderAdaptor':
        # Adaptors are paired with a MobileSAM image encoder as the backbone.
        adaptor = model
        image_encoder = load_sam('weights/sam/mobile_sam.pt', 'mobile_sam', 'cpu').image_encoder
    else:
        adaptor = None
        image_encoder = model
    sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        adaptor=adaptor,
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    return sam
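

# Illustrative model_config (keys are assumptions; 'type' must name a class importable in
# this module, the remaining keys are that class's constructor kwargs, and the encoder must
# produce SAM's 256-channel, 64x64 image embedding for a 1024px input):
#
#   model_config = {'type': 'SAMEncoderViT', 'img_size': 1024, 'patch_size': 16,
#                   'embed_dim': 768, 'depth': 12, 'num_heads': 12, 'out_chans': 256}
#   sam = setup_model(model_config)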


def load_distillation_sam(distillation_sam_ckpt_path,
                          device='cuda'):
    ckpt = torch.load(distillation_sam_ckpt_path)
    sam = setup_model(ckpt['model_config'])
    sam.load_state_dict(ckpt['model'])
    return sam.to(device)


def load_sam(sam_ckpt_path, sam_version, device):
    # Download the checkpoint from the SAM_REPO_ID Hub repo if it is not present locally.
    if not os.path.exists(sam_ckpt_path):
        parent_dir = os.path.dirname(sam_ckpt_path)
        os.makedirs(parent_dir, exist_ok=True)
        hf_sam_download(filename=os.path.basename(sam_ckpt_path), local_dir=parent_dir)
    if sam_version == 'sam':
        sam = build_sam(sam_ckpt_path).to(device)
    elif sam_version == 'mobile_sam':
        sam = load_mobile_sam(sam_ckpt_path, device)
    elif sam_version == 'distillation_sam':
        sam = load_distillation_sam(sam_ckpt_path, device)
    else:
        raise ValueError("sam version error, please give a sam version in ['sam', 'mobile_sam', 'distillation_sam']")
    return sam
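

# Example (the MobileSAM path below mirrors the one used in setup_model; the checkpoint is
# pulled from SAM_REPO_ID on the Hugging Face Hub when it is missing locally):
#
#   sam = load_sam('weights/sam/mobile_sam.pt', 'mobile_sam', 'cuda')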


if __name__ == '__main__':
    from distillation.utils import get_parameter_number
    vit = SAMEncoderViT(depth=3,
                        embed_dim=256,
                        img_size=512,
                        mlp_ratio=4,
                        num_heads=16,
                        patch_size=8,
                        qkv_bias=True,
                        use_rel_pos=True,
                        global_attn_indexes=[1],
                        window_size=16,
                        out_chans=256,
                        multi_scale=False,
                        output_shape=None).cuda()
    x = torch.randn((1, 3, 512, 512)).cuda()
    print(vit(x).shape)
    print(get_parameter_number(vit))