Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,24 +1,21 @@
-# app.py – encoder-only demo
 # launch: python app.py
-
-import json
-import sys
 from pathlib import Path, PurePosixPath
-from itertools import islice

-import
-import spaces
-import torch
-import torch.nn.functional as F
 from huggingface_hub import snapshot_download

 from bert_handler import create_handler_from_checkpoint

 # ------------------------------------------------------------------
-# 0.
 # ------------------------------------------------------------------
-REPO_ID
-LOCAL_CKPT

 snapshot_download(
     repo_id=REPO_ID,
@@ -28,35 +25,32 @@ snapshot_download(
 )

 cfg_path = Path(LOCAL_CKPT) / "config.json"
-
-
-
-auto_map = cfg.get("auto_map", {})
-patched = False
 for k, v in auto_map.items():
-    if "--" in v:
         auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
-
-
-
-
-with cfg_path.open("w") as f:
-    json.dump(cfg, f, indent=2)
-print("🛠️ Patched config.json → auto_map paths fixed")

 # ------------------------------------------------------------------
-# 1.
 # ------------------------------------------------------------------
 handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
 full_model = full_model.eval().cuda()

 encoder    = full_model.bert.encoder
 embeddings = full_model.bert.embeddings
 emb_ln     = full_model.bert.emb_ln
 emb_drop   = full_model.bert.emb_drop

 # ------------------------------------------------------------------
-# 2.
 # ------------------------------------------------------------------
 SYMBOLIC_ROLES = [
     "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
@@ -66,114 +60,108 @@ SYMBOLIC_ROLES = [
     "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
     "<fabric>", "<jewelry>",
 ]
-
-
-missing = [t for t in SYMBOLIC_ROLES if tokenizer.convert_tokens_to_ids(t) == unk]
 if missing:
-    sys.exit(f"❌ Tokenizer

-
-
-# ------------------------------------------------------------------
-def reduce_to_three(table):
-    """
-    table : list of dicts {role, token, score}
-    repeatedly remove lowest and highest,
-    replace with their average,
-    until len(table)==3.
-    """
-    working = table[:]
-    working.sort(key=lambda x: x["score"])
-    while len(working) > 3:
-        low  = working.pop(0)
-        high = working.pop(-1)
-        merged = {
-            "role": f"{high['role']}|{low['role']}",
-            "token": f"{high['token']}/{low['token']}",
-            "score": (high["score"] + low["score"]) / 2.0,
-        }
-        working.append(merged)
-        working.sort(key=lambda x: x["score"])
-    # highest first for display
-    working.sort(key=lambda x: x["score"], reverse=True)
-    return working

 # ------------------------------------------------------------------
-#
-
-
-
     with torch.no_grad():
-
-
-
-
-
-
-
-
-        ext = full_model.bert.get_extended_attention_mask(attn, x.shape[:-1])
-        hs  = encoder(x, attention_mask=ext)                 # (1,S,H)
-
-        # token-level embeddings (before LN) for similarity calc
-        token_emb = embeddings(ids).squeeze(0)               # (S,H)
-
-        rows = []
-        for role in selected_roles:
-            rid  = tokenizer.convert_tokens_to_ids(role)
-            rvec = embeddings.word_embeddings.weight[rid]    # (H,)
-            # cosine similarity to every *input* token embedding
-            sims = F.cosine_similarity(rvec.unsqueeze(0), token_emb, dim=-1)
-            best = torch.argmax(sims).item()
-            rows.append({
-                "role" : role,
-                "token": tokenizer.convert_ids_to_tokens([ids[0, best].item()])[0],
-                "score": sims[best].item()
-            })
-
-        if not rows:
-            return "(none selected)", "", ""
-
-        final3   = reduce_to_three(rows)
-        out_strs = [f"{r['role']} ↔ {r['token']} ({r['score']:+.2f})" for r in final3]
-        # pad so we always return 3 strings
-        while len(out_strs) < 3:
-            out_strs.append("")
-        return out_strs[0], out_strs[1], out_strs[2]

 # ------------------------------------------------------------------
-#
 # ------------------------------------------------------------------
-
     with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
         gr.Markdown(
-            "
-            "
-            "• compute role ↔ token cosine scores\n"
-            "• iteratively merge low+high pairs until **3 composite buckets** remain."
         )
-
         with gr.Row():
             with gr.Column():
-                txt = gr.Textbox(
-                    label="Input with Symbolic Tokens",
-                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
-                    lines=3,
-                )
                 roles = gr.CheckboxGroup(
-
-
                 )
-                btn = gr.Button("
             with gr.Column():
-
-
-
-
-        btn.click(encode_and_trace, [txt, roles], [cat1, cat2, cat3])

     return demo

-
 if __name__ == "__main__":
-
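For reference, the removed reduce_to_three helper collapsed the role/token score table by repeatedly averaging its lowest- and highest-scoring rows until three composite buckets remained; this is the behavior the new pool-and-test routine replaces. A minimal standalone sketch of that merging, using made-up sample rows (illustrative only, not part of the commit):

def reduce_to_three(table):
    """Merge the lowest- and highest-scoring rows (averaging scores) until three remain."""
    working = sorted(table, key=lambda x: x["score"])
    while len(working) > 3:
        low, high = working.pop(0), working.pop(-1)
        working.append({
            "role":  f"{high['role']}|{low['role']}",
            "token": f"{high['token']}/{low['token']}",
            "score": (high["score"] + low["score"]) / 2.0,
        })
        working.sort(key=lambda x: x["score"])
    working.sort(key=lambda x: x["score"], reverse=True)   # highest first for display
    return working

# Made-up rows for illustration:
rows = [
    {"role": "<subject>", "token": "woman", "score": 0.82},
    {"role": "<pose>",    "token": "stand", "score": 0.41},
    {"role": "<emotion>", "token": "smile", "score": 0.67},
    {"role": "<style>",   "token": "photo", "score": 0.12},
    {"role": "<fabric>",  "token": "silk",  "score": 0.55},
]
print(reduce_to_three(rows))
# → three buckets; <subject> (0.82) and <style> (0.12) are merged first into a 0.47 composite.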
+# app.py – encoder-only demo + pool-and-test prototype
+# ----------------------------------------------------
 # launch: python app.py
+# UI: http://localhost:7860
+import json, re, sys, math
 from pathlib import Path, PurePosixPath

+import torch, torch.nn.functional as F
+import gradio as gr, spaces
 from huggingface_hub import snapshot_download

 from bert_handler import create_handler_from_checkpoint

 # ------------------------------------------------------------------
+# 0.  One-time patch of auto_map in config.json
 # ------------------------------------------------------------------
+REPO_ID    = "AbstractPhil/bert-beatrix-2048"
+LOCAL_CKPT = "bert-beatrix-2048"

 snapshot_download(
     repo_id=REPO_ID,

 )

 cfg_path = Path(LOCAL_CKPT) / "config.json"
+cfg      = json.loads(cfg_path.read_text())
+auto_map = cfg.get("auto_map", {})
+changed  = False
 for k, v in auto_map.items():
+    if "--" in v:                                    # strip “repo--”
         auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
+        changed = True
+if changed:
+    cfg_path.write_text(json.dumps(cfg, indent=2))
+    print("🛠️ Patched config.json → auto_map points to local modules")

 # ------------------------------------------------------------------
+# 1.  Load model + tokenizer with BERTHandler
 # ------------------------------------------------------------------
 handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
 full_model = full_model.eval().cuda()

+# pull encoder & embedding stack
 encoder    = full_model.bert.encoder
 embeddings = full_model.bert.embeddings
+emb_weight = embeddings.word_embeddings.weight       # <- correct tensor
 emb_ln     = full_model.bert.emb_ln
 emb_drop   = full_model.bert.emb_drop

 # ------------------------------------------------------------------
+# 2.  Symbolic roles
 # ------------------------------------------------------------------
 SYMBOLIC_ROLES = [
     "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",

     "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
     "<fabric>", "<jewelry>",
 ]
+missing = [t for t in SYMBOLIC_ROLES
+           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
 if missing:
+    sys.exit(f"❌ Tokenizer missing {missing}")

+MASK_ID  = tokenizer.mask_token_id
+MASK_TOK = tokenizer.mask_token

 # ------------------------------------------------------------------
+# helpers -----------------------------------------------------------
+def contextual_vectors(ids, mask):
+    """run through embedding→encoder, return (S,H) hidden states"""
+    x = emb_drop(emb_ln(embeddings(ids)))             # (1,S,H)
+    ext = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
+    return encoder(x, attention_mask=ext).squeeze(0)  # (S,H)
+
+def pool_accuracy(ids, mask, pool_positions):
+    """mask positions in pool, predict, calc accuracy"""
+    masked = ids.clone()
+    masked[0, pool_positions] = MASK_ID
     with torch.no_grad():
+        logits = full_model(masked, attention_mask=mask).logits[0]
+    preds   = logits.argmax(-1)
+    gold    = ids.squeeze(0)
+    correct = (preds[pool_positions] == gold[pool_positions]).sum().item()
+    return correct / len(pool_positions) if pool_positions else 0.0
+
+# cosine utility
+def cos(a, b): return F.cosine_similarity(a, b, dim=-1, eps=1e-8).item()

 # ------------------------------------------------------------------
+# 3.  Core routine ---------------------------------------------------
+@spaces.GPU
+def encode_and_trace(text: str, picked_roles: list[str]):
+    # -------- tokenise ----------
+    batch = tokenizer(text, return_tensors="pt").to("cuda")
+    ids, attn = batch.input_ids, batch.attention_mask
+    hid = contextual_vectors(ids, attn)               # (S,H)
+
+    # -------- decide which roles we analyse ----------
+    present = {tid: pos for pos, tid in enumerate(ids[0].tolist())
+               if tid in {tokenizer.convert_tokens_to_ids(r) for r in SYMBOLIC_ROLES}}
+    if picked_roles:
+        present = {tid: pos for tid, pos in present.items()
+                   if tokenizer.convert_ids_to_tokens([tid])[0] in picked_roles}
+    if not present:
+        return "No symbolic tokens in sentence", "", ""
+
+    # -------- similarity scores ----------
+    sims = []
+    for tid, pos in present.items():
+        rvec = emb_weight[tid]                        # static embedding
+        cvec = hid[pos]                               # contextual
+        sims.append((cos(cvec, rvec), tid, pos))
+    sims.sort()                                       # low → high
+    # pools: bottom-2, top-2 (expand later)
+    low_pool, high_pool = sims[:2], sims[-2:]
+    accepted = []
+
+    for grow in range(1 + math.ceil(len(sims)/2)):    # ≤26 shots
+        for tag, pool in [("low", low_pool), ("high", high_pool)]:
+            pool_pos = [p for _, _, p in pool]
+            acc = pool_accuracy(ids, attn, pool_pos)
+            if acc >= 0.5:                            # category accepted
+                roles = [tokenizer.convert_ids_to_tokens([tid])[0] for _, tid, _ in pool]
+                accepted.append(f"{tag}:{roles} (acc {acc:.2f})")
+        if accepted: break                            # stop once something passed
+        # grow pools by two (if any left)
+        next_lo = sims[2+grow*2 : 4+grow*2]
+        next_hi = sims[-4-grow*2 : -2-grow*2] if 4+grow*2 <= len(sims) else []
+        low_pool  += next_lo
+        high_pool += next_hi
+
+    if not accepted:
+        accepted = ["(none hit 50 %)"]
+
+    return ", ".join(accepted), f"{len(present)} roles analysed", f"{text[:80]}…"
+
 # ------------------------------------------------------------------
+# 4.  UI -------------------------------------------------------------
+def build_ui():
     with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
         gr.Markdown(
+            "## 🧠 Symbolic Encoder Inspector \n"
+            "Select roles, paste text, and watch the pool-and-test prototype work."
         )
         with gr.Row():
             with gr.Column():
+                txt   = gr.Textbox(lines=3, label="Input")
                 roles = gr.CheckboxGroup(
+                    SYMBOLIC_ROLES,
+                    value=SYMBOLIC_ROLES,
+                    label="Roles to consider (else all present)"
                 )
+                btn = gr.Button("Run")
             with gr.Column():
+                out_cat     = gr.Textbox(label="Accepted categories")
+                out_info    = gr.Textbox(label="Debug")
+                out_excerpt = gr.Textbox(label="Excerpt")

+        btn.click(encode_and_trace, [txt, roles], [out_cat, out_info, out_excerpt])
     return demo

 if __name__ == "__main__":
+    build_ui().launch()
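The heart of the new encode_and_trace is the pool-and-test loop: seed a bottom-2 and a top-2 pool from the similarity list sorted low to high, mask each pool's positions, accept a pool once the masked-LM recovers at least 50 % of its tokens, and otherwise grow both pools by two and retry. A model-free sketch of just that loop follows; the sims tuples and pool_accuracy_stub are made-up stand-ins for the encoder outputs and the real pool_accuracy, so only the slicing and acceptance logic mirror the code above.

import math

# Illustrative stand-ins: (similarity, token_id, position) tuples sorted low → high.
sims = [(0.12, 101, 3), (0.25, 102, 5), (0.40, 103, 7),
        (0.55, 104, 8), (0.70, 105, 11), (0.83, 106, 13)]

def pool_accuracy_stub(pool_positions):
    # Pretend the masked-LM only recovers the tokens at positions 7 and 8.
    if not pool_positions:
        return 0.0
    return sum(p in {7, 8} for p in pool_positions) / len(pool_positions)

low_pool, high_pool = sims[:2], sims[-2:]           # bottom-2 / top-2 seed pools
accepted = []
for grow in range(1 + math.ceil(len(sims) / 2)):
    for tag, pool in [("low", low_pool), ("high", high_pool)]:
        pool_pos = [p for _, _, p in pool]
        acc = pool_accuracy_stub(pool_pos)
        if acc >= 0.5:                              # category accepted at ≥50 % recovery
            accepted.append(f"{tag}: positions {pool_pos} (acc {acc:.2f})")
    if accepted:
        break
    # grow each pool by the next two candidates, mirroring the slicing in app.py
    low_pool  += sims[2 + grow * 2 : 4 + grow * 2]
    high_pool += sims[-4 - grow * 2 : -2 - grow * 2] if 4 + grow * 2 <= len(sims) else []

print(accepted or ["(none hit 50 %)"])
# → ['low: positions [3, 5, 7, 8] (acc 0.50)', 'high: positions [11, 13, 7, 8] (acc 0.50)']

In this sketch both seed pools miss on the first pass, each grows by the next two candidates, and both then clear the 0.5 threshold on the second pass.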