AbstractPhil committed on
Commit ed080e6 · verified · 1 Parent(s): aa28bbb

Update app.py

Files changed (1)
  1. app.py +88 -84
app.py CHANGED
@@ -1,10 +1,10 @@
-# app.py – encoder-only demo for bert-beatrix-2048 + role-probe
-# ------------------------------------------------------------
+# app.py – encoder-only demo for bert-beatrix-2048
 # launch: python app.py
-# (gradio UI appears at http://localhost:7860)
 
-import json, sys
+import json
+import sys
 from pathlib import Path, PurePosixPath
+from itertools import islice
 
 import gradio as gr
 import spaces
@@ -14,41 +14,40 @@ from huggingface_hub import snapshot_download
 
 from bert_handler import create_handler_from_checkpoint
 
-
 # ------------------------------------------------------------------
 # 0. Download & patch config.json --------------------------------
 # ------------------------------------------------------------------
-REPO_ID = "AbstractPhil/bert-beatrix-2048"
-LOCAL_DIR = "bert-beatrix-2048"  # local cache dir
+REPO_ID = "AbstractPhil/bert-beatrix-2048"
+LOCAL_CKPT = "bert-beatrix-2048"
 
 snapshot_download(
     repo_id=REPO_ID,
     revision="main",
-    local_dir=LOCAL_DIR,
+    local_dir=LOCAL_CKPT,
     local_dir_use_symlinks=False,
 )
 
-cfg_path = Path(LOCAL_DIR) / "config.json"
+cfg_path = Path(LOCAL_CKPT) / "config.json"
 with cfg_path.open() as f:
     cfg = json.load(f)
 
-auto_map = cfg.get("auto_map", {})
-patched = False
+auto_map = cfg.get("auto_map", {})
+patched = False
 for k, v in auto_map.items():
-    if "--" in v:  # e.g. "repo--module.Class"
+    if "--" in v:  # "repo--module.Class"
         auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
         patched = True
 
 if patched:
+    cfg["auto_map"] = auto_map
     with cfg_path.open("w") as f:
         json.dump(cfg, f, indent=2)
-    print("🛠️ Patched config.json → auto_map fixed.")
-
+    print("🛠️ Patched config.json → auto_map paths fixed")
 
 # ------------------------------------------------------------------
 # 1. Model / tokenizer -------------------------------------------
 # ------------------------------------------------------------------
-handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_DIR)
+handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
 full_model = full_model.eval().cuda()
 
 encoder = full_model.bert.encoder
@@ -56,7 +55,6 @@ embeddings = full_model.bert.embeddings
 emb_ln = full_model.bert.emb_ln
 emb_drop = full_model.bert.emb_drop
 
-
 # ------------------------------------------------------------------
 # 2. Symbolic token set ------------------------------------------
 # ------------------------------------------------------------------
@@ -68,105 +66,111 @@ SYMBOLIC_ROLES = [
     "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
     "<fabric>", "<jewelry>",
 ]
-ROLE_ID = {tok: tokenizer.convert_tokens_to_ids(tok) for tok in SYMBOLIC_ROLES}
-missing = [tok for tok, tid in ROLE_ID.items() if tid == tokenizer.unk_token_id]
+
+unk = tokenizer.unk_token_id
+missing = [t for t in SYMBOLIC_ROLES if tokenizer.convert_tokens_to_ids(t) == unk]
 if missing:
     sys.exit(f"❌ Tokenizer is missing {missing}")
 
+# ------------------------------------------------------------------
+# 3. helper: merge lowest + highest until 3 remain ----------------
+# ------------------------------------------------------------------
+def reduce_to_three(table):
+    """
+    table : list of dicts {role, token, score}
+    repeatedly remove lowest and highest,
+    replace with their average,
+    until len(table)==3.
+    """
+    working = table[:]
+    working.sort(key=lambda x: x["score"])
+    while len(working) > 3:
+        low = working.pop(0)
+        high = working.pop(-1)
+        merged = {
+            "role": f"{high['role']}|{low['role']}",
+            "token": f"{high['token']}/{low['token']}",
+            "score": (high["score"] + low["score"]) / 2.0,
+        }
+        working.append(merged)
+        working.sort(key=lambda x: x["score"])
+    # highest first for display
+    working.sort(key=lambda x: x["score"], reverse=True)
+    return working
 
 # ------------------------------------------------------------------
-# 3. Encoder-only + role-similarity probe ------------------------
+# 4. Encoder-only inference util ---------------------------------
 # ------------------------------------------------------------------
 @spaces.GPU
 def encode_and_trace(text: str, selected_roles: list[str]):
-    """
-    For each *selected* role:
-      • find the contextual token whose hidden state is most similar to that
-        role's own embedding (cosine similarity)
-      • return "role → token (sim)", using tokens even when the prompt
-        contained no <role> markers at all.
-    Also keeps the older diagnostics.
-    """
     with torch.no_grad():
+        if not text.strip():
+            return "(no input)","",""
+
        batch = tokenizer(text, return_tensors="pt").to("cuda")
-        ids, mask = batch.input_ids, batch.attention_mask  # (1, S)
+        ids, attn = batch.input_ids, batch.attention_mask
 
-        # ---------- encoder ----------
-        x = emb_drop(emb_ln(embeddings(ids)))
-        msk = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
-        h = encoder(x, attention_mask=msk).squeeze(0)  # (S, H)
+        # encoder forward
+        x = emb_drop(emb_ln(embeddings(ids)))
+        ext = full_model.bert.get_extended_attention_mask(attn, x.shape[:-1])
+        hs = encoder(x, attention_mask=ext)  # (1,S,H)
 
-        # L2-normalise hidden states once
-        h_norm = F.normalize(h, dim=-1)  # (S, H)
+        # token-level embeddings (before LN) for similarity calc
+        token_emb = embeddings(ids).squeeze(0)  # (S,H)
 
-        # ---------- probe each selected role -----------------------
-        matches = []
+        rows = []
        for role in selected_roles:
-            role_vec = embeddings.word_embeddings.weight[ROLE_ID[role]].to(h.device)
-            role_vec = F.normalize(role_vec, dim=-1)  # (H)
-
-            sims = (h_norm @ role_vec)  # (S)
-            best_idx = int(sims.argmax().item())
-            best_sim = float(sims[best_idx])
-
-            match_tok = tokenizer.convert_ids_to_tokens(int(ids[0, best_idx]))
-            matches.append(f"{role} → {match_tok} ({best_sim:.2f})")
-
-        match_str = ", ".join(matches) if matches else "(no roles selected)"
-
-        # ---------- string-match diagnostics -----------------------
-        present = [tok for tok_id, tok in zip(ids[0].tolist(),
-                                              tokenizer.convert_ids_to_tokens(ids[0]))
-                   if tok in selected_roles]
-        present_str = ", ".join(present) or "(none)"
-        count = len(present)
-
-        # ---------- hidden-state norm of *explicit* role tokens ----
-        if count:
-            exp_mask = torch.tensor([tid in ROLE_ID.values() for tid in ids[0]], device=h.device)
-            norm_val = f"{h[exp_mask].mean(0).norm().item():.4f}"
-        else:
-            norm_val = "0.0000"
-
-        return present_str, match_str, norm_val, count
-
+            rid = tokenizer.convert_tokens_to_ids(role)
+            rvec = embeddings.weight[rid]  # (H,)
+            # cosine similarity to every *input* token embedding
+            sims = F.cosine_similarity(rvec.unsqueeze(0), token_emb, dim=-1)
+            best = torch.argmax(sims).item()
+            rows.append({
+                "role" : role,
+                "token": tokenizer.convert_ids_to_tokens([ids[0, best].item()])[0],
+                "score": sims[best].item()
+            })
+
+        if not rows:
+            return "(none selected)","",""
+
+        final3 = reduce_to_three(rows)
+        out_strs = [f"{r['role']} ↔ {r['token']} ({r['score']:+.2f})" for r in final3]
+        # pad so we always return 3 strings
+        while len(out_strs) < 3:
+            out_strs.append("")
+        return out_strs[0], out_strs[1], out_strs[2]
 
 # ------------------------------------------------------------------
-# 4. Gradio UI ----------------------------------------------------
+# 5. Gradio UI ----------------------------------------------------
 # ------------------------------------------------------------------
 def build_interface():
     with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
         gr.Markdown(
-            "## 🧠 Symbolic Encoder Inspector \n"
-            "Select one or more symbolic *roles* on the left. "
-            "The tool shows which regular tokens (if any) the model thinks "
-            "best fit each role — even when your text doesn't contain the "
-            "explicit `<role>` marker."
+            "### 🧠 Symbolic Encoder Inspector\n"
+            "Paste text with `<role>` tokens, pick roles to track, then we\n"
+            "• compute role ↔ token cosine scores\n"
+            "• iteratively merge low+high pairs until **3 composite buckets** remain."
        )
 
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(
-                    label="Input text",
+                    label="Input with Symbolic Tokens",
+                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                    lines=3,
-                    placeholder="Example: A small child in bright red boots jumps over a muddy puddle…",
                )
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
-                    label="Roles to probe",
+                    label="Trace these symbolic roles",
                )
-                btn = gr.Button("Run encoder probe")
+                btn = gr.Button("Encode & Merge")
            with gr.Column():
-                out_present = gr.Textbox(label="Explicit role tokens found")
-                out_match = gr.Textbox(label="Role → Best-Match Token (cos θ)")
-                out_norm = gr.Textbox(label="Mean hidden-state norm (explicit)")
-                out_count = gr.Textbox(label="# explicit role tokens")
-
-        btn.click(
-            encode_and_trace,
-            inputs=[txt, roles],
-            outputs=[out_present, out_match, out_norm, out_count],
-        )
+                cat1 = gr.Textbox(label="Category 1 (highest)")
+                cat2 = gr.Textbox(label="Category 2")
+                cat3 = gr.Textbox(label="Category 3 (lowest)")
+
+        btn.click(encode_and_trace, [txt, roles], [cat1, cat2, cat3])
 
     return demo
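
The merge step this commit introduces (reduce_to_three) can be exercised on its own, outside the Space. Below is a minimal standalone sketch of that lowest+highest pairing logic on made-up scores; the role, token, and score values are invented for illustration and are not output from the model.

# Standalone sketch of the merge logic added in this commit (toy data only).
def reduce_to_three(table):
    # sort ascending by score, then repeatedly average the two extremes
    working = sorted(table, key=lambda x: x["score"])
    while len(working) > 3:
        low, high = working.pop(0), working.pop(-1)
        working.append({
            "role":  f"{high['role']}|{low['role']}",
            "token": f"{high['token']}/{low['token']}",
            "score": (high["score"] + low["score"]) / 2.0,
        })
        working.sort(key=lambda x: x["score"])
    # highest first for display
    return sorted(working, key=lambda x: x["score"], reverse=True)

rows = [
    {"role": "<subject>",  "token": "child",  "score": 0.81},
    {"role": "<relation>", "token": "over",   "score": 0.55},
    {"role": "<intent>",   "token": "jumps",  "score": 0.47},
    {"role": "<style>",    "token": "bright", "score": 0.34},
    {"role": "<fabric>",   "token": "boots",  "score": 0.12},
]
for r in reduce_to_three(rows):
    # prints three "role ↔ token (score)" lines; two are composite buckets
    print(f"{r['role']} ↔ {r['token']} ({r['score']:+.2f})")

With five selected roles, two merge passes leave one plain row and two composite buckets, which lines up with the three output boxes the updated UI exposes.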