import torch
import torch.nn as nn
import numpy as np
from functools import partial

from lib.model_zoo.common.get_model import register

symbol = 'clip'


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


from transformers import CLIPTokenizer, CLIPTextModel


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class CLIPTextContextEncoderSDv1(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True):  # clip-vit-base-patch32
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        with torch.no_grad():
            batch_encoding = self.tokenizer(
                text, truncation=True, max_length=self.max_length, return_length=True,
                return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
            tokens = batch_encoding["input_ids"].to(self.device)
            max_token_n = self.transformer.text_model.embeddings.position_ids.shape[1]
            positional_ids = torch.arange(max_token_n)[None].to(self.device)
            outputs = self.transformer(
                input_ids=tokens,
                position_ids=positional_ids, )
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
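
# Illustrative usage sketch (not part of the original module); the helper name
# below is hypothetical and assumes the default HuggingFace weights can be
# downloaded.
def _example_text_encode_sdv1():
    encoder = CLIPTextContextEncoderSDv1(device='cpu')  # keep everything on CPU for the sketch
    z = encoder.encode(["a photo of a cat", "a photo of a dog"])
    # For openai/clip-vit-large-patch14, z should be a [2, 77, 768] tensor.
    return z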

#############################
# copied from justin's code #
#############################

class CLIPImageContextEncoderJustin(AbstractEncoder):
    """
    Uses the CLIP image encoder.
    """
    def __init__(
            self,
            model='ViT-L/14',
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=False,
        ):
        super().__init__()
        from . import clip_justin
        self.model, _ = clip_justin.load(name=model, device=device, jit=jit)
        self.device = device
        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)

        # I didn't call this originally, but seems like it was frozen anyway
        self.freeze()

    def freeze(self):
        # note: eval() returns the module itself, so this also aliases the CLIP model as self.transformer
        self.transformer = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def preprocess(self, x):
        import kornia
        # Expects inputs in the range -1, 1
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic', align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        return self.model.encode_image(self.preprocess(x)).float()

    def encode(self, im):
        return self(im).unsqueeze(1)
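
# Illustrative usage sketch (not part of the original module); assumes the
# bundled clip_justin loader and kornia are available. The helper name is
# hypothetical.
def _example_image_encode_justin():
    encoder = CLIPImageContextEncoderJustin(device='cpu')
    images = torch.rand(2, 3, 256, 256) * 2 - 1  # dummy batch in [-1, 1]
    z = encoder.encode(images)
    # encode() adds a token axis: for ViT-L/14, z should be [2, 1, 768].
    return z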

###############
# for vd next #
###############

from transformers import CLIPModel


class CLIPTextContextEncoder(AbstractEncoder):
    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 max_length=77,
                 fp16=False, ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.model = CLIPModel.from_pretrained(version)
        self.max_length = max_length
        self.fp16 = fp16
        self.freeze()

    def get_device(self):
        # A trick to get device
        return self.model.text_projection.weight.device

    def freeze(self):
        self.model = self.model.eval()
        self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def encode(self, text):
        batch_encoding = self.tokenizer(
            text, truncation=True, max_length=self.max_length, return_length=True,
            return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        outputs = self.model.text_model(input_ids=tokens)
        z = self.model.text_projection(outputs.last_hidden_state)
        z_pooled = self.model.text_projection(outputs.pooler_output)
        # Rescale all projected token embeddings by the norm of the pooled (projected) embedding.
        z = z / torch.norm(z_pooled.unsqueeze(1), dim=-1, keepdim=True)
        return z
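
# Illustrative usage sketch (not part of the original module); helper name is
# hypothetical.
def _example_text_encode_vd():
    encoder = CLIPTextContextEncoder()
    z = encoder.encode(["a dream of a village in china, by Caspar David Friedrich"])
    # Projected token embeddings, normalized by the pooled norm: [1, 77, 768].
    return z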

from transformers import CLIPProcessor


class CLIPImageContextEncoder(AbstractEncoder):
    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 fp16=False, ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.processor = CLIPProcessor.from_pretrained(version)
        self.model = CLIPModel.from_pretrained(version)
        self.fp16 = fp16
        self.freeze()

    def get_device(self):
        # A trick to get device
        return self.model.text_projection.weight.device

    def freeze(self):
        self.model = self.model.eval()
        self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def _encode(self, images):
        if isinstance(images, torch.Tensor):
            import torchvision.transforms as tvtrans
            images = [tvtrans.ToPILImage()(i) for i in images]
        inputs = self.processor(images=images, return_tensors="pt")
        pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
        pixels = pixels.to(self.get_device())
        outputs = self.model.vision_model(pixel_values=pixels)
        z = outputs.last_hidden_state
        z = self.model.vision_model.post_layernorm(z)
        z = self.model.visual_projection(z)
        z_pooled = z[:, 0:1]
        z = z / torch.norm(z_pooled, dim=-1, keepdim=True)
        return z

    def _encode_wmask(self, images, masks):
        assert isinstance(masks, torch.Tensor)
        assert (len(masks.shape)==4) and (masks.shape[1]==1)
        masks = torch.clamp(masks, 0, 1)
        masked_images = images*masks
        masks = masks.float()
        masks = F.interpolate(masks, [224, 224], mode='bilinear')
        if masks.sum() == masks.numel():
            return self._encode(images)

        device = images.device
        dtype = images.dtype
        gscale = masks.mean(axis=[1, 2, 3], keepdim=True).flatten(2)
        # Downsample the mask to one value per vision token by convolving it with an
        # all-ones kernel that matches the patch embedding's kernel size and stride.
        vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
        vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
        mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
        vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
        vtoken_mask = vtoken_mask/np.prod(vtoken_kernel_size)
        vtoken_mask = torch.concat([gscale, vtoken_mask], axis=1)

        import types
        def customized_embedding_forward(self, pixel_values):
            batch_size = pixel_values.shape[0]
            patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
            patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
            class_embeds = self.class_embedding.expand(batch_size, 1, -1)
            embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
            embeddings = embeddings + self.position_embedding(self.position_ids)
            embeddings = embeddings*vtoken_mask.to(embeddings.dtype)
            return embeddings

        # Temporarily patch the embedding forward so the per-token mask is applied
        # before the transformer, then restore the original forward afterwards.
        old_forward = self.model.vision_model.embeddings.forward
        self.model.vision_model.embeddings.forward = types.MethodType(
            customized_embedding_forward, self.model.vision_model.embeddings)
        z = self._encode(images)
        self.model.vision_model.embeddings.forward = old_forward
        z = z * vtoken_mask.to(dtype)
        return z

    # def _encode_wmask(self, images, masks):
    #     assert isinstance(masks, torch.Tensor)
    #     assert (len(masks.shape)==4) and (masks.shape[1]==1)
    #     masks = torch.clamp(masks, 0, 1)
    #     masks = masks.float()
    #     masks = F.interpolate(masks, [224, 224], mode='bilinear')
    #     if masks.sum() == masks.numel():
    #         return self._encode(images)
    #     device = images.device
    #     dtype = images.dtype
    #     vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
    #     vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
    #     mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
    #     vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
    #     vtoken_mask = vtoken_mask/np.prod(vtoken_kernel_size)
    #     z = self._encode(images)
    #     z[:, 1:, :] = z[:, 1:, :] * vtoken_mask.to(dtype)
    #     z[:, 0, :] = 0
    #     return z

    def encode(self, images, masks=None):
        if masks is None:
            return self._encode(images)
        else:
            return self._encode_wmask(images, masks)
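
# Illustrative usage sketch (not part of the original module); helper name is
# hypothetical. With an all-ones mask the call reduces to plain _encode().
def _example_image_encode_with_mask():
    encoder = CLIPImageContextEncoder()
    images = torch.rand(1, 3, 512, 512)          # dummy image batch in [0, 1]
    masks = torch.zeros(1, 1, 512, 512)
    masks[:, :, 128:384, 128:384] = 1.0          # keep only the center region
    z = encoder.encode(images, masks=masks)
    # Projected vision tokens (class token + 16x16 patches): [1, 257, 768].
    return z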

class CLIPImageContextEncoderPA(CLIPImageContextEncoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        import types
        def customized_embedding_forward(self, pixel_values):
            batch_size = pixel_values.shape[0]
            patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
            patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
            class_embeds = self.class_embedding.expand(batch_size, 1, -1)
            embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
            pembeddings = self.position_embedding(self.position_ids)
            # Keep the class-token position embedding; replace every patch position
            # embedding with the average over all patch positions.
            pembeddings = torch.cat([
                pembeddings[:, 0:1],
                pembeddings[:, 1: ].mean(dim=1, keepdim=True).repeat(1, 256, 1)], dim=1)
            embeddings = embeddings + pembeddings
            return embeddings
        self.model.vision_model.embeddings.forward = types.MethodType(
            customized_embedding_forward, self.model.vision_model.embeddings)


##############
# from sd2.0 #
##############

import open_clip
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class FrozenOpenCLIPTextEmbedderSDv2(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    LAYERS = [
        # "pooled",
        "last",
        "penultimate"
    ]

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="last"):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
        del model.visual
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)
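
# Illustrative usage sketch (not part of the original module); helper name is
# hypothetical. Downloads the laion2b ViT-H-14 text tower on first use.
def _example_openclip_text_sdv2():
    encoder = FrozenOpenCLIPTextEmbedderSDv2(device='cpu', layer="penultimate")
    z = encoder.encode(["a professional photograph of an astronaut riding a horse"])
    # Token states after ln_final: [1, 77, 1024] for ViT-H-14.
    return z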


class FrozenOpenCLIPTextEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    def __init__(self,
                 arch="ViT-H-14",
                 version="laion2b_s32b_b79k",
                 max_length=77,
                 freeze=True,):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
        del model.visual
        self.model = model
        self.max_length = max_length
        self.device = 'cpu'
        if freeze:
            self.freeze()

    def to(self, device):
        self.device = device
        return super().to(device)

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        self.device = self.model.ln_final.weight.device  # ugly trick to track the current device
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.model.transformer(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        # Pool at the EOT token, project, and rescale all tokens by the pooled norm.
        x_pool = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection
        # x_pool_debug = F.normalize(x_pool, dim=-1)
        x = x @ self.model.text_projection
        x = x / x_pool.norm(dim=1, keepdim=True).unsqueeze(1)
        return x

    def encode(self, text):
        return self(text)
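
# Illustrative usage sketch (not part of the original module); helper name is
# hypothetical.
def _example_openclip_text_vd():
    encoder = FrozenOpenCLIPTextEmbedder()
    z = encoder.encode(["a painting of a fox in the snow"])
    # Projected token embeddings scaled by the pooled (EOT) norm: [1, 77, 1024].
    return z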

class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for images
    """
    def __init__(self,
                 arch="ViT-H-14",
                 version="laion2b_s32b_b79k",
                 freeze=True,):
        super().__init__()
        model, _, preprocess = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=version)
        self.model = model.visual
        self.device = 'cpu'
        import torchvision.transforms as tvtrans
        # we only need resize & normalization
        preprocess.transforms[0].size = [224, 224]  # make it more precise
        self.preprocess = tvtrans.Compose([
            preprocess.transforms[0],
            preprocess.transforms[4],])
        if freeze:
            self.freeze()

    def to(self, device):
        self.device = device
        return super().to(device)

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, image):
        z = self.preprocess(image)
        z = self.encode_with_transformer(z)
        return z

    def encode_with_transformer(self, image):
        x = self.model.conv1(image)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        x = torch.cat([
            self.model.class_embedding.to(x.dtype)
            + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
            x], dim=1)
        x = x + self.model.positional_embedding.to(x.dtype)
        x = self.model.ln_pre(x)
        x = x.permute(1, 0, 2)
        x = self.model.transformer(x)
        x = x.permute(1, 0, 2)
        x = self.model.ln_post(x)
        if self.model.proj is not None:
            x = x @ self.model.proj
        x_pool = x[:, 0, :]
        # x_pool_debug = self.model(image)
        # x_pooln_debug = F.normalize(x_pool_debug, dim=-1)
        x = x / x_pool.norm(dim=1, keepdim=True).unsqueeze(1)
        return x

    def _encode(self, image):
        return self(image)

    def _encode_wmask(self, images, masks):
        z = self._encode(images)
        device = z.device
        vtoken_kernel_size = self.model.conv1.kernel_size
        vtoken_stride = self.model.conv1.stride
        mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, dtype=z.dtype, requires_grad=False)
        mask_kernal /= np.prod(vtoken_kernel_size)

        assert isinstance(masks, torch.Tensor)
        assert (len(masks.shape)==4) and (masks.shape[1]==1)
        masks = torch.clamp(masks, 0, 1)
        masks = F.interpolate(masks, [224, 224], mode='bilinear')
        # Weight each patch token by the average of (1 - mask) over its patch, and zero the class token.
        vtoken_mask = torch.nn.functional.conv2d(1-masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)

        z[:, 1:, :] = z[:, 1:, :] * vtoken_mask
        z[:, 0, :] = 0
        return z

    def encode(self, images, masks=None):
        if masks is None:
            return self._encode(images)
        else:
            return self._encode_wmask(images, masks)
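
# Illustrative usage sketch (not part of the original module); helper name is
# hypothetical. Inputs are [B, 3, H, W] tensors that the stored torchvision
# preprocess can resize and normalize.
def _example_openclip_image():
    encoder = FrozenOpenCLIPImageEmbedder()
    images = torch.rand(1, 3, 512, 512)       # dummy batch in [0, 1]
    masks = torch.ones(1, 1, 512, 512)
    masks[:, :, :, :256] = 0                   # mask out the left half
    z = encoder.encode(images, masks=masks)
    # Projected vision tokens (class token zeroed): [1, 257, 1024] for ViT-H-14.
    return z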

############################
# def customized tokenizer #
############################

from open_clip import SimpleTokenizer


class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV1(FrozenOpenCLIPTextEmbedderSDv2):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    def __init__(self, customized_tokens, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if isinstance(customized_tokens, str):
            customized_tokens = [customized_tokens]
        self.tokenizer = open_clip.SimpleTokenizer(special_tokens=customized_tokens)
        self.num_regular_tokens = self.model.token_embedding.weight.shape[0]
        self.embedding_dim = self.model.ln_final.weight.shape[0]
        self.customized_token_embedding = nn.Embedding(
            len(customized_tokens), embedding_dim=self.embedding_dim)
        nn.init.normal_(self.customized_token_embedding.weight, std=0.02)

    def tokenize(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        sot_token = self.tokenizer.encoder["<start_of_text>"]
        eot_token = self.tokenizer.encoder["<end_of_text>"]
        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
        maxn = self.num_regular_tokens
        # Split each sequence into regular BPE ids, customized-token ids, and a 0/1
        # mask marking which positions hold customized tokens.
        regular_tokens = [[ti if ti < maxn else 0 for ti in tokens] for tokens in all_tokens]
        token_mask = [[0 if ti < maxn else 1 for ti in tokens] for tokens in all_tokens]
        customized_tokens = [[ti-maxn if ti >= maxn else 0 for ti in tokens] for tokens in all_tokens]
        return regular_tokens, customized_tokens, token_mask

    def pad_to_length(self, tokens, context_length=77, eot_token=None):
        result = torch.zeros(len(tokens), context_length, dtype=torch.long)
        eot_token = self.tokenizer.encoder["<end_of_text>"] if eot_token is None else eot_token
        for i, tokens in enumerate(tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # Truncate
                tokens[-1] = eot_token
            result[i, :len(tokens)] = torch.tensor(tokens)
        return result

    def forward(self, text):
        self.device = self.model.ln_final.weight.device  # ugly trick to track the current device
        regular_tokens, customized_tokens, token_mask = self.tokenize(text)
        regular_tokens = self.pad_to_length(regular_tokens).to(self.device)
        customized_tokens = self.pad_to_length(customized_tokens, eot_token=0).to(self.device)
        token_mask = self.pad_to_length(token_mask, eot_token=0).to(self.device)
        z0 = self.encode_with_transformer(regular_tokens)
        z1 = self.customized_token_embedding(customized_tokens)
        token_mask = token_mask[:, :, None].type(z0.dtype)
        z = z0 * (1-token_mask) + z1 * token_mask
        return z

class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV2(FrozenOpenCLIPTextEmbedderSDv2):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    def __init__(self, customized_tokens, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if isinstance(customized_tokens, str):
            customized_tokens = [customized_tokens]
        self.tokenizer = open_clip.SimpleTokenizer(special_tokens=customized_tokens)
        self.num_regular_tokens = self.model.token_embedding.weight.shape[0]
        self.embedding_dim = self.model.token_embedding.weight.shape[1]
        self.customized_token_embedding = nn.Embedding(
            len(customized_tokens), embedding_dim=self.embedding_dim)
        nn.init.normal_(self.customized_token_embedding.weight, std=0.02)

    def tokenize(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        sot_token = self.tokenizer.encoder["<start_of_text>"]
        eot_token = self.tokenizer.encoder["<end_of_text>"]
        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
        maxn = self.num_regular_tokens
        regular_tokens = [[ti if ti < maxn else 0 for ti in tokens] for tokens in all_tokens]
        token_mask = [[0 if ti < maxn else 1 for ti in tokens] for tokens in all_tokens]
        customized_tokens = [[ti-maxn if ti >= maxn else 0 for ti in tokens] for tokens in all_tokens]
        return regular_tokens, customized_tokens, token_mask

    def pad_to_length(self, tokens, context_length=77, eot_token=None):
        result = torch.zeros(len(tokens), context_length, dtype=torch.long)
        eot_token = self.tokenizer.encoder["<end_of_text>"] if eot_token is None else eot_token
        for i, tokens in enumerate(tokens):
            if len(tokens) > context_length:
                tokens = tokens[:context_length]  # Truncate
                tokens[-1] = eot_token
            result[i, :len(tokens)] = torch.tensor(tokens)
        return result

    def forward(self, text):
        self.device = self.model.token_embedding.weight.device  # ugly trick to track the current device
        regular_tokens, customized_tokens, token_mask = self.tokenize(text)
        regular_tokens = self.pad_to_length(regular_tokens).to(self.device)
        customized_tokens = self.pad_to_length(customized_tokens, eot_token=0).to(self.device)
        token_mask = self.pad_to_length(token_mask, eot_token=0).to(self.device)
        z = self.encode_with_transformer(regular_tokens, customized_tokens, token_mask)
        return z

    def encode_with_transformer(self, token, customized_token, token_mask):
        # Mix regular token embeddings with the learnable customized embeddings
        # before the transformer, rather than after it as in V1.
        x0 = self.model.token_embedding(token)
        x1 = self.customized_token_embedding(customized_token)
        token_mask = token_mask[:, :, None].type(x0.dtype)
        x = x0 * (1-token_mask) + x1 * token_mask
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x
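
# Illustrative usage sketch (not part of the original module); the placeholder
# token string and helper name are hypothetical.
def _example_customized_tokenizer_v2():
    encoder = FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV2(
        customized_tokens=["<my_concept>"], device='cpu', layer="penultimate")
    z = encoder.encode("a photo of <my_concept> on a table")
    # Token states [1, 77, 1024]; the <my_concept> position uses the learnable embedding.
    return z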

class ln_freezed_temp(nn.LayerNorm):
    def forward(self, x):
        self.weight.requires_grad = False
        self.bias.requires_grad = False
        return super().forward(x)


class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV3(FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV2):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    def __init__(self, customized_tokens, texpand=4, lora_rank=None, lora_bias_trainable=True, *args, **kwargs):
        super().__init__(customized_tokens, *args, **kwargs)
        if isinstance(customized_tokens, str):
            customized_tokens = [customized_tokens]
        self.texpand = texpand
        self.customized_token_embedding = nn.Embedding(
            len(customized_tokens)*texpand, embedding_dim=self.embedding_dim)
        nn.init.normal_(self.customized_token_embedding.weight, std=0.02)

        if lora_rank is not None:
            from .lora import freeze_param, freeze_module, to_lora
            def convert_resattnblock(module):
                module.ln_1.__class__ = ln_freezed_temp
                # freeze_module(module.ln_1)
                module.attn = to_lora(module.attn, lora_rank, lora_bias_trainable)
                module.ln_2.__class__ = ln_freezed_temp
                # freeze_module(module.ln_2)
                module.mlp.c_fc = to_lora(module.mlp.c_fc, lora_rank, lora_bias_trainable)
                module.mlp.c_proj = to_lora(module.mlp.c_proj, lora_rank, lora_bias_trainable)
            freeze_param(self.model, 'positional_embedding')
            freeze_param(self.model, 'text_projection')
            freeze_param(self.model, 'logit_scale')
            for idx, resattnblock in enumerate(self.model.transformer.resblocks):
                convert_resattnblock(resattnblock)
            freeze_module(self.model.token_embedding)
            self.model.ln_final.__class__ = ln_freezed_temp
            # freeze_module(self.model.ln_final)

    def tokenize(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        sot_token = self.tokenizer.encoder["<start_of_text>"]
        eot_token = self.tokenizer.encoder["<end_of_text>"]
        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
        maxn = self.num_regular_tokens
        # Each customized token is expanded into `texpand` consecutive learnable embeddings.
        regular_tokens = [[[ti] if ti < maxn else [0]*self.texpand for ti in tokens] for tokens in all_tokens]
        token_mask = [[[ 0] if ti < maxn else [1]*self.texpand for ti in tokens] for tokens in all_tokens]
        custom_tokens = [[[ 0] if ti < maxn else [
            (ti-maxn)*self.texpand+ii for ii in range(self.texpand)]
            for ti in tokens] for tokens in all_tokens]
        from itertools import chain
        regular_tokens = [[i for i in chain(*tokens)] for tokens in regular_tokens]
        token_mask = [[i for i in chain(*tokens)] for tokens in token_mask]
        custom_tokens = [[i for i in chain(*tokens)] for tokens in custom_tokens]
        return regular_tokens, custom_tokens, token_mask

###################
# clip expandable #
###################

class CLIPTextSD1CE(nn.Module):
    def __init__(
            self,
            replace_info="text|elon musk",
            version="openai/clip-vit-large-patch14",
            max_length=77):
        super().__init__()
        self.name = 'clip_text_sdv1_customized_embedding'
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.reset_replace_info(replace_info)
        self.max_length = max_length
        self.special_token = "<new_token>"

    def reset_replace_info(self, replace_info):
        rtype, rpara = replace_info.split("|")
        self.replace_type = rtype
        if rtype == "token_embedding":
            ce_num = int(rpara)
            ce_dim = self.transformer.text_model.embeddings.token_embedding.weight.size(1)
            self.cembedding = nn.Embedding(ce_num, ce_dim)
            self.cembedding = self.cembedding.to(self.get_device())
        elif rtype == "context_embedding":
            ce_num = int(rpara)
            ce_dim = self.transformer.text_model.encoder.layers[-1].layer_norm2.weight.size(0)
            self.cembedding = nn.Embedding(ce_num, ce_dim)
            self.cembedding = self.cembedding.to(self.get_device())
        else:
            assert rtype=="text"
            self.replace_type = "text"
            self.replace_string = rpara
            self.cembedding = None

    def get_device(self):
        return self.transformer.text_model.embeddings.token_embedding.weight.device

    def position_to_mask(self, tokens, positions):
        mask = torch.zeros_like(tokens)
        for idxb, idxs, idxe in zip(*positions):
            mask[idxb, idxs:idxe] = 1
        return mask

    def forward(self, text):
        tokens, positions = self.tokenize(text)
        mask = self.position_to_mask(tokens, positions)
        max_token_n = tokens.size(1)
        positional_ids = torch.arange(max_token_n)[None].to(self.get_device())

        if self.replace_type == 'token_embedding':
            cembeds = self.cembedding(tokens * mask)
            def embedding_customized_forward(
                    self, input_ids=None, position_ids=None, inputs_embeds=None,):
                seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
                if position_ids is None:
                    position_ids = self.position_ids[:, :seq_length]
                if inputs_embeds is None:
                    inputs_embeds = self.token_embedding(input_ids)
                # Zero out the original embeddings at the marked positions, then add the customized embeddings.
                inputs_embeds = inputs_embeds * (1-mask.float())[:, :, None]
                inputs_embeds = inputs_embeds + cembeds
                position_embeddings = self.position_embedding(position_ids)
                embeddings = inputs_embeds + position_embeddings
                return embeddings
            import types
            self.transformer.text_model.embeddings.forward = types.MethodType(
                embedding_customized_forward, self.transformer.text_model.embeddings)
        else:
            # TODO: Implement
            assert False

        outputs = self.transformer(
            input_ids=tokens,
            position_ids=positional_ids, )
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)

    def tokenize(self, text):
        if isinstance(text, str):
            text = [text]
        bos_special_text = "<|startoftext|>"
        text = [ti.replace(self.special_token, bos_special_text) for ti in text]
        batch_encoding = self.tokenizer(
            text, truncation=True, max_length=self.max_length, return_length=True,
            return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"]
        bosid = tokens[0, 0]
        eosid = tokens[0, -1]
        bs, maxn = tokens.shape

        if self.replace_type in ['token_embedding', 'context_embedding']:
            newtokens = []
            ce_num = self.cembedding.weight.size(0)
            idxi = []; idxstart = []; idxend = [];
            for idxii, tokeni in enumerate(tokens):
                newtokeni = []
                idxjj = 0
                for ii, tokenii in enumerate(tokeni):
                    if (tokenii == bosid) and (ii != 0):
                        # An in-text <|startoftext|> marks where the special token was;
                        # expand it into ce_num placeholder ids and record its span.
                        newtokeni.extend([i for i in range(ce_num)])
                        idxi.append(idxii); idxstart.append(idxjj);
                        idxjj += ce_num
                        idxjj_record = idxjj if idxjj<=maxn-1 else maxn-1
                        idxend.append(idxjj_record);
                    else:
                        newtokeni.extend([tokenii])
                        idxjj += 1
                newtokeni = newtokeni[:maxn]
                newtokeni[-1] = eosid
                newtokens.append(newtokeni)
            return torch.LongTensor(newtokens).to(self.get_device()), (idxi, idxstart, idxend)
        else:
            # TODO: Implement
            assert False