|
""" Conversion functions for third-party state-dicts and non-torch-native checkpoint formats.
|
""" |
|
from typing import Union |
|
|
|
import torch |
|
import numpy as np |
|
|
|
from .model import CLIP, CustomTextCLIP |
|
from .transformer import TextTransformer, Transformer |
|
|
|
|
|
@torch.no_grad()
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
    """ Load weights from .npz checkpoints for official Google big_vision image-text models

    Currently the SigLIP source models are supported and a CustomTextCLIP destination model
    w/ timm image encoder.

    Mapping performed:
      * ``params/img/...`` -> ``model.visual.trunk`` (timm encoder, incl. optional MAP attention-pool head)
      * ``params/txt/...`` -> ``model.text`` (OpenCLIP ``TextTransformer``)
      * ``params/b`` / ``params/t`` -> ``model.logit_bias`` / ``model.logit_scale``

    Args:
        model: Destination model; its parameters are overwritten in place.
        checkpoint_path: Path to the big_vision ``.npz`` checkpoint.

    Raises:
        KeyError: If an expected parameter name is missing from the checkpoint.
        ValueError: If the checkpoint's image position embedding shape does not
            match ``model.visual.trunk.pos_embed``.
    """
    from timm.layers import resample_patch_embed

    # Resampling settings used when the checkpoint patch size differs from the model's.
    interpolation = 'bilinear'
    antialias = False

    def _n2p(arr, t=True):
        """Convert a numpy array to a torch tensor.

        Arrays shaped (1, 1, 1, N) are flattened to (N,). When ``t`` is True the axes of
        2-D/3-D/4-D arrays are reordered to PyTorch layout (presumably from the Flax
        kernel layout with output features on the last axis -- confirm per tensor).
        """
        if arr.ndim == 4 and arr.shape[0] == arr.shape[1] == arr.shape[2] == 1:
            arr = arr.flatten()
        if t:
            if arr.ndim == 4:
                arr = arr.transpose([3, 2, 0, 1])
            elif arr.ndim == 3:
                arr = arr.transpose([2, 0, 1])
            elif arr.ndim == 2:
                arr = arr.transpose([1, 0])
        return torch.from_numpy(arr)

    def _convert_timm_img(module, prefix):
        """Copy the image-tower parameters under ``prefix`` into the timm ``module``."""
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
        if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
            # Checkpoint kernel spatial size differs from the model's patch size: resample it.
            embed_conv_w = resample_patch_embed(
                embed_conv_w,
                module.patch_embed.proj.weight.shape[-2:],
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.patch_embed.proj.weight.copy_(embed_conv_w)
        module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))

        if module.cls_token is not None:
            module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))

        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
        if pos_embed_w.shape != module.pos_embed.shape:
            # Position-embedding resampling is deliberately not attempted here. Raise
            # explicitly (an ``assert`` would be stripped under ``python -O``) so a
            # shape mismatch always fails loudly.
            raise ValueError(
                f'Position embedding shape mismatch: checkpoint {tuple(pos_embed_w.shape)} '
                f'vs model {tuple(module.pos_embed.shape)}'
            )
        module.pos_embed.copy_(pos_embed_w)

        # Sub-module indices inside each big_vision encoder block.
        mha_sub, b_sub, ln1_sub = (0, 0, 1)
        for i, block in enumerate(module.blocks.children()):
            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            # Separate q/k/v kernels are fused into timm's single qkv projection.
            block.attn.qkv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.qkv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            for r in range(2):
                getattr(block.mlp, f'fc{r + 1}').weight.copy_(
                    _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
                getattr(block.mlp, f'fc{r + 1}').bias.copy_(
                    _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))

        module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
        module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))

        if module.attn_pool is not None:
            block_prefix = f'{prefix}MAPHead_0/'
            mha_prefix = block_prefix + 'MultiHeadDotProductAttention_0/'
            module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
            module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
            module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
            # Key and value kernels are fused into the pool's single kv projection.
            module.attn_pool.kv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
            module.attn_pool.kv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
            module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            for r in range(2):
                getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(
                    _n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
                getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(
                    _n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

    def _convert_openclip_transformer(module: Transformer, prefix):
        """Copy big_vision encoder blocks under ``prefix`` into OpenCLIP ``module.resblocks``."""
        for i, block in enumerate(module.resblocks.children()):
            block_prefix = f'{prefix}encoderblock_{i}/'
            mha_prefix = block_prefix + 'MultiHeadDotProductAttention_0/'
            block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
            block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
            # Separate q/k/v kernels are fused into the packed in_proj parameters.
            block.attn.in_proj_weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.in_proj_bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
            block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
            block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
            block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
            block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
            block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
            block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
            block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
            block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))

    def _convert_openclip_txt(module: TextTransformer, prefix):
        """Copy the text-tower parameters under ``prefix`` into the OpenCLIP ``module``."""
        module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
        module.positional_embedding.copy_(pos_embed_w)
        _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
        module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
        module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
        module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))

    # The NpzFile keeps the archive open and reads arrays lazily, so every access
    # (made by the helpers above through the closure over ``w``) must happen inside
    # this block; the file is closed on exit, including on error.
    with np.load(checkpoint_path) as w:
        _convert_timm_img(model.visual.trunk, 'params/img/')
        _convert_openclip_txt(model.text, 'params/txt/')
        model.logit_bias.copy_(_n2p(w['params/b'])[0])
        model.logit_scale.copy_(_n2p(w['params/t'])[0])
