"""Convert SigLIP checkpoints from the original repository. |
|
|
|
URL: https://github.com/google-research/big_vision/tree/main |
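
Example usage (the script filename below is a placeholder for wherever this file is saved):

    python convert_siglip_to_hf.py --model_name siglip-base-patch16-224 --pytorch_dump_folder_path /path/to/output [--push_to_hub]

Note that the path to the original .npz checkpoint is currently hardcoded in `convert_siglip_checkpoint`.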
|
""" |
|
|
|
|
|
import argparse |
|
import collections |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import requests |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from numpy import load |
|
from PIL import Image |
|
|
|
from transformers import SiglipConfig, SiglipModel |
|
from transformers.utils import logging |
|
|
|
|
|
logging.set_verbosity_info() |
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
def get_siglip_config(model_name): |
|
config = SiglipConfig() |
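    # override the defaults depending on the size indicated by the model name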
|
|
|
|
|
    if "base" in model_name:
        config.vision_config.image_size = 224
        config.vision_config.patch_size = 16
        config.text_config.vocab_size = 32000
        config.text_config.hidden_size = 768
        config.text_config.intermediate_size = 3072
        config.text_config.max_position_embeddings = 64
        config.text_config.num_attention_heads = 12
    elif "large" in model_name:
        config.vision_config.hidden_size = 1024
        config.vision_config.intermediate_size = 4096
        config.vision_config.num_hidden_layers = 24
        config.vision_config.num_attention_heads = 16
        config.text_config.vocab_size = 32000
        config.text_config.hidden_size = 1024
        config.text_config.intermediate_size = 4096
        config.text_config.num_hidden_layers = 24
        config.text_config.num_attention_heads = 16
        config.text_config.max_position_embeddings = 64
    else:
        raise ValueError(f"Model {model_name} not supported")
|
|
|
return config |
|
|
|
|
|
def create_rename_keys(config): |
|
rename_keys = [] |
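    # vision tower: patch and position embeddings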
|
|
|
|
|
|
|
|
|
rename_keys.append(("params/img/embedding/kernel", "vision_model.vision_model.embeddings.patch_embedding.weight")) |
|
rename_keys.append(("params/img/embedding/bias", "vision_model.vision_model.embeddings.patch_embedding.bias")) |
|
rename_keys.append(("params/img/pos_embedding", "vision_model.vision_model.embeddings.position_embedding.weight")) |
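
    # vision transformer encoder layers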
|
|
|
for i in range(config.vision_config.num_hidden_layers): |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/scale", f"vision_model.vision_model.encoder.layers.{i}.layer_norm1.weight")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/bias", f"vision_model.vision_model.encoder.layers.{i}.layer_norm1.bias")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/scale", f"vision_model.vision_model.encoder.layers.{i}.layer_norm2.weight")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/bias", f"vision_model.vision_model.encoder.layers.{i}.layer_norm2.bias")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"vision_model.vision_model.encoder.layers.{i}.mlp.fc1.weight")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"vision_model.vision_model.encoder.layers.{i}.mlp.fc1.bias")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"vision_model.vision_model.encoder.layers.{i}.mlp.fc2.weight")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"vision_model.vision_model.encoder.layers.{i}.mlp.fc2.bias")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"vision_model.vision_model.encoder.layers.{i}.self_attn.k_proj.weight")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"vision_model.vision_model.encoder.layers.{i}.self_attn.k_proj.bias")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"vision_model.vision_model.encoder.layers.{i}.self_attn.v_proj.weight")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"vision_model.vision_model.encoder.layers.{i}.self_attn.v_proj.bias")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"vision_model.vision_model.encoder.layers.{i}.self_attn.q_proj.weight")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"vision_model.vision_model.encoder.layers.{i}.self_attn.q_proj.bias")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"vision_model.vision_model.encoder.layers.{i}.self_attn.out_proj.weight")) |
|
rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"vision_model.vision_model.encoder.layers.{i}.self_attn.out_proj.bias")) |
|
|
|
rename_keys.append(("params/img/Transformer/encoder_norm/scale", "vision_model.vision_model.post_layernorm.weight")) |
|
rename_keys.append(("params/img/Transformer/encoder_norm/bias", "vision_model.vision_model.post_layernorm.bias")) |
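
    # attention pooling (MAP) head; its q/k/v projections are handled separately in read_in_q_k_v_head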
|
|
|
rename_keys.append(("params/img/MAPHead_0/probe", "vision_model.vision_model.head.probe")) |
|
rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/scale", "vision_model.vision_model.head.layernorm.weight")) |
|
rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/bias", "vision_model.vision_model.head.layernorm.bias")) |
|
rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel", "vision_model.vision_model.head.mlp.fc1.weight")) |
|
rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/bias", "vision_model.vision_model.head.mlp.fc1.bias")) |
|
rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel", "vision_model.vision_model.head.mlp.fc2.weight")) |
|
rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/bias", "vision_model.vision_model.head.mlp.fc2.bias")) |
|
rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel", "vision_model.vision_model.head.attention.out_proj.weight")) |
|
rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias", "vision_model.vision_model.head.attention.out_proj.bias")) |
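
    # text tower: token and position embeddings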
|
|
|
|
|
|
|
rename_keys.append(("params/txt/Embed_0/embedding", "text_model.text_model.embeddings.token_embedding.weight")) |
|
rename_keys.append(("params/txt/pos_embedding", "text_model.text_model.embeddings.position_embedding.weight")) |
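
    # text transformer encoder layers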
|
|
|
for i in range(config.text_config.num_hidden_layers): |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/scale", f"text_model.text_model.encoder.layers.{i}.layer_norm1.weight")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/bias", f"text_model.text_model.encoder.layers.{i}.layer_norm1.bias")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/scale", f"text_model.text_model.encoder.layers.{i}.layer_norm2.weight")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/bias", f"text_model.text_model.encoder.layers.{i}.layer_norm2.bias")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"text_model.text_model.encoder.layers.{i}.mlp.fc1.weight")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"text_model.text_model.encoder.layers.{i}.mlp.fc1.bias")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"text_model.text_model.encoder.layers.{i}.mlp.fc2.weight")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"text_model.text_model.encoder.layers.{i}.mlp.fc2.bias")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"text_model.text_model.encoder.layers.{i}.self_attn.k_proj.weight")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"text_model.text_model.encoder.layers.{i}.self_attn.k_proj.bias")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"text_model.text_model.encoder.layers.{i}.self_attn.v_proj.weight")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"text_model.text_model.encoder.layers.{i}.self_attn.v_proj.bias")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"text_model.text_model.encoder.layers.{i}.self_attn.q_proj.weight")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"text_model.text_model.encoder.layers.{i}.self_attn.q_proj.bias")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"text_model.text_model.encoder.layers.{i}.self_attn.out_proj.weight")) |
|
rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"text_model.text_model.encoder.layers.{i}.self_attn.out_proj.bias")) |
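
    # text tower: final layer norm and output head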
|
|
|
rename_keys.append(("params/txt/Encoder_0/encoder_norm/scale", "text_model.text_model.final_layer_norm.weight")) |
|
rename_keys.append(("params/txt/Encoder_0/encoder_norm/bias", "text_model.text_model.final_layer_norm.bias")) |
|
rename_keys.append(("params/txt/head/kernel", "text_model.text_model.head.weight")) |
|
rename_keys.append(("params/txt/head/bias", "text_model.text_model.head.bias")) |
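
    # learned temperature and bias of the sigmoid loss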
|
|
|
|
|
rename_keys.append(("params/t", "temperature")) |
|
rename_keys.append(("params/b", "bias")) |
|
|
|
|
|
return rename_keys |
|
|
|
|
|
def rename_key(dct, old, new, config): |
|
val = dct.pop(old) |
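
    # attention projection weights/biases: merge the separate (num_heads, head_dim) axes into a single hidden axis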
|
|
|
if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "vision" in new: |
|
val = val.reshape(-1, config.vision_config.hidden_size) |
|
if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "text" in new: |
|
val = val.reshape(-1, config.text_config.hidden_size) |
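
    # patch embedding: Flax conv kernels are stored as (height, width, in_channels, out_channels), whereas
    # PyTorch expects (out_channels, in_channels, height, width); other dense kernels are (in, out) and only need a transpose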
|
|
|
if "patch_embedding.weight" in new: |
|
val = val.transpose(3, 2, 0, 1) |
|
elif new.endswith("weight") and "position_embedding" not in new and "token_embedding" not in new: |
|
val = val.T |
|
|
|
if "position_embedding" in new and "vision" in new: |
|
val = val.reshape(-1, config.vision_config.hidden_size) |
|
if "position_embedding" in new and "text" in new: |
|
val = val.reshape(-1, config.text_config.hidden_size) |
|
|
|
if new.endswith("bias"): |
|
val = val.reshape(-1) |
|
|
|
dct[new] = torch.from_numpy(val) |
|
|
|
|
|
def read_in_q_k_v_head(state_dict, config): |
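    # the attention pooling head uses a single fused in_proj for q, k and v, so the separate projection
    # matrices of the original checkpoint are concatenated here instead of being renamed one-to-one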
|
|
|
key_proj_weight = ( |
|
state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/kernel") |
|
.reshape(-1, config.vision_config.hidden_size) |
|
.T |
|
) |
|
key_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/bias").reshape(-1) |
|
value_proj_weight = ( |
|
state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/kernel") |
|
.reshape(-1, config.vision_config.hidden_size) |
|
.T |
|
) |
|
value_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/bias").reshape(-1) |
|
query_proj_weight = ( |
|
state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/kernel") |
|
.reshape(-1, config.vision_config.hidden_size) |
|
.T |
|
) |
|
query_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/bias").reshape(-1) |
|
|
|
|
|
state_dict["vision_model.vision_model.head.attention.in_proj_weight"] = torch.from_numpy( |
|
np.concatenate([query_proj_weight, key_proj_weight, value_proj_weight], axis=0) |
|
) |
|
state_dict["vision_model.vision_model.head.attention.in_proj_bias"] = torch.from_numpy( |
|
np.concatenate([query_proj_bias, key_proj_bias, value_proj_bias], axis=0) |
|
) |
|
|
|
|
|
|
|
def prepare_img(): |
|
url = "http://images.cocodataset.org/val2017/000000039769.jpg" |
|
image = Image.open(requests.get(url, stream=True).raw) |
|
return image |
|
|
|
|
|
def flatten_nested_dict(params, parent_key="", sep="/"): |
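    # flatten a nested parameter tree into a flat dict with "/"-separated keys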
|
items = [] |
|
|
|
for k, v in params.items(): |
|
new_key = parent_key + sep + k if parent_key else k |
|
|
|
if isinstance(v, collections.abc.MutableMapping): |
|
items.extend(flatten_nested_dict(v, new_key, sep=sep).items()) |
|
else: |
|
items.append((new_key, v)) |
|
return dict(items) |
|
|
|
|
|
@torch.no_grad() |
|
def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False): |
|
""" |
|
Copy/paste/tweak model's weights to our SigLIP structure. |
|
""" |
|
|
|
|
|
config = get_siglip_config(model_name) |
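
    # load the original (flattened) checkpoint; note the hardcoded local path to the original .npz file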
|
|
|
|
|
data = load("/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz") |
|
state_dict = flatten_nested_dict(data) |
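
    # remap the original parameter names to the Hugging Face naming scheme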
|
|
|
|
|
rename_keys = create_rename_keys(config) |
|
for src, dest in rename_keys: |
|
rename_key(state_dict, src, dest, config) |
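
    # the q/k/v projections of the attention pooling head need to be fused, so they get special treatment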
|
|
|
|
|
read_in_q_k_v_head(state_dict, config) |
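
    # load the remapped state dict into a freshly initialized HuggingFace model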
|
|
|
|
|
model = SiglipModel(config).eval() |
|
model.load_state_dict(state_dict) |
|
|
|
print("Original temperature:", data["params/t"]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # verify the converted model on precomputed pixel values and input ids hosted on the hub
    filepath = hf_hub_download(repo_id="nielsr/test-image", filename="pixel_values_siglip.npy", repo_type="dataset")
|
pixel_values = np.load(filepath) |
|
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2) |
|
filepath = hf_hub_download(repo_id="nielsr/test-image", filename="input_ids_siglip.npy", repo_type="dataset") |
|
input_ids = np.load(filepath) |
|
input_ids = torch.from_numpy(input_ids) |
|
|
|
with torch.no_grad(): |
|
outputs = model(input_ids=input_ids, pixel_values=pixel_values) |
|
|
|
|
|
expected_slice = torch.tensor( |
|
[[-2.9621, -2.1672, -1.7837], [-0.2713, 0.2910, -10.6595], [-13.6617, -13.1611, -17.4408]] |
|
) |
|
assert torch.allclose(outputs.logits_per_image[:3, :3], expected_slice, atol=1e-4) |
|
print("Looks ok!") |
|
|
|
if pytorch_dump_folder_path is not None: |
|
        Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
|
print(f"Saving model {model_name} to {pytorch_dump_folder_path}") |
|
model.save_pretrained(pytorch_dump_folder_path) |
|
|
|
|
|
|
|
if push_to_hub: |
|
model.push_to_hub(f"nielsr/{model_name}") |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument( |
|
"--model_name", |
|
default="siglip-base-patch16-224", |
|
type=str, |
|
choices=["siglip-base-patch16-224"], |
|
help="Name of the model you'd like to convert.", |
|
) |
|
parser.add_argument( |
|
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." |
|
) |
|
parser.add_argument( |
|
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." |
|
) |
|
|
|
args = parser.parse_args() |
|
convert_siglip_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) |
|
|