import torch.nn as nn
from transformers import CLIPVisionModel

from .xf import LayerNorm, Transformer


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenCLIPImageEmbedder(AbstractEncoder):
    """Uses the CLIP vision transformer (from Hugging Face) to embed images."""

    def __init__(self, version="openai/clip-vit-large-patch14"):
        super().__init__()
        self.transformer = CLIPVisionModel.from_pretrained(version)
        self.final_ln = LayerNorm(1024)
        # Small trainable transformer that maps the pooled CLIP embedding
        # into the conditioning space: a single context token of width 1024
        # (the ViT-L/14 hidden size), 5 layers, 1 attention head.
        self.mapper = Transformer(
            1,
            1024,
            5,
            1,
        )

        self.freeze()

    def freeze(self):
        # Freeze the pretrained CLIP vision tower; only the mapper and the
        # final LayerNorm remain trainable.
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
        for param in self.mapper.parameters():
            param.requires_grad = True
        for param in self.final_ln.parameters():
            param.requires_grad = True

    def forward(self, image):
        outputs = self.transformer(pixel_values=image)
        z = outputs.pooler_output  # (batch, 1024) pooled image embedding
        z = z.unsqueeze(1)  # (batch, 1, 1024): a single context token
        z = self.mapper(z)
        z = self.final_ln(z)
        return z

    def encode(self, image):
        # Accept list-wrapped inputs by taking the first element.
        if isinstance(image, list):
            image = image[0]
        return self(image)


if __name__ == "__main__":
    from ldm.util import count_params

    model = FrozenCLIPImageEmbedder()
    count_params(model, verbose=True)
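
    # Minimal smoke test (a sketch, not part of the original script): run a
    # dummy batch through the encoder. Assumes the 224x224 input resolution
    # used by clip-vit-large-patch14; expected output shape is (batch, 1, 1024).
    import torch

    with torch.no_grad():
        dummy = torch.randn(2, 3, 224, 224)
        z = model.encode(dummy)
    print(f"encoded shape: {tuple(z.shape)}")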