# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------
# Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/)
# ------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
from clip.model import CLIP, convert_weights
from clip.simple_tokenizer import SimpleTokenizer, default_bpe
"""===== Monkey-Patching original CLIP for JIT compile ====="""
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = F.layer_norm(
x.type(torch.float32),
self.normalized_shape,
self.weight,
self.bias,
self.eps,
)
return ret.type(orig_type)
# Swap in the fp16-safe LayerNorm above and drop the original CLIP.forward so
# that TorchScript does not try to compile it; CustomizedCLIP below exports the
# two encoders individually instead.
clip.model.LayerNorm = LayerNorm
delattr(clip.model.CLIP, "forward")
"""===== End of Monkey-Patching ====="""
class CustomizedCLIP(CLIP):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@torch.jit.export
def encode_image(self, image):
return self.visual(image)
@torch.jit.export
def encode_text(self, text):
        # Re-defined from CLIP.encode_text to additionally return the unpooled
        # per-token features alongside the pooled embedding.
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
x_seq = x
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x_out, x_seq
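    # Shape note (illustrative addition, not in the original file): for B
    # tokenized prompts, encode_text returns
    #   x_out: [B, embed_dim]                 pooled eot feature @ text_projection
    #   x_seq: [B, n_ctx, transformer_width]  unpooled per-token features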
@torch.jit.ignore
    def forward(self, image, text):
        # Excluded from scripting; note the base CLIP.forward was removed above,
        # so the exported encode_image / encode_text should be used instead.
        return super().forward(image, text)
@classmethod
    def load_from_checkpoint(cls, ckpt_path: str):
        # The checkpoint stores a model object rather than a bare state dict;
        # recover its weights, then infer every architecture hyperparameter
        # from tensor shapes.
        state_dict = torch.load(ckpt_path, map_location="cpu").state_dict()

        # ViT checkpoints expose the visual projection matrix; the modified
        # ResNet ones do not.
        vit = "visual.proj" in state_dict
        if vit:
            # ViT backbone: infer width, depth, and patch size from conv1 and
            # the visual positional embedding.
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len(
[
k
for k in state_dict.keys()
if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
]
)
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round(
(state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
)
image_resolution = vision_patch_size * grid_size
        else:
            # Modified-ResNet backbone: count the residual blocks in each of
            # the four stages.
counts: list = [
len(
set(
k.split(".")[2]
for k in state_dict
if k.startswith(f"visual.layer{b}")
)
)
for b in [1, 2, 3, 4]
]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round(
(state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
)
vision_patch_size = None
assert (
output_width**2 + 1
== state_dict["visual.attnpool.positional_embedding"].shape[0]
)
image_resolution = output_width * 32
        # Text tower hyperparameters, likewise inferred from tensor shapes.
        embed_dim = state_dict["text_projection"].shape[1]
        context_length = state_dict["positional_embedding"].shape[0]
        vocab_size = state_dict["token_embedding.weight"].shape[0]
        transformer_width = state_dict["ln_final.weight"].shape[0]
        transformer_heads = transformer_width // 64  # CLIP uses 64-dim attention heads
transformer_layers = len(
set(
k.split(".")[2]
for k in state_dict
if k.startswith("transformer.resblocks")
)
)
model = cls(
embed_dim,
image_resolution,
vision_layers,
vision_width,
vision_patch_size,
context_length,
vocab_size,
transformer_width,
transformer_heads,
transformer_layers,
)
for key in ["input_resolution", "context_length", "vocab_size"]:
if key in state_dict:
del state_dict[key]
convert_weights(model)
model.load_state_dict(state_dict)
model.eval()
model.float()
return model
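

# A hedged usage sketch, not part of the original module: `ckpt_path`,
# `images`, and `token_ids` are placeholder arguments. Scripting works because
# the base CLIP.forward was removed above and this class's forward is ignored.
def _scripted_encoders_example(ckpt_path: str, images: torch.Tensor, token_ids: torch.Tensor):
    model = torch.jit.script(CustomizedCLIP.load_from_checkpoint(ckpt_path))
    img_feat = model.encode_image(images)  # [B, embed_dim]
    txt_feat, txt_seq = model.encode_text(token_ids)  # pooled, unpooled
    return img_feat, txt_feat, txt_seq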
class CustomizedTokenizer(SimpleTokenizer):
def __init__(self):
super().__init__(bpe_path=default_bpe())
self.sot_token = self.encoder["<|startoftext|>"]
self.eot_token = self.encoder["<|endoftext|>"]
def padded_tokens_and_mask(self, texts, text_ctx):
assert isinstance(texts, list) and all(
isinstance(elem, str) for elem in texts
), "texts should be a list of strings"
        # Wrap each prompt with start-of-text / end-of-text markers.
        all_tokens = [
            [self.sot_token] + self.encode(text) + [self.eot_token] for text in texts
        ]
        # Boolean mask marking real tokens (True) versus padding (False),
        # clipped to the context length.
        mask = [
            [True] * min(text_ctx, len(tokens))
            + [False] * max(text_ctx - len(tokens), 0)
            for tokens in all_tokens
        ]
        mask = torch.tensor(mask, dtype=torch.bool)
        result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int)
        for i, tokens in enumerate(all_tokens):
            if len(tokens) > text_ctx:
                # Truncate over-long prompts but keep the eot token at the end,
                # since encode_text pools at the eot position.
                tokens = tokens[:text_ctx]
                tokens[-1] = self.eot_token
            result[i, : len(tokens)] = torch.tensor(tokens)
return result, mask
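

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original file: tokenize a prompt and
    # inspect the padded ids and mask. The prompt and text_ctx value are
    # illustrative; 77 is the standard CLIP context length.
    tokenizer = CustomizedTokenizer()
    tokens, mask = tokenizer.padded_tokens_and_mask(["a photo of a cat"], text_ctx=77)
    print(tokens.shape, mask.shape)  # torch.Size([1, 77]) torch.Size([1, 77])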