# Prepare the models to speed up loading them later
import json
import os

import torch
from torch import nn, Tensor
from tqdm import tqdm

from .utils import load

# Map the OpenAI CLIP model names to filesystem-friendly identifiers.
model_name_map = {
    "ViT-B/16": "vit_b_16",
    "ViT-L/14": "vit_l_14",
}


class CLIPTextEncoderTemp(nn.Module):
    """Thin wrapper that collects only the text-side submodules of a CLIP model,
    so their weights can be saved as a standalone state_dict."""

    def __init__(
        self,
        clip: nn.Module,
    ) -> None:
        super().__init__()
        self.context_length = clip.context_length
        self.vocab_size = clip.vocab_size
        self.dtype = clip.dtype
        self.token_embedding = clip.token_embedding
        self.positional_embedding = clip.positional_embedding
        self.transformer = clip.transformer
        self.ln_final = clip.ln_final
        self.text_projection = clip.text_projection

    def forward(self, text: Tensor) -> None:
        # Intentionally a no-op: this module exists only to export weights.
        pass


def prepare() -> None:
    """Export full-model, image-encoder, and text-encoder weights and configs."""
    print("Preparing CLIP models...")
    curr_dir = os.path.dirname(os.path.abspath(__file__))
    weight_dir = os.path.join(curr_dir, "weights")
    config_dir = os.path.join(curr_dir, "configs")
    os.makedirs(weight_dir, exist_ok=True)
    os.makedirs(config_dir, exist_ok=True)

    device = torch.device("cpu")
    for model_name in tqdm(["ViT-B/16", "ViT-L/14"]):
        model = load(model_name, device=device).to(device)
        image_encoder = model.visual.to(device)
        text_encoder = CLIPTextEncoderTemp(model).to(device)

        # Save the weights of the full model and of each encoder separately.
        torch.save(model.state_dict(), os.path.join(weight_dir, f"clip_{model_name_map[model_name]}.pth"))
        torch.save(image_encoder.state_dict(), os.path.join(weight_dir, f"clip_image_encoder_{model_name_map[model_name]}.pth"))
        torch.save(text_encoder.state_dict(), os.path.join(weight_dir, f"clip_text_encoder_{model_name_map[model_name]}.pth"))

        model_config = {
            "embed_dim": model.embed_dim,
            # vision
            "image_resolution": model.image_resolution,
            "vision_layers": model.vision_layers,
            "vision_width": model.vision_width,
            "vision_patch_size": model.vision_patch_size,
            # text
            "context_length": model.context_length,
            "vocab_size": model.vocab_size,
            "transformer_width": model.transformer_width,
            "transformer_heads": model.transformer_heads,
            "transformer_layers": model.transformer_layers,
        }
        image_encoder_config = {
            "embed_dim": model.embed_dim,
            "image_resolution": model.image_resolution,
            "vision_layers": model.vision_layers,
            "vision_width": model.vision_width,
            "vision_patch_size": model.vision_patch_size,
            "vision_heads": model.vision_heads,
        }
        text_encoder_config = {
            "embed_dim": model.embed_dim,
            "context_length": model.context_length,
            "vocab_size": model.vocab_size,
            "transformer_width": model.transformer_width,
            "transformer_heads": model.transformer_heads,
            "transformer_layers": model.transformer_layers,
        }

        # Save the matching configs as JSON next to the weights.
        with open(os.path.join(config_dir, f"clip_{model_name_map[model_name]}.json"), "w") as f:
            json.dump(model_config, f, indent=4)
        with open(os.path.join(config_dir, f"clip_image_encoder_{model_name_map[model_name]}.json"), "w") as f:
            json.dump(image_encoder_config, f, indent=4)
        with open(os.path.join(config_dir, f"clip_text_encoder_{model_name_map[model_name]}.json"), "w") as f:
            json.dump(text_encoder_config, f, indent=4)

    print("Done!")
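

# Optional entry point, not part of the original module: a minimal sketch that
# lets the preparation step be run directly, e.g. via `python -m <package>.prepare`
# (the package and file names here are illustrative assumptions; the relative
# import of `.utils.load` above means the file must be run as part of its package).
if __name__ == "__main__":
    prepare()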