Spaces: Running on Zero

Update app.py

app.py CHANGED
The commit rewrites app.py in four hunks: @@ -1,10 +1,10 @@, @@ -14,41 +14,40 @@ (context: from huggingface_hub import snapshot_download), @@ -56,7 +55,6 @@ (context: embeddings = full_model.bert.embeddings), and @@ -68,105 +66,111 @@ (context: SYMBOLIC_ROLES = [). Several removed lines in the old-file column of the view are cut off mid-line; what is visible of the removed code:

- the old header comments: a decorative rule and the note "# (gradio UI appears at http://localhost:7860)";
- the old REPO_ID assignment and the line after it (values cut off), the old local_dir= argument to snapshot_download, the old cfg_path = Path(… line, the old auto_map and patched assignments, and the old "Patched config.json – auto_map fixed…" print (also cut off);
- the old argument to create_handler_from_checkpoint( and a stray blank line after emb_drop;
- the old encode_and_trace docstring ("For each *selected* role: • find the contextual token whose hidden state is most similar to that role's own embedding (cosine similarity) • return 'role → token (sim)', using tokens even when the prompt contained no <role> markers at all. Also keeps the older diagnostics.") and its body, which built a matches list, computed a norm over exp_mask = torch.tensor([tid in ROLE_ID.values() for tid in ids[0]], device=h.device) when count was non-zero (else "0.0000"), and returned four values: present_str, match_str, norm_val, count;
- the old UI strings (input label, role-group label, button caption, the markdown line "explicit `<role>` marker.", the placeholder "Example: A small child in bright red boots jumps over a muddy puddle…"), the old output widgets, and the old wiring btn.click(encode_and_trace, inputs=[txt, roles], outputs=[out_present, out_match, out_norm, out_count]).

New version of app.py (unchanged regions are elided at the @@ hunk markers):
# app.py – encoder-only demo for bert-beatrix-2048
# launch: python app.py

import json
import sys
from pathlib import Path, PurePosixPath
from itertools import islice

import gradio as gr
import spaces

@@ -14,41 +14,40 @@ from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. Download & patch config.json --------------------------------
# ------------------------------------------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)

cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

auto_map = cfg.get("auto_map", {})
patched = False
for k, v in auto_map.items():
    if "--" in v:  # "repo--module.Class"
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        patched = True

if patched:
    cfg["auto_map"] = auto_map
    with cfg_path.open("w") as f:
        json.dump(cfg, f, indent=2)
    print("🛠️ Patched config.json – auto_map paths fixed")
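The loop above rewrites Hub-style auto_map entries of the form "repo--module.Class" into plain module paths so the custom model code resolves from the local checkpoint folder. A minimal sketch of that rewrite with a hypothetical entry (the checkpoint's real auto_map keys and class names are not shown in the diff):

    from pathlib import PurePosixPath

    # Hypothetical entry of the "repo--module.Class" shape the patch targets.
    before = {"AutoModel": "AbstractPhil/bert-beatrix-2048--modeling_bert.BertBeatrix"}

    after = {
        k: PurePosixPath(v.split("--", 1)[1]).as_posix() if "--" in v else v
        for k, v in before.items()
    }
    # after == {"AutoModel": "modeling_bert.BertBeatrix"}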

# ------------------------------------------------------------------
# 1. Model / tokenizer -------------------------------------------
# ------------------------------------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()

encoder = full_model.bert.encoder

@@ -56,7 +55,6 @@ embeddings = full_model.bert.embeddings

emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop

# ------------------------------------------------------------------
# 2. Symbolic token set ------------------------------------------
# ------------------------------------------------------------------

@@ -68,105 +66,111 @@ SYMBOLIC_ROLES = [
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

unk = tokenizer.unk_token_id
missing = [t for t in SYMBOLIC_ROLES if tokenizer.convert_tokens_to_ids(t) == unk]
if missing:
    sys.exit(f"❌ Tokenizer is missing {missing}")
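The guard above relies on a BERT-style tokenizer mapping token strings it does not know to unk_token_id, so any role token that was never added to the vocabulary ends up in missing. A quick sketch with a stock tokenizer standing in for the checkpoint's own (an assumption for illustration only):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")            # stand-in tokenizer
    print(tok.convert_tokens_to_ids("<subject>") == tok.unk_token_id)   # True: token never added
    tok.add_tokens(["<subject>"])
    print(tok.convert_tokens_to_ids("<subject>") == tok.unk_token_id)   # False once it is added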

# ------------------------------------------------------------------
# 3. helper: merge lowest + highest until 3 remain ----------------
# ------------------------------------------------------------------
def reduce_to_three(table):
    """
    table : list of dicts {role, token, score}
    repeatedly remove lowest and highest,
    replace with their average,
    until len(table)==3.
    """
    working = table[:]
    working.sort(key=lambda x: x["score"])
    while len(working) > 3:
        low = working.pop(0)
        high = working.pop(-1)
        merged = {
            "role": f"{high['role']}|{low['role']}",
            "token": f"{high['token']}/{low['token']}",
            "score": (high["score"] + low["score"]) / 2.0,
        }
        working.append(merged)
        working.sort(key=lambda x: x["score"])
    # highest first for display
    working.sort(key=lambda x: x["score"], reverse=True)
    return working
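A worked example of the merge rule, with made-up roles, tokens, and scores: five rows collapse to three, pairing off the extremes first.

    rows = [
        {"role": "<subject>",  "token": "child",  "score": 0.91},
        {"role": "<style>",    "token": "bright", "score": 0.40},
        {"role": "<fabric>",   "token": "boots",  "score": 0.75},
        {"role": "<relation>", "token": "over",   "score": 0.10},
        {"role": "<intent>",   "token": "jumps",  "score": 0.55},
    ]
    print(reduce_to_three(rows))
    # Two merges happen: <subject> (0.91) pairs with <relation> (0.10) at 0.505,
    # then <fabric> (0.75) pairs with <style> (0.40) at 0.575.  Final three, highest first:
    #   <fabric>|<style>     -> boots/bright (0.575)
    #   <intent>             -> jumps        (0.55)
    #   <subject>|<relation> -> child/over   (0.505)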

# ------------------------------------------------------------------
# 4. Encoder-only inference util ---------------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    with torch.no_grad():
        if not text.strip():
            return "(no input)", "", ""

        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, attn = batch.input_ids, batch.attention_mask

        # encoder forward
        x = emb_drop(emb_ln(embeddings(ids)))
        ext = full_model.bert.get_extended_attention_mask(attn, x.shape[:-1])
        hs = encoder(x, attention_mask=ext)        # (1,S,H)

        # token-level embeddings (before LN) for similarity calc
        token_emb = embeddings(ids).squeeze(0)     # (S,H)

        rows = []
        for role in selected_roles:
            rid = tokenizer.convert_tokens_to_ids(role)
            rvec = embeddings.weight[rid]          # (H,)
            # cosine similarity to every *input* token embedding
            sims = F.cosine_similarity(rvec.unsqueeze(0), token_emb, dim=-1)
            best = torch.argmax(sims).item()
            rows.append({
                "role": role,
                "token": tokenizer.convert_ids_to_tokens([ids[0, best].item()])[0],
                "score": sims[best].item(),
            })

        if not rows:
            return "(none selected)", "", ""

        final3 = reduce_to_three(rows)
        out_strs = [f"{r['role']} → {r['token']} ({r['score']:+.2f})" for r in final3]
        # pad so we always return 3 strings
        while len(out_strs) < 3:
            out_strs.append("")
        return out_strs[0], out_strs[1], out_strs[2]
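Note that the encoder forward pass (hs) is run, but the scoring itself compares each role's embedding row against the pre-encoder token embeddings (token_emb): one broadcasted cosine similarity per role. A shape-level sketch with random tensors, independent of the model weights:

    import torch
    import torch.nn.functional as F

    S, H = 12, 768                       # sequence length, hidden size
    token_emb = torch.randn(S, H)        # stands in for embeddings(ids).squeeze(0)
    rvec = torch.randn(H)                # stands in for embeddings.weight[rid]

    sims = F.cosine_similarity(rvec.unsqueeze(0), token_emb, dim=-1)   # shape (S,)
    best = torch.argmax(sims).item()     # index of the input token closest to the role
    print(sims.shape, best)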

# ------------------------------------------------------------------
# 5. Gradio UI ----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "### 🧠 Symbolic Encoder Inspector\n"
            "Paste text with `<role>` tokens, pick roles to track, then we\n"
            "• compute role → token cosine scores\n"
            "• iteratively merge low+high pairs until **3 composite buckets** remain."
        )

        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                    lines=3,
                )
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Trace these symbolic roles",
                )
                btn = gr.Button("Encode & Merge")
            with gr.Column():
                cat1 = gr.Textbox(label="Category 1 (highest)")
                cat2 = gr.Textbox(label="Category 2")
                cat3 = gr.Textbox(label="Category 3 (lowest)")

        btn.click(encode_and_trace, [txt, roles], [cat1, cat2, cat3])

    return demo
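The shown hunks stop at build_interface(); the launch code sits outside the diff. Given the "# launch: python app.py" header, the file presumably finishes with something along these lines (an assumption, not part of the diff):

    if __name__ == "__main__":
        demo = build_interface()
        demo.launch()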