# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------
# Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/)
# ------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
from clip.model import CLIP, convert_weights
from clip.simple_tokenizer import SimpleTokenizer, default_bpe


"""===== Monkey-Patching original CLIP for JIT compile ====="""


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        # Run the normalization in fp32 for numerical stability, then cast the
        # result back to the input dtype so fp16 activations pass through cleanly.
        ret = F.layer_norm(
            x.type(torch.float32),
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return ret.type(orig_type)


clip.model.LayerNorm = LayerNorm
# Part of the JIT-compile patch: drop the original CLIP.forward; CustomizedCLIP
# below provides its own @torch.jit.export'ed encode_* methods and an ignored forward.
delattr(clip.model.CLIP, "forward")

"""===== End of Monkey-Patching ====="""


class CustomizedCLIP(CLIP):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.jit.export
    def encode_image(self, image):
        return self.visual(image)

    @torch.jit.export
    def encode_text(self, text):
        # re-define this function to return unpooled text features
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        x_seq = x
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x_out, x_seq

    @torch.jit.ignore
    def forward(self, image, text):
        # Excluded from TorchScript compilation; the exported encode_* methods
        # above are the intended entry points.
        super().forward(image, text)

    @classmethod
    def load_from_checkpoint(cls, ckpt_path: str):
        state_dict = torch.load(ckpt_path, map_location="cpu").state_dict()

        # Infer the architecture hyperparameters from the checkpoint's state_dict
        # (mirrors clip.model.build_model). "visual.proj" only exists for ViT backbones.
        vit = "visual.proj" in state_dict
        if vit:
            vision_width = state_dict["visual.conv1.weight"].shape[0]
            vision_layers = len(
                [
                    k
                    for k in state_dict.keys()
                    if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
                ]
            )
            vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
            grid_size = round(
                (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
            )
            image_resolution = vision_patch_size * grid_size
        else:
            # ResNet visual backbone: count the residual blocks in each of the four stages.
            counts: list = [
                len(
                    set(
                        k.split(".")[2]
                        for k in state_dict
                        if k.startswith(f"visual.layer{b}")
                    )
                )
                for b in [1, 2, 3, 4]
            ]
            vision_layers = tuple(counts)
            vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
            output_width = round(
                (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
            )
            vision_patch_size = None
            assert (
                output_width**2 + 1
                == state_dict["visual.attnpool.positional_embedding"].shape[0]
            )
            image_resolution = output_width * 32

        embed_dim = state_dict["text_projection"].shape[1]
        context_length = state_dict["positional_embedding"].shape[0]
        vocab_size = state_dict["token_embedding.weight"].shape[0]
        transformer_width = state_dict["ln_final.weight"].shape[0]
        transformer_heads = transformer_width // 64
        transformer_layers = len(
            set(
                k.split(".")[2]
                for k in state_dict
                if k.startswith("transformer.resblocks")
            )
        )

        model = cls(
            embed_dim,
            image_resolution,
            vision_layers,
            vision_width,
            vision_patch_size,
            context_length,
            vocab_size,
            transformer_width,
            transformer_heads,
            transformer_layers,
        )

        for key in ["input_resolution", "context_length", "vocab_size"]:
            if key in state_dict:
                del state_dict[key]

        # convert_weights() casts eligible weights to fp16 (as in clip.model.build_model);
        # model.float() restores fp32 once the state_dict has been loaded.
        convert_weights(model)
        model.load_state_dict(state_dict)
        model.eval()
        model.float()
        return model
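
# ------------------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file; the checkpoint path below is an
# assumption). load_from_checkpoint() rebuilds the architecture from the state_dict,
# and the monkey-patched class is meant to be TorchScript-compiled, e.g.:
#
#   clip_model = CustomizedCLIP.load_from_checkpoint("ViT-L-14.pt")
#   clip_model = torch.jit.script(clip_model)
#   txt_feat, txt_feat_seq = clip_model.encode_text(tokens)  # pooled + per-token features
# ------------------------------------------------------------------------------------
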
class CustomizedTokenizer(SimpleTokenizer):
    def __init__(self):
        super().__init__(bpe_path=default_bpe())

        self.sot_token = self.encoder["<|startoftext|>"]
        self.eot_token = self.encoder["<|endoftext|>"]

    def padded_tokens_and_mask(self, texts, text_ctx):
        assert isinstance(texts, list) and all(
            isinstance(elem, str) for elem in texts
        ), "texts should be a list of strings"

        all_tokens = [
            [self.sot_token] + self.encode(text) + [self.eot_token] for text in texts
        ]

        # True for real tokens, False for padding, clipped to the text_ctx window.
        mask = [
            [True] * min(text_ctx, len(tokens))
            + [False] * max(text_ctx - len(tokens), 0)
            for tokens in all_tokens
        ]
        mask = torch.tensor(mask, dtype=torch.bool)

        result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int)
        for i, tokens in enumerate(all_tokens):
            if len(tokens) > text_ctx:
                # Truncate over-long sequences and re-terminate them with the EOT token.
                tokens = tokens[:text_ctx]
                tokens[-1] = self.eot_token
            result[i, : len(tokens)] = torch.tensor(tokens)

        return result, mask
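

if __name__ == "__main__":
    # Hedged example, not part of the original module: tokenize two prompts and
    # inspect the padded token / mask shapes. Only the `clip` package's BPE
    # vocabulary (via default_bpe()) is required, no model checkpoint.
    tokenizer = CustomizedTokenizer()
    tokens, mask = tokenizer.padded_tokens_and_mask(
        ["a photo of a corgi", "a watercolor painting of a lighthouse"], text_ctx=77
    )
    print(tokens.shape, mask.shape)  # torch.Size([2, 77]) torch.Size([2, 77])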