import torch
import torch.nn as nn
import numpy as np
from functools import partial
from lib.model_zoo.common.get_model import register
symbol = 'clip'
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
from transformers import CLIPTokenizer, CLIPTextModel
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
@register('clip_text_context_encoder_sdv1')
class CLIPTextContextEncoderSDv1(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True): # clip-vit-base-patch32
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
with torch.no_grad():
batch_encoding = self.tokenizer(
text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
max_token_n = self.transformer.text_model.embeddings.position_ids.shape[1]
positional_ids = torch.arange(max_token_n)[None].to(self.device)
outputs = self.transformer(
input_ids=tokens,
position_ids=positional_ids, )
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
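# Minimal usage sketch (assumes the default "openai/clip-vit-large-patch14"
# weights can be downloaded and a CUDA device is available):
#
#   encoder = CLIPTextContextEncoderSDv1(device="cuda").to("cuda")
#   z = encoder.encode(["a photograph of an astronaut riding a horse"])
#   # z has shape [batch, 77, 768]: one CLIP hidden state per token slot.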
#############################
# copied from justin's code #
#############################
@register('clip_image_context_encoder_justin')
class CLIPImageContextEncoderJustin(AbstractEncoder):
"""
Uses the CLIP image encoder.
"""
def __init__(
self,
model='ViT-L/14',
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=False,
):
super().__init__()
from . import clip_justin
self.model, _ = clip_justin.load(name=model, device=device, jit=jit)
self.device = device
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
# I didn't call this originally, but seems like it was frozen anyway
self.freeze()
def freeze(self):
        self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def preprocess(self, x):
import kornia
# Expects inputs in the range -1, 1
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
# x is assumed to be in range [-1,1]
return self.model.encode_image(self.preprocess(x)).float()
def encode(self, im):
return self(im).unsqueeze(1)
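# Minimal usage sketch (assumes the OpenAI "ViT-L/14" checkpoint is reachable
# through the bundled clip_justin loader, kornia is installed, and inputs are
# float tensors in [-1, 1]):
#
#   encoder = CLIPImageContextEncoderJustin(device="cuda").to("cuda")
#   im = torch.rand(4, 3, 256, 256).to("cuda") * 2 - 1
#   z = encoder.encode(im)   # pooled CLIP image features, shape [4, 1, 768]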
###############
# for vd next #
###############
from transformers import CLIPModel
@register('clip_text_context_encoder')
class CLIPTextContextEncoder(AbstractEncoder):
def __init__(self,
version="openai/clip-vit-large-patch14",
max_length=77,
fp16=False, ):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.model = CLIPModel.from_pretrained(version)
self.max_length = max_length
self.fp16 = fp16
self.freeze()
def get_device(self):
# A trick to get device
return self.model.text_projection.weight.device
def freeze(self):
self.model = self.model.eval()
self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def encode(self, text):
batch_encoding = self.tokenizer(
text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.get_device())
outputs = self.model.text_model(input_ids=tokens)
z = self.model.text_projection(outputs.last_hidden_state)
z_pooled = self.model.text_projection(outputs.pooler_output)
z = z / torch.norm(z_pooled.unsqueeze(1), dim=-1, keepdim=True)
return z
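# Minimal usage sketch (assumes the default "openai/clip-vit-large-patch14"
# weights; get_device() reads the device from text_projection, so move the
# module before encoding):
#
#   encoder = CLIPTextContextEncoder().to("cuda")
#   z = encoder.encode(["a painting of a fox in the snow"])
#   # z: [batch, 77, 768] token states after text_projection, rescaled by the
#   # norm of the pooled (EOS) embedding.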
from transformers import CLIPProcessor
@register('clip_image_context_encoder')
class CLIPImageContextEncoder(AbstractEncoder):
def __init__(self,
version="openai/clip-vit-large-patch14",
fp16=False, ):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.processor = CLIPProcessor.from_pretrained(version)
self.model = CLIPModel.from_pretrained(version)
self.fp16 = fp16
self.freeze()
def get_device(self):
# A trick to get device
return self.model.text_projection.weight.device
def freeze(self):
self.model = self.model.eval()
self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def _encode(self, images):
if isinstance(images, torch.Tensor):
import torchvision.transforms as tvtrans
images = [tvtrans.ToPILImage()(i) for i in images]
inputs = self.processor(images=images, return_tensors="pt")
pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
pixels = pixels.to(self.get_device())
outputs = self.model.vision_model(pixel_values=pixels)
z = outputs.last_hidden_state
z = self.model.vision_model.post_layernorm(z)
z = self.model.visual_projection(z)
z_pooled = z[:, 0:1]
z = z / torch.norm(z_pooled, dim=-1, keepdim=True)
return z
@torch.no_grad()
def _encode_wmask(self, images, masks):
assert isinstance(masks, torch.Tensor)
assert (len(masks.shape)==4) and (masks.shape[1]==1)
masks = torch.clamp(masks, 0, 1)
masked_images = images*masks
masks = masks.float()
masks = F.interpolate(masks, [224, 224], mode='bilinear')
if masks.sum() == masks.numel():
return self._encode(images)
device = images.device
dtype = images.dtype
gscale = masks.mean(axis=[1, 2, 3], keepdim=True).flatten(2)
vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
vtoken_mask = vtoken_mask/np.prod(vtoken_kernel_size)
vtoken_mask = torch.concat([gscale, vtoken_mask], axis=1)
import types
def customized_embedding_forward(self, pixel_values):
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
embeddings = embeddings*vtoken_mask.to(embeddings.dtype)
return embeddings
old_forward = self.model.vision_model.embeddings.forward
self.model.vision_model.embeddings.forward = types.MethodType(
customized_embedding_forward, self.model.vision_model.embeddings)
z = self._encode(images)
self.model.vision_model.embeddings.forward = old_forward
z = z * vtoken_mask.to(dtype)
return z
# def _encode_wmask(self, images, masks):
# assert isinstance(masks, torch.Tensor)
# assert (len(masks.shape)==4) and (masks.shape[1]==1)
# masks = torch.clamp(masks, 0, 1)
# masks = masks.float()
# masks = F.interpolate(masks, [224, 224], mode='bilinear')
# if masks.sum() == masks.numel():
# return self._encode(images)
# device = images.device
# dtype = images.dtype
# vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
# vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
# mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
# vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
# vtoken_mask = vtoken_mask/np.prod(vtoken_kernel_size)
# z = self._encode(images)
# z[:, 1:, :] = z[:, 1:, :] * vtoken_mask.to(dtype)
# z[:, 0, :] = 0
# return z
def encode(self, images, masks=None):
if masks is None:
return self._encode(images)
else:
return self._encode_wmask(images, masks)
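# Minimal usage sketch (assumes float image tensors in [0, 1] or PIL images,
# plus an optional binary mask of shape [B, 1, H, W]; everything stays on CPU
# here for simplicity):
#
#   encoder = CLIPImageContextEncoder()
#   images = torch.rand(2, 3, 512, 512)
#   z = encoder.encode(images)                                 # [2, 257, 768]
#   z_m = encoder.encode(images, torch.ones(2, 1, 512, 512))   # all-ones mask falls back to the unmasked path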
@register('clip_image_context_encoder_position_agnostic')
class CLIPImageContextEncoderPA(CLIPImageContextEncoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
import types
def customized_embedding_forward(self, pixel_values):
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
pembeddings = self.position_embedding(self.position_ids)
pembeddings = torch.cat([
pembeddings[:, 0:1],
pembeddings[:, 1: ].mean(dim=1, keepdim=True).repeat(1, 256, 1)], dim=1)
embeddings = embeddings + pembeddings
return embeddings
self.model.vision_model.embeddings.forward = types.MethodType(
customized_embedding_forward, self.model.vision_model.embeddings)
##############
# from sd2.0 #
##############
import open_clip
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
@register('openclip_text_context_encoder_sdv2')
class FrozenOpenCLIPTextEmbedderSDv2(AbstractEncoder):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS = [
#"pooled",
"last",
"penultimate"
]
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
freeze=True, layer="last"):
super().__init__()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
del model.visual
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "last":
self.layer_idx = 0
elif self.layer == "penultimate":
self.layer_idx = 1
else:
raise NotImplementedError()
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
tokens = open_clip.tokenize(text)
z = self.encode_with_transformer(tokens.to(self.device))
return z
def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_final(x)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
def encode(self, text):
return self(text)
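# Minimal usage sketch (assumes the "laion2b_s32b_b79k" ViT-H-14 weights are
# available through open_clip and the module is moved to the target device):
#
#   embedder = FrozenOpenCLIPTextEmbedderSDv2(device="cuda", layer="penultimate").to("cuda")
#   z = embedder.encode(["a cinematic photo of a lighthouse at dusk"])
#   # z: [batch, 77, 1024] hidden states taken before the last block ("penultimate").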
@register('openclip_text_context_encoder')
class FrozenOpenCLIPTextEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP transformer encoder for text
"""
def __init__(self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
max_length=77,
freeze=True,):
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
del model.visual
self.model = model
self.max_length = max_length
self.device = 'cpu'
if freeze:
self.freeze()
    def to(self, device):
        self.device = device
        return super().to(device)
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
        self.device = self.model.ln_final.weight.device  # ugly trick
tokens = open_clip.tokenize(text)
z = self.encode_with_transformer(tokens.to(self.device))
return z
def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.model.transformer(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_final(x)
x_pool = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection
# x_pool_debug = F.normalize(x_pool, dim=-1)
x = x @ self.model.text_projection
x = x / x_pool.norm(dim=1, keepdim=True).unsqueeze(1)
return x
def encode(self, text):
return self(text)
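# Minimal usage sketch (same ViT-H-14 / laion2b_s32b_b79k assumption; forward()
# re-reads its device from ln_final, so moving the module is enough):
#
#   embedder = FrozenOpenCLIPTextEmbedder()
#   embedder.to("cuda")
#   z = embedder.encode(["a watercolor of a mountain lake"])
#   # z: [batch, 77, 1024] token states after text_projection, scaled by the
#   # norm of the pooled (argmax/EOS) embedding.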
@register('openclip_image_context_encoder')
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
"""
    Uses the OpenCLIP vision transformer encoder for images
"""
def __init__(self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
freeze=True,):
super().__init__()
model, _, preprocess = open_clip.create_model_and_transforms(
arch, device=torch.device('cpu'), pretrained=version)
self.model = model.visual
self.device = 'cpu'
import torchvision.transforms as tvtrans
# we only need resize & normalization
preprocess.transforms[0].size = [224, 224] # make it more precise
self.preprocess = tvtrans.Compose([
preprocess.transforms[0],
preprocess.transforms[4],])
if freeze:
self.freeze()
    def to(self, device):
        self.device = device
        return super().to(device)
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, image):
z = self.preprocess(image)
z = self.encode_with_transformer(z)
return z
def encode_with_transformer(self, image):
x = self.model.conv1(image)
x = x.reshape(x.shape[0], x.shape[1], -1)
x = x.permute(0, 2, 1)
x = torch.cat([
self.model.class_embedding.to(x.dtype)
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
x], dim=1)
x = x + self.model.positional_embedding.to(x.dtype)
x = self.model.ln_pre(x)
x = x.permute(1, 0, 2)
x = self.model.transformer(x)
x = x.permute(1, 0, 2)
x = self.model.ln_post(x)
if self.model.proj is not None:
x = x @ self.model.proj
x_pool = x[:, 0, :]
# x_pool_debug = self.model(image)
# x_pooln_debug = F.normalize(x_pool_debug, dim=-1)
x = x / x_pool.norm(dim=1, keepdim=True).unsqueeze(1)
return x
def _encode(self, image):
return self(image)
def _encode_wmask(self, images, masks):
z = self._encode(images)
device = z.device
vtoken_kernel_size = self.model.conv1.kernel_size
vtoken_stride = self.model.conv1.stride
mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, dtype=z.dtype, requires_grad=False)
mask_kernal /= np.prod(vtoken_kernel_size)
assert isinstance(masks, torch.Tensor)
assert (len(masks.shape)==4) and (masks.shape[1]==1)
masks = torch.clamp(masks, 0, 1)
masks = F.interpolate(masks, [224, 224], mode='bilinear')
vtoken_mask = torch.nn.functional.conv2d(1-masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
z[:, 1:, :] = z[:, 1:, :] * vtoken_mask
z[:, 0, :] = 0
return z
def encode(self, images, masks=None):
if masks is None:
return self._encode(images)
else:
return self._encode_wmask(images, masks)
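# Minimal usage sketch (assumes float images in [0, 1] shaped [B, 3, H, W]; the
# stripped-down preprocess only resizes to 224x224 and normalizes, and an
# optional [B, 1, H, W] mask reweights the patch tokens via _encode_wmask):
#
#   embedder = FrozenOpenCLIPImageEmbedder()
#   embedder.to("cuda")
#   images = torch.rand(2, 3, 512, 512).to("cuda")
#   z = embedder.encode(images)   # [2, 257, 1024] projected vision tokens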
############################
# customized tokenizers    #
############################
from open_clip import SimpleTokenizer
@register('openclip_text_context_encoder_sdv2_customized_tokenizer_v1')
class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV1(FrozenOpenCLIPTextEmbedderSDv2):
"""
Uses the OpenCLIP transformer encoder for text
"""
def __init__(self, customized_tokens, *args, **kwargs):
super().__init__(*args, **kwargs)
if isinstance(customized_tokens, str):
customized_tokens = [customized_tokens]
self.tokenizer = open_clip.SimpleTokenizer(special_tokens=customized_tokens)
self.num_regular_tokens = self.model.token_embedding.weight.shape[0]
self.embedding_dim = self.model.ln_final.weight.shape[0]
self.customized_token_embedding = nn.Embedding(
len(customized_tokens), embedding_dim=self.embedding_dim)
nn.init.normal_(self.customized_token_embedding.weight, std=0.02)
def tokenize(self, texts):
if isinstance(texts, str):
texts = [texts]
sot_token = self.tokenizer.encoder["<start_of_text>"]
eot_token = self.tokenizer.encoder["<end_of_text>"]
all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
maxn = self.num_regular_tokens
regular_tokens = [[ti if ti < maxn else 0 for ti in tokens] for tokens in all_tokens]
token_mask = [[0 if ti < maxn else 1 for ti in tokens] for tokens in all_tokens]
customized_tokens = [[ti-maxn if ti >= maxn else 0 for ti in tokens] for tokens in all_tokens]
return regular_tokens, customized_tokens, token_mask
def pad_to_length(self, tokens, context_length=77, eot_token=None):
result = torch.zeros(len(tokens), context_length, dtype=torch.long)
eot_token = self.tokenizer.encoder["<end_of_text>"] if eot_token is None else eot_token
        for i, tokens_i in enumerate(tokens):
            if len(tokens_i) > context_length:
                tokens_i = tokens_i[:context_length]  # Truncate
                tokens_i[-1] = eot_token
            result[i, :len(tokens_i)] = torch.tensor(tokens_i)
return result
def forward(self, text):
        self.device = self.model.ln_final.weight.device  # ugly trick
regular_tokens, customized_tokens, token_mask = self.tokenize(text)
regular_tokens = self.pad_to_length(regular_tokens).to(self.device)
customized_tokens = self.pad_to_length(customized_tokens, eot_token=0).to(self.device)
token_mask = self.pad_to_length(token_mask, eot_token=0).to(self.device)
z0 = self.encode_with_transformer(regular_tokens)
z1 = self.customized_token_embedding(customized_tokens)
token_mask = token_mask[:, :, None].type(z0.dtype)
z = z0 * (1-token_mask) + z1 * token_mask
return z
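# Minimal usage sketch (the placeholder token name "<obj>" is purely
# illustrative; any string passed as `customized_tokens` works the same way):
#
#   embedder = FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV1("<obj>").to("cuda")
#   z = embedder.encode(["a photo of <obj> on a beach"])
#   # slots occupied by "<obj>" come from the trainable customized_token_embedding,
#   # all other slots from the frozen OpenCLIP text transformer.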
@register('openclip_text_context_encoder_sdv2_customized_tokenizer_v2')
class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV2(FrozenOpenCLIPTextEmbedderSDv2):
"""
Uses the OpenCLIP transformer encoder for text
"""
def __init__(self, customized_tokens, *args, **kwargs):
super().__init__(*args, **kwargs)
if isinstance(customized_tokens, str):
customized_tokens = [customized_tokens]
self.tokenizer = open_clip.SimpleTokenizer(special_tokens=customized_tokens)
self.num_regular_tokens = self.model.token_embedding.weight.shape[0]
self.embedding_dim = self.model.token_embedding.weight.shape[1]
self.customized_token_embedding = nn.Embedding(
len(customized_tokens), embedding_dim=self.embedding_dim)
nn.init.normal_(self.customized_token_embedding.weight, std=0.02)
def tokenize(self, texts):
if isinstance(texts, str):
texts = [texts]
sot_token = self.tokenizer.encoder["<start_of_text>"]
eot_token = self.tokenizer.encoder["<end_of_text>"]
all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
maxn = self.num_regular_tokens
regular_tokens = [[ti if ti < maxn else 0 for ti in tokens] for tokens in all_tokens]
token_mask = [[0 if ti < maxn else 1 for ti in tokens] for tokens in all_tokens]
customized_tokens = [[ti-maxn if ti >= maxn else 0 for ti in tokens] for tokens in all_tokens]
return regular_tokens, customized_tokens, token_mask
def pad_to_length(self, tokens, context_length=77, eot_token=None):
result = torch.zeros(len(tokens), context_length, dtype=torch.long)
eot_token = self.tokenizer.encoder["<end_of_text>"] if eot_token is None else eot_token
        for i, tokens_i in enumerate(tokens):
            if len(tokens_i) > context_length:
                tokens_i = tokens_i[:context_length]  # Truncate
                tokens_i[-1] = eot_token
            result[i, :len(tokens_i)] = torch.tensor(tokens_i)
return result
def forward(self, text):
        self.device = self.model.token_embedding.weight.device  # ugly trick
regular_tokens, customized_tokens, token_mask = self.tokenize(text)
regular_tokens = self.pad_to_length(regular_tokens).to(self.device)
customized_tokens = self.pad_to_length(customized_tokens, eot_token=0).to(self.device)
token_mask = self.pad_to_length(token_mask, eot_token=0).to(self.device)
z = self.encode_with_transformer(regular_tokens, customized_tokens, token_mask)
return z
def encode_with_transformer(self, token, customized_token, token_mask):
x0 = self.model.token_embedding(token)
x1 = self.customized_token_embedding(customized_token)
token_mask = token_mask[:, :, None].type(x0.dtype)
x = x0 * (1-token_mask) + x1 * token_mask
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_final(x)
return x
class ln_freezed_temp(nn.LayerNorm):
def forward(self, x):
self.weight.requires_grad = False
self.bias.requires_grad = False
return super().forward(x)
@register('openclip_text_context_encoder_sdv2_customized_tokenizer_v3')
class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV3(FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV2):
"""
Uses the OpenCLIP transformer encoder for text
"""
def __init__(self, customized_tokens, texpand=4, lora_rank=None, lora_bias_trainable=True, *args, **kwargs):
super().__init__(customized_tokens, *args, **kwargs)
if isinstance(customized_tokens, str):
customized_tokens = [customized_tokens]
self.texpand = texpand
self.customized_token_embedding = nn.Embedding(
len(customized_tokens)*texpand, embedding_dim=self.embedding_dim)
nn.init.normal_(self.customized_token_embedding.weight, std=0.02)
if lora_rank is not None:
from .lora import freeze_param, freeze_module, to_lora
def convert_resattnblock(module):
module.ln_1.__class__ = ln_freezed_temp
# freeze_module(module.ln_1)
module.attn = to_lora(module.attn, lora_rank, lora_bias_trainable)
module.ln_2.__class__ = ln_freezed_temp
# freeze_module(module.ln_2)
module.mlp.c_fc = to_lora(module.mlp.c_fc, lora_rank, lora_bias_trainable)
module.mlp.c_proj = to_lora(module.mlp.c_proj, lora_rank, lora_bias_trainable)
freeze_param(self.model, 'positional_embedding')
freeze_param(self.model, 'text_projection')
freeze_param(self.model, 'logit_scale')
for idx, resattnblock in enumerate(self.model.transformer.resblocks):
convert_resattnblock(resattnblock)
freeze_module(self.model.token_embedding)
self.model.ln_final.__class__ = ln_freezed_temp
# freeze_module(self.model.ln_final)
def tokenize(self, texts):
if isinstance(texts, str):
texts = [texts]
sot_token = self.tokenizer.encoder["<start_of_text>"]
eot_token = self.tokenizer.encoder["<end_of_text>"]
all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
maxn = self.num_regular_tokens
regular_tokens = [[[ti] if ti < maxn else [0]*self.texpand for ti in tokens] for tokens in all_tokens]
token_mask = [[[ 0] if ti < maxn else [1]*self.texpand for ti in tokens] for tokens in all_tokens]
custom_tokens = [[[ 0] if ti < maxn else [
(ti-maxn)*self.texpand+ii for ii in range(self.texpand)]
for ti in tokens] for tokens in all_tokens]
from itertools import chain
regular_tokens = [[i for i in chain(*tokens)] for tokens in regular_tokens]
token_mask = [[i for i in chain(*tokens)] for tokens in token_mask]
custom_tokens = [[i for i in chain(*tokens)] for tokens in custom_tokens]
return regular_tokens, custom_tokens, token_mask
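# Sketch of the V3 token expansion (hypothetical placeholder "<obj>", texpand=4):
# each occurrence of "<obj>" is widened into 4 consecutive customized-token
# slots before padding, so one placeholder consumes 4 of the 77 context slots.
#
#   embedder = FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV3("<obj>", texpand=4)
#   regular, custom, mask = embedder.tokenize("a photo of <obj>")
#   # mask[0] holds four consecutive 1s; custom[0] holds indices 0..3 at those positions.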
###################
# clip expandable #
###################
@register('clip_text_sdv1_customized_embedding')
class CLIPTextSD1CE(nn.Module):
def __init__(
self,
replace_info="text|elon musk",
version="openai/clip-vit-large-patch14",
max_length=77):
super().__init__()
self.name = 'clip_text_sdv1_customized_embedding'
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.reset_replace_info(replace_info)
self.max_length = max_length
self.special_token = "<new_token>"
def reset_replace_info(self, replace_info):
rtype, rpara = replace_info.split("|")
self.replace_type = rtype
if rtype == "token_embedding":
ce_num = int(rpara)
ce_dim = self.transformer.text_model.embeddings.token_embedding.weight.size(1)
self.cembedding = nn.Embedding(ce_num, ce_dim)
self.cembedding = self.cembedding.to(self.get_device())
elif rtype == "context_embedding":
ce_num = int(rpara)
ce_dim = self.transformer.text_model.encoder.layers[-1].layer_norm2.weight.size(0)
self.cembedding = nn.Embedding(ce_num, ce_dim)
self.cembedding = self.cembedding.to(self.get_device())
else:
assert rtype=="text"
self.replace_type = "text"
self.replace_string = rpara
self.cembedding = None
def get_device(self):
return self.transformer.text_model.embeddings.token_embedding.weight.device
def position_to_mask(self, tokens, positions):
mask = torch.zeros_like(tokens)
for idxb, idxs, idxe in zip(*positions):
mask[idxb, idxs:idxe] = 1
return mask
def forward(self, text):
tokens, positions = self.tokenize(text)
mask = self.position_to_mask(tokens, positions)
max_token_n = tokens.size(1)
positional_ids = torch.arange(max_token_n)[None].to(self.get_device())
        if self.replace_type == 'token_embedding':
cembeds = self.cembedding(tokens * mask)
def embedding_customized_forward(
self, input_ids=None, position_ids=None, inputs_embeds=None,):
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
inputs_embeds = inputs_embeds * (1-mask.float())[:, :, None]
inputs_embeds = inputs_embeds + cembeds
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
import types
self.transformer.text_model.embeddings.forward = types.MethodType(
embedding_customized_forward, self.transformer.text_model.embeddings)
        else:
            # TODO: Implement
            raise NotImplementedError
outputs = self.transformer(
input_ids=tokens,
position_ids=positional_ids, )
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
@torch.no_grad()
def tokenize(self, text):
if isinstance(text, str):
text = [text]
bos_special_text = "<|startoftext|>"
text = [ti.replace(self.special_token, bos_special_text) for ti in text]
batch_encoding = self.tokenizer(
text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"]
bosid = tokens[0, 0]
eosid = tokens[0, -1]
bs, maxn = tokens.shape
        if self.replace_type in ['token_embedding', 'context_embedding']:
newtokens = []
ce_num = self.cembedding.weight.size(0)
idxi = []; idxstart = []; idxend = [];
for idxii, tokeni in enumerate(tokens):
newtokeni = []
idxjj = 0
for ii, tokenii in enumerate(tokeni):
if (tokenii == bosid) and (ii != 0):
newtokeni.extend([i for i in range(ce_num)])
idxi.append(idxii); idxstart.append(idxjj);
idxjj += ce_num
idxjj_record = idxjj if idxjj<=maxn-1 else maxn-1
idxend.append(idxjj_record);
else:
newtokeni.extend([tokenii])
idxjj += 1
newtokeni = newtokeni[:maxn]
newtokeni[-1] = eosid
newtokens.append(newtokeni)
return torch.LongTensor(newtokens).to(self.get_device()), (idxi, idxstart, idxend)
        else:
            # TODO: Implement
            raise NotImplementedError
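# Minimal usage sketch for the "token_embedding" replacement mode (the count 4
# below is illustrative; "<new_token>" is the special placeholder defined above):
#
#   ce = CLIPTextSD1CE(replace_info="token_embedding|4")
#   z = ce.encode("a portrait of <new_token>, oil on canvas")
#   # every "<new_token>" expands into 4 learned embedding slots inside the
#   # 77-token context before the CLIP text transformer runs; z: [1, 77, 768].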