Commit 8a00d0d
Parent(s): initial commit

Files changed:
- .gitignore +1 -0
- __pycache__/text_encoder.cpython-311.pyc +0 -0
- __pycache__/train.cpython-311.pyc +0 -0
- __pycache__/vision_encoder.cpython-311.pyc +0 -0
- _dataset/__pycache__/preprocess_images.cpython-311.pyc +0 -0
- _dataset/preprocess_captions.ipynb +188 -0
- _dataset/preprocess_images.py +79 -0
- demo.ipynb +240 -0
- text_encoder.py +27 -0
- train.py +204 -0
- vision_encoder.py +56 -0
.gitignore
ADDED
@@ -0,0 +1 @@
checkpoints
__pycache__/text_encoder.cpython-311.pyc
ADDED
Binary file (1.82 kB).

__pycache__/train.cpython-311.pyc
ADDED
Binary file (11.5 kB).

__pycache__/vision_encoder.cpython-311.pyc
ADDED
Binary file (3.05 kB).

_dataset/__pycache__/preprocess_images.cpython-311.pyc
ADDED
Binary file (5.8 kB).
_dataset/preprocess_captions.ipynb
ADDED
@@ -0,0 +1,188 @@
Jupyter notebook (kernel: hf-env, Python 3.11.11); the code cells are reproduced below.

# --- Cell 1: tokenize the Flickr30k caption file with the ModernBERT tokenizer ---
from collections import defaultdict
from transformers import AutoTokenizer
from tqdm import tqdm
import json

def load_and_process_token_file(input_path, tokenizer_name="answerdotai/ModernBERT-base"):
    captions_dict = defaultdict(list)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    max_length = 0  # Initialize max length counter

    # Read and process the token file with tokenization
    with open(input_path, 'r') as file:
        for line in tqdm(file, desc="Processing Captions"):
            image_id, caption = line.strip().split('\t')
            jpg_number = image_id.split('.')[0]

            # Tokenize without padding and truncation to calculate the true length
            tokens = tokenizer(caption, return_tensors="pt", padding=False, truncation=False)
            token_ids = tokens['input_ids'].squeeze(0).tolist()

            # Update max_length based on this tokenized sequence length
            max_length = max(max_length, len(token_ids))

            # Tokenize with padding and attention mask (padded to 2**7 = 128 tokens; the longest caption is 93 tokens, 93 < 2**7)
            tokens_padded = tokenizer(caption, return_tensors="pt", padding="max_length", truncation=True, max_length=2**7)
            token_ids_padded = tokens_padded['input_ids'].squeeze(0).tolist()
            attention_mask = tokens_padded['attention_mask'].squeeze(0).tolist()

            # Save the raw caption, the tokenized version, and the attention mask
            captions_dict[jpg_number].append({
                "text": caption,
                "tokenized": token_ids_padded,
                "attention_mask": attention_mask
            })

    print(f"Maximum sequence length (before padding): {max_length}")
    return captions_dict, max_length

# Define the input path and process the file
input_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/results_20130124.token'
captions_dict, max_length = load_and_process_token_file(input_path)

# Save the dictionary with tokenized captions and attention masks to a JSON file
output_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'
with open(output_path, 'w') as json_file:
    json.dump(captions_dict, json_file)

# Display the maximum token length
print(f"Final maximum token length across dataset: {max_length}")

# Display the first few entries to verify the content
for jpg, captions in list(captions_dict.items())[:5]:
    print(f"{jpg}: {captions}")

# --- Cell 2: save the captions dictionary to a second JSON file ---
output_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_dict.json'
with open(output_path, 'w') as json_file:
    json.dump(captions_dict, json_file)

print(f"Captions dictionary saved to {output_path}")

# --- Cell 3: dataset and DataLoader over the precomputed vision embeddings ---
import torch
from torch.utils.data import Dataset, DataLoader
import os
import json
import numpy as np
import random


# Vision Caption Dataset
class VisionCaptionDataset(torch.utils.data.Dataset):
    def __init__(self, captions_path, embeddings_dir, normalize=True):
        with open(captions_path, 'r') as f:
            self.captions_dict = json.load(f)

        self.embeddings_dir = embeddings_dir
        self.image_ids = list(self.captions_dict.keys())
        self.normalize = normalize

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]

        # Randomly select a caption and load the tokenized version
        caption_entry = random.choice(self.captions_dict[image_id])
        tokenized_caption = caption_entry["tokenized"]
        attention_mask = caption_entry["attention_mask"]

        # Load vision embedding
        embedding_path = os.path.join(self.embeddings_dir, f"{image_id}.npy")
        embedding = np.load(embedding_path)

        # Convert vision embedding, tokenized caption, and mask to tensors
        embedding = torch.tensor(embedding, dtype=torch.float32)
        tokenized_caption = torch.tensor(tokenized_caption, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)

        return embedding, tokenized_caption, attention_mask

# Paths for the dataset
captions_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'
embeddings_dir = '/mnt/nvme/shared_A/datasets/flickr30k/data/reduced_vision_embeddings'

# Initialize the dataset (the train/validation split is done in train.py)
full_dataset = VisionCaptionDataset(captions_path, embeddings_dir)

# Initialize the DataLoader with `num_workers` and `pin_memory`
train_dataloader = DataLoader(full_dataset, batch_size=16, shuffle=True, num_workers=8, pin_memory=True)

# --- Cell 4: verify a batch ---
for batch in train_dataloader:
    embeddings, captions, attn_mask = batch
    print(embeddings.shape, len(captions))
    break
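
For readers unfamiliar with the two tokenization passes in Cell 1 (one unpadded pass to measure the true caption length, one padded to a fixed 2**7 = 128 tokens), here is a minimal standalone sketch; the caption string is illustrative and the printed unpadded length depends on the text:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
caption = "Two dogs are playing in the grass."  # illustrative caption, not from the dataset

# Pass 1: no padding/truncation, used only to track the longest sequence in the corpus
unpadded = tokenizer(caption, return_tensors="pt", padding=False, truncation=False)
print(unpadded["input_ids"].shape)           # e.g. torch.Size([1, 10]); length varies per caption

# Pass 2: fixed-length output, what actually gets stored in captions_tokenized.json
padded = tokenizer(caption, return_tensors="pt", padding="max_length", truncation=True, max_length=2**7)
print(padded["input_ids"].shape)             # torch.Size([1, 128])
print(int(padded["attention_mask"].sum()))   # number of real (non-padding) tokens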
_dataset/preprocess_images.py
ADDED
@@ -0,0 +1,79 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import shutil
from PIL import Image
from transformers.image_utils import load_image
import sys
sys.path.append('..')
from vision_encoder import ideficsV3
from tqdm import tqdm

class VisionPreprocessor:
    def __init__(self, device=None, param_dtype=torch.float32):
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.param_dtype = param_dtype

        # Initialize and freeze the vision encoder
        self.vision_encoder = ideficsV3("HuggingFaceTB/SmolVLM-Instruct").eval().to(self.device)
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

    def load_image(self, image_path):
        """Load an image and convert it to pixel values with the vision encoder's image processor."""
        image = load_image(image_path)
        inputs = self.vision_encoder.image_processor(images=[image], return_tensors="pt")
        pixel_values = inputs.pixel_values.to(self.param_dtype).to(self.device)
        return pixel_values

    def extract_embedding(self, image_tensor):
        """Extract raw vision embedding."""
        with torch.no_grad():
            vision_output = self.vision_encoder(image_tensor)

        vision_output = vision_output.mean(axis=0)

        return vision_output

    def save_embedding(self, vision_output, file_path):
        """Save the vision output to a numpy file."""
        np.save(file_path, vision_output.cpu().numpy())

    def process_directory(self, image_paths, output_dir):
        """Process all images in a directory with a progress bar and save the embeddings."""
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
            print(f"Existing directory cleared: {output_dir}")
        os.makedirs(output_dir, exist_ok=True)

        # Adding tqdm for progress bar
        for image_path in tqdm(image_paths, desc="Processing Images", unit="image"):

            # Load the image and extract its features
            image_tensor = self.load_image(image_path)
            vision_output = self.extract_embedding(image_tensor)

            # Save the output with the same filename but as a .npy
            output_file_path = os.path.join(output_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}.npy")
            self.save_embedding(vision_output, output_file_path)


if __name__ == "__main__":
    torch.manual_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    param_dtype = torch.float32

    # Instantiate the pipeline
    pipeline = VisionPreprocessor(device, param_dtype)

    # Specify input and output directories
    input_directory = "/mnt/nvme/shared_A/datasets/flickr30k/data/flickr30k-images"
    output_directory = "/mnt/nvme/shared_A/datasets/flickr30k/data/vision_embeddings_reduced2"

    image_paths = [os.path.join(input_directory, f) for f in os.listdir(input_directory) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    # Process all images in the input directory
    pipeline.process_directory(image_paths, output_directory)
    print("Processing complete!")
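
As a sanity check on what process_directory writes to disk: each .npy file should hold the patch-token features averaged over the image's sub-patches. A minimal sketch, assuming the [num_patches, 729, 1152] encoder output shape noted in vision_encoder.py; the filename below is hypothetical (real files mirror the source image filenames):

import numpy as np

# Hypothetical output file produced by process_directory
emb = np.load("/mnt/nvme/shared_A/datasets/flickr30k/data/vision_embeddings_reduced2/1000092795.npy")
print(emb.shape, emb.dtype)   # expected: (729, 1152) float32, one 1152-dim vector per patch token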
demo.ipynb
ADDED
@@ -0,0 +1,240 @@
Jupyter notebook (kernel: hf-env, Python 3.11.11); the cells are reproduced below.

# --- Cell 1 (markdown): Image search with modernBERT ---

# --- Cell 2: imports ---
from _dataset.preprocess_images import *
import random

# --- Cell 3: select 25 COCO val2017 images ---
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = VisionPreprocessor(device, param_dtype=torch.float32)

num_images = 25
input_directory = "/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/val2017"
image_paths = [os.path.join(input_directory, f) for f in os.listdir(input_directory) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

# Shuffle and take the first 25 images
# random.shuffle(image_paths)
image_paths = image_paths[:num_images]

# Print the selected image paths
print("Selected Image Paths:")
for path in image_paths:
    print(path)

# --- Cell 4: precompute vision embeddings for the selected images ---
import os
import shutil

# Specify the output directory
output_directory = "/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/vision_embeddings"

# Clear the vision embeddings directory if it exists, otherwise create it
if os.path.exists(output_directory):
    shutil.rmtree(output_directory)
    print(f"Existing directory cleared: {output_directory}")
os.makedirs(output_directory, exist_ok=True)

# Process all images in the input directory
pipeline.process_directory(image_paths, output_directory)
print("Image embeddings saved!")

# --- Cell 5: load the trained model and build the retrieval helpers ---
from train import JointNetwork

def load_checkpoint_and_prepare_model(checkpoint_path, device="cuda"):
    """Load a trained JointNetwork() from a checkpoint."""
    device = torch.device(device)
    model = JointNetwork()
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    model.device = device
    print(f"Model loaded successfully from {checkpoint_path}.")
    return model

def get_text_embedding(model, text_prompt):
    """Encode a text prompt to get its embedding using the modernBERT encoder."""
    tokenized_text = model.text_encoder.tokenizer(text_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        text_features = model.text_encoder(tokenized_text)
        text_features = model.text_projector(text_features.mean(dim=1))
        text_features = F.normalize(text_features, dim=1)
    return text_features

def load_image_embeddings(model, embeddings_dir):
    """Load all precomputed image embeddings from the specified directory."""
    vision_embeddings = []
    for file in sorted(os.listdir(embeddings_dir)):
        if file.endswith(".npy"):
            image_encoding = torch.tensor(np.load(os.path.join(embeddings_dir, file)), dtype=torch.float32).to(model.device)
            vision_pooled = image_encoding.mean(dim=0).unsqueeze(0)
            vision_embedded = model.vision_projector(vision_pooled)
            vision_embedded = F.normalize(vision_embedded, dim=1)
            vision_embeddings.append(vision_embedded)

    if len(vision_embeddings) == 0:
        raise ValueError("No vision embeddings found in the specified directory.")
    print(f"Vision embeddings loaded successfully from {embeddings_dir}.")
    return torch.stack(vision_embeddings).squeeze(1)

def compare_text_to_images(text_embedding, vision_embeddings):
    """Compare a text embedding against a batch of image embeddings using cosine similarity."""
    cosine_similarities = torch.matmul(text_embedding, vision_embeddings.T).squeeze(0)
    similarity_scores = cosine_similarities.cpu().detach().numpy()
    ranked_indices = similarity_scores.argsort()[::-1]  # Sort in descending order
    return ranked_indices, similarity_scores


# Paths and settings
checkpoint_path = "/home/nolan4/projects/hf-contest/checkpoints/model_checkpoint_20250109_102039.pth"
embeddings_dir = "/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/vision_embeddings"

# Load the model and precomputed vision embeddings
model = load_checkpoint_and_prepare_model(checkpoint_path)
vision_embeddings = load_image_embeddings(model, embeddings_dir)

# --- Cell 6: helper to display images ---
import matplotlib.pyplot as plt
import os
from PIL import Image

def display_images_from_paths(image_paths, num_images=5):
    num_images = min(num_images, len(image_paths))
    if num_images == 0:
        print("No images found in the directory.")
        return

    plt.figure(figsize=(12, 8))
    for i, image_path in enumerate(image_paths[:num_images]):
        img = Image.open(image_path)
        plt.subplot(1, num_images, i + 1)
        plt.imshow(img)
        plt.axis('off')
        plt.title(f"{os.path.basename(image_path).split('.')[0]}")

    plt.tight_layout()
    plt.show()

# Example usage
# random.shuffle(image_paths)
display_images_from_paths(image_paths, num_images=10)

# --- Cell 7: run a text query against the image embeddings ---
text_prompt = "cars driving down the road"
# text_prompt = "stuffed brown teddy bear"

# Encode the text prompt
text_embedding = get_text_embedding(model, text_prompt)

# Perform comparison and display results
ranked_indices, similarity_scores = compare_text_to_images(text_embedding, vision_embeddings)
print(f"\nTop 5 Most Similar Images:")
for idx in ranked_indices[:5]:
    print(f"Image Index: {idx}, Similarity Score: {similarity_scores[idx]:.4f}")

# --- Cell 8: display the top-ranked images ---
# Map the ranked indices back to image paths
selected_image_paths = [image_paths[idx] for idx in ranked_indices[:10]]

# Display the top N ranked images
display_images_from_paths(selected_image_paths, num_images=4)
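
Since get_text_embedding and load_image_embeddings both end with F.normalize, the torch.matmul in compare_text_to_images is exactly cosine similarity. A tiny self-contained check of that ranking logic with random stand-in embeddings (illustrative only, no model required):

import torch
import torch.nn.functional as F

text = F.normalize(torch.randn(1, 512), dim=1)    # stand-in for one text embedding
imgs = F.normalize(torch.randn(25, 512), dim=1)   # stand-ins for 25 image embeddings

scores = torch.matmul(text, imgs.T).squeeze(0)    # cosine similarities, shape [25]
ranked = scores.argsort(descending=True)          # highest similarity first
print(ranked[:5].tolist(), scores[ranked[:5]].tolist())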
text_encoder.py
ADDED
@@ -0,0 +1,27 @@
from transformers import AutoTokenizer, ModernBertModel
import torch
import torch.nn as nn
import torch.optim as optim
import pdb

class modernBERT(nn.Module):
    def __init__(self, model_name="answerdotai/ModernBERT-base"):
        super(modernBERT, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.bert = ModernBertModel.from_pretrained(model_name)

    def forward(self, inputs):
        # inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        outputs = self.bert(**inputs)

        return outputs.last_hidden_state  # token-level hidden states, [batch, seq_len, 768]

# Example usage
if __name__ == "__main__":
    model = modernBERT("answerdotai/ModernBERT-base")

    texts = ["Potato's no name for a dog"]
    text_inputs = model.tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    output = model(text_inputs)

    print(output[0].shape)
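
modernBERT returns token-level hidden states; both train.py and demo.ipynb mean-pool them over the sequence dimension before projecting into the shared 512-dim space. A short sketch of the shapes involved, assuming ModernBERT-base's 768-dim hidden size (the sentence and the printed sequence length are illustrative):

import torch
from text_encoder import modernBERT

model = modernBERT("answerdotai/ModernBERT-base").eval()
inputs = model.tokenizer(["a dog running on the beach"], return_tensors="pt", padding=True)

with torch.no_grad():
    hidden = model(inputs)        # last_hidden_state
print(hidden.shape)               # e.g. torch.Size([1, 8, 768]): [batch, seq_len, hidden]

pooled = hidden.mean(dim=1)       # sentence-level embedding, as pooled before text_projector
print(pooled.shape)               # torch.Size([1, 768])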
train.py
ADDED
@@ -0,0 +1,204 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from text_encoder import *
from vision_encoder import *
import os
import json
import numpy as np
import random
from tqdm import tqdm
import datetime

# Vision Caption Dataset
class VisionCaptionDataset(torch.utils.data.Dataset):
    def __init__(self, captions_path, embeddings_dir, normalize=True):
        with open(captions_path, 'r') as f:
            self.captions_dict = json.load(f)

        self.embeddings_dir = embeddings_dir
        self.image_ids = list(self.captions_dict.keys())
        self.normalize = normalize

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]

        caption_entry = random.choice(self.captions_dict[image_id])
        tokenized_caption = caption_entry["tokenized"]
        attention_mask = caption_entry["attention_mask"]

        embedding_path = os.path.join(self.embeddings_dir, f"{image_id}.npy")
        embedding = np.load(embedding_path)

        embedding = torch.tensor(embedding, dtype=torch.float32)
        tokenized_caption = torch.tensor(tokenized_caption, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)

        return embedding, tokenized_caption, attention_mask


class JointNetwork(nn.Module):
    def __init__(self):
        super(JointNetwork, self).__init__()

        self.text_encoder = modernBERT("answerdotai/ModernBERT-base")

        for param in self.text_encoder.parameters():
            param.requires_grad = True

        self.vision_projector = nn.Linear(1152, 512)
        self.text_projector = nn.Linear(768, 512)

    def forward(self, tokenized_text, image_encoding):
        vision_patch_pooled = image_encoding.mean(dim=1)
        text_output = self.text_encoder(tokenized_text)
        text_pooled = text_output.mean(dim=1)

        vision_embedded = self.vision_projector(vision_patch_pooled)
        text_embedded = self.text_projector(text_pooled)

        vision_embedded = F.normalize(vision_embedded, dim=1)
        text_embedded = F.normalize(text_embedded, dim=1)

        return text_embedded, vision_embedded


def infoNCE_loss(text_features, vision_features, temperature=0.07):
    text_features = F.normalize(text_features, p=2, dim=-1)
    vision_features = F.normalize(vision_features, p=2, dim=-1)

    similarity_matrix = torch.matmul(text_features, vision_features.T) / temperature
    batch_size = vision_features.size(0)
    labels = torch.arange(batch_size, device=vision_features.device)

    loss_text_to_image = F.cross_entropy(similarity_matrix, labels)
    loss_image_to_text = F.cross_entropy(similarity_matrix.T, labels)

    return (loss_text_to_image + loss_image_to_text) / 2


def train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs=5, freeze_text_encoder=True, checkpoint_path=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    best_val_loss = float('inf')  # Initialize with a very high value

    # Freeze text encoder if specified
    if freeze_text_encoder:
        for param in model.text_encoder.parameters():
            param.requires_grad = False

    # Ensure new layers are trainable
    for param in model.vision_projector.parameters():
        param.requires_grad = True
    for param in model.text_projector.parameters():
        param.requires_grad = True

    model.to(device)

    for epoch in range(num_epochs):

        # Train loop
        model.train()
        total_loss = 0.0

        print(f"\nEpoch {epoch + 1}/{num_epochs} - Training:")
        train_progress = tqdm(train_loader, desc="Training", leave=True)

        for image_embeddings, tokenized_captions, attention_masks in train_progress:
            text_inputs = {"input_ids": tokenized_captions.to(device), "attention_mask": attention_masks.to(device)}
            image_embeddings = image_embeddings.to(device)

            optimizer.zero_grad()
            text_features, vision_features = model(text_inputs, image_embeddings)
            loss = infoNCE_loss(text_features, vision_features)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            train_progress.set_postfix(loss=loss.item())

        scheduler.step()

        # Validation loop
        model.eval()
        val_loss = 0.0

        print(f"\nEpoch {epoch + 1}/{num_epochs} - Validation:")
        val_progress = tqdm(val_loader, desc="Validation", leave=True)

        with torch.no_grad():
            for image_embeddings, tokenized_captions, attention_masks in val_progress:
                text_inputs = {"input_ids": tokenized_captions.to(device), "attention_mask": attention_masks.to(device)}
                image_embeddings = image_embeddings.to(device)

                text_features, vision_features = model(text_inputs, image_embeddings)
                loss = infoNCE_loss(text_features, vision_features)
                val_loss += loss.item()
                val_progress.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        print(f"\nEpoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Save best model
        if checkpoint_path is not None:
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': best_val_loss
                }, checkpoint_path)
                print(f"New Best Model Saved at: {checkpoint_path} (Val Loss: {best_val_loss:.4f})")

    print("Training completed!")


if __name__ == "__main__":
    # Set random seed for reproducibility
    # torch.manual_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Paths for dataset
    captions_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'
    # embeddings_dir = '/mnt/nvme/shared_A/datasets/flickr30k/data/reduced_vision_embeddings'
    embeddings_dir = '/mnt/nvme/shared_A/datasets/flickr30k/data/vision_embeddings_reduced2'

    # Initialize datasets and loaders
    full_dataset = VisionCaptionDataset(captions_path, embeddings_dir)
    train_size = int(0.85 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=8, pin_memory=True)

    # Initialize model, optimizer, and scheduler
    model = JointNetwork().to(device)

    checkpoint_path = f"./checkpoints/model_checkpoint_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.pth"

    # **Phase 1 Configuration: Training new layers only**
    initial_lr = 1e-4
    min_lr = 1e-6
    num_epochs = 16
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=min_lr)

    # **Phase 1: Train new layers only, freeze text encoder**
    print("\n### Phase 1: Training new layers only (Text Encoder Frozen) ###")
    train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs=num_epochs, freeze_text_encoder=True, checkpoint_path=checkpoint_path)

    # # **Phase 2 Configuration: Fine-tuning with adjusted learning rate**
    # initial_lr = 1e-4
    # min_lr = 1e-6
    # num_epochs = 3
    # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=min_lr)

    # print("\n### Phase 2: Fine-tuning text encoder and new layers ###")
    # train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs=num_epochs, freeze_text_encoder=False, checkpoint_path=checkpoint_path)
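
For reference, the symmetric InfoNCE objective that infoNCE_loss implements, written out. With L2-normalized text embeddings t_i and vision embeddings v_i for a batch of size N and temperature \tau = 0.07, the two F.cross_entropy terms are the text-to-image and image-to-text directions:

\[
\mathcal{L} = \frac{1}{2}\left(
\frac{1}{N}\sum_{i=1}^{N} -\log \frac{\exp(t_i \cdot v_i/\tau)}{\sum_{j=1}^{N}\exp(t_i \cdot v_j/\tau)}
\;+\;
\frac{1}{N}\sum_{i=1}^{N} -\log \frac{\exp(v_i \cdot t_i/\tau)}{\sum_{j=1}^{N}\exp(v_i \cdot t_j/\tau)}
\right)
\]

Matched caption/image pairs sit on the diagonal of the similarity matrix, which is why labels is simply torch.arange(batch_size).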
vision_encoder.py
ADDED
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image


class ideficsV3(nn.Module):
    def __init__(self, model_name="HuggingFaceTB/SmolVLM-Instruct"):
        super().__init__()

        # Load the SmolVLM model from Hugging Face
        self.image_processor = AutoProcessor.from_pretrained(model_name).image_processor
        smolVLM = AutoModelForVision2Seq.from_pretrained(model_name, torch_dtype=torch.float32)

        # Extract the necessary modules
        self.vision_model = smolVLM.model.vision_model

    def forward(self, pixel_values):
        # pixel_values arrives as a 5D tensor [batch_size, num_patches, 3, 384, 384] because the
        # image processor splits each image into patches, while the vision transformer expects a
        # 4D tensor [batch_size, channels, height, width] (passing the 5D tensor raises
        # "ValueError: too many values to unpack (expected 4)"). Flatten the patch dimension into
        # the batch dimension before the forward pass.
        batch_size, num_patches, channels, height, width = pixel_values.shape
        pixel_values = pixel_values.view(batch_size * num_patches, channels, height, width)

        # Run images through the vision transformer
        vision_outputs = self.vision_model(pixel_values)
        x = vision_outputs.last_hidden_state  # shape := [batch_size * num_patches, 729, 1152]

        return x


if __name__ == "__main__":

    # Instantiate the truncated model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    truncated_model = ideficsV3().to(device).eval()

    image1 = load_image("https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg")
    image2 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")

    inputs1 = truncated_model.image_processor(images=[image1, image2], return_tensors="pt")
    pixel_values = inputs1.pixel_values.to(torch.float32).to(device)

    # Pass pixel_values through the truncated model
    with torch.no_grad():
        outputs = truncated_model(pixel_values)

    print(outputs.shape)  # [batch_size * num_patches, 729, 1152]