""" Conversion functions for 3rd party state-dicts and non-torch native checkpoint formats.
"""
from typing import Union

import torch
import numpy as np

from .model import CLIP, CustomTextCLIP
from .transformer import TextTransformer, Transformer


@torch.no_grad()
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
    """ Load weights from .npz checkpoints for official Google big_vision image-text models

    Currently the SigLIP source models are supported and a CustomTextCLIP destination model
    w/ timm image encoder.

    Args:
        model: destination model. Its ``visual.trunk`` (timm image encoder), ``text``
            (open_clip TextTransformer), ``logit_bias`` and ``logit_scale`` are overwritten in place.
        checkpoint_path: path to the big_vision ``.npz`` archive. The archive is closed once
            loading finishes, including when loading fails part-way.

    If the checkpoint's patch-embedding kernel size or position-embedding shape differs from the
    model's, they are resampled with timm's ``resample_patch_embed`` / ``resample_abs_pos_embed``.

    Raises:
        KeyError: if an expected parameter entry is missing from the archive.
    """
    from timm.layers import resample_patch_embed, resample_abs_pos_embed

    def _n2p(arr, t=True):
        # Convert a checkpoint array to a torch tensor. Rank-4 arrays whose first three dims are
        # all 1 are flattened to 1-D. With `t`, axes are reordered: rank 4 -> (3, 2, 0, 1),
        # rank 3 -> (2, 0, 1), rank 2 -> transposed. Presumably this maps Flax (H, W, in, out)
        # / (in, out) layouts to torch (out, in, H, W) / (out, in) -- TODO confirm per key.
        if arr.ndim == 4 and arr.shape[0] == arr.shape[1] == arr.shape[2] == 1:
            arr = arr.flatten()
        if t:
            if arr.ndim == 4:
                arr = arr.transpose([3, 2, 0, 1])
            elif arr.ndim == 3:
                arr = arr.transpose([2, 0, 1])
            elif arr.ndim == 2:
                arr = arr.transpose([1, 0])
        return torch.from_numpy(arr)

    interpolation = 'bilinear'
    antialias = False

    # NOTE: the helpers below read `w`, the open .npz archive bound by the `with` statement at
    # the end of this function; they are only called inside that `with` block.

    def _convert_timm_img(module, prefix):
        # Patch embedding: resample the conv kernel if the checkpoint's patch size differs.
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
        if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
            embed_conv_w = resample_patch_embed(
                embed_conv_w,
                module.patch_embed.proj.weight.shape[-2:],
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.patch_embed.proj.weight.copy_(embed_conv_w)
        module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))

        if module.cls_token is not None:
            module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))

        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
        if pos_embed_w.shape != module.pos_embed.shape:
            # The checkpoint's token grid differs from the model's (e.g. a different image or patch
            # size), so resize the absolute position embedding to the model's grid. This is the same
            # resampling timm applies when loading big_vision checkpoints. It was previously
            # unreachable behind an unconditional `assert False` that `python -O` strips.
            num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
            pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
                pos_embed_w,
                new_size=module.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.pos_embed.copy_(pos_embed_w)

        # Sub-module indices used inside each big_vision encoder block.
        mha_sub, b_sub, ln1_sub = (0, 0, 1)
        for i, block in enumerate(module.blocks.children()):
            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            # Separate query/key/value kernels are flattened past dim 0, transposed, and stacked
            # along dim 0 to form the fused qkv projection.
            block.attn.qkv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.qkv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            for r in range(2):
                getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
                getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))

        module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
        module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))

        # Optional attention-pooling head (big_vision `MAPHead_0`).
        if module.attn_pool is not None:
            block_prefix = f'{prefix}MAPHead_0/'
            mha_prefix = block_prefix + 'MultiHeadDotProductAttention_0/'
            module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
            module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
            module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
            module.attn_pool.kv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
            module.attn_pool.kv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
            module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            for r in range(2):
                getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
                getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

    def _convert_openclip_transformer(module: Transformer, prefix):
        # Copy each big_vision `encoderblock_{i}` into the i-th open_clip residual block.
        for i, block in enumerate(module.resblocks.children()):
            block_prefix = f'{prefix}encoderblock_{i}/'
            mha_prefix = block_prefix + 'MultiHeadDotProductAttention_0/'
            block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            block.attn.in_proj_weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.in_proj_bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
            block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
            block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
            block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
            block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
            block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))

    def _convert_openclip_txt(module: TextTransformer, prefix):
        # Text tower: token/position embeddings, encoder stack, final norm and projection head.
        module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
        # Drop the leading singleton dim of the stored position embedding.
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
        module.positional_embedding.copy_(pos_embed_w)
        _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
        module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
        module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
        module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

    # An NpzFile keeps the archive's file handle open until closed; the context manager
    # guarantees it is released even if a conversion step raises.
    with np.load(checkpoint_path) as w:
        _convert_timm_img(model.visual.trunk, 'params/img/')
        _convert_openclip_txt(model.text, 'params/txt/')
        # `params/b` and `params/t` are indexed with [0]; the first element is copied into the model.
        model.logit_bias.copy_(_n2p(w['params/b'])[0])
        model.logit_scale.copy_(_n2p(w['params/t'])[0])


