ClemSummer committed on
Commit
7b2eca8
·
1 Parent(s): dba8761

Resolve README.md conflict and merge with remote Hugging Face content

.gitignore ADDED
@@ -0,0 +1,3 @@
+ __pycache__/
+ *.png
+ **/artifacts/
.huggingface/space.yaml ADDED
@@ -0,0 +1 @@
+ sdk: docker
Dockerfile ADDED
@@ -0,0 +1,10 @@
+ # 🐍 Use official Python
+ FROM python:3.11-slim
+
+ WORKDIR /app
+ COPY . .
+
+ RUN pip install --upgrade pip
+ RUN pip install -r requirements.txt
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
README.md CHANGED
@@ -8,4 +8,10 @@ pinned: false
  short_description: Clement's AI Lab to demonstrate advanced AI models
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # AI Lab
+
+ This Hugging Face Space includes multiple AI tools:
+ - 🖼️ ViT image captioning
+ - 📝 PPO-based Reddit summarization (coming soon)
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
main.py ADDED
@@ -0,0 +1,47 @@
+ # main.py
+
+ from fastapi import FastAPI, UploadFile, File
+ from fastapi.responses import HTMLResponse
+ from fastapi.staticfiles import StaticFiles
+ import shutil
+ from pathlib import Path
+
+ from vit_captioning.generate import CaptionGenerator
+
+ app = FastAPI()
+
+ # Serve static files
+ static_dir = Path(__file__).parent / "vit_captioning" / "static"
+ app.mount("/static", StaticFiles(directory=static_dir), name="static")
+
+ # ✅ Landing page at `/`
+ @app.get("/", response_class=HTMLResponse)
+ async def landing():
+     return Path("vit_captioning/static/landing.html").read_text()
+
+ # ✅ Captioning page at `/captioning`
+ @app.get("/captioning", response_class=HTMLResponse)
+ async def captioning():
+     return Path("vit_captioning/static/captioning/index.html").read_text()
+
+ # ✅ Example: Project 2 placeholder
+ @app.get("/project2", response_class=HTMLResponse)
+ async def project2():
+     return "<h1>Coming Soon: Project 2</h1>"
+
+ # ✅ Caption generation endpoint for the captioning app
+ # Keep the path consistent with your JS fetch()!
+ caption_generator = CaptionGenerator(
+     model_type="CLIPEncoder",
+     checkpoint_path="./vit_captioning/artifacts/CLIPEncoder_40epochs_unfreeze12.pth",
+     quantized=False
+ )
+
+ @app.post("/generate")
+ async def generate(file: UploadFile = File(...)):
+     temp_file = f"temp_{file.filename}"
+     with open(temp_file, "wb") as buffer:
+         shutil.copyfileobj(file.file, buffer)
+
+     captions = caption_generator.generate_caption(temp_file)
+     return captions
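As a quick check of the `/generate` contract, a minimal client-side sketch could look like the following (the host/port and image path are assumptions, not part of this commit; it also assumes `requests` is installed):

```python
# sketch: send an image to /generate and print the three captions
import requests

with open("example.jpg", "rb") as f:                       # placeholder image path
    resp = requests.post(
        "http://localhost:8000/generate",                  # assumed local deployment
        files={"file": ("example.jpg", f, "image/jpeg")},
    )

print(resp.json())   # expected keys: "greedy", "topk", "topp"
```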
ppo_summarizer/predict_ppo.py ADDED
@@ -0,0 +1,71 @@
+ # predict_ppo.py
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
+ import argparse
+ import os
+
+ # -------------------------------
+ # Config
+ # -------------------------------
+ MODEL_NAME = "Qwen/Qwen3-0.6B-Base"
+ CHECKPOINT_DIR = "./artifacts/qwen_loRA"
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ MAX_NEW_TOKENS = 64
+
+ # -------------------------------
+ # Load tokenizer and model
+ # -------------------------------
+ print("🔄 Loading tokenizer and model...")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME,
+     torch_dtype=torch.float16,
+     device_map="auto"
+ )
+
+ model = PeftModel.from_pretrained(base_model, CHECKPOINT_DIR)
+ model.eval()
+ model = model.to(DEVICE)
+
+ # -------------------------------
+ # Generate Summary
+ # -------------------------------
+ def generate_summary(title: str, post: str) -> str:
+     prompt = f"Title: {title}\n\nPost: {post}\n\nSummary:"
+     inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=MAX_NEW_TOKENS,
+             do_sample=True,
+             top_k=50,
+             top_p=0.95,
+             temperature=0.7,
+             pad_token_id=tokenizer.pad_token_id,
+             use_cache=True
+         )
+
+     full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     summary = full_output.split("Summary:")[-1].strip()
+     return summary
+
+ # -------------------------------
+ # CLI
+ # -------------------------------
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Generate a summary with the trained Qwen PPO model")
+     parser.add_argument("--title", type=str, required=True, help="Title of the post")
+     parser.add_argument("--post", type=str, required=True, help="Content of the post")
+     args = parser.parse_args()
+
+     print("\n📘 Title:", args.title)
+     print("📝 Post:", args.post[:100] + ("..." if len(args.post) > 100 else ""))
+     print("\n🤖 Generating summary...\n")
+
+     summary = generate_summary(args.title, args.post)
+     print("✅ Summary:\n", summary)
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ fastapi
+ uvicorn[standard]
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch==2.6.0+cpu
+ numpy<2
+ transformers
+ pillow
+ python-multipart
vit_captioning/generate.py ADDED
@@ -0,0 +1,115 @@
+ # generate.py
+
+ import torch
+ from PIL import Image
+ from transformers import ViTImageProcessor, CLIPProcessor, AutoTokenizer
+ from vit_captioning.models.encoder import ViTEncoder, CLIPEncoder
+ from vit_captioning.models.decoder import TransformerDecoder
+
+ import argparse
+
+
+ class CaptionGenerator:
+     def __init__(self, model_type: str, checkpoint_path: str, quantized=False):
+         print(f"Loading {model_type} | Quantized: {quantized}")
+         # Setup device
+         if torch.cuda.is_available():
+             self.device = torch.device("cuda")
+             print("Using NVIDIA CUDA GPU acceleration.")
+         elif torch.backends.mps.is_available():
+             self.device = torch.device("mps")
+             print("Using Apple MPS GPU acceleration.")
+         else:
+             self.device = torch.device("cpu")
+             print("No GPU found, falling back to CPU.")
+
+         # Load tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+
+         # Select encoder, processor, output dim
+         if model_type == "ViTEncoder":
+             self.encoder = ViTEncoder().to(self.device)
+             self.encoder_dim = 768
+             self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+         elif model_type == "CLIPEncoder":
+             self.encoder = CLIPEncoder().to(self.device)
+             self.encoder_dim = 512
+             self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+         else:
+             raise ValueError("Unknown model type")
+
+         if quantized:
+             print("Applying dynamic quantization to encoder...")
+             self.encoder = torch.ao.quantization.quantize_dynamic(
+                 self.encoder,
+                 {torch.nn.Linear},
+                 dtype=torch.qint8
+             )
+
+         # Initialize decoder
+         self.decoder = TransformerDecoder(
+             vocab_size=30522,
+             hidden_dim=self.encoder_dim,
+             encoder_dim=self.encoder_dim
+         ).to(self.device)
+
+         # Load checkpoint
+         checkpoint = torch.load(checkpoint_path, map_location=self.device)
+         self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
+         self.decoder.load_state_dict(checkpoint['decoder_state_dict'])
+         self.encoder.eval()
+         self.decoder.eval()
+
+     def generate_caption(self, image_path: str) -> dict:
+         image = Image.open(image_path).convert("RGB")
+         encoding = self.processor(images=image, return_tensors='pt')
+         pixel_values = encoding['pixel_values'].to(self.device)
+
+         captions = {}
+
+         with torch.no_grad():
+             encoder_outputs = self.encoder(pixel_values)
+
+             # Greedy
+             caption_ids = self.decoder.generate(encoder_outputs, mode="greedy")
+             captions['greedy'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)
+
+             # Top-k
+             caption_ids = self.decoder.generate(encoder_outputs, mode="topk", top_k=30)
+             captions['topk'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)
+
+             # Top-p
+             caption_ids = self.decoder.generate(encoder_outputs, mode="topp", top_p=0.92)
+             captions['topp'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)
+
+         return captions
+
+
+ if __name__ == "__main__":
+     # CLI usage
+     parser = argparse.ArgumentParser(description="Generate caption using ViT or CLIP.")
+     parser.add_argument("--model", type=str, default="ViTEncoder",
+                         choices=["ViTEncoder", "CLIPEncoder"],
+                         help="Choose encoder: ViTEncoder or CLIPEncoder")
+     parser.add_argument("--checkpoint", type=str, required=True,
+                         help="Path to the .pth checkpoint file")
+     parser.add_argument("--image", type=str, required=True,
+                         help="Path to input image file")
+     parser.add_argument(
+         "--quantized",
+         action="store_true",
+         help="Load encoder with dynamic quantization"
+     )
+
+     args = parser.parse_args()
+
+     generator = CaptionGenerator(
+         model_type=args.model,
+         checkpoint_path=args.checkpoint,
+         quantized=args.quantized  # forward the CLI flag parsed above
+     )
+
+     captions = generator.generate_caption(args.image)
+
+     print(f"Greedy-argmax (deterministic, factual): {captions['greedy']}")
+     print(f"Top-k (diverse, creative): {captions['topk']}")
+     print(f"Top-p (diverse, human-like): {captions['topp']}")
vit_captioning/models/decoder.py ADDED
@@ -0,0 +1,228 @@
+ # decoder.py
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+
+         pe = torch.zeros(max_len, d_model)  # [max_len, d_model]
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+
+         pe[:, 0::2] = torch.sin(position * div_term)  # dim 2i
+         pe[:, 1::2] = torch.cos(position * div_term)  # dim 2i+1
+
+         pe = pe.unsqueeze(1)  # [max_len, 1, d_model]
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         # x: [seq_len, batch_size, d_model]
+         x = x + self.pe[:x.size(0)]
+         return x
+
+
+ def generate_square_subsequent_mask(sz):
+     mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+     mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+     return mask
+
+
+ class TransformerDecoder(nn.Module):
+     def __init__(self, vocab_size, hidden_dim=512, encoder_dim=768, num_layers=2):
+         super(TransformerDecoder, self).__init__()
+
+         self.vocab_size = vocab_size
+         self.embedding = nn.Embedding(vocab_size, hidden_dim)
+         self.positional_encoding = PositionalEncoding(hidden_dim)
+
+         decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=8)
+         self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
+
+         self.fc_out = nn.Linear(hidden_dim, vocab_size)
+
+         # Project ViT encoder output to decoder hidden_dim if needed
+         self.encoder_projection = nn.Linear(encoder_dim, hidden_dim)
+
+     def forward(self, input_ids, encoder_outputs, tgt_attention_mask=None):
+         embedded = self.embedding(input_ids).permute(1, 0, 2)
+         embedded = self.positional_encoding(embedded)
+
+         memory = self.encoder_projection(encoder_outputs).unsqueeze(0)
+
+         tgt_mask = generate_square_subsequent_mask(embedded.size(0)).to(embedded.device)
+
+         if tgt_attention_mask is not None:
+             tgt_key_padding_mask = ~tgt_attention_mask.bool()
+         else:
+             tgt_key_padding_mask = None
+
+         output = self.transformer_decoder(
+             tgt=embedded,
+             memory=memory,
+             tgt_mask=tgt_mask,
+             tgt_key_padding_mask=tgt_key_padding_mask
+         )
+
+         output = self.fc_out(output).permute(1, 0, 2)
+         return output
+
+     def generate(
+         self,
+         encoder_outputs,
+         start_token_id=101,   # [CLS] token for BERT
+         eos_token_id=102,
+         max_length=50,
+         mode="greedy",        # "greedy", "beam", "topk", "topp"
+         num_beams=3,
+         top_k=50,
+         top_p=0.95,
+         length_penalty=1.0
+     ):
+         """
+         Generate caption using specified decoding mode.
+         """
+         device = encoder_outputs.device
+         batch_size = encoder_outputs.size(0)
+         input_ids = torch.full(
+             (batch_size, 1),
+             start_token_id,
+             dtype=torch.long,
+             device=device
+         )
+
+         if mode == "beam":
+             return self._generate_beam_search(
+                 encoder_outputs,
+                 input_ids,
+                 max_length,
+                 eos_token_id,
+                 num_beams,
+                 length_penalty
+             )
+
+         # Greedy or sampling
+         generated = input_ids
+
+         for _ in range(max_length):
+             logits = self.forward(generated, encoder_outputs)  # (batch, seq_len, vocab)
+             next_token_logits = logits[:, -1, :]  # (batch, vocab)
+
+             if mode == "greedy":
+                 next_token = next_token_logits.argmax(dim=-1, keepdim=True)
+
+             elif mode == "topk":
+                 probs = F.softmax(next_token_logits, dim=-1)
+                 topk_probs, topk_indices = torch.topk(probs, top_k)
+                 next_token = topk_indices[
+                     torch.arange(probs.size(0)),
+                     torch.multinomial(topk_probs, num_samples=1).squeeze(-1)
+                 ].unsqueeze(-1)
+
+             elif mode == "topp":
+                 probs = F.softmax(next_token_logits, dim=-1)
+                 sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+                 cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+                 # Remove tokens with cumulative probs above threshold
+                 sorted_mask = cumulative_probs <= top_p
+                 sorted_mask[..., 0] = 1  # Always keep at least 1 token
+
+                 filtered_probs = sorted_probs * sorted_mask
+                 filtered_probs /= filtered_probs.sum(dim=-1, keepdim=True)
+
+                 next_token = sorted_indices[
+                     torch.arange(probs.size(0)),
+                     torch.multinomial(filtered_probs, num_samples=1).squeeze(-1)
+                 ].unsqueeze(-1)
+
+             else:
+                 raise ValueError(f"Unknown mode: {mode}")
+
+             generated = torch.cat((generated, next_token), dim=1)
+
+             if eos_token_id is not None:
+                 if (next_token == eos_token_id).all():
+                     break
+
+         return generated[:, 1:]  # Remove BOS if needed
+
+     def _generate_beam_search(
+         self,
+         encoder_outputs,
+         input_ids,
+         max_length=50,
+         eos_token_id=102,
+         num_beams=3,
+         length_penalty=1.0
+     ):
+         """
+         Custom beam search decoder for batch_size = 1.
+         """
+         device = encoder_outputs.device
+         batch_size = encoder_outputs.size(0)
+         vocab_size = self.vocab_size
+
+         # Assume batch_size = 1 for simplicity
+         assert batch_size == 1, "Basic beam search only supports batch size 1 here."
+
+         # Initialize beams
+         beam_sequences = [input_ids] * num_beams
+         beam_scores = torch.zeros(num_beams, device=device)
+
+         finished_sequences = []
+         finished_scores = []
+
+         for step in range(max_length):
+             all_candidates = []
+
+             for beam_idx in range(num_beams):
+                 seq = beam_sequences[beam_idx]
+                 score = beam_scores[beam_idx]
+
+                 logits = self.forward(seq, encoder_outputs)  # (1, seq_len, vocab)
+                 next_token_logits = logits[:, -1, :]  # (1, vocab)
+                 log_probs = F.log_softmax(next_token_logits, dim=-1).squeeze(0)  # (vocab,)
+
+                 for token_id in range(vocab_size):
+                     new_seq = torch.cat([seq, torch.tensor([[token_id]], device=device)], dim=1)
+                     new_score = score + log_probs[token_id]
+                     all_candidates.append((new_seq, new_score))
+
+             # Get top beams
+             all_candidates.sort(key=lambda x: x[1], reverse=True)
+             beam_sequences = []
+             beam_scores = []
+
+             for seq, score in all_candidates[:num_beams]:
+                 if eos_token_id is not None and seq[0, -1].item() == eos_token_id:
+                     finished_sequences.append(seq)
+                     finished_scores.append(score)
+                 else:
+                     beam_sequences.append(seq)
+                     beam_scores.append(score)
+
+             beam_scores = torch.stack(beam_scores) if beam_scores else torch.tensor([], device=device)
+
+             # Early stopping if all beams ended
+             if len(beam_sequences) == 0:
+                 break
+
+         # Add unfinished beams to finished
+         if not finished_sequences:
+             finished_sequences = beam_sequences
+             finished_scores = beam_scores
+
+         # Length penalty
+         finished_scores = [s / (len(seq[0]) ** length_penalty) for seq, s in zip(finished_sequences, finished_scores)]
+
+         # Pick best
+         best_idx = torch.tensor(finished_scores).argmax().item()
+         best_seq = finished_sequences[best_idx]
+
+         return best_seq[:, 1:]  # remove BOS if needed
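Since the decoder only ever sees a single pooled image vector as memory, a quick smoke test with random weights and a dummy encoder output (dimensions chosen to match the CLIP configuration used in generate.py; no checkpoint required) can confirm the expected shapes:

```python
# sketch: exercise TransformerDecoder.generate with a dummy pooled image embedding
import torch
from vit_captioning.models.decoder import TransformerDecoder

decoder = TransformerDecoder(vocab_size=30522, hidden_dim=512, encoder_dim=512)
decoder.eval()

encoder_outputs = torch.randn(1, 512)   # stand-in for a CLIP image embedding

with torch.no_grad():
    ids = decoder.generate(encoder_outputs, mode="greedy", max_length=10)

print(ids.shape)   # (1, <=10) token ids with the BOS stripped; tokens are random here (untrained weights)
```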
vit_captioning/models/encoder.py ADDED
@@ -0,0 +1,38 @@
+ # models/encoder.py
+ import torch.nn as nn
+ from transformers import ViTModel, CLIPModel
+
+
+ class ViTEncoder(nn.Module):
+     def __init__(self):
+         super(ViTEncoder, self).__init__()
+         self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
+
+     def forward(self, pixel_values):
+         # ViTModel output shape = [batch, seq_len, hidden]
+         outputs = self.vit(pixel_values=pixel_values)
+
+         # Take the CLS token from last_hidden_state as the image representation
+         cls_embedding = outputs.last_hidden_state[:, 0]
+         return cls_embedding  # shape: [batch_size, 768]
+
+
+ class CLIPEncoder(nn.Module):
+     def __init__(self):
+         super(CLIPEncoder, self).__init__()
+         self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+
+     def forward(self, pixel_values):
+         # ✅ Directly get the pooled image features (already the final representation)
+         image_features = self.clip.get_image_features(pixel_values=pixel_values)
+         return image_features  # shape: [batch_size, hidden_dim]
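Both encoders reduce an image to a single vector, but with different widths (768 for ViT-base, 512 for CLIP ViT-B/32), which is why CaptionGenerator sets `encoder_dim` per model type. A rough shape check, as a sketch (it downloads the pretrained weights on first run; the dummy input is an assumption):

```python
# sketch: verify the pooled output widths that CaptionGenerator relies on
import torch
from vit_captioning.models.encoder import ViTEncoder, CLIPEncoder

pixel_values = torch.randn(1, 3, 224, 224)     # dummy ImageNet-sized input

with torch.no_grad():
    print(ViTEncoder()(pixel_values).shape)    # expected: torch.Size([1, 768])
    print(CLIPEncoder()(pixel_values).shape)   # expected: torch.Size([1, 512])
```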
vit_captioning/static/captioning/index.html ADDED
@@ -0,0 +1,153 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8" />
+     <title>🤖 Image 🖼️ Captioning</title>
+     <meta name="viewport" content="width=device-width, initial-scale=1">
+
+     <!-- ✅ Tailwind CDN -->
+     <script src="https://cdn.tailwindcss.com"></script>
+ </head>
+
+ <body class="bg-gray-100 flex items-center justify-center min-h-screen">
+     <a href="/" class="absolute top-4 left-4 text-blue-600 hover:text-blue-800 text-sm font-semibold flex items-center">
+         <!-- back arrow -->
+         <svg class="w-5 h-5 mr-1" fill="none" stroke="currentColor" stroke-width="2" viewBox="0 0 24 24"
+              xmlns="http://www.w3.org/2000/svg">
+             <path stroke-linecap="round" stroke-linejoin="round" d="M15 19l-7-7 7-7"></path>
+         </svg>
+         Back to Home
+     </a>
+
+     <div class="bg-white p-8 rounded-xl shadow-md w-full max-w-md text-center">
+         <h1 class="text-2xl font-bold mb-4 text-gray-800">AI Image Captioning</h1>
+
+         <!-- Upload Form -->
+         <form id="uploadForm" class="space-y-4">
+             <input
+                 type="file"
+                 id="fileInput"
+                 accept="image/*"
+                 required
+                 class="block w-full text-sm text-gray-700 file:mr-4 file:py-2 file:px-4 file:rounded-full file:border-0 file:text-sm file:font-semibold file:bg-blue-50 file:text-blue-700 hover:file:bg-blue-100"
+             />
+
+             <!-- Live Image Preview -->
+             <div id="previewContainer" class="mt-4 hidden">
+                 <img id="previewImage" src="#" alt="Preview" class="mx-auto max-h-64 rounded-md shadow" />
+             </div>
+
+             <button
+                 type="submit"
+                 class="w-full bg-blue-600 hover:bg-blue-700 text-white font-semibold py-2 px-4 rounded-lg transition"
+             >
+                 Generate Captions
+             </button>
+         </form>
+
+         <!-- Captions -->
+         <div id="result" class="mt-6 text-left hidden">
+             <h2 class="text-lg font-semibold mb-2 text-gray-700">Captions:</h2>
+             <p><strong>Factual 🤖:</strong> <span id="greedy" class="text-gray-800"></span></p>
+             <p><strong>Creative 🤪:</strong> <span id="topk" class="text-gray-800"></span></p>
+             <p><strong>Human like 🫀:</strong> <span id="topp" class="text-gray-800"></span></p>
+         </div>
+     </div>
+
+     <script>
+         const fileInput = document.getElementById('fileInput');
+         const previewContainer = document.getElementById('previewContainer');
+         const previewImage = document.getElementById('previewImage');
+         const form = document.getElementById('uploadForm');
+         const result = document.getElementById('result');
+
+         // ✅ Live preview + clear old captions
+         fileInput.addEventListener('change', () => {
+             const file = fileInput.files[0];
+             if (file) {
+                 const reader = new FileReader();
+                 reader.onload = e => {
+                     previewImage.src = e.target.result;
+                     previewContainer.classList.remove('hidden');
+                 };
+                 reader.readAsDataURL(file);
+
+                 // Clear old captions
+                 document.getElementById('greedy').innerText = "";
+                 document.getElementById('topk').innerText = "";
+                 document.getElementById('topp').innerText = "";
+                 result.classList.add('hidden');
+
+             } else {
+                 previewContainer.classList.add('hidden');
+             }
+         });
+
+         // ✅ Submit form
+         form.addEventListener('submit', async e => {
+             e.preventDefault();
+             const file = fileInput.files[0];
+             const formData = new FormData();
+             formData.append('file', file);
+
+             const res = await fetch('/generate', {
+                 method: 'POST',
+                 body: formData
+             });
+
+             const data = await res.json();
+             document.getElementById('greedy').innerText = data.greedy || "N/A";
+             document.getElementById('topk').innerText = data.topk || "N/A";
+             document.getElementById('topp').innerText = data.topp || "N/A";
+             result.classList.remove('hidden');
+         });
+     </script>
+
+     <!-- Floating Help Button -->
+     <button id="helpButton"
+             class="fixed bottom-4 right-4 bg-blue-600 text-white rounded-full w-12 h-12 text-2xl font-bold shadow-lg hover:bg-blue-700 transition">
+         ?
+     </button>
+
+     <!-- Help Modal -->
+     <div id="helpModal" class="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center hidden">
+         <div class="bg-white rounded-lg p-6 max-w-sm w-full shadow-lg text-left">
+             <h2 class="text-xl font-semibold mb-4">🤖 Image Captioning</h2>
+             <p class="text-gray-700 mb-4">
+                 Upload an image and press "Generate Captions"; the model will generate captions for it.
+                 It uses google/vit-base-patch16-224-in21k or openai/clip-vit-base-patch32
+                 as the image encoder, trained together with a custom transformer decoder.<br>
+                 The available caption styles are "Factual 🤖", "Creative 🤪", and "Human like 🫀",
+                 which correspond to argmax (greedy), top-k, and top-p decoding respectively.
+             </p>
+             <button id="closeModal"
+                     class="mt-2 bg-blue-600 text-white px-4 py-2 rounded hover:bg-blue-700">
+                 Close
+             </button>
+         </div>
+     </div>
+
+     <script>
+         const helpButton = document.getElementById('helpButton');
+         const helpModal = document.getElementById('helpModal');
+         const closeModal = document.getElementById('closeModal');
+
+         helpButton.addEventListener('click', () => {
+             helpModal.classList.remove('hidden');
+         });
+
+         closeModal.addEventListener('click', () => {
+             helpModal.classList.add('hidden');
+         });
+
+         // Optional: close modal when clicking outside the modal box
+         helpModal.addEventListener('click', (e) => {
+             if (e.target === helpModal) {
+                 helpModal.classList.add('hidden');
+             }
+         });
+     </script>
+
+ </body>
+ </html>
vit_captioning/static/landing.html ADDED
@@ -0,0 +1,65 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8" />
+     <title>Clement's AI Lab</title>
+     <meta name="viewport" content="width=device-width, initial-scale=1">
+     <script src="https://cdn.tailwindcss.com"></script>
+ </head>
+ <body class="bg-gray-100 flex items-center justify-center min-h-screen">
+     <div class="max-w-md w-full p-6 text-center">
+         <h1 class="text-3xl font-bold mb-6 text-gray-800">🚀 Clement's AI Lab</h1>
+         <div class="space-y-4">
+             <a href="/captioning" class="block w-full bg-blue-600 hover:bg-blue-700 text-white py-3 rounded-lg shadow text-lg font-semibold">
+                 🖼️ Image Captioning
+             </a>
+             <a href="/project2" class="block w-full bg-green-600 hover:bg-green-700 text-white py-3 rounded-lg shadow text-lg font-semibold">
+                 🤖 Coming Soon: Word calculator
+             </a>
+             <!-- Add more project links here -->
+         </div>
+     </div>
+
+     <!-- Floating Help Button -->
+     <button id="helpButton"
+             class="fixed bottom-4 right-4 bg-blue-600 text-white rounded-full w-12 h-12 text-2xl font-bold shadow-lg hover:bg-blue-700 transition">
+         ?
+     </button>
+
+     <!-- Help Modal -->
+     <div id="helpModal" class="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center hidden">
+         <div class="bg-white rounded-lg p-6 max-w-sm w-full shadow-lg text-left">
+             <h2 class="text-xl font-semibold mb-4">Clement's AI Lab</h2>
+             <p class="text-gray-700 mb-4">
+                 Welcome! I'm Clement, and I've built these interactive models for you to experiment with. Whether you're curious about AI or just want to have some fun, there's something here for everyone.
+             </p>
+             <button id="closeModal"
+                     class="mt-2 bg-blue-600 text-white px-4 py-2 rounded hover:bg-blue-700">
+                 Close
+             </button>
+         </div>
+     </div>
+
+     <script>
+         const helpButton = document.getElementById('helpButton');
+         const helpModal = document.getElementById('helpModal');
+         const closeModal = document.getElementById('closeModal');
+
+         helpButton.addEventListener('click', () => {
+             helpModal.classList.remove('hidden');
+         });
+
+         closeModal.addEventListener('click', () => {
+             helpModal.classList.add('hidden');
+         });
+
+         // Optional: close modal when clicking outside the modal box
+         helpModal.addEventListener('click', (e) => {
+             if (e.target === helpModal) {
+                 helpModal.classList.add('hidden');
+             }
+         });
+     </script>
+
+ </body>
+ </html>