Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,52 +1,54 @@
-# app.py – encoder-only demo for bert-beatrix-2048
-#
-# launch: python app.py
+# app.py – encoder-only demo for bert-beatrix-2048 + role-probe
+# ------------------------------------------------------------
+# launch: python app.py
+# (gradio UI appears at http://localhost:7860)
 
-import json, re, sys
-from pathlib import Path, PurePosixPath   # ← PurePosixPath import
+import json, sys
+from pathlib import Path, PurePosixPath
+
 import gradio as gr
 import spaces
 import torch
+import torch.nn.functional as F
 from huggingface_hub import snapshot_download
 
 from bert_handler import create_handler_from_checkpoint
 
 
 # ------------------------------------------------------------------
-# 0. Download & patch config.json
+# 0. Download & patch config.json --------------------------------
 # ------------------------------------------------------------------
-REPO_ID
+REPO_ID = "AbstractPhil/bert-beatrix-2048"
+LOCAL_DIR = "bert-beatrix-2048"  # local cache dir
 
 snapshot_download(
     repo_id=REPO_ID,
     revision="main",
-    local_dir=
+    local_dir=LOCAL_DIR,
     local_dir_use_symlinks=False,
 )
 
-cfg_path = Path(
+cfg_path = Path(LOCAL_DIR) / "config.json"
 with cfg_path.open() as f:
     cfg = json.load(f)
 
 auto_map = cfg.get("auto_map", {})
 patched = False
 for k, v in auto_map.items():
-    if "--" in v:
+    if "--" in v:  # e.g. "repo--module.Class"
         auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
         patched = True
 
 if patched:
-    cfg["auto_map"] = auto_map
     with cfg_path.open("w") as f:
         json.dump(cfg, f, indent=2)
-    print("🛠️ Patched config.json → auto_map
+    print("🛠️ Patched config.json → auto_map fixed.")
 
 
 # ------------------------------------------------------------------
-# 1.
+# 1. Model / tokenizer -------------------------------------------
 # ------------------------------------------------------------------
-handler, full_model, tokenizer = create_handler_from_checkpoint(
+handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_DIR)
 full_model = full_model.eval().cuda()
 
 encoder = full_model.bert.encoder
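Note on the config patch above: Hub snapshots can carry `auto_map` entries in the repo-prefixed form `<repo_id>--<module>.<Class>`, which presumably fails to resolve once the code is loaded from the local snapshot directory, so the loop keeps only the part after `--`. A minimal sketch of that transformation, using a hypothetical mapping value (the real entries in this checkpoint's config.json may differ):

from pathlib import PurePosixPath

# Hypothetical auto_map value; illustrative only, not read from the actual config.json.
entry = "AbstractPhil/bert-beatrix-2048--modeling_beatrix.BertBeatrixModel"
patched = PurePosixPath(entry.split("--", 1)[1]).as_posix()
print(patched)  # modeling_beatrix.BertBeatrixModel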
@@ -56,7 +58,7 @@ emb_drop = full_model.bert.emb_drop
 
 
 # ------------------------------------------------------------------
-# 2. Symbolic
+# 2. Symbolic token set ------------------------------------------
 # ------------------------------------------------------------------
 SYMBOLIC_ROLES = [
     "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
@@ -66,48 +68,67 @@ SYMBOLIC_ROLES = [
     "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
     "<fabric>", "<jewelry>",
 ]
-missing = [
-           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
+ROLE_ID = {tok: tokenizer.convert_tokens_to_ids(tok) for tok in SYMBOLIC_ROLES}
+missing = [tok for tok, tid in ROLE_ID.items() if tid == tokenizer.unk_token_id]
 if missing:
     sys.exit(f"❌ Tokenizer is missing {missing}")
 
 
 # ------------------------------------------------------------------
-# 3. Encoder-only
+# 3. Encoder-only + role-similarity probe ------------------------
 # ------------------------------------------------------------------
 @spaces.GPU
 def encode_and_trace(text: str, selected_roles: list[str]):
     """
+    For each *selected* role:
+      • find the contextual token whose hidden state is most similar to that
+        role’s own embedding (cosine similarity)
+      • return “role → token (sim)”, using tokens even when the prompt
+        contained no <role> markers at all.
+    Also keeps the older diagnostics.
     """
     with torch.no_grad():
         batch = tokenizer(text, return_tensors="pt").to("cuda")
-        ids,
+        ids, mask = batch.input_ids, batch.attention_mask  # (1, S)
+
+        # ---------- encoder ----------
+        x = emb_drop(emb_ln(embeddings(ids)))
+        msk = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
+        h = encoder(x, attention_mask=msk).squeeze(0)  # (S, H)
+
+        # L2-normalise hidden states once
+        h_norm = F.normalize(h, dim=-1)  # (S, H)
+
+        # ---------- probe each selected role -----------------------
+        matches = []
+        for role in selected_roles:
+            role_vec = embeddings.word_embeddings.weight[ROLE_ID[role]].to(h.device)
+            role_vec = F.normalize(role_vec, dim=-1)  # (H)
+
+            sims = (h_norm @ role_vec)  # (S)
+            best_idx = int(sims.argmax().item())
+            best_sim = float(sims[best_idx])
+
+            match_tok = tokenizer.convert_ids_to_tokens(int(ids[0, best_idx]))
+            matches.append(f"{role} → {match_tok} ({best_sim:.2f})")
+
+        match_str = ", ".join(matches) if matches else "(no roles selected)"
+
+        # ---------- string-match diagnostics -----------------------
+        present = [tok for tok_id, tok in zip(ids[0].tolist(),
+                                              tokenizer.convert_ids_to_tokens(ids[0]))
+                   if tok in selected_roles]
+        present_str = ", ".join(present) or "(none)"
+        count = len(present)
+
+        # ---------- hidden-state norm of *explicit* role tokens ----
+        if count:
+            exp_mask = torch.tensor([tid in ROLE_ID.values() for tid in ids[0]], device=h.device)
+            norm_val = f"{h[exp_mask].mean(0).norm().item():.4f}"
         else:
+            norm_val = "0.0000"
 
-        return tokens_str, norm_str, count_int   # ← three outputs!
+        return present_str, match_str, norm_val, count
 
 
 # ------------------------------------------------------------------
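The role-similarity probe added above boils down to: L2-normalise the hidden states and the role embedding, take dot products, and pick the argmax position. A self-contained sketch of that pattern on random tensors (toy shapes, no checkpoint involved):

import torch
import torch.nn.functional as F

S, H = 6, 8                                # toy sequence length / hidden size
h = torch.randn(S, H)                      # stand-in for encoder hidden states
role_vec = torch.randn(H)                  # stand-in for one role embedding row

h_norm = F.normalize(h, dim=-1)            # unit-length hidden states, (S, H)
role_norm = F.normalize(role_vec, dim=-1)  # unit-length role vector, (H,)

sims = h_norm @ role_norm                  # (S,) cosine similarities
best_idx = int(sims.argmax())              # position of the best-matching token
best_sim = float(sims[best_idx])
print(best_idx, round(best_sim, 2))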
@@ -116,23 +137,36 @@ def encode_and_trace(text: str, selected_roles: list[str]):
 def build_interface():
     with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
         gr.Markdown(
-            "
-            "
-            "
+            "## 🧠 Symbolic Encoder Inspector \n"
+            "Select one or more symbolic *roles* on the left. "
+            "The tool shows which regular tokens (if any) the model thinks "
+            "best fit each role — even when your text doesn’t contain the "
+            "explicit `<role>` marker."
         )
 
         with gr.Row():
             with gr.Column():
-                txt
+                txt = gr.Textbox(
+                    label="Input text",
+                    lines=3,
+                    placeholder="Example: A small child in bright red boots jumps over a muddy puddle…",
+                )
+                roles = gr.CheckboxGroup(
+                    choices=SYMBOLIC_ROLES,
+                    label="Roles to probe",
+                )
+                btn = gr.Button("Run encoder probe")
             with gr.Column():
+                out_present = gr.Textbox(label="Explicit role tokens found")
+                out_match = gr.Textbox(label="Role → Best-Match Token (cos θ)")
+                out_norm = gr.Textbox(label="Mean hidden-state norm (explicit)")
+                out_count = gr.Textbox(label="# explicit role tokens")
+
+        btn.click(
+            encode_and_trace,
+            inputs=[txt, roles],
+            outputs=[out_present, out_match, out_norm, out_count],
+        )
 
     return demo
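The diff ends inside build_interface. Assuming the remainder of app.py keeps the usual Gradio entry point (not shown in this commit), serving the interface would look like this:

if __name__ == "__main__":
    build_interface().launch()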
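For a quick check outside the UI, the probe can also be called directly once the module above has run on a machine with a CUDA device; a hypothetical smoke test:

present, matches, norm, count = encode_and_trace(
    "A small child in bright red boots jumps over a muddy puddle",
    ["<subject>", "<pose>"],
)
print(matches)  # one "role → token (cosine)" entry per selected role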