|
|
|
|
|
@torch.no_grad()
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit=True):
    """Convert a MobileCLIP-format state-dict to OpenCLIP key names.

    The image-tower entries are remapped by timm's ``checkpoint_filter_fn`` and
    re-prefixed with ``visual.trunk.``. Entries under ``text_encoder.`` are renamed
    to the OpenCLIP ``TextTransformer`` layout and prefixed with ``text.``.

    Args:
        model: Destination model. Only ``model.visual.trunk`` is used, and it is
            passed to timm's filter function.
        state_dict: Source state-dict in MobileCLIP naming.
        fastvit: If True, use ``timm.models.fastvit.checkpoint_filter_fn``; otherwise
            use ``timm.models.vision_transformer_hybrid.checkpoint_filter_fn``.

    Returns:
        A new dict containing ``visual.trunk.*``, ``text.*`` and ``logit_scale`` entries.

    Raises:
        KeyError: If ``state_dict`` has no ``logit_scale`` entry.
    """

    def _convert_timm_img(state_dict):
        # Import lazily so the timm variant is only needed when it is actually used.
        if fastvit:
            from timm.models.fastvit import checkpoint_filter_fn
        else:
            from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
        timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
        return {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}

    def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
        text_dict = {}
        for k, v in state_dict.items():
            if not k.startswith(prefix):
                continue
            # Strip only the leading prefix; str.replace would also drop later occurrences.
            k = k[len(prefix):]
            k = k.replace('projection_layer', 'text_projection')
            k = k.replace('embedding_layer', 'token_embedding')
            if k.startswith('positional_embedding.pos_embed.pos_embed'):
                k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
                # Drop singleton dims so the tensor matches OpenCLIP's positional_embedding.
                v = v.squeeze()
            k = k.replace('final_layer_norm', 'ln_final')
            # Pre-norm MHA: index 0 is the LayerNorm, index 1 the attention module.
            k = k.replace('pre_norm_mha.0', 'ln_1')
            k = k.replace('pre_norm_mha.1', 'attn')
            # Pre-norm FFN: index 0 LayerNorm, index 1 first linear, index 4 output linear.
            k = k.replace('pre_norm_ffn.0', 'ln_2')
            k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
            k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
            k = k.replace('qkv_proj.weight', 'in_proj_weight')
            k = k.replace('qkv_proj.bias', 'in_proj_bias')
            k = k.replace('transformer.', 'transformer.resblocks.')
            text_dict['text.' + k] = v
        return text_dict

    image_dict = _convert_timm_img(state_dict)
    text_dict = _convert_openclip_txt(state_dict)
    out_dict = {**image_dict, **text_dict}
    out_dict['logit_scale'] = state_dict['logit_scale']
    return out_dict
|
|
|
|
|
def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
    """Convert a non-native checkpoint state-dict to OpenCLIP naming when recognised.

    Detection is by marker keys. Each marker is checked in turn against the
    current (possibly already converted) state-dict, and each match triggers
    ``convert_mobile_clip_state_dict`` with the matching ``fastvit`` flag.

    Args:
        model: Destination model, forwarded to the converter.
        state_dict: Source state-dict as loaded from disk.

    Returns:
        The converted state-dict if a marker matched, otherwise ``state_dict`` itself.
    """
    known_layouts = (
        # (marker key, use the FastViT image-tower filter)
        ('image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight', True),
        ('image_encoder.model.patch_emb.0.block.conv.weight', False),
    )
    for marker_key, is_fastvit in known_layouts:
        if marker_key in state_dict:
            state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=is_fastvit)
    return state_dict
|
|