from collections import OrderedDict
from typing import Tuple, Union

import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from timm.models.layers import DropPath, trunc_normal_

from .backbone import Backbone
from .build import BACKBONE_REGISTRY
from .det_swin import SwinTransformer
from ..text_encoder import build_text_encoder
from ..text_encoder import build_tokenizer


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        pdtype = x.dtype
        x = x.float()
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x.to(pdtype) + self.bias


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
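

# Note: QuickGELU (above) is the sigmoid-based GELU approximation used in the
# original CLIP code; x * torch.sigmoid(1.702 * x) closely tracks the exact GELU.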


class ResidualAttentionBlock(nn.Module):
    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None,
                 drop_path: float = 0.0):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
            if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.drop_path(self.attention(self.ln_1(x)))
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):
    def __init__(self,
                 context_length: int,
                 vocab_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 drop_path: float = 0.0):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, width)

        self.context_length = context_length
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, width)
        )

        self.width = width
        self.layers = layers
        attn_mask = self.build_attention_mask()

        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]  # stochastic depth decay rule
        self.resblocks = nn.Sequential(
            *[
                ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
                for i in range(layers)
            ]
        )

        self.ln_final = LayerNorm(width)

        trunc_normal_(self.positional_embedding, std=.02)
        # nn.init.normal_(self.token_embedding, std=.02)
        trunc_normal_(self.token_embedding.weight, std=.02)
        self.apply(self._init_weights)

    def build_attention_mask(self):
        # causal attention mask over the text sequence;
        # pytorch uses an additive attention mask, so fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # keep -inf strictly above the diagonal, zero elsewhere
        return mask

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)

    def no_weight_decay(self):
        return {
            'positional_embedding',
            'token_embedding',
        }

    def forward(self, text: torch.Tensor):
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.resblocks(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_final(x)

        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]

        return x
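

# Note: `Transformer.forward` (and `CLIP.encode_text` below) pools the sequence
# at the position of the largest token id; with the CLIP BPE tokenizer that is
# the end-of-text token, so the pooled vector is the EOT embedding.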


class CLIP(Backbone):
    def __init__(self, config: dict):
        super().__init__()

        spec_text = config['MODEL']['SPEC']['TEXT']
        assert spec_text['TOKENIZER'] == 'clip', 'Only the clip tokenizer is supported'
        self.tokenizer_style = spec_text['TOKENIZER']
        self.tokenizer = build_tokenizer(spec_text)

        self.text_encoder = build_text_encoder(spec_text, self.tokenizer, True)

        embed_dim = config['MODEL']['SPEC']['EMBED_DIM']
        self.text_projection = nn.Parameter(
            torch.empty(spec_text['WIDTH'], embed_dim)
        )

        spec_vision = config['MODEL']['SPEC']['VISION']
        self.image_encoder = SwinTransformer(
            patch_size=spec_vision['PATCH_SIZE'],
            in_chans=spec_vision['IN_CHANS'],
            embed_dim=spec_vision['EMBED_DIM'],
            depths=spec_vision['DEPTHS'],
            num_heads=spec_vision['NUM_HEADS'],
            window_size=spec_vision['WINDOW_SIZE'],
            mlp_ratio=spec_vision['MLP_RATIO'],
            qkv_bias=spec_vision['QKV_BIAS'],
            qk_scale=spec_vision.get('QK_SCALE', None),
            drop_rate=spec_vision['DROP_RATE'],
            attn_drop_rate=spec_vision['ATTN_DROP_RATE'],
            drop_path_rate=spec_vision['DROP_PATH_RATE'],
            ape=spec_vision['APE'],
            patch_norm=spec_vision['PATCH_NORM'],
            out_indices=(0, 1, 2, 3),
            frozen_stages=-1,
            use_checkpoint=False,
        )

        width = spec_vision['EMBED_DIM'] * 2 ** (len(spec_vision['DEPTHS']) - 1)
        self.image_projection = nn.Parameter(
            torch.empty(width, embed_dim)
        )

        # self.logit_scale = nn.Parameter(torch.FloatTensor([np.log(1 / 0.07)]))
        self.logit_scale = nn.Parameter(torch.ones([]))

        trunc_normal_(self.text_projection, std=.02)
        trunc_normal_(self.image_projection, std=.02)

    def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained, map_location='cpu')
            logging.info(f'=> loading pretrained model {pretrained}')
            model_dict = self.state_dict()
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items()
                if k in model_dict.keys()
            }
            need_init_state_dict = {}
            for k, v in pretrained_dict.items():
                need_init = (
                    k.split('.')[0] in pretrained_layers
                    or pretrained_layers[0] == '*'
                )
                if need_init:
                    if verbose:
                        logging.info(f'=> init {k} from {pretrained}')
                    need_init_state_dict[k] = v
            self.load_state_dict(need_init_state_dict, strict=False)

    def no_weight_decay(self):
        no_weight_decay = {'logit_scale'}
        for k in self.text_encoder.no_weight_decay():
            no_weight_decay.add('text.' + k)

        for k in self.image_encoder.no_weight_decay():
            no_weight_decay.add('visual.' + k)

        return no_weight_decay

    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def dtype(self):
        return self.image_encoder.conv1.weight.dtype

    def encode_image(self, image, norm=True):
        x = self.image_encoder(image)
        return x

    def encode_text(self, text, norm=True):
        assert isinstance(text, str), "only a single query string is supported"
        tokens = self.tokenizer(
            text, padding='max_length', truncation=True, max_length=77, return_tensors='pt'
        )
        tokens = {
            key: (val.cuda() if next(self.parameters()).is_cuda else val)
            for key, val in tokens.items()
        }

        x = self.text_encoder(**tokens)
        x = x['last_hidden_state']

        x = x[torch.arange(x.size(0)), tokens['input_ids'].argmax(dim=-1)]
        x = x @ self.text_projection

        if norm:
            x = x / x.norm(dim=-1, keepdim=True)

        return x

    def forward(self, image):
        features_image = self.image_encoder(image)
        return features_image


def build_clip_swin_backbone(cfg, input_shape):
    """
    Create the CLIP Swin vision backbone from config.

    Returns:
        SwinTransformer: a :class:`SwinTransformer` instance.
    """
    spec_vision = cfg.MODEL.CLIP.VISION
    return SwinTransformer(
        patch_size=spec_vision['PATCH_SIZE'],
        in_chans=spec_vision['IN_CHANS'],
        embed_dim=spec_vision['EMBED_DIM'],
        depths=spec_vision['DEPTHS'],
        num_heads=spec_vision['NUM_HEADS'],
        window_size=spec_vision['WINDOW_SIZE'],
        mlp_ratio=spec_vision['MLP_RATIO'],
        qkv_bias=spec_vision['QKV_BIAS'],
        qk_scale=spec_vision.get('QK_SCALE', None),
        drop_rate=spec_vision['DROP_RATE'],
        attn_drop_rate=spec_vision['ATTN_DROP_RATE'],
        drop_path_rate=spec_vision['DROP_PATH_RATE'],
        ape=spec_vision['APE'],
        patch_norm=spec_vision['PATCH_NORM'],
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        use_checkpoint=False,
    )


def build_clip_swin(cfg, input_shape):
    """
    Create a full CLIP model (Swin image encoder plus text encoder) from config.

    Returns:
        CLIP: a :class:`CLIP` instance.
    """
    return CLIP(cfg)
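

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative, not part of the original module): it only
# exercises the standalone text `Transformer` defined above with dummy token
# ids. The width/heads/layers/vocab numbers below are assumptions chosen to
# mirror common CLIP settings, not values read from any project config.
if __name__ == "__main__":
    text_tower = Transformer(
        context_length=77,   # CLIP-style context length (assumed)
        vocab_size=49408,    # CLIP BPE vocabulary size (assumed)
        width=512,
        layers=12,
        heads=8,
    )
    dummy_tokens = torch.randint(0, 49408, (2, 77))  # [batch, context_length]
    with torch.no_grad():
        pooled = text_tower(dummy_tokens)            # one pooled feature per sequence
    print(pooled.shape)                              # torch.Size([2, 512])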