@torch.no_grad()
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit=True):
    """ Convert an Apple MobileCLIP state-dict into open_clip CustomTextCLIP key naming.

    Args:
        model: destination model. Only ``model.visual.trunk`` is read, and it is passed to timm's
            ``checkpoint_filter_fn``.
        state_dict: the original MobileCLIP state-dict. It is not mutated.
        fastvit: if True the image tower is filtered with timm's FastViT ``checkpoint_filter_fn``,
            otherwise with the ``vision_transformer_hybrid`` one.

    Returns:
        A new dict containing the image keys (prefixed ``visual.trunk.``), the renamed text-encoder
        keys (prefixed ``text.``) and ``logit_scale``. Text keys outside ``text_encoder.`` that
        ``checkpoint_filter_fn`` does not emit are not carried over.
    """
    def _convert_timm_img(state_dict):
        if fastvit:
            from timm.models.fastvit import checkpoint_filter_fn
        else:
            from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
        timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
        timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
        return timm_state_dict

    def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
        text_dict = {}
        for k, v in state_dict.items():
            if not k.startswith(prefix):
                continue
            # Strip only the leading prefix; str.replace would also remove later occurrences.
            k = k[len(prefix):]
            k = k.replace('projection_layer', 'text_projection')
            k = k.replace('embedding_layer', 'token_embedding')
            if k.startswith('positional_embedding.pos_embed.pos_embed'):
                k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
                v = v.squeeze()
            k = k.replace('final_layer_norm', 'ln_final')
            k = k.replace('pre_norm_mha.0', 'ln_1')
            k = k.replace('pre_norm_mha.1', 'attn')
            k = k.replace('pre_norm_ffn.0', 'ln_2')
            k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
            k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
            k = k.replace('qkv_proj.weight', 'in_proj_weight')
            k = k.replace('qkv_proj.bias', 'in_proj_bias')
            k = k.replace('transformer.', 'transformer.resblocks.')
            text_dict['text.' + k] = v
        return text_dict

    image_dict = _convert_timm_img(state_dict)
    text_dict = _convert_openclip_txt(state_dict)
    out_dict = {**image_dict, **text_dict}
    out_dict['logit_scale'] = state_dict['logit_scale']
    return out_dict


def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
    """ Detect known third-party state-dict layouts and convert them to open_clip naming.

    Unrecognised state-dicts are returned unchanged.
    """
    if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
        # Apple MobileCLIP S1 & S2 (FastViT image tower); S0 is not currently supported.
        state_dict = convert_mobile_clip_state_dict(model, state_dict)
    elif 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
        # Apple MobileCLIP B (vision_transformer_hybrid image tower).
        state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
    return state_dict