Spaces: Running on Zero
Update app.py
app.py
CHANGED
@@ -1,57 +1,53 @@
Old version (removed lines are prefixed "-"; "-…" marks removed lines whose content is not visible in this view):

 # app.py – encoder-only demo for bert-beatrix-2048
-# ------------------------------------------------------------------
 # launch: python app.py
-#
-…
-import json, re, sys
 from pathlib import Path, PurePosixPath

 import gradio as gr
 import spaces
-import torch
 from huggingface_hub import snapshot_download

 from bert_handler import create_handler_from_checkpoint


 # ------------------------------------------------------------------
-# 0.
-# ------------------------------------------------------------------
 REPO_ID = "AbstractPhil/bert-beatrix-2048"
-…

-snapshot_download(
-…

-…

-…
-for k,
     if "--" in v:
-…
-cfg["auto_map"] = auto_map
-cfg_path.write_text(json.dumps(cfg, indent=2))
-print("🛠️ Patched config.json → auto_map now points at local modules")

-…
-# ------------------------------------------------------------------
-# 1. Model / tokenizer -------------------------------------------
 # ------------------------------------------------------------------
-…
 full_model = full_model.eval().cuda()

-encoder
-embeddings
-emb_ln
-emb_drop

-…
-# ------------------------------------------------------------------
-# 2. Symbolic token set ------------------------------------------
 # ------------------------------------------------------------------
 SYMBOLIC_ROLES = [
     "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
     "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
@@ -60,108 +56,124 @@ SYMBOLIC_ROLES = [
 "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
 "<fabric>", "<jewelry>",
 ]

-…
-if missing:
-    sys.exit(f"❌ Tokenizer is missing {missing}")


 # ------------------------------------------------------------------
-# 3. Encoder
-…

-…
-def encode_and_trace(text: str, _ignored):  # all roles auto-selected
     """
-…
     """
-…


 # ------------------------------------------------------------------
-# 4. Gradio UI
-# ------------------------------------------------------------------
 def build_interface():
     with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
-        gr.Markdown(
-            "## 🧠 Symbolic Encoder Inspector\n"
-            "Enter text containing the `<role>` tokens.\n"
-            "Cosine probe **and** real mask-prediction accuracy are shown."
-        )

         with gr.Row():
             with gr.Column():
-                txt
-…
                 )
-…
-                roles = gr.CheckboxGroup(
-                    choices=SYMBOLIC_ROLES,
-                    label="(all roles auto-selected)",
-                    value=SYMBOLIC_ROLES,
-                    interactive=False,
-                )
-                btn = gr.Button("Run probe + MLM check")
             with gr.Column():
-…
-                out_cnt = gr.Textbox(label="
-                out_acc = gr.Textbox(label="Mask-prediction accuracy")
-…
-        btn.click(encode_and_trace,
-                  inputs=[txt, roles],
-                  outputs=[out_tok, out_norm, out_cnt, out_acc])

     return demo


-if __name__
-    build_interface().launch()

Updated app.py (the file after this commit; added lines are prefixed "+"):

 # app.py – encoder-only demo for bert-beatrix-2048
 # launch: python app.py
+# -----------------------------------------------
+import json, re, sys, math
 from pathlib import Path, PurePosixPath

+import torch, torch.nn.functional as F
 import gradio as gr
 import spaces
 from huggingface_hub import snapshot_download

 from bert_handler import create_handler_from_checkpoint


 # ------------------------------------------------------------------
+# 0. Download & patch HF checkpoint --------------------------------
 REPO_ID = "AbstractPhil/bert-beatrix-2048"
+LOCAL_CKPT = "bert-beatrix-2048"

+snapshot_download(
+    repo_id=REPO_ID,
+    revision="main",
+    local_dir=LOCAL_CKPT,
+    local_dir_use_symlinks=False,
+)

+# → strip repo prefix in auto_map (one-time)
+cfg_path = Path(LOCAL_CKPT) / "config.json"
+with cfg_path.open() as f: cfg = json.load(f)

+amap = cfg.get("auto_map", {})
+for k,v in amap.items():
     if "--" in v:
+        amap[k] = PurePosixPath(v.split("--",1)[1]).as_posix()
+cfg["auto_map"] = amap
+with cfg_path.open("w") as f: json.dump(cfg,f,indent=2)

 # ------------------------------------------------------------------
+# 1. Load model & components --------------------------------------
+handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
 full_model = full_model.eval().cuda()

+encoder    = full_model.bert.encoder
+embeddings = full_model.bert.embeddings
+emb_ln     = full_model.bert.emb_ln
+emb_drop   = full_model.bert.emb_drop
+mlm_head   = full_model.cls          # prediction head

 # ------------------------------------------------------------------
+# 2. Symbolic roles -------------------------------------------------
 SYMBOLIC_ROLES = [
     "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
     "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
     "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
     "<fabric>", "<jewelry>",
 ]
+if any(tokenizer.convert_tokens_to_ids(t)==tokenizer.unk_token_id
+       for t in SYMBOLIC_ROLES):
+    sys.exit("❌ tokenizer missing special tokens")

+# Quick helpers
+MASK = tokenizer.mask_token


 # ------------------------------------------------------------------
+# 3. Encoder-plus-MLM logic ---------------------------------------
+def cosine(a,b):
+    return torch.nn.functional.cosine_similarity(a,b,dim=-1)

+def pool_accuracy(ids, logits, pool_mask):
     """
+    ids       : (S,) gold token ids
+    logits    : (S,V) MLM logits
+    pool_mask : bool (S,) which tokens belong to the candidate pool
+    returns accuracy over masked positions only (if none, return 0)
     """
+    idx = pool_mask.nonzero(as_tuple=False).flatten()
+    if idx.numel()==0: return 0.0
+    preds = logits.argmax(-1)[idx]
+    gold  = ids[idx]
+    return (preds==gold).float().mean().item()
+
+
+@spaces.GPU
+def encode_and_trace(text, selected_roles):
+    # if user unchecked everything we treat as "all"
+    if not selected_roles:
+        selected_roles = SYMBOLIC_ROLES
+    sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
+
+    # ---- Tokenise & encode once ----
+    batch = tokenizer(text, return_tensors="pt").to("cuda")
+    ids, att = batch.input_ids, batch.attention_mask
+    x   = emb_drop(emb_ln(embeddings(ids)))
+    ext = full_model.bert.get_extended_attention_mask(att, x.shape[:-1])
+    enc = encoder(x, attention_mask=ext)[0, :, :]             # (S,H)
+
+    # ---- compute max-cos per token (F-0/F-1) ----
+    role_mat = embeddings.word_embeddings(
+        torch.tensor(sorted(sel_ids), device=enc.device)
+    )                                                          # (R,H)
+    cos = cosine(enc.unsqueeze(1), role_mat.unsqueeze(0))      # (S,R)
+    maxcos, argrole = cos.max(-1)                              # (S,)
+
+    # ---- split tokens into High / Low half (F-2) ----
+    S = len(ids[0])
+    sort_idx = maxcos.argsort(descending=True)
+    hi_idx = sort_idx[: S//2]
+    lo_idx = sort_idx[S//2:]
+
+    # container for summary text
+    report_lines = []
+
+    # ---- pool builder helper (uses S-4…S-7) ----
+    def greedy_pool(token_indices, direction):
+        # direction=='hi' or 'lo'
+        pool = []
+        μ = None
+        for tix in token_indices:
+            t_vec = enc[tix]
+            # incremental update (S-7)
+            μ = t_vec if μ is None else (μ*len(pool) + t_vec)/(len(pool)+1)
+            pool.append(tix)
+            # mask all tokens *except* this pool
+            masked_ids = ids.clone()
+            keep = torch.tensor(pool, device=ids.device)
+            mask_mask = torch.ones_like(ids, dtype=torch.bool)
+            mask_mask[0, keep] = False
+            masked_ids[mask_mask] = tokenizer.mask_token_id
+            # run MLM
+            with torch.no_grad():
+                logits = mlm_head(full_model.bert.emb_dl(enc.unsqueeze(0))).logits[0]
+            acc = pool_accuracy(ids[0], logits, ~mask_mask[0])
+            report_lines.append(f"{direction}-pool size {len(pool)} → acc={acc:.2f}")
+            if acc >= 0.5: break
+        return pool, acc
+
+    pool_lo, acc_lo = greedy_pool(lo_idx, "low")
+    pool_hi, acc_hi = greedy_pool(hi_idx, "high")
+
+    # ---- package textual result ----
+    res_json = {
+        "Low-pool tokens":  tokenizer.decode(ids[0, pool_lo]),
+        "Low accuracy":     f"{acc_lo:.2f}",
+        "High-pool tokens": tokenizer.decode(ids[0, pool_hi]),
+        "High accuracy":    f"{acc_hi:.2f}",
+        "Trace":            "\n".join(report_lines)
+    }
+    # three outputs expected by UI
+    return json.dumps(res_json, indent=2), f"{maxcos.max():.4f}", len(selected_roles)


 # ------------------------------------------------------------------
+# 4. Gradio UI -----------------------------------------------------
 def build_interface():
     with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
+        gr.Markdown("## 🧠 Symbolic Encoder Inspector")

         with gr.Row():
             with gr.Column():
+                txt   = gr.Textbox(label="Prompt", lines=3)
+                roles = gr.CheckboxGroup(
+                    choices=SYMBOLIC_ROLES, label="Roles",
+                    value=SYMBOLIC_ROLES          # pre-checked
                 )
+                btn = gr.Button("Run")
             with gr.Column():
+                out_json = gr.Textbox(label="Result JSON")
+                out_max  = gr.Textbox(label="Max cos")
+                out_cnt  = gr.Textbox(label="# roles")

+        btn.click(encode_and_trace, [txt,roles], [out_json,out_max,out_cnt])
     return demo


+if __name__=="__main__":
+    build_interface().launch()
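
For a quick sanity check of the probe outside the Gradio UI, a minimal sketch is given below. It assumes the updated file is importable as `app` from the same directory, that the module-level snapshot download and checkpoint load succeed, and that a CUDA device is available; the prompt string and the `sanity_check.py` name are purely illustrative.

# sanity_check.py – illustrative sketch only, not part of the commit above.
# Assumes app.py (the updated file) is importable from the working directory,
# its module-level download/load succeeds, and a CUDA device is present.
from app import encode_and_trace, SYMBOLIC_ROLES

demo_prompt = "a <subject> in <lighting> wearing <fabric>"   # example prompt

# encode_and_trace returns (result JSON string, max cosine as a string, role count)
result_json, max_cos, n_roles = encode_and_trace(demo_prompt, SYMBOLIC_ROLES)

print(result_json)               # low/high pools, their accuracies, and the trace
print("max cosine:", max_cos)
print("roles probed:", n_roles)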