import torch
import torch.nn as nn
import torch.nn.functional as F
import clip

from clip.model import CLIP, convert_weights
from clip.simple_tokenizer import SimpleTokenizer, default_bpe


"""===== Monkey-Patching original CLIP for JIT compile ====="""

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = F.layer_norm(
            x.type(torch.float32),
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return ret.type(orig_type)


clip.model.LayerNorm = LayerNorm
delattr(clip.model.CLIP, "forward")

"""===== End of Monkey-Patching ====="""
class CustomizedCLIP(CLIP):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.jit.export
    def encode_image(self, image):
        return self.visual(image)

    @torch.jit.export
    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # Per-token features for the full sequence.
        x_seq = x
        # Pooled feature: take the embedding at the EOT token (the highest token id
        # in each sequence) and project it into the joint embedding space.
        x_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x_out, x_seq

    @torch.jit.ignore
    def forward(self, image, text):
        # Eager-mode only; skipped by TorchScript compilation.
        return super().forward(image, text)

    @classmethod
    def load_from_checkpoint(cls, ckpt_path: str):
        state_dict = torch.load(ckpt_path, map_location="cpu").state_dict()

        # Infer the architecture hyperparameters from the checkpoint keys
        # (ViT if a visual projection is present, ResNet otherwise).
        vit = "visual.proj" in state_dict
        if vit:
            vision_width = state_dict["visual.conv1.weight"].shape[0]
            vision_layers = len(
                [
                    k
                    for k in state_dict.keys()
                    if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
                ]
            )
            vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
            grid_size = round(
                (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
            )
            image_resolution = vision_patch_size * grid_size
        else:
            counts: list = [
                len(
                    set(
                        k.split(".")[2]
                        for k in state_dict
                        if k.startswith(f"visual.layer{b}")
                    )
                )
                for b in [1, 2, 3, 4]
            ]
            vision_layers = tuple(counts)
            vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
            output_width = round(
                (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
            )
            vision_patch_size = None
            assert (
                output_width**2 + 1
                == state_dict["visual.attnpool.positional_embedding"].shape[0]
            )
            image_resolution = output_width * 32

        embed_dim = state_dict["text_projection"].shape[1]
        context_length = state_dict["positional_embedding"].shape[0]
        vocab_size = state_dict["token_embedding.weight"].shape[0]
        transformer_width = state_dict["ln_final.weight"].shape[0]
        transformer_heads = transformer_width // 64
        transformer_layers = len(
            set(
                k.split(".")[2]
                for k in state_dict
                if k.startswith("transformer.resblocks")
            )
        )

        model = cls(
            embed_dim,
            image_resolution,
            vision_layers,
            vision_width,
            vision_patch_size,
            context_length,
            vocab_size,
            transformer_width,
            transformer_heads,
            transformer_layers,
        )

        # Drop entries that are not parameters or buffers of the module.
        for key in ["input_resolution", "context_length", "vocab_size"]:
            if key in state_dict:
                del state_dict[key]

        # Match the (typically fp16) checkpoint dtypes before loading, then run in fp32.
        convert_weights(model)
        model.load_state_dict(state_dict)
        model.eval()
        model.float()
        return model
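
# Usage sketch for the scripted model; the checkpoint path and the `images` /
# `tokens` batches below are placeholders, not part of this module:
#
#   model = CustomizedCLIP.load_from_checkpoint("/path/to/clip_checkpoint.pt")
#   scripted = torch.jit.script(model)
#   image_features = scripted.encode_image(images)          # [B, embed_dim]
#   text_features, text_seq = scripted.encode_text(tokens)  # [B, embed_dim], [B, n_ctx, width]
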
class CustomizedTokenizer(SimpleTokenizer):
    def __init__(self):
        super().__init__(bpe_path=default_bpe())

        self.sot_token = self.encoder["<|startoftext|>"]
        self.eot_token = self.encoder["<|endoftext|>"]

    def padded_tokens_and_mask(self, texts, text_ctx):
        assert isinstance(texts, list) and all(
            isinstance(elem, str) for elem in texts
        ), "texts should be a list of strings"

        # Wrap each prompt with the start-of-text / end-of-text tokens.
        all_tokens = [
            [self.sot_token] + self.encode(text) + [self.eot_token] for text in texts
        ]

        # Boolean mask marking real tokens (True) vs. padding (False).
        mask = [
            [True] * min(text_ctx, len(tokens))
            + [False] * max(text_ctx - len(tokens), 0)
            for tokens in all_tokens
        ]
        mask = torch.tensor(mask, dtype=torch.bool)
        result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int)
        for i, tokens in enumerate(all_tokens):
            if len(tokens) > text_ctx:
                # Truncate over-long prompts, keeping the end-of-text token last.
                tokens = tokens[:text_ctx]
                tokens[-1] = self.eot_token
            result[i, : len(tokens)] = torch.tensor(tokens)

        return result, mask
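

if __name__ == "__main__":
    # Minimal sketch of the tokenizer on its own: the prompts are placeholders,
    # and 77 is assumed here as the usual CLIP text context length.
    tokenizer = CustomizedTokenizer()
    tokens, mask = tokenizer.padded_tokens_and_mask(
        ["a photo of a cat", "a photo of a dog"], text_ctx=77
    )
    print(tokens.shape, mask.shape)  # torch.Size([2, 77]) torch.Size([2, 77])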