John6666's picture
Upload 351 files
e84842d verified
raw
history blame
43.4 kB
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
Based on https://github.com/mlfoundations/open_clip
"""
""" CLIP Model
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import datetime
import json
import logging
import os
import re
import time
import warnings
from collections import OrderedDict
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from lavis.common.registry import registry
from lavis.common.utils import get_abs_path
from lavis.models.base_model import BaseModel
from lavis.models.clip_models.clip_outputs import ClipOutput, ClipOutputFeatures
from lavis.models.clip_models.timm_model import TimmModel
from lavis.models.clip_models.transform import image_transform
from lavis.models.clip_models.utils import freeze_batch_norm_2d
from lavis.tasks.multimodal_classification import MultimodalClassificationTask
from torch import nn
from .pretrained import (
download_pretrained,
get_pretrained_url,
list_pretrained_tag_models,
)
_MODEL_CONFIG_PATHS = [Path(__file__).parent.parent.parent / f"configs/models/clip/"]
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(
OrderedDict(
[
("-1", nn.AvgPool2d(stride)),
(
"0",
nn.Conv2d(
inplanes,
planes * self.expansion,
1,
stride=1,
bias=False,
),
),
("1", nn.BatchNorm2d(planes * self.expansion)),
]
)
)
def forward(self, x: torch.Tensor):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class AttentionPool2d(nn.Module):
def __init__(
self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
):
super().__init__()
self.positional_embedding = nn.Parameter(
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
2, 0, 1
) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x,
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat(
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False,
)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
super().__init__()
self.output_dim = output_dim
self.image_size = image_size
# the 3-layer stem
self.conv1 = nn.Conv2d(
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(width // 2)
self.conv2 = nn.Conv2d(
width // 2, width // 2, kernel_size=3, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(width // 2)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.avgpool = nn.AvgPool2d(2)
self.relu = nn.ReLU(inplace=True)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
self.init_parameters()
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def init_parameters(self):
if self.attnpool is not None:
std = self.attnpool.c_proj.in_features**-0.5
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
assert (
unlocked_groups == 0
), "partial locking not currently supported for this model"
for param in self.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self)
def stem(self, x):
for conv, bn in [
(self.conv1, self.bn1),
(self.conv2, self.bn2),
(self.conv3, self.bn3),
]:
x = self.relu(bn(conv(x)))
x = self.avgpool(x)
return x
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return x.to(orig_type)
class QuickGELU(nn.Module):
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict(
[
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", act_layer()),
("c_proj", nn.Linear(d_model * 4, d_model)),
]
)
)
self.ln_2 = LayerNorm(d_model)
def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(
self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(width, heads, act_layer=act_layer)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
for r in self.resblocks:
x = r(x, attn_mask=attn_mask)
return x
class VisualTransformer(nn.Module):
def __init__(
self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
output_dim: int,
act_layer: Callable = nn.GELU,
):
super().__init__()
self.image_size = image_size
self.output_dim = output_dim
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)
scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(
scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
)
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads, act_layer=act_layer)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
assert (
unlocked_groups == 0
), "partial locking not currently supported for this model"
for param in self.parameters():
param.requires_grad = False
def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat(
[
self.class_embedding.to(x.dtype)
+ torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = x @ self.proj
return x
@dataclass
class CLIPVisionCfg:
layers: Union[Tuple[int, int, int, int], int] = 12
width: int = 768
patch_size: int = 16
image_size: Union[Tuple[int, int], int] = 224
timm_model_name: str = (
None # a valid model name overrides layers, width, patch_size
)
timm_model_pretrained: bool = (
False # use (imagenet) pretrained weights for named model
)
timm_pool: str = (
"avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
)
timm_proj: str = (
"linear" # linear projection for timm model output ('linear', 'mlp', '')
)
@dataclass
class CLIPTextCfg:
context_length: int
vocab_size: int
width: int
heads: int
layers: int
@registry.register_model("clip")
@registry.register_model("clip_feature_extractor")
class CLIP(BaseModel):
PRETRAINED_MODEL_CONFIG_DICT = {
"ViT-B-32": "configs/models/clip_vit_base32.yaml",
"ViT-B-16": "configs/models/clip_vit_base16.yaml",
"ViT-L-14": "configs/models/clip_vit_large14.yaml",
"ViT-L-14-336": "configs/models/clip_vit_large14_336.yaml",
"RN50": "configs/models/clip_resnet50.yaml",
}
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
):
from .tokenizer import tokenize
super().__init__()
self.tokenizer = tokenize
self._loss = None
if isinstance(vision_cfg, dict):
vision_cfg = CLIPVisionCfg(**vision_cfg)
if isinstance(text_cfg, dict):
text_cfg = CLIPTextCfg(**text_cfg)
self.context_length = text_cfg.context_length
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
# memory efficient in recent PyTorch releases (>= 1.10).
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
act_layer = QuickGELU if quick_gelu else nn.GELU
if vision_cfg.timm_model_name:
self.visual = TimmModel(
vision_cfg.timm_model_name,
pretrained=vision_cfg.timm_model_pretrained,
pool=vision_cfg.timm_pool,
proj=vision_cfg.timm_proj,
embed_dim=embed_dim,
image_size=vision_cfg.image_size,
)
act_layer = (
nn.GELU
) # so that text transformer doesn't use QuickGELU w/ timm models
elif isinstance(vision_cfg.layers, (tuple, list)):
vision_heads = vision_cfg.width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_cfg.layers,
output_dim=embed_dim,
heads=vision_heads,
image_size=vision_cfg.image_size,
width=vision_cfg.width,
)
else:
vision_heads = vision_cfg.width // 64
self.visual = VisualTransformer(
image_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
width=vision_cfg.width,
layers=vision_cfg.layers,
heads=vision_heads,
output_dim=embed_dim,
act_layer=act_layer,
)
self.transformer = Transformer(
width=text_cfg.width,
layers=text_cfg.layers,
heads=text_cfg.heads,
act_layer=act_layer,
)
self.vocab_size = text_cfg.vocab_size
self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, text_cfg.width)
)
self.ln_final = LayerNorm(text_cfg.width)
self.text_projection = nn.Parameter(torch.empty(text_cfg.width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
self.prompt_templates = openai_imagenet_template
self.classifier = None
self.init_parameters()
@property
def loss(self):
if self._loss is None:
from lavis.models.clip_models.loss import ClipLoss
from torch import distributed as dist
self._loss = ClipLoss(
world_size=dist.get_world_size(),
rank=dist.get_rank(),
local_loss=False,
gather_with_grad=False,
use_horovod=False,
)
return self._loss
def init_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
nn.init.constant_(self.logit_scale, np.log(1 / 0.07))
if hasattr(self.visual, "init_parameters"):
self.visual.init_parameters()
proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers) ** -0.5
)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(
unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats
)
def encode_image(self, image):
return self.visual(image)
def encode_text(self, text):
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=self.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
# def forward(self, image, text):
def forward(self, samples):
image = samples.get("image")
text = samples.get("text_input")
if text is not None:
text = self.tokenizer(text).to(self.device)
if image is None:
return self.encode_text(text)
elif text is None:
return self.encode_image(image)
image_embeds = self.encode_image(image)
image_features = F.normalize(image_embeds, dim=-1)
text_embeds = self.encode_text(text)
text_features = F.normalize(text_embeds, dim=-1)
loss = self.loss(image_features, text_features, self.logit_scale.exp())
# return image_features, text_features, self.logit_scale.exp()
# return {"loss": loss}
return ClipOutput(
intermediate_output=ClipOutputFeatures(
image_embeds=image_embeds,
image_embeds_proj=image_features,
text_embeds=text_embeds,
text_embeds_proj=text_features,
),
loss=loss,
logit_scale_exp=self.logit_scale.exp(),
)
def extract_features(self, samples):
"""
Extract features from the model for samples.
Keys allowed are "image" and "text_input" in samples.
If either key is missing, the corresponding features are not extracted.
Args:
samples: dict of samples to extract features from.
Returns:
ClipOutputFeatures object with features for the samples.
"""
image = samples.get("image")
text = samples.get("text_input")
if text is not None:
text = self.tokenizer(text).to(self.device)
if image is None:
return self.encode_text(text)
elif text is None:
return self.encode_image(image)
image_embeds = self.encode_image(image)
image_features = F.normalize(image_embeds, dim=-1)
text_embeds = self.encode_text(text)
text_features = F.normalize(text_embeds, dim=-1)
return ClipOutputFeatures(
image_embeds=image_embeds,
image_embeds_proj=image_features,
text_embeds=text_embeds,
text_embeds_proj=text_features,
)
def predict(self, samples):
image = samples["image"]
targets = samples["label"]
image_features = self.encode_image(image)
image_features = F.normalize(image_features, dim=-1)
logits = 100.0 * image_features @ self.classifier
return {"predictions": logits, "targets": targets}
def before_evaluation(self, dataset, task_type, **kwargs):
if task_type == MultimodalClassificationTask:
self.classifier = self.zero_shot_classifier(
classnames=dataset.classnames,
templates=self.prompt_templates,
)
def zero_shot_classifier(self, classnames, templates):
with torch.no_grad():
zeroshot_weights = []
for classname in classnames:
texts = [
template(classname) for template in templates
] # format with class
texts = self.tokenizer(texts).to(self.device) # tokenize
class_embeddings = self.encode_text(texts)
class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
class_embedding /= class_embedding.norm()
zeroshot_weights.append(class_embedding)
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(self.device)
return zeroshot_weights
@classmethod
def default_config_path(cls, model_type="base"):
model_type = "ViT-B-32" if model_type == "base" else model_type
assert (
model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
), "Unknown model type {}. \n Available types: {}".format(
model_type, cls.PRETRAINED_MODEL_CONFIG_DICT.keys()
)
return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
@classmethod
def from_config(cls, cfg=None):
model_name = cfg.model_type
pretrained = cfg.pretrained
precision = cfg.get("precision", "fp32")
return create_model(
model_name=model_name, pretrained=pretrained, precision=precision
)
def zero_shot_predict(self, image_path, categories):
assert isinstance(
categories, list
), f"categories must be a list, got {type(categories)}."
assert os.path.exists(image_path), f"File {image_path} does not exist."
from lavis.processors.clip_processors import ClipImageEvalProcessor
from PIL import Image
image_preprocess = ClipImageEvalProcessor()
image = image_preprocess(Image.open(image_path)).unsqueeze(0)
text = self.tokenizer(categories)
with torch.no_grad():
image_features = self.encode_image(image)
text_features = self.encode_text(text)
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
print("Label probs:", text_probs) # prints: [[1., 0., 0.]]
def compute_sim_matrix(self, data_loader, **kwargs):
logging.info("Computing features for evaluation...")
start_time = time.time()
texts = data_loader.dataset.text
num_text = len(texts)
text_bs = 256
text_features = []
for i in range(0, num_text, text_bs):
text = texts[i : min(num_text, i + text_bs)]
text_input = self.tokenizer(text).to(self.device)
text_feat = self.encode_text(text_input)
text_feat = F.normalize(text_feat, dim=-1)
text_features.append(text_feat)
text_features = torch.cat(text_features, dim=0)
image_features = []
for samples in data_loader:
image = samples["image"]
image = image.to(self.device)
image_feat = self.encode_image(image)
image_feat = F.normalize(image_feat, dim=-1)
image_features.append(image_feat)
image_features = torch.cat(image_features, dim=0)
sims_matrix_i2t = image_features @ text_features.t()
sims_matrix_t2i = sims_matrix_i2t.t()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logging.info("Evaluation time {}".format(total_time_str))
return sims_matrix_i2t.cpu().numpy(), sims_matrix_t2i.cpu().numpy()
def convert_weights_to_fp16(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
if isinstance(l, nn.MultiheadAttention):
for attr in [
*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
"in_proj_bias",
"bias_k",
"bias_v",
]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
def build_model_from_openai_state_dict(state_dict: dict):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len(
[
k
for k in state_dict.keys()
if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
]
)
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round(
(state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
)
image_size = vision_patch_size * grid_size
else:
counts: list = [
len(
set(
k.split(".")[2]
for k in state_dict
if k.startswith(f"visual.layer{b}")
)
)
for b in [1, 2, 3, 4]
]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round(
(state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
)
vision_patch_size = None
assert (
output_width**2 + 1
== state_dict["visual.attnpool.positional_embedding"].shape[0]
)
image_size = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(
set(
k.split(".")[2]
for k in state_dict
if k.startswith(f"transformer.resblocks")
)
)
vision_cfg = CLIPVisionCfg(
layers=vision_layers,
width=vision_width,
patch_size=vision_patch_size,
image_size=image_size,
)
text_cfg = CLIPTextCfg(
context_length=context_length,
vocab_size=vocab_size,
width=transformer_width,
heads=transformer_heads,
layers=transformer_layers,
)
model = CLIP(
embed_dim,
vision_cfg=vision_cfg,
text_cfg=text_cfg,
quick_gelu=True, # OpenAI models were trained with QuickGELU
)
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
convert_weights_to_fp16(model)
model.load_state_dict(state_dict)
return model.eval()
def trace_model(model, batch_size=256, device=torch.device("cpu")):
model.eval()
image_size = model.visual.image_size
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
example_text = torch.zeros(
(batch_size, model.context_length), dtype=torch.int, device=device
)
model = torch.jit.trace_module(
model,
inputs=dict(
forward=(example_images, example_text),
encode_text=(example_text,),
encode_image=(example_images,),
),
)
model.visual.image_size = image_size
return
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
def _rescan_model_configs():
global _MODEL_CONFIGS
config_ext = (".json",)
config_files = []
for config_path in _MODEL_CONFIG_PATHS:
if config_path.is_file() and config_path.suffix in config_ext:
config_files.append(config_path)
elif config_path.is_dir():
for ext in config_ext:
config_files.extend(config_path.glob(f"*{ext}"))
for cf in config_files:
with open(cf, "r") as f:
model_cfg = json.load(f)
if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")):
_MODEL_CONFIGS[cf.stem] = model_cfg
_MODEL_CONFIGS = {
k: v
for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
}
_rescan_model_configs() # initial populate of model config registry
def load_state_dict(checkpoint_path: str, map_location="cpu"):
checkpoint = torch.load(checkpoint_path, map_location=map_location)
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:
state_dict = checkpoint
if next(iter(state_dict.items()))[0].startswith("module"):
state_dict = {k[7:]: v for k, v in state_dict.items()}
return state_dict
def create_model(
model_name: str,
pretrained: str = "",
precision: str = "fp32",
device: torch.device = torch.device("cpu"),
jit: bool = False,
force_quick_gelu: bool = False,
pretrained_image: bool = False,
):
model_name = model_name.replace(
"/", "-"
) # for callers using old naming with / in ViT names
if pretrained.lower() == "openai":
logging.info(f"Loading pretrained {model_name} from OpenAI.")
model = load_openai_model(model_name, device=device, jit=jit)
# See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
if precision == "amp" or precision == "fp32":
model = model.float()
else:
logging.info(f"No pretrained weights loaded for {model_name} model.")
if model_name in _MODEL_CONFIGS:
logging.info(f"Loading {model_name} model config.")
model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
else:
logging.error(
f"Model config for {model_name} not found; available models {list_models()}."
)
raise RuntimeError(f"Model config for {model_name} not found.")
if force_quick_gelu:
# override for use of QuickGELU on non-OpenAI transformer models
model_cfg["quick_gelu"] = True
if pretrained_image:
if "timm_model_name" in model_cfg.get("vision_cfg", {}):
# pretrained weight loading for timm models set via vision_cfg
model_cfg["vision_cfg"]["timm_model_pretrained"] = True
else:
assert (
False
), "pretrained image towers currently only supported for timm models"
model = CLIP(**model_cfg)
if pretrained:
checkpoint_path = ""
url = get_pretrained_url(model_name, pretrained)
if url:
checkpoint_path = download_pretrained(url)
elif os.path.exists(pretrained):
checkpoint_path = pretrained
if checkpoint_path:
logging.info(f"Loading pretrained {model_name} weights ({pretrained}).")
model.load_state_dict(load_state_dict(checkpoint_path))
else:
logging.warning(
f"Pretrained weights ({pretrained}) not found for model {model_name}."
)
raise RuntimeError(
f"Pretrained weights ({pretrained}) not found for model {model_name}."
)
model.to(device=device)
if precision == "fp16":
assert device.type != "cpu"
convert_weights_to_fp16(model)
if jit:
model = torch.jit.script(model)
return model
def create_model_and_transforms(
model_name: str,
pretrained: str = "",
precision: str = "fp32",
device: torch.device = torch.device("cpu"),
jit: bool = False,
force_quick_gelu: bool = False,
pretrained_image: bool = False,
):
model = create_model(
model_name,
pretrained,
precision,
device,
jit,
force_quick_gelu=force_quick_gelu,
pretrained_image=pretrained_image,
)
preprocess_train = image_transform(model.visual.image_size, is_train=True)
preprocess_val = image_transform(model.visual.image_size, is_train=False)
return model, preprocess_train, preprocess_val
def list_models():
"""enumerate available model architectures based on config files"""
return list(_MODEL_CONFIGS.keys())
def add_model_config(path):
"""add model config path or file and update registry"""
if not isinstance(path, Path):
path = Path(path)
_MODEL_CONFIG_PATHS.append(path)
_rescan_model_configs()
def list_openai_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list_pretrained_tag_models("openai")
def load_openai_model(
name: str,
device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
jit=True,
):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
Returns
-------
model : torch.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if get_pretrained_url(name, "openai"):
model_path = download_pretrained(get_pretrained_url(name, "openai"))
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(
f"Model {name} not found; available models = {list_openai_models()}"
)
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(
f"File {model_path} is not a JIT archive. Loading as a state dict instead"
)
jit = False
state_dict = torch.load(model_path, map_location="cpu")
if not jit:
try:
model = build_model_from_openai_state_dict(
state_dict or model.state_dict()
).to(device)
except KeyError:
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
model = build_model_from_openai_state_dict(sd).to(device)
if str(device) == "cpu":
model.float()
return model
# patch the device names
device_holder = torch.jit.trace(
lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
)
device_node = [
n
for n in device_holder.graph.findAllNodes("prim::Constant")
if "Device" in repr(n)
][-1]
def patch_device(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith(
"cuda"
):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if str(device) == "cpu":
float_holder = torch.jit.trace(
lambda: torch.ones([]).float(), example_inputs=[]
)
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [
1,
2,
]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
# ensure image_size attr available at consistent location for both jit and non-jit
model.visual.image_size = model.input_resolution.item()
return model
openai_imagenet_template = [
lambda c: f"a bad photo of a {c}.",
lambda c: f"a photo of many {c}.",
lambda c: f"a sculpture of a {c}.",
lambda c: f"a photo of the hard to see {c}.",
lambda c: f"a low resolution photo of the {c}.",
lambda c: f"a rendering of a {c}.",
lambda c: f"graffiti of a {c}.",
lambda c: f"a bad photo of the {c}.",
lambda c: f"a cropped photo of the {c}.",
lambda c: f"a tattoo of a {c}.",
lambda c: f"the embroidered {c}.",
lambda c: f"a photo of a hard to see {c}.",
lambda c: f"a bright photo of a {c}.",
lambda c: f"a photo of a clean {c}.",
lambda c: f"a photo of a dirty {c}.",
lambda c: f"a dark photo of the {c}.",
lambda c: f"a drawing of a {c}.",
lambda c: f"a photo of my {c}.",
lambda c: f"the plastic {c}.",
lambda c: f"a photo of the cool {c}.",
lambda c: f"a close-up photo of a {c}.",
lambda c: f"a black and white photo of the {c}.",
lambda c: f"a painting of the {c}.",
lambda c: f"a painting of a {c}.",
lambda c: f"a pixelated photo of the {c}.",
lambda c: f"a sculpture of the {c}.",
lambda c: f"a bright photo of the {c}.",
lambda c: f"a cropped photo of a {c}.",
lambda c: f"a plastic {c}.",
lambda c: f"a photo of the dirty {c}.",
lambda c: f"a jpeg corrupted photo of a {c}.",
lambda c: f"a blurry photo of the {c}.",
lambda c: f"a photo of the {c}.",
lambda c: f"a good photo of the {c}.",
lambda c: f"a rendering of the {c}.",
lambda c: f"a {c} in a video game.",
lambda c: f"a photo of one {c}.",
lambda c: f"a doodle of a {c}.",
lambda c: f"a close-up photo of the {c}.",
lambda c: f"a photo of a {c}.",
lambda c: f"the origami {c}.",
lambda c: f"the {c} in a video game.",
lambda c: f"a sketch of a {c}.",
lambda c: f"a doodle of the {c}.",
lambda c: f"a origami {c}.",
lambda c: f"a low resolution photo of a {c}.",
lambda c: f"the toy {c}.",
lambda c: f"a rendition of the {c}.",
lambda c: f"a photo of the clean {c}.",
lambda c: f"a photo of a large {c}.",
lambda c: f"a rendition of a {c}.",
lambda c: f"a photo of a nice {c}.",
lambda c: f"a photo of a weird {c}.",
lambda c: f"a blurry photo of a {c}.",
lambda c: f"a cartoon {c}.",
lambda c: f"art of a {c}.",
lambda c: f"a sketch of the {c}.",
lambda c: f"a embroidered {c}.",
lambda c: f"a pixelated photo of a {c}.",
lambda c: f"itap of the {c}.",
lambda c: f"a jpeg corrupted photo of the {c}.",
lambda c: f"a good photo of a {c}.",
lambda c: f"a plushie {c}.",
lambda c: f"a photo of the nice {c}.",
lambda c: f"a photo of the small {c}.",
lambda c: f"a photo of the weird {c}.",
lambda c: f"the cartoon {c}.",
lambda c: f"art of the {c}.",
lambda c: f"a drawing of the {c}.",
lambda c: f"a photo of the large {c}.",
lambda c: f"a black and white photo of a {c}.",
lambda c: f"the plushie {c}.",
lambda c: f"a dark photo of a {c}.",
lambda c: f"itap of a {c}.",
lambda c: f"graffiti of the {c}.",
lambda c: f"a toy {c}.",
lambda c: f"itap of my {c}.",
lambda c: f"a photo of a cool {c}.",
lambda c: f"a photo of a small {c}.",
lambda c: f"a tattoo of the {c}.",
]