Spaces: Running on CPU Upgrade

Commit · 79a071f · 1 Parent(s): 3aa3d4c

Added two more games. Container runs locally.
Files changed:
- .dockerignore +1 -0
- .gitignore +2 -1
- Dockerfile +8 -1
- cbow_logic.py +52 -0
- content_distillery/static/content_distillery.html +120 -0
- main.py +84 -18
- ppo_summarizer/predict_ppo.py → ppo_logic.py +14 -12
- qwen_loRA/README.md +207 -0
- qwen_loRA/adapter_config.json +36 -0
- qwen_loRA/adapter_model.safetensors +3 -0
- qwen_loRA/reward_head.pt +3 -0
- requirements.txt +3 -0
- save_cbow_model.py +16 -0
- templates/cbow.html +248 -0
- vit_captioning/generate.py +12 -7
- vit_captioning/models/encoder.py +1 -6
- vit_captioning/static/landing.html +4 -4
.dockerignore CHANGED
@@ -27,6 +27,7 @@ clip-checkpoints/
 *.pt
 *.pth
 *.onnx
+models/
 
 # Docker or Space-specific
 docker-compose.yaml
.gitignore CHANGED
@@ -1,3 +1,4 @@
 __pycache__/
 *.png
-**/artifacts/
+**/artifacts/
+models/
Dockerfile CHANGED
@@ -19,9 +19,16 @@ RUN pip install -r requirements.txt
 
 RUN mkdir -p /models/clip && \
     python3 -c "from transformers import CLIPModel; CLIPModel.from_pretrained('openai/clip-vit-base-patch32').save_pretrained('/models/clip')"
-
 RUN python3 -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('bert-base-uncased').save_pretrained('/models/bert-tokenizer')"
 RUN python3 -c "from transformers import CLIPProcessor; CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32').save_pretrained('/models/clip')"
+
+RUN mkdir -p /models/cbow && \
+    python3 -c "import gensim.downloader as api; model = api.load('glove-twitter-200'); model.save('/models/cbow_model.kv')"
+
+RUN mkdir -p /models/qwen && \
+    python3 -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('Qwen/Qwen3-0.6B-Base').save_pretrained('/models/qwen')"
+RUN python3 -c "from transformers import AutoModelForCausalLM; AutoModelForCausalLM.from_pretrained('Qwen/Qwen3-0.6B-Base').save_pretrained('/models/qwen')"
+
 EXPOSE 7860
 
 # Install curl if it's not already installed
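
These RUN steps bake the GloVe vectors and the Qwen tokenizer/weights into /models at build time, so a quick smoke test inside the container can confirm the assets load before the app starts. A minimal sketch, assuming the image layout produced by the steps above (this script is not part of the commit):

# smoke_test_models.py (hypothetical check, not part of this commit)
from gensim.models import KeyedVectors
from transformers import AutoTokenizer, AutoModelForCausalLM

# GloVe vectors saved by the cbow RUN step; note the save lands at
# /models/cbow_model.kv, not inside the /models/cbow directory mkdir creates
kv = KeyedVectors.load("/models/cbow_model.kv", mmap="r")
print("GloVe vocab size:", len(kv.key_to_index))

# Qwen tokenizer and base weights saved by the qwen RUN steps
tok = AutoTokenizer.from_pretrained("/models/qwen")
model = AutoModelForCausalLM.from_pretrained("/models/qwen")
print("Qwen loaded:", model.config.model_type)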
cbow_logic.py ADDED
@@ -0,0 +1,52 @@
+# cbow_logic.py
+import gensim
+import os
+import argparse
+from typing import List, Tuple
+import shlex
+
+
+class MeaningCalculator:
+    def __init__(self, model_path: str = "/models/cbow_model.kv"):
+        if not os.path.exists(model_path):
+            raise FileNotFoundError(f"Model not found at: {model_path}")
+        self.model = gensim.models.KeyedVectors.load(model_path, mmap='r')
+
+    def evaluate_expression(self, expression: str, topn: int = 10) -> List[Tuple[str, float]]:
+        # Evaluate expressions like '"new york" - city + capital'.
+        tokens = shlex.split(expression)  # Handles quoted terms properly
+        positive = []
+        negative = []
+        current_op = "+"
+
+        for token in tokens:
+            print(token)
+            if token in ["+", "-"]:
+                current_op = token
+            else:
+                if current_op == "+":
+                    positive.append(token)
+                else:
+                    negative.append(token)
+
+        try:
+            return self.model.most_similar(positive=positive, negative=negative, topn=topn)
+        except KeyError as e:
+            return [("InputError", 0.0)]
+
+from gensim.models import KeyedVectors
+
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Evaluate word vector expressions using CBOW.")
+    parser.add_argument("expression", type=str, help="Expression like 'king - man + woman'")
+    parser.add_argument("--model_path", type=str, default="./models/cbow_model.kv", help="Path to CBOW model")
+    args = parser.parse_args()
+
+    calc = MeaningCalculator(model_path=args.model_path)
+    results = calc.evaluate_expression(args.expression)
+
+    print(f"\nExpression: {args.expression}\nTop Results:")
+    for word, score in results:
+        print(f"  {word:<15} {score:.4f}")
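
MeaningCalculator parses the expression with shlex.split, so quoted multi-word terms survive as single tokens, then hands the positive/negative lists to gensim's most_similar. A minimal usage sketch, assuming the vectors exist at the default container path:

# Hypothetical usage of the class added above
from cbow_logic import MeaningCalculator

calc = MeaningCalculator(model_path="/models/cbow_model.kv")
for word, score in calc.evaluate_expression("king - man + woman", topn=5):
    print(f"{word:<15} {score:.4f}")

# Equivalent CLI form, via the argparse entry point:
#   python cbow_logic.py "king - man + woman" --model_path ./models/cbow_model.kv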
content_distillery/static/content_distillery.html ADDED
@@ -0,0 +1,120 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+  <meta charset="UTF-8" />
+  <title>📝 Content Distillery</title>
+  <meta name="viewport" content="width=device-width, initial-scale=1">
+  <script src="https://cdn.tailwindcss.com"></script>
+</head>
+<body class="bg-gray-100 min-h-screen p-6 flex flex-col items-center">
+
+  <div class="max-w-2xl w-full space-y-4">
+    <h1 class="text-3xl font-bold text-gray-800 text-center">📝 Content Distillery</h1>
+
+    <!-- Source Dropdown with Refresh Button -->
+    <label class="block font-semibold text-gray-700">Choose a Source:</label>
+    <div class="flex items-center space-x-2">
+      <select id="sourceSelect" class="flex-grow p-2 border rounded">
+        <option value="">-- Select Source --</option>
+        <option value="reddit_romance">Romantic Relationship</option>
+        <option value="reddit_aita">Am I The Asshole</option>
+        <option value="reddit_careers">Career Guidance</option>
+        <option value="reddit_cars">Car discussion</option>
+        <option value="reddit_whatcarshouldibuy">What Car Should I Buy</option>
+        <option value="reddit_nosleep">Horror stories</option>
+        <option value="reddit_maliciouscompliance">People following bad instructions exactly</option>
+        <option value="reddit_talesfromtechsupport">Tech support stories</option>
+        <option value="reddit_decidingtobebetter">Self-improvement and habit change</option>
+        <option value="reddit_askphilosophy">Big-life questions</option>
+      </select>
+      <button id="refreshPost" title="Get another post"
+              class="p-2 bg-gray-200 hover:bg-gray-300 rounded text-lg">🔄</button>
+    </div>
+
+    <!-- Input Text -->
+    <label class="block font-semibold text-gray-700 mt-4">Original Text: (can modify or create your own)</label>
+    <textarea id="inputText" rows="8" class="w-full p-3 border rounded" placeholder="Paste or fetch post text here..."></textarea>
+
+    <!-- Generate Button -->
+    <button id="generateBtn" class="w-full bg-red-600 text-white py-2 rounded hover:bg-red-700 font-semibold">
+      🤖 Generate Summary
+    </button>
+
+    <!-- Output Text -->
+    <label class="block font-semibold text-gray-700 mt-4">Summary:</label>
+    <textarea id="outputText" rows="5" class="w-full p-3 border rounded bg-gray-50" readonly placeholder="Summary will appear here..."></textarea>
+  </div>
+
+  <!-- Floating Help Button -->
+  <button id="helpButton"
+          class="fixed bottom-4 right-4 bg-blue-600 text-white rounded-full w-12 h-12 text-2xl font-bold shadow-lg hover:bg-blue-700 transition">
+    ?
+  </button>
+
+  <!-- Help Modal -->
+  <div id="helpModal" class="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center hidden">
+    <div class="bg-white rounded-lg p-6 max-w-sm w-full shadow-lg text-left">
+      <h2 class="text-xl font-semibold mb-4">About Content Distillery</h2>
+      <p class="text-gray-700 mb-4">
+        This tool fetches content from various online sources and distills it into concise summaries using a PPO model.
+        Choose a source, optionally edit the text, and press "Generate" to see a summary.
+        The reward and PPO models were fine-tuned from Qwen/Qwen3-0.6B-Base on the openai/summarize_from_feedback dataset, following the OpenAI paper "Learning to summarize from human feedback".
+      </p>
+      <button id="closeModal"
+              class="mt-2 bg-blue-600 text-white px-4 py-2 rounded hover:bg-blue-700">
+        Close
+      </button>
+    </div>
+  </div>
+
+  <script>
+    const helpButton = document.getElementById('helpButton');
+    const helpModal = document.getElementById('helpModal');
+    const closeModal = document.getElementById('closeModal');
+
+    helpButton.addEventListener('click', () => helpModal.classList.remove('hidden'));
+    closeModal.addEventListener('click', () => helpModal.classList.add('hidden'));
+    helpModal.addEventListener('click', (e) => {
+      if (e.target === helpModal) helpModal.classList.add('hidden');
+    });
+
+    async function fetchPost() {
+      const selected = document.getElementById("sourceSelect").value;
+      if (!selected) return;
+      document.getElementById("inputText").value = "Fetching post...";
+      const res = await fetch(`/get_sample?source=${selected}`);
+      const data = await res.text();
+      document.getElementById("inputText").value = data;
+    }
+
+    document.getElementById("sourceSelect").addEventListener("change", fetchPost);
+    document.getElementById("refreshPost").addEventListener("click", fetchPost);
+
+    document.getElementById('generateBtn').addEventListener('click', async function () {
+      const btn = document.getElementById('generateBtn');
+      const post = document.getElementById('inputText').value;
+
+      btn.disabled = true;
+      btn.textContent = "Generating...";
+      btn.classList.add("opacity-50");
+
+      try {
+        const res = await fetch('/contentdistillery', {
+          method: 'POST',
+          headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
+          body: new URLSearchParams({ post })
+        });
+        const summary = await res.text();
+        document.getElementById('outputText').value = summary;
+      } catch (e) {
+        document.getElementById('outputText').value = "Error generating summary.";
+      }
+
+      btn.disabled = false;
+      btn.textContent = "Generate Summary";
+      btn.classList.remove("opacity-50");
+    });
+  </script>
+
+</body>
+</html>
main.py CHANGED
@@ -1,16 +1,34 @@
 # app/main.py
 
-from fastapi import FastAPI, UploadFile, File
-from fastapi.responses import HTMLResponse
+from fastapi import FastAPI, UploadFile, File, Request, Form, Query
+from fastapi.responses import HTMLResponse, PlainTextResponse
 from fastapi.staticfiles import StaticFiles
+from fastapi.templating import Jinja2Templates
+from cbow_logic import MeaningCalculator
+from ppo_logic import generate_summary
+
+import numpy as np
+import json
 import shutil
 from pathlib import Path
 import uvicorn
 import os
+import praw
+import random
 
 from vit_captioning.generate import CaptionGenerator
+from cbow_logic import MeaningCalculator
+
+reddit = praw.Reddit(
+    client_id="geuNJZLDwSCdz7sV5vkDNQ",
+    client_secret="IFz7zPVGP3hO6VMy1YU1WX_bX3FpfQ",
+    user_agent="ContentDistilleryBot/0.1 by ClementHa"
+)
 
 app = FastAPI()
+templates = Jinja2Templates(directory="templates")
+calculator = MeaningCalculator()
+
 
 # Serve static files
 static_dir = Path(__file__).parent / "vit_captioning" / "static"
@@ -21,10 +39,6 @@ app.mount("/static", StaticFiles(directory=static_dir), name="static")
 async def landing():
     return Path("vit_captioning/static/landing.html").read_text()
 
-# @app.get("/", response_class=HTMLResponse)
-# def root():
-#     return "<h3>✅ Hugging Face Space is alive</h3>"
-
 @app.get("/health")
 def health_check():
     return {"status": "ok"}
@@ -34,22 +48,17 @@ def health_check():
 async def captioning():
     return Path("vit_captioning/static/captioning/index.html").read_text()
 
-
-
-
-    return "<h1>Coming Soon: Project 2</h1>"
-
-# ✅ Example: Project 2 placeholder
-@app.get("/project3", response_class=HTMLResponse)
-async def project2():
-    return "<h1>Coming Soon: Project 3</h1>"
+@app.get("/contentdistillery", response_class=HTMLResponse)
+async def contentdistillery():
+    return Path("content_distillery/static/content_distillery.html").read_text()
 
 # ✅ Caption generation endpoint for captioning app
 # Keep the path consistent with your JS fetch()!
 caption_generator = CaptionGenerator(
     model_type="CLIPEncoder",
     checkpoint_path="./vit_captioning/artifacts/CLIPEncoder_40epochs_unfreeze12.pth",
-    quantized=False
+    quantized=False,
+    runAsContainer=False
 )
 
 @app.post("/generate")
@@ -61,5 +70,62 @@ async def generate(file: UploadFile = File(...)):
     captions = caption_generator.generate_caption(temp_file)
     return captions
 
-
-
+@app.get("/cbow", response_class=HTMLResponse)
+async def cbow_form(request: Request):
+    return templates.TemplateResponse("cbow.html", {"request": request})
+
+@app.post("/cbow")
+async def cbow(request: Request, expression: str = Form(...)):
+    expression = expression.lower()
+    results = MeaningCalculator().evaluate_expression(expression = expression)
+    # formatted = [
+    #     (word, f"{score:.2f}" if score >= 0.4 else "Irrelevant result")
+    #     for word, score in results[:5]
+    # ]
+    return templates.TemplateResponse("cbow.html", {
+        "request": request,
+        "expression": expression,
+        "results": results
+    })
+
+@app.get("/contentdistillery", response_class=HTMLResponse)
+async def contentdistillery_page():
+    return Path("contentdistillery.html").read_text(encoding="utf-8")
+
+@app.post("/contentdistillery", response_class=PlainTextResponse)
+async def generate_summary_from_post(post: str = Form(...)):
+    return generate_summary(post)
+
+@app.get("/get_sample", response_class=PlainTextResponse)
+def get_sample(source: str = Query(...)):
+    try:
+        if source == "reddit_romance":
+            submissions = reddit.subreddit("relationships").top(limit=10)
+        elif source == "reddit_aita":
+            submissions = reddit.subreddit("AmItheAsshole").hot(limit=10)
+        elif source == "reddit_careers":
+            submissions = reddit.subreddit("careerguidance").hot(limit=10)
+        elif source == "reddit_cars":
+            submissions = reddit.subreddit("cars").hot(limit=10)
+        elif source == "reddit_whatcarshouldibuy":
+            submissions = reddit.subreddit("whatcarshouldibuy").top(limit=10)
+        elif source == "reddit_nosleep":
+            submissions = reddit.subreddit("nosleep").top(limit=10)
+        elif source == "reddit_maliciouscompliance":
+            submissions = reddit.subreddit("maliciouscompliance").hot(limit=10)
+        elif source == "reddit_talesfromtechsupport":
+            submissions = reddit.subreddit("talesfromtechsupport").top(limit=10)
+        elif source == "reddit_decidingtobebetter":
+            submissions = reddit.subreddit("decidingtobebetter").hot(limit=10)
+        elif source == "reddit_askphilosophy":
+            submissions = reddit.subreddit("askphilosophy").top(limit=10)
+        else:
+            return "Unsupported source."
+
+        posts = [s.selftext.strip() for s in submissions if s.selftext.strip()]
+        if posts:
+            return random.choice(posts)
+        return "No suitable post found."
+
+    except Exception as e:
+        return f"Error fetching Reddit post: {str(e)}"
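
Taken together, the new routes form a small pipeline: GET /get_sample pulls a Reddit post, POST /contentdistillery summarizes it, and POST /cbow evaluates a word-vector expression. A client-side sketch for exercising them (assuming the Space listens on localhost:7860, per the Dockerfile's EXPOSE; requests is a hypothetical extra dependency here):

# Hypothetical client check against a locally running container
import requests

BASE = "http://localhost:7860"

# Fetch a sample post for one of the supported sources
sample = requests.get(f"{BASE}/get_sample", params={"source": "reddit_aita"}).text
print("Sample post:", sample[:100], "...")

# Submit it for summarization (form-encoded, matching the HTML's fetch call)
summary = requests.post(f"{BASE}/contentdistillery", data={"post": sample}).text
print("Summary:", summary)

# Word-vector arithmetic; this endpoint renders templates/cbow.html as HTML
resp = requests.post(f"{BASE}/cbow", data={"expression": "king - man + woman"})
print(resp.status_code)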
ppo_summarizer/predict_ppo.py → ppo_logic.py RENAMED
@@ -9,10 +9,11 @@ import os
 # -------------------------------
 # Config
 # -------------------------------
-MODEL_NAME = "Qwen/Qwen3-0.6B-Base"
-
+#MODEL_NAME = "Qwen/Qwen3-0.6B-Base"
+MODEL_NAME = "/models/qwen"
+CHECKPOINT_DIR = "./qwen_loRA"
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-MAX_NEW_TOKENS =
+MAX_NEW_TOKENS = 256
 
 # -------------------------------
 # Load tokenizer and model
@@ -34,18 +35,21 @@ model = model.to(DEVICE)
 # -------------------------------
 # Generate Summary
 # -------------------------------
-def generate_summary(
-    prompt = f"
+def generate_summary(post: str) -> str:
+    #prompt = f"Instruction: Summarize the post in one sentence.\n\nPost:\n{post}\n\nSummary:"
+    # prompt = f"Please summarize the following Reddit post in 1–2 sentences:\n\n{post}\n\nSummary:"
+    prompt = f"Instruction: Summarize the post in 1-2 sentences.\n\nPost:\n{post}\n\nSummary:"
+
     inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
 
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
             max_new_tokens=MAX_NEW_TOKENS,
-            do_sample=
-            top_k=50,
-            top_p=0.95,
-            temperature=0
+            do_sample=False,
+            # top_k=50,
+            # top_p=0.95,
+            temperature=1.0,
             pad_token_id=tokenizer.pad_token_id,
             use_cache=True
         )
@@ -59,13 +63,11 @@ def generate_summary(title: str, post: str) -> str:
 # -------------------------------
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Generate summary with trained Qwen PPO model")
-    parser.add_argument("--title", type=str, required=True, help="Title of the post")
     parser.add_argument("--post", type=str, required=True, help="Content of the post")
    args = parser.parse_args()
 
-    print("\n📝 Title:", args.title)
     print("📝 Post:", args.post[:100] + ("..." if len(args.post) > 100 else ""))
     print("\n🤖 Generating summary...\n")
 
-    summary = generate_summary(args.
+    summary = generate_summary(args.post)
     print("✅ Summary:\n", summary)
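
The hunk between the Config block and the Generate Summary section is elided, so the exact tokenizer/model loading is not shown in this diff. Given MODEL_NAME, CHECKPOINT_DIR, and the LoRA adapter added below, a plausible reconstruction looks like this; treat it as an assumption-labeled sketch, not the commit's actual code:

# Hypothetical reconstruction of the elided "Load tokenizer and model" section
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

MODEL_NAME = "/models/qwen"        # base weights baked into the image
CHECKPOINT_DIR = "./qwen_loRA"     # LoRA adapter shipped in this commit

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # base Qwen has no pad token

base = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = PeftModel.from_pretrained(base, CHECKPOINT_DIR)  # applies the q_proj/v_proj LoRA
model.eval()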
qwen_loRA/README.md ADDED
@@ -0,0 +1,207 @@
+---
+base_model: Qwen/Qwen3-0.6B-Base
+library_name: peft
+pipeline_tag: text-generation
+tags:
+- base_model:adapter:Qwen/Qwen3-0.6B-Base
+- lora
+- transformers
+---
+
+# Model Card for Model ID
+
+<!-- Provide a quick summary of what the model is/does. -->
+
+
+
+## Model Details
+
+### Model Description
+
+<!-- Provide a longer summary of what this model is. -->
+
+
+
+- **Developed by:** [More Information Needed]
+- **Funded by [optional]:** [More Information Needed]
+- **Shared by [optional]:** [More Information Needed]
+- **Model type:** [More Information Needed]
+- **Language(s) (NLP):** [More Information Needed]
+- **License:** [More Information Needed]
+- **Finetuned from model [optional]:** [More Information Needed]
+
+### Model Sources [optional]
+
+<!-- Provide the basic links for the model. -->
+
+- **Repository:** [More Information Needed]
+- **Paper [optional]:** [More Information Needed]
+- **Demo [optional]:** [More Information Needed]
+
+## Uses
+
+<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+### Direct Use
+
+<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+[More Information Needed]
+
+### Downstream Use [optional]
+
+<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+[More Information Needed]
+
+### Out-of-Scope Use
+
+<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+[More Information Needed]
+
+## Bias, Risks, and Limitations
+
+<!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+[More Information Needed]
+
+### Recommendations
+
+<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+## How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+[More Information Needed]
+
+## Training Details
+
+### Training Data
+
+<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+[More Information Needed]
+
+### Training Procedure
+
+<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+#### Preprocessing [optional]
+
+[More Information Needed]
+
+
+#### Training Hyperparameters
+
+- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+#### Speeds, Sizes, Times [optional]
+
+<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+[More Information Needed]
+
+## Evaluation
+
+<!-- This section describes the evaluation protocols and provides the results. -->
+
+### Testing Data, Factors & Metrics
+
+#### Testing Data
+
+<!-- This should link to a Dataset Card if possible. -->
+
+[More Information Needed]
+
+#### Factors
+
+<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+[More Information Needed]
+
+#### Metrics
+
+<!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+[More Information Needed]
+
+### Results
+
+[More Information Needed]
+
+#### Summary
+
+
+
+## Model Examination [optional]
+
+<!-- Relevant interpretability work for the model goes here -->
+
+[More Information Needed]
+
+## Environmental Impact
+
+<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** [More Information Needed]
+- **Hours used:** [More Information Needed]
+- **Cloud Provider:** [More Information Needed]
+- **Compute Region:** [More Information Needed]
+- **Carbon Emitted:** [More Information Needed]
+
+## Technical Specifications [optional]
+
+### Model Architecture and Objective
+
+[More Information Needed]
+
+### Compute Infrastructure
+
+[More Information Needed]
+
+#### Hardware
+
+[More Information Needed]
+
+#### Software
+
+[More Information Needed]
+
+## Citation [optional]
+
+<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+**BibTeX:**
+
+[More Information Needed]
+
+**APA:**
+
+[More Information Needed]
+
+## Glossary [optional]
+
+<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+[More Information Needed]
+
+## More Information [optional]
+
+[More Information Needed]
+
+## Model Card Authors [optional]
+
+[More Information Needed]
+
+## Model Card Contact
+
+[More Information Needed]
+### Framework versions
+
+- PEFT 0.16.0
qwen_loRA/adapter_config.json ADDED
@@ -0,0 +1,36 @@
+{
+  "alpha_pattern": {},
+  "auto_mapping": null,
+  "base_model_name_or_path": "Qwen/Qwen3-0.6B-Base",
+  "bias": "none",
+  "corda_config": null,
+  "eva_config": null,
+  "exclude_modules": null,
+  "fan_in_fan_out": false,
+  "inference_mode": true,
+  "init_lora_weights": true,
+  "layer_replication": null,
+  "layers_pattern": null,
+  "layers_to_transform": null,
+  "loftq_config": {},
+  "lora_alpha": 16,
+  "lora_bias": false,
+  "lora_dropout": 0.05,
+  "megatron_config": null,
+  "megatron_core": "megatron.core",
+  "modules_to_save": null,
+  "peft_type": "LORA",
+  "qalora_group_size": 16,
+  "r": 8,
+  "rank_pattern": {},
+  "revision": null,
+  "target_modules": [
+    "v_proj",
+    "q_proj"
+  ],
+  "task_type": "CAUSAL_LM",
+  "trainable_token_indices": null,
+  "use_dora": false,
+  "use_qalora": false,
+  "use_rslora": false
+}
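
With target_modules limited to q_proj and v_proj at r=8 and alpha=16, the adapter stays small, which is consistent with the 4.6 MB safetensors file below. The config can also be inspected programmatically; a sketch assuming peft 0.16 as pinned in the README's framework versions:

# Sketch: inspect the shipped adapter config with peft
from peft import PeftConfig

cfg = PeftConfig.from_pretrained("./qwen_loRA")   # reads adapter_config.json
print(cfg.base_model_name_or_path)                # Qwen/Qwen3-0.6B-Base
print(cfg.peft_type, cfg.task_type)               # LORA CAUSAL_LM
print(cfg.r, cfg.lora_alpha, cfg.target_modules)  # 8 16 {'v_proj', 'q_proj'}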
qwen_loRA/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c32e6b6f21996e7a38d4d017c3e9addc2a6aa24e0e03f13fe6af64052cfc0701
+size 4602248
qwen_loRA/reward_head.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12b7d7d367d489625b74f6e6d9560acfb0349e6d4042c8f507b4484dd4a76fa5
+size 6021
requirements.txt CHANGED
@@ -6,3 +6,6 @@ numpy<2
 transformers
 pillow
 python-multipart
+gensim
+peft
+praw
save_cbow_model.py ADDED
@@ -0,0 +1,16 @@
+import gensim.downloader as api
+model = api.load("glove-twitter-200")
+print("Model loaded.")
+
+print("new-york" in model.key_to_index)  # ✅ True if token is present
+print("new" in model.key_to_index)       # ✅ Also true
+print("new york" in model.key_to_index)  # ❌ False – space not valid
+
+# Optional: print 5 most similar to test
+if "new-york" in model.key_to_index:
+    print(model.most_similar("new-york"))
+
+compound_terms = [key for key in model.key_to_index if "-" in key]
+print(f"Sample compound tokens: {compound_terms[:10]}")
+
+model.save("../models/cbow_model.kv")
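
model.save() writes gensim's native KeyedVectors format, which is what MeaningCalculator later reloads memory-mapped. A reload sketch, assuming the relative save path above; note that compound GloVe tokens are hyphenated, matching the checks in this script:

# Sketch: reload the saved vectors the same way cbow_logic.py does
from gensim.models import KeyedVectors

kv = KeyedVectors.load("../models/cbow_model.kv", mmap="r")  # read-only memory map
print(kv.most_similar(positive=["new-york", "capital"], negative=["city"], topn=3))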
templates/cbow.html ADDED
@@ -0,0 +1,248 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+  <meta charset="UTF-8">
+  <meta name="viewport" content="width=device-width, initial-scale=1.0">
+  <title>CBOW Vector Calculator</title>
+  <style>
+    body {
+      font-family: Arial, sans-serif;
+      max-width: 600px;
+      margin: auto;
+      padding: 1rem;
+      background-color: #f8f9fa;
+      position: relative;
+    }
+
+    h2 {
+      text-align: center;
+    }
+
+    form {
+      margin-top: 2rem;
+    }
+
+    textarea {
+      width: 100%;
+      height: 80px;
+      padding: 0.5rem;
+      font-size: 1rem;
+      resize: vertical;
+    }
+
+    button {
+      margin-top: 1rem;
+      width: 100%;
+      padding: 0.75rem;
+      font-size: 1rem;
+      background-color: #007bff;
+      color: white;
+      border: none;
+      border-radius: 4px;
+      cursor: pointer;
+    }
+
+    button:hover {
+      background-color: #0056b3;
+    }
+
+    .results {
+      margin-top: 2rem;
+      background-color: white;
+      padding: 1rem;
+      border-radius: 6px;
+      box-shadow: 0 0 10px rgba(0,0,0,0.1);
+    }
+
+    .result-row {
+      margin: 0.5rem 0;
+    }
+
+    .result-word {
+      font-weight: bold;
+    }
+
+    .score {
+      color: #666;
+      margin-left: 0.5rem;
+    }
+
+    .floating-icons {
+      position: fixed;
+      bottom: 1rem;
+      right: 1rem;
+      display: flex;
+      gap: 1rem;
+    }
+
+    .icon-button {
+      background: white;
+      border: 1px solid #ccc;
+      border-radius: 50%;
+      width: 40px;
+      height: 40px;
+      font-size: 1.2rem;
+      text-align: center;
+      line-height: 40px;
+      cursor: pointer;
+      box-shadow: 0 2px 6px rgba(0, 0, 0, 0.15);
+    }
+
+    .modal {
+      display: none;
+      position: fixed;
+      z-index: 1000;
+      left: 0; top: 0;
+      width: 100%; height: 100%;
+      background-color: rgba(0,0,0,0.4);
+    }
+
+    .modal-content {
+      background-color: #fff;
+      margin: 10% auto;
+      padding: 2rem;
+      border-radius: 8px;
+      max-width: 400px;
+      position: relative;
+    }
+
+    .close {
+      position: absolute;
+      top: 0.5rem;
+      right: 0.75rem;
+      font-size: 1.2rem;
+      cursor: pointer;
+    }
+  </style>
+</head>
+<body>
+  <h2>CBOW Vector Calculator</h2>
+  <form method="post" action="/cbow">
+    <label for="expression">Enter a word vector expression <small>(e.g. <code>king - man + woman</code>)</small>:</label><br>
+    <textarea name="expression" rows="4" style="width: 100%">{{ expression or "" }}</textarea>
+    <button type="submit">Calculate</button>
+  </form>
+
+  {% if results %}
+  <div class="results">
+    {% if results and results|length > 0 %}
+    <p><strong>{{ expression }}</strong> → <strong>{{ results[0][0] }}</strong></p>
+    {% endif %}
+    <h3>Results:</h3>
+    <table style="width: 100%; border-collapse: collapse;">
+      <thead>
+        <tr>
+          <th style="text-align: left; padding: 0.5rem; border-bottom: 1px solid #ccc;">#</th>
+          <th style="text-align: left; padding: 0.5rem; border-bottom: 1px solid #ccc;">Result</th>
+          <th style="text-align: left; padding: 0.5rem; border-bottom: 1px solid #ccc;">Score</th>
+        </tr>
+      </thead>
+      <tbody>
+        {% for word, score in results %}
+        <tr>
+          <td style="padding: 0.5rem;">{{ loop.index }}</td>
+          <td style="padding: 0.5rem;">{{ word }}</td>
+          <td style="padding: 0.5rem;">
+            {% if score >= 0.4 %}
+            {{ "%.2f"|format(score) }}
+            {% else %}
+            Irrelevant result
+            {% endif %}
+          </td>
+        </tr>
+        {% endfor %}
+      </tbody>
+    </table>
+  </div>
+  {% endif %}
+
+  <div class="floating-icons">
+    <div class="icon-button" onclick="openModal('suggestionsModal')">💡</div>
+    <div class="icon-button" onclick="openModal('aboutModal')">?</div>
+  </div>
+
+  <div id="suggestionsModal" class="modal">
+    <div class="modal-content">
+      <span class="close" onclick="closeModal('suggestionsModal')">&times;</span>
+      <h3>Suggestions</h3>
+      <ul>
+        <li>Try: <code>paris - france + italy</code></li>
+        <li>Try: <code>man + smart</code></li>
+        <li>Use <code>-</code> and <code>+</code> operators</li>
+      </ul>
+    </div>
+  </div>
+
+  <div id="aboutModal" class="modal">
+    <div class="modal-content">
+      <span class="close" onclick="closeModal('aboutModal')">&times;</span>
+      <h3>About</h3>
+      <p>This tool calculates vector arithmetic on words using the pretrained "glove-twitter-200" word embeddings. Built by Clement Ha.</p>
+    </div>
+  </div>
+
+  <script>
+    function openModal(id) {
+      document.getElementById(id).style.display = 'block';
+    }
+
+    function closeModal(id) {
+      document.getElementById(id).style.display = 'none';
+    }
+
+    window.onclick = function(event) {
+      const modals = document.querySelectorAll('.modal');
+      modals.forEach(modal => {
+        if (event.target == modal) {
+          modal.style.display = "none";
+        }
+      });
+    }
+  </script>
+  <script>
+    document.addEventListener("DOMContentLoaded", function () {
+      const textarea = document.querySelector("textarea[name='expression']");
+      const form = document.querySelector("form");
+
+      textarea.addEventListener("keydown", function (event) {
+        if (event.key === "Enter" && !event.shiftKey) {
+          event.preventDefault(); // prevent newline
+          form.submit();          // trigger form submission
+        }
+      });
+    });
+  </script>
+
+  <script>
+    document.addEventListener("DOMContentLoaded", function () {
+      const form = document.querySelector("form");
+      const textarea = document.querySelector("textarea[name='expression']");
+      let resultShown = {{ 'true' if results else 'false' }};
+
+      // 1. Validate spacing between tokens
+      form.addEventListener("submit", function (event) {
+        const input = textarea.value.trim();
+
+        // Simple regex to find missing spaces (e.g., "word+word", "word-word")
+        const spacingIssues = input.match(/\b\w+[\+\-]\w+\b/);
+
+        if (spacingIssues) {
+          event.preventDefault();
+          const problem = spacingIssues[0];
+          const suggestion = problem.replace(/([\+\-])/, ' $1 ');
+          alert(`⚠️ It looks like you missed spacing in: "${problem}".\nDid you mean: "${suggestion}"?`);
+        }
+      });
+
+      // 2. Clear textarea when focused again, only if result was shown
+      if (resultShown) {
+        textarea.addEventListener("focus", () => {
+          textarea.value = "";
+          resultShown = false; // prevent it from clearing again on next focus
+        });
+      }
+    });
+  </script>
+
+</body>
+</html>
vit_captioning/generate.py CHANGED
@@ -10,7 +10,7 @@ import argparse
 
 
 class CaptionGenerator:
-    def __init__(self, model_type: str, checkpoint_path: str, quantized=False):
+    def __init__(self, model_type: str, checkpoint_path: str, quantized=False, runAsContainer=False):
         print(f"Loading {model_type} | Quantized: {quantized}")
         # Setup device
         if torch.cuda.is_available():
@@ -25,9 +25,10 @@ class CaptionGenerator:
 
         # Load tokenizer
         #self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
-
-
-
+        if (runAsContainer):
+            self.tokenizer = AutoTokenizer.from_pretrained('/models/bert-tokenizer')
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
 
         # Select encoder, processor, output dim
        if model_type == "ViTEncoder":
@@ -37,8 +38,11 @@ class CaptionGenerator:
         elif model_type == "CLIPEncoder":
             self.encoder = CLIPEncoder().to(self.device)
             self.encoder_dim = 512
-
-
+            if (runAsContainer):
+                self.processor = CLIPProcessor.from_pretrained("/models/clip")
+            else:
+                self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
         else:
             raise ValueError("Unknown model type")
 
@@ -109,7 +113,8 @@ if __name__ == "__main__":
 
     generator = CaptionGenerator(
         model_type=args.model,
-        checkpoint_path=args.checkpoint
+        checkpoint_path=args.checkpoint,
+        runAsContainer=True
    )
 
    captions = generator.generate_caption(args.image)
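
The new runAsContainer flag decides whether the tokenizer and CLIP processor come from the baked /models paths or from the Hub; main.py passes False while the CLI entry point hard-codes True. An instantiation sketch mirroring main.py (the image path is hypothetical):

# Sketch: local (non-container) use pulls tokenizer/processor from the Hub
from vit_captioning.generate import CaptionGenerator

generator = CaptionGenerator(
    model_type="CLIPEncoder",
    checkpoint_path="./vit_captioning/artifacts/CLIPEncoder_40epochs_unfreeze12.pth",
    quantized=False,
    runAsContainer=False,  # True reads /models/bert-tokenizer and /models/clip instead
)
captions = generator.generate_caption("example.jpg")  # hypothetical image path
print(captions)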
vit_captioning/models/encoder.py CHANGED
@@ -13,9 +13,7 @@ class ViTEncoder(nn.Module):
 
         #weights = ViT_B_16_Weights.DEFAULT
 
-
-        #HF needs all model downloads to a special read-write cache dir
-        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k', cache_dir="/tmp")
+        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
 
     def forward(self, pixel_values):
 
@@ -34,10 +32,7 @@ class CLIPEncoder(nn.Module):
     def __init__(self):
         super(CLIPEncoder, self).__init__()
         #self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-        #HF needs all model downloads to a special read-write cache dir
-        #self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir="/tmp")
         self.clip = CLIPModel.from_pretrained("/models/clip")
-
     def forward(self, pixel_values):
         # ✅ Directly get the pooled image features (already the final representation)
         image_features = self.clip.get_image_features(pixel_values=pixel_values)
vit_captioning/static/landing.html CHANGED
@@ -13,11 +13,11 @@
     <a href="/captioning" class="block w-full bg-blue-600 hover:bg-blue-700 text-white py-3 rounded-lg shadow text-lg font-semibold">
       🖼️ Image Captioning
     </a>
-    <a href="/
-
+    <a href="/cbow" class="block w-full bg-green-600 hover:bg-green-700 text-white py-3 rounded-lg shadow text-lg font-semibold">
+      🧙‍♀️ Word Alchemy
     </a>
-    <a href="/
-      📝
+    <a href="/contentdistillery" class="block w-full bg-red-600 hover:bg-red-700 text-white py-3 rounded-lg shadow text-lg font-semibold">
+      📝 Content Distillery
     </a>
     <!-- Add more project links here -->
   </div>