Spaces: Running on CPU Upgrade
Commit · 7b2eca8
1 Parent(s): dba8761
Resolve README.md conflict and merge with remote Hugging Face content
Files changed:
- .gitignore +3 -0
- .huggingface/space.yaml +1 -0
- Dockerfile +10 -0
- README.md +7 -1
- main.py +47 -0
- ppo_summarizer/predict_ppo.py +71 -0
- requirements.txt +8 -0
- vit_captioning/generate.py +115 -0
- vit_captioning/models/decoder.py +228 -0
- vit_captioning/models/encoder.py +38 -0
- vit_captioning/static/captioning/index.html +153 -0
- vit_captioning/static/landing.html +65 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+__pycache__/
+*.png
+**/artifacts/
.huggingface/space.yaml ADDED
@@ -0,0 +1 @@
+sdk: docker
Dockerfile ADDED
@@ -0,0 +1,10 @@
+# 🐍 Use official Python
+FROM python:3.11-slim
+
+WORKDIR /app
+COPY . .
+
+RUN pip install --upgrade pip
+RUN pip install -r requirements.txt
+
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
README.md CHANGED
@@ -8,4 +8,10 @@ pinned: false
 short_description: Clement's AI Lab to demonstrate advanced AI models
 ---
 
-
+# AI Lab
+
+This Hugging Face Space includes multiple AI tools:
+- 🖼️ ViT image captioning
+- 📝 PPO-based Reddit summarization (coming soon)
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
main.py ADDED
@@ -0,0 +1,47 @@
+# app/main.py
+
+from fastapi import FastAPI, UploadFile, File
+from fastapi.responses import HTMLResponse
+from fastapi.staticfiles import StaticFiles
+import shutil
+from pathlib import Path
+
+from vit_captioning.generate import CaptionGenerator
+
+app = FastAPI()
+
+# Serve static files
+static_dir = Path(__file__).parent / "vit_captioning" / "static"
+app.mount("/static", StaticFiles(directory=static_dir), name="static")
+
+# ✅ Landing page at `/`
+@app.get("/", response_class=HTMLResponse)
+async def landing():
+    return Path("vit_captioning/static/landing.html").read_text()
+
+# ✅ Captioning page at `/captioning`
+@app.get("/captioning", response_class=HTMLResponse)
+async def captioning():
+    return Path("vit_captioning/static/captioning/index.html").read_text()
+
+# ✅ Example: Project 2 placeholder
+@app.get("/project2", response_class=HTMLResponse)
+async def project2():
+    return "<h1>Coming Soon: Project 2</h1>"
+
+# ✅ Caption generation endpoint for captioning app
+# Keep the path consistent with your JS fetch()!
+caption_generator = CaptionGenerator(
+    model_type="CLIPEncoder",
+    checkpoint_path="./vit_captioning/artifacts/CLIPEncoder_40epochs_unfreeze12.pth",
+    quantized=False
+)
+
+@app.post("/generate")
+async def generate(file: UploadFile = File(...)):
+    temp_file = f"temp_{file.filename}"
+    with open(temp_file, "wb") as buffer:
+        shutil.copyfileobj(file.file, buffer)
+
+    captions = caption_generator.generate_caption(temp_file)
+    return captions
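A minimal client sketch for testing the endpoint above locally (not part of this commit): it assumes the app is running on http://localhost:8000 and that the `requests` package is installed; `example.jpg` is only a placeholder path.

# client_sketch.py — hypothetical local test, not part of this commit
import requests

# POST an image to the /generate endpoint served by main.py
with open("example.jpg", "rb") as f:  # placeholder image path
    resp = requests.post("http://localhost:8000/generate", files={"file": f})

print(resp.json())  # expected keys: 'greedy', 'topk', 'topp'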
ppo_summarizer/predict_ppo.py ADDED
@@ -0,0 +1,71 @@
+# predict.py
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import PeftModel
+import argparse
+import os
+
+# -------------------------------
+# Config
+# -------------------------------
+MODEL_NAME = "Qwen/Qwen3-0.6B-Base"
+CHECKPOINT_DIR = "./artifacts/qwen_loRA"
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+MAX_NEW_TOKENS = 64
+
+# -------------------------------
+# Load tokenizer and model
+# -------------------------------
+print("🔄 Loading tokenizer and model...")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+tokenizer.pad_token = tokenizer.eos_token
+
+base_model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    torch_dtype=torch.float16,
+    device_map="auto"
+)
+
+model = PeftModel.from_pretrained(base_model, CHECKPOINT_DIR)
+model.eval()
+model = model.to(DEVICE)
+
+# -------------------------------
+# Generate Summary
+# -------------------------------
+def generate_summary(title: str, post: str) -> str:
+    prompt = f"Title: {title}\n\nPost: {post}\n\nSummary:"
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
+
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=MAX_NEW_TOKENS,
+            do_sample=True,
+            top_k=50,
+            top_p=0.95,
+            temperature=0.7,
+            pad_token_id=tokenizer.pad_token_id,
+            use_cache=True
+        )
+
+    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    summary = full_output.split("Summary:")[-1].strip()
+    return summary
+
+# -------------------------------
+# CLI
+# -------------------------------
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Generate summary with trained Qwen PPO model")
+    parser.add_argument("--title", type=str, required=True, help="Title of the post")
+    parser.add_argument("--post", type=str, required=True, help="Content of the post")
+    args = parser.parse_args()
+
+    print("\n📌 Title:", args.title)
+    print("📝 Post:", args.post[:100] + ("..." if len(args.post) > 100 else ""))
+    print("\n🤖 Generating summary...\n")
+
+    summary = generate_summary(args.title, args.post)
+    print("✅ Summary:\n", summary)
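A hypothetical programmatic usage sketch (not part of this commit). It assumes `ppo_summarizer` is importable as a package, that the LoRA checkpoint exists at ./artifacts/qwen_loRA, and that `peft` is installed (requirements.txt currently covers only the captioning app); note the model is loaded at import time.

# sketch: calling the summarizer from Python — paths and the example post are placeholders
from ppo_summarizer.predict_ppo import generate_summary  # loads Qwen + LoRA on import

summary = generate_summary(
    title="Trained a PPO summarizer this weekend",
    post="Long Reddit-style post body goes here...",
)
print(summary)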
requirements.txt ADDED
@@ -0,0 +1,8 @@
+fastapi
+uvicorn[standard]
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.6.0+cpu
+numpy<2
+transformers
+pillow
+python-multipart
vit_captioning/generate.py ADDED
@@ -0,0 +1,115 @@
+# generate.py
+
+import torch
+from PIL import Image
+from transformers import ViTImageProcessor, CLIPProcessor, AutoTokenizer
+from vit_captioning.models.encoder import ViTEncoder, CLIPEncoder
+from vit_captioning.models.decoder import TransformerDecoder
+
+import argparse
+
+
+class CaptionGenerator:
+    def __init__(self, model_type: str, checkpoint_path: str, quantized=False):
+        print(f"Loading {model_type} | Quantized: {quantized}")
+        # Setup device
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+            print("Using NVIDIA CUDA GPU acceleration.")
+        elif torch.backends.mps.is_available():
+            self.device = torch.device("mps")
+            print("Using Apple MPS GPU acceleration.")
+        else:
+            self.device = torch.device("cpu")
+            print("No GPU found, falling back to CPU.")
+
+        # Load tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+
+        # Select encoder, processor, output dim
+        if model_type == "ViTEncoder":
+            self.encoder = ViTEncoder().to(self.device)
+            self.encoder_dim = 768
+            self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+        elif model_type == "CLIPEncoder":
+            self.encoder = CLIPEncoder().to(self.device)
+            self.encoder_dim = 512
+            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+        else:
+            raise ValueError("Unknown model type")
+
+        if quantized:
+            print("Applying dynamic quantization to encoder...")
+            self.encoder = torch.ao.quantization.quantize_dynamic(
+                self.encoder,
+                {torch.nn.Linear},
+                dtype=torch.qint8
+            )
+
+        # Initialize decoder
+        self.decoder = TransformerDecoder(
+            vocab_size=30522,
+            hidden_dim=self.encoder_dim,
+            encoder_dim=self.encoder_dim
+        ).to(self.device)
+
+        # Load checkpoint
+        checkpoint = torch.load(checkpoint_path, map_location=self.device)
+        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
+        self.decoder.load_state_dict(checkpoint['decoder_state_dict'])
+        self.encoder.eval()
+        self.decoder.eval()
+
+    def generate_caption(self, image_path: str) -> dict:
+        image = Image.open(image_path).convert("RGB")
+        encoding = self.processor(images=image, return_tensors='pt')
+        pixel_values = encoding['pixel_values'].to(self.device)
+
+        captions = {}
+
+        with torch.no_grad():
+            encoder_outputs = self.encoder(pixel_values)
+
+            # Greedy
+            caption_ids = self.decoder.generate(encoder_outputs, mode="greedy")
+            captions['greedy'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)
+
+            # Top-k
+            caption_ids = self.decoder.generate(encoder_outputs, mode="topk", top_k=30)
+            captions['topk'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)
+
+            # Top-p
+            caption_ids = self.decoder.generate(encoder_outputs, mode="topp", top_p=0.92)
+            captions['topp'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)
+
+        return captions
+
+
+if __name__ == "__main__":
+    # CLI usage
+    parser = argparse.ArgumentParser(description="Generate caption using ViT or CLIP.")
+    parser.add_argument("--model", type=str, default="ViTEncoder",
+                        choices=["ViTEncoder", "CLIPEncoder"],
+                        help="Choose encoder: ViTEncoder or CLIPEncoder")
+    parser.add_argument("--checkpoint", type=str, required=True,
+                        help="Path to the .pth checkpoint file")
+    parser.add_argument("--image", type=str, required=True,
+                        help="Path to input image file")
+    parser.add_argument(
+        "--quantized",
+        action="store_true",
+        help="Load encoder with dynamic quantization"
+    )  ### ✅ ADDED
+
+    args = parser.parse_args()
+
+    generator = CaptionGenerator(
+        model_type=args.model,
+        checkpoint_path=args.checkpoint
+    )
+
+    captions = generator.generate_caption(args.image)
+
+    print(f"Greedy-argmax (deterministic, factual): {captions['greedy']}")
+    print(f"Top-k (diverse, creative): {captions['topk']}")
+    print(f"Top-p (diverse, human-like): {captions['topp']}")
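A hypothetical direct-usage sketch (not part of this commit), mirroring how main.py constructs the generator; the checkpoint and image paths are placeholders.

# sketch: using CaptionGenerator outside the FastAPI app — paths are placeholders
from vit_captioning.generate import CaptionGenerator

generator = CaptionGenerator(
    model_type="CLIPEncoder",
    checkpoint_path="./vit_captioning/artifacts/CLIPEncoder_40epochs_unfreeze12.pth",
    quantized=False,
)
captions = generator.generate_caption("example.jpg")  # placeholder image path
for style, text in captions.items():
    print(f"{style}: {text}")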
vit_captioning/models/decoder.py ADDED
@@ -0,0 +1,228 @@
+# decoder.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len=5000):
+        super(PositionalEncoding, self).__init__()
+
+        pe = torch.zeros(max_len, d_model)  # [max_len, d_model]
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+
+        pe[:, 0::2] = torch.sin(position * div_term)  # dim 2i
+        pe[:, 1::2] = torch.cos(position * div_term)  # dim 2i+1
+
+        pe = pe.unsqueeze(1)  # [max_len, 1, d_model]
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        # x: [seq_len, batch_size, d_model]
+        x = x + self.pe[:x.size(0)]
+        return x
+
+
+def generate_square_subsequent_mask(sz):
+    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+    return mask
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, vocab_size, hidden_dim=512, encoder_dim=768, num_layers=2):
+        super(TransformerDecoder, self).__init__()
+
+        self.vocab_size = vocab_size
+        self.embedding = nn.Embedding(vocab_size, hidden_dim)
+        self.positional_encoding = PositionalEncoding(hidden_dim)
+
+        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=8)
+        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
+
+        self.fc_out = nn.Linear(hidden_dim, vocab_size)
+
+        # Project ViT encoder output to decoder hidden_dim if needed
+        self.encoder_projection = nn.Linear(encoder_dim, hidden_dim)
+
+    def forward(self, input_ids, encoder_outputs, tgt_attention_mask=None):
+        embedded = self.embedding(input_ids).permute(1, 0, 2)
+        embedded = self.positional_encoding(embedded)
+
+        memory = self.encoder_projection(encoder_outputs).unsqueeze(0)
+
+        tgt_mask = generate_square_subsequent_mask(embedded.size(0)).to(embedded.device)
+
+        if tgt_attention_mask is not None:
+            tgt_key_padding_mask = ~tgt_attention_mask.bool()
+        else:
+            tgt_key_padding_mask = None
+
+        output = self.transformer_decoder(
+            tgt=embedded,
+            memory=memory,
+            tgt_mask=tgt_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask
+        )
+
+        output = self.fc_out(output).permute(1, 0, 2)
+        return output
+
+    def generate(
+        self,
+        encoder_outputs,
+        start_token_id=101,  # [CLS] token for BERT
+        eos_token_id=102,
+        max_length=50,
+        mode="greedy",  # "greedy", "beam", "topk", "topp"
+        num_beams=3,
+        top_k=50,
+        top_p=0.95,
+        length_penalty=1.0
+    ):
+
+        device = encoder_outputs.device
+
+        """
+        Generate caption using specified decoding mode.
+        """
+        batch_size = encoder_outputs.size(0)
+        input_ids = torch.full(
+            (batch_size, 1),
+            start_token_id,
+            dtype=torch.long,
+            device=device
+        )
+
+        if mode == "beam":
+            return self._generate_beam_search(
+                encoder_outputs,
+                input_ids,
+                max_length,
+                eos_token_id,
+                num_beams,
+                length_penalty
+            )
+
+        # Greedy or sampling
+        generated = input_ids
+
+        for _ in range(max_length):
+            logits = self.forward(generated, encoder_outputs)  # (batch, seq_len, vocab)
+            next_token_logits = logits[:, -1, :]  # (batch, vocab)
+
+            if mode == "greedy":
+                next_token = next_token_logits.argmax(dim=-1, keepdim=True)
+
+            elif mode == "topk":
+                probs = F.softmax(next_token_logits, dim=-1)
+                topk_probs, topk_indices = torch.topk(probs, top_k)
+                next_token = topk_indices[
+                    torch.arange(probs.size(0)),
+                    torch.multinomial(topk_probs, num_samples=1).squeeze(-1)
+                ].unsqueeze(-1)
+
+            elif mode == "topp":
+                probs = F.softmax(next_token_logits, dim=-1)
+                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+                # Remove tokens with cumulative probs above threshold
+                sorted_mask = cumulative_probs <= top_p
+                sorted_mask[..., 0] = 1  # Always keep at least 1 token
+
+                filtered_probs = sorted_probs * sorted_mask
+                filtered_probs /= filtered_probs.sum(dim=-1, keepdim=True)
+
+                next_token = sorted_indices[
+                    torch.arange(probs.size(0)),
+                    torch.multinomial(filtered_probs, num_samples=1).squeeze(-1)
+                ].unsqueeze(-1)
+
+            else:
+                raise ValueError(f"Unknown mode: {mode}")
+
+            generated = torch.cat((generated, next_token), dim=1)
+
+            if eos_token_id is not None:
+                if (next_token == eos_token_id).all():
+                    break
+
+        return generated[:, 1:]  # Remove BOS if needed
+
+    def _generate_beam_search(
+        self,
+        encoder_outputs,
+        input_ids,
+        max_length=50,
+        eos_token_id=102,
+        num_beams=3,
+        length_penalty=1.0
+    ):
+        """
+        Custom beam search decoder for batch_size = 1.
+        """
+        device = encoder_outputs.device
+        batch_size = encoder_outputs.size(0)
+        vocab_size = self.vocab_size
+
+        # Assume batch_size = 1 for simplicity
+        assert batch_size == 1, "Basic beam search only supports batch size 1 here."
+
+        # Initialize beams
+        beam_sequences = [input_ids] * num_beams
+        beam_scores = torch.zeros(num_beams, device=device)
+
+        finished_sequences = []
+        finished_scores = []
+
+        for step in range(max_length):
+            all_candidates = []
+
+            for beam_idx in range(num_beams):
+                seq = beam_sequences[beam_idx]
+                score = beam_scores[beam_idx]
+
+                logits = self.forward(seq, encoder_outputs)  # (1, seq_len, vocab)
+                next_token_logits = logits[:, -1, :]  # (1, vocab)
+                log_probs = F.log_softmax(next_token_logits, dim=-1).squeeze(0)  # (vocab,)
+
+                for token_id in range(vocab_size):
+                    new_seq = torch.cat([seq, torch.tensor([[token_id]], device=device)], dim=1)
+                    new_score = score + log_probs[token_id]
+                    all_candidates.append((new_seq, new_score))
+
+            # Get top beams
+            all_candidates.sort(key=lambda x: x[1], reverse=True)
+            beam_sequences = []
+            beam_scores = []
+
+            for seq, score in all_candidates[:num_beams]:
+                if eos_token_id is not None and seq[0, -1].item() == eos_token_id:
+                    finished_sequences.append(seq)
+                    finished_scores.append(score)
+                else:
+                    beam_sequences.append(seq)
+                    beam_scores.append(score)
+
+            beam_scores = torch.stack(beam_scores) if beam_scores else torch.tensor([], device=device)
+
+            # Early stopping if all beams ended
+            if len(beam_sequences) == 0:
+                break
+
+        # Add unfinished beams to finished
+        if not finished_sequences:
+            finished_sequences = beam_sequences
+            finished_scores = beam_scores
+
+        # Length penalty
+        finished_scores = [s / (len(seq[0]) ** length_penalty) for seq, s in zip(finished_sequences, finished_scores)]
+
+        # Pick best
+        best_idx = torch.tensor(finished_scores).argmax().item()
+        best_seq = finished_sequences[best_idx]
+
+        return best_seq[:, 1:]  # remove BOS if needed
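A self-contained sketch (not part of this commit) of the nucleus (top-p) filtering step used in TransformerDecoder.generate, shown on a toy next-token distribution; the logit values are illustrative only, and index mapping is done with gather rather than the fancy indexing used above.

# sketch: top-p (nucleus) filtering on a toy distribution
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])   # (batch=1, vocab=5), illustrative values
probs = F.softmax(logits, dim=-1)

sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)

mask = cumulative <= 0.92          # keep the smallest set covering ~92% of probability mass
mask[..., 0] = True                # always keep the most likely token
filtered = sorted_probs * mask
filtered /= filtered.sum(dim=-1, keepdim=True)

sample = torch.multinomial(filtered, num_samples=1)     # index into the sorted order
next_token = sorted_indices.gather(-1, sample)          # map back to vocabulary ids
print(next_token)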
vit_captioning/models/encoder.py ADDED
@@ -0,0 +1,38 @@
+# models/encoder.py
+from transformers import ViTModel, ViTImageProcessor, CLIPModel
+import torch.nn as nn
+import torch
+from PIL import Image
+
+
+import torch.nn as nn
+
+class ViTEncoder(nn.Module):
+    def __init__(self):  # Make decoder_dim configurable!
+        super(ViTEncoder, self).__init__()
+
+        # weights = ViT_B_16_Weights.DEFAULT
+        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
+
+    def forward(self, pixel_values):
+
+        # ViTModel - output shape = [batch, seq_len, hidden]
+        outputs = self.vit(pixel_values=pixel_values)
+
+        # Take CLS: last_hidden_state
+
+        cls_embedding = outputs.last_hidden_state[:, 0]
+        return cls_embedding
+
+# encoder.py
+from transformers import CLIPModel
+
+class CLIPEncoder(nn.Module):
+    def __init__(self):
+        super(CLIPEncoder, self).__init__()
+        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+
+    def forward(self, pixel_values):
+        # ✅ Directly get the pooled image features (already the final representation)
+        image_features = self.clip.get_image_features(pixel_values=pixel_values)
+        return image_features  # shape: [batch_size, hidden_dim]
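A quick shape-check sketch (not part of this commit) showing how the CLIP encoder output feeds the decoder. It assumes the pretrained CLIP weights download successfully and uses a random tensor in place of a real preprocessed image, so the generated token ids are meaningless; only the shapes matter here.

# sketch: shape check for CLIPEncoder -> TransformerDecoder wiring (dummy input)
import torch
from vit_captioning.models.encoder import CLIPEncoder
from vit_captioning.models.decoder import TransformerDecoder

encoder = CLIPEncoder().eval()
decoder = TransformerDecoder(vocab_size=30522, hidden_dim=512, encoder_dim=512).eval()

pixel_values = torch.randn(1, 3, 224, 224)    # stand-in for a preprocessed image
with torch.no_grad():
    features = encoder(pixel_values)          # [1, 512] pooled image features
    token_ids = decoder.generate(features, mode="greedy", max_length=10)
print(features.shape, token_ids.shape)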
vit_captioning/static/captioning/index.html ADDED
@@ -0,0 +1,153 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+  <meta charset="UTF-8" />
+  <title>🤖 Image 🖼️ Captioning</title>
+  <meta name="viewport" content="width=device-width, initial-scale=1">
+
+  <!-- ✅ Tailwind CDN -->
+  <script src="https://cdn.tailwindcss.com"></script>
+</head>
+
+<body class="bg-gray-100 flex items-center justify-center min-h-screen">
+  <a href="/" class="absolute top-4 left-4 text-blue-600 hover:text-blue-800 text-sm font-semibold flex items-center">
+    <!-- back arrow -->
+    <svg class="w-5 h-5 mr-1" fill="none" stroke="currentColor" stroke-width="2" viewBox="0 0 24 24"
+         xmlns="http://www.w3.org/2000/svg">
+      <path stroke-linecap="round" stroke-linejoin="round" d="M15 19l-7-7 7-7"></path>
+    </svg>
+    Back to Home
+  </a>
+
+  <div class="bg-white p-8 rounded-xl shadow-md w-full max-w-md text-center">
+    <h1 class="text-2xl font-bold mb-4 text-gray-800">AI Image Captioning</h1>
+
+    <!-- Upload Form -->
+    <form id="uploadForm" class="space-y-4">
+      <input
+        type="file"
+        id="fileInput"
+        accept="image/*"
+        required
+        class="block w-full text-sm text-gray-700 file:mr-4 file:py-2 file:px-4 file:rounded-full file:border-0 file:text-sm file:font-semibold file:bg-blue-50 file:text-blue-700 hover:file:bg-blue-100"
+      />
+
+      <!-- Live Image Preview -->
+      <div id="previewContainer" class="mt-4 hidden">
+        <img id="previewImage" src="#" alt="Preview" class="mx-auto max-h-64 rounded-md shadow" />
+      </div>
+
+      <button
+        type="submit"
+        class="w-full bg-blue-600 hover:bg-blue-700 text-white font-semibold py-2 px-4 rounded-lg transition"
+      >
+        Generate Captions
+      </button>
+    </form>

+    <!-- Captions -->
+    <div id="result" class="mt-6 text-left hidden">
+      <h2 class="text-lg font-semibold mb-2 text-gray-700">Captions:</h2>
+      <p><strong>Factual 🤖:</strong> <span id="greedy" class="text-gray-800"></span></p>
+      <p><strong>Creative 🤪:</strong> <span id="topk" class="text-gray-800"></span></p>
+      <p><strong>Human-like 💫:</strong> <span id="topp" class="text-gray-800"></span></p>
+    </div>
+  </div>
+
+  <script>
+    const fileInput = document.getElementById('fileInput');
+    const previewContainer = document.getElementById('previewContainer');
+    const previewImage = document.getElementById('previewImage');
+    const form = document.getElementById('uploadForm');
+    const result = document.getElementById('result');
+
+    // ✅ Live preview + clear old captions
+    fileInput.addEventListener('change', () => {
+      const file = fileInput.files[0];
+      if (file) {
+        const reader = new FileReader();
+        reader.onload = e => {
+          previewImage.src = e.target.result;
+          previewContainer.classList.remove('hidden');
+        };
+        reader.readAsDataURL(file);
+
+        // Clear old captions
+        document.getElementById('greedy').innerText = "";
+        document.getElementById('topk').innerText = "";
+        document.getElementById('topp').innerText = "";
+        result.classList.add('hidden');
+
+      } else {
+        previewContainer.classList.add('hidden');
+      }
+    });
+
+    // ✅ Submit form
+    form.addEventListener('submit', async e => {
+      e.preventDefault();
+      const file = fileInput.files[0];
+      const formData = new FormData();
+      formData.append('file', file);
+
+      const res = await fetch('/generate', {
+        method: 'POST',
+        body: formData
+      });
+
+      const data = await res.json();
+      document.getElementById('greedy').innerText = data.greedy || "N/A";
+      document.getElementById('topk').innerText = data.topk || "N/A";
+      document.getElementById('topp').innerText = data.topp || "N/A";
+      result.classList.remove('hidden');
+    });
+  </script>
+
+  <!-- Floating Help Button -->
+  <button id="helpButton"
+          class="fixed bottom-4 right-4 bg-blue-600 text-white rounded-full w-12 h-12 text-2xl font-bold shadow-lg hover:bg-blue-700 transition">
+    ?
+  </button>
+
+  <!-- Help Modal -->
+  <div id="helpModal" class="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center hidden">
+    <div class="bg-white rounded-lg p-6 max-w-sm w-full shadow-lg text-left">
+      <h2 class="text-xl font-semibold mb-4">🤖 Image Captioning</h2>
+      <p class="text-gray-700 mb-4">
+        Upload an image and press "Generate Captions"; the model will generate captions for it.
+        The model uses google/vit-base-patch16-224-in21k or openai/clip-vit-base-patch32
+        as the image encoder, trained together with a custom transformer decoder to generate captions.<br>
+        The available caption styles are "Factual 🤖", "Creative 🤪", and "Human-like 💫",
+        which correspond to argmax (greedy), top-k, and top-p decoding respectively.
+
+      </p>
+      <button id="closeModal"
+              class="mt-2 bg-blue-600 text-white px-4 py-2 rounded hover:bg-blue-700">
+        Close
+      </button>
+    </div>
+  </div>
+
+  <script>
+    const helpButton = document.getElementById('helpButton');
+    const helpModal = document.getElementById('helpModal');
+    const closeModal = document.getElementById('closeModal');
+
+    helpButton.addEventListener('click', () => {
+      helpModal.classList.remove('hidden');
+    });
+
+    closeModal.addEventListener('click', () => {
+      helpModal.classList.add('hidden');
+    });
+
+    // Optional: close modal when clicking outside the modal box
+    helpModal.addEventListener('click', (e) => {
+      if (e.target === helpModal) {
+        helpModal.classList.add('hidden');
+      }
+    });
+  </script>
+
+</body>
+</html>
vit_captioning/static/landing.html ADDED
@@ -0,0 +1,65 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+  <meta charset="UTF-8" />
+  <title>Clement's AI Lab</title>
+  <meta name="viewport" content="width=device-width, initial-scale=1">
+  <script src="https://cdn.tailwindcss.com"></script>
+</head>
+<body class="bg-gray-100 flex items-center justify-center min-h-screen">
+  <div class="max-w-md w-full p-6 text-center">
+    <h1 class="text-3xl font-bold mb-6 text-gray-800">🚀 Clement's AI Lab</h1>
+    <div class="space-y-4">
+      <a href="/captioning" class="block w-full bg-blue-600 hover:bg-blue-700 text-white py-3 rounded-lg shadow text-lg font-semibold">
+        🖼️ Image Captioning
+      </a>
+      <a href="/project2" class="block w-full bg-green-600 hover:bg-green-700 text-white py-3 rounded-lg shadow text-lg font-semibold">
+        🤖 Coming Soon: Word calculator
+      </a>
+      <!-- Add more project links here -->
+    </div>
+  </div>
+
+  <!-- Floating Help Button -->
+  <button id="helpButton"
+          class="fixed bottom-4 right-4 bg-blue-600 text-white rounded-full w-12 h-12 text-2xl font-bold shadow-lg hover:bg-blue-700 transition">
+    ?
+  </button>
+
+  <!-- Help Modal -->
+  <div id="helpModal" class="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center hidden">
+    <div class="bg-white rounded-lg p-6 max-w-sm w-full shadow-lg text-left">
+      <h2 class="text-xl font-semibold mb-4">Clement's AI Lab</h2>
+      <p class="text-gray-700 mb-4">
+        Welcome! I'm Clement, and I've built these interactive models for you to experiment with. Whether you're curious about AI or just want to have some fun, there's something here for everyone.<!-- Page-specific explanation goes here -->
+      </p>
+      <button id="closeModal"
+              class="mt-2 bg-blue-600 text-white px-4 py-2 rounded hover:bg-blue-700">
+        Close
+      </button>
+    </div>
+  </div>
+
+  <script>
+    const helpButton = document.getElementById('helpButton');
+    const helpModal = document.getElementById('helpModal');
+    const closeModal = document.getElementById('closeModal');
+
+    helpButton.addEventListener('click', () => {
+      helpModal.classList.remove('hidden');
+    });
+
+    closeModal.addEventListener('click', () => {
+      helpModal.classList.add('hidden');
+    });
+
+    // Optional: close modal when clicking outside the modal box
+    helpModal.addEventListener('click', (e) => {
+      if (e.target === helpModal) {
+        helpModal.classList.add('hidden');
+      }
+    });
+  </script>
+
+</body>
+</html>