AbstractPhil committed
Commit aa28bbb · verified · 1 Parent(s): 415afa1

Update app.py

Files changed (1): app.py +90 -56
app.py CHANGED
@@ -1,52 +1,54 @@
-# app.py – encoder-only demo for bert-beatrix-2048
-# ------------------------------------------------
-# launch: python app.py → http://localhost:7860
-
-import json, re, sys
-from pathlib import Path, PurePosixPath          # ← PurePosixPath import
 import gradio as gr
 import spaces
 import torch
 from huggingface_hub import snapshot_download

 from bert_handler import create_handler_from_checkpoint


 # ------------------------------------------------------------------
-# 0. Download & patch config.json ---------------------------------
 # ------------------------------------------------------------------
-REPO_ID = "AbstractPhil/bert-beatrix-2048"
-LOCAL_CKPT = "bert-beatrix-2048"          # cache dir name

 snapshot_download(
     repo_id=REPO_ID,
     revision="main",
-    local_dir=LOCAL_CKPT,
     local_dir_use_symlinks=False,
 )

-cfg_path = Path(LOCAL_CKPT) / "config.json"
 with cfg_path.open() as f:
     cfg = json.load(f)

 auto_map = cfg.get("auto_map", {})
 patched = False
 for k, v in auto_map.items():
-    if "--" in v:                         # strip repo--module.Class
         auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
         patched = True

 if patched:
-    cfg["auto_map"] = auto_map
     with cfg_path.open("w") as f:
         json.dump(cfg, f, indent=2)
-    print("🛠️ Patched config.json → auto_map now points to local modules")


 # ------------------------------------------------------------------
-# 1. Load model / tokenizer ---------------------------------------
 # ------------------------------------------------------------------
-handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
 full_model = full_model.eval().cuda()

 encoder = full_model.bert.encoder
@@ -56,7 +58,7 @@ emb_drop = full_model.bert.emb_drop


 # ------------------------------------------------------------------
-# 2. Symbolic roles ------------------------------------------------
 # ------------------------------------------------------------------
 SYMBOLIC_ROLES = [
     "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
@@ -66,48 +68,67 @@ SYMBOLIC_ROLES = [
     "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
     "<fabric>", "<jewelry>",
 ]
-
-missing = [t for t in SYMBOLIC_ROLES
-           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
 if missing:
     sys.exit(f"❌ Tokenizer is missing {missing}")


 # ------------------------------------------------------------------
-# 3. Encoder-only helper ------------------------------------------
 # ------------------------------------------------------------------
 @spaces.GPU
 def encode_and_trace(text: str, selected_roles: list[str]):
     """
-    Returns **exactly three scalars** matching the 3 gradio outputs:
-      1) tokens_str (comma-separated list found)
-      2) norm_str   (mean L2-norm of those embeddings)
-      3) count_int  (# tokens matched)
     """
     with torch.no_grad():
         batch = tokenizer(text, return_tensors="pt").to("cuda")
-        ids, attn = batch.input_ids, batch.attention_mask
-
-        x = emb_drop(emb_ln(embeddings(ids)))
-        ext = full_model.bert.get_extended_attention_mask(attn, x.shape[:-1])
-        enc = encoder(x, attention_mask=ext)               # (1, S, H)
-
-        role_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
-        mask = torch.tensor([tid in role_ids for tid in ids[0].tolist()],
-                            device=enc.device, dtype=torch.bool)
-
-        found = [tokenizer.convert_ids_to_tokens([tid])[0]
-                 for tid in ids[0].tolist() if tid in role_ids]
-        tokens_str = ", ".join(found) or "(none)"
-
-        if mask.any():
-            mean_vec = enc[0][mask].mean(0)
-            norm_str = f"{mean_vec.norm().item():.4f}"
         else:
-            norm_str = "0.0000"

-        count_int = int(mask.sum().item())
-        return tokens_str, norm_str, count_int             # ← three outputs!


 # ------------------------------------------------------------------
@@ -116,23 +137,36 @@ def encode_and_trace(text: str, selected_roles: list[str]):
 def build_interface():
     with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
         gr.Markdown(
-            "### 🧠 Symbolic Encoder Inspector\n"
-            "Paste text that includes the `<role>` tokens and inspect their "
-            "hidden-state statistics."
         )

         with gr.Row():
             with gr.Column():
-                txt = gr.Textbox(label="Input", lines=3,
-                                 placeholder="A <subject> wearing <upper_body_clothing> …")
-                chk = gr.CheckboxGroup(SYMBOLIC_ROLES, label="Roles to trace")
-                run = gr.Button("Encode & Trace")
             with gr.Column():
-                out_tok = gr.Textbox(label="Tokens found")
-                out_norm = gr.Textbox(label="Mean norm")
-                out_cnt = gr.Textbox(label="Token count")
-
-        run.click(encode_and_trace, [txt, chk], [out_tok, out_norm, out_cnt])

     return demo
 
+# app.py – encoder-only demo for bert-beatrix-2048 + role-probe
+# ------------------------------------------------------------
+# launch: python app.py
+# (gradio UI appears at http://localhost:7860)
+
+import json, sys
+from pathlib import Path, PurePosixPath

 import gradio as gr
 import spaces
 import torch
+import torch.nn.functional as F
 from huggingface_hub import snapshot_download

 from bert_handler import create_handler_from_checkpoint


 # ------------------------------------------------------------------
+# 0. Download & patch config.json --------------------------------
 # ------------------------------------------------------------------
+REPO_ID = "AbstractPhil/bert-beatrix-2048"
+LOCAL_DIR = "bert-beatrix-2048"          # local cache dir

 snapshot_download(
     repo_id=REPO_ID,
     revision="main",
+    local_dir=LOCAL_DIR,
     local_dir_use_symlinks=False,
 )

+cfg_path = Path(LOCAL_DIR) / "config.json"
 with cfg_path.open() as f:
     cfg = json.load(f)

 auto_map = cfg.get("auto_map", {})
 patched = False
 for k, v in auto_map.items():
+    if "--" in v:                         # e.g. "repo--module.Class"
         auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
         patched = True

 if patched:
     with cfg_path.open("w") as f:
         json.dump(cfg, f, indent=2)
+    print("🛠️ Patched config.json → auto_map fixed.")


 # ------------------------------------------------------------------
+# 1. Model / tokenizer -------------------------------------------
 # ------------------------------------------------------------------
+handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_DIR)
 full_model = full_model.eval().cuda()

 encoder = full_model.bert.encoder


 # ------------------------------------------------------------------
+# 2. Symbolic token set ------------------------------------------
 # ------------------------------------------------------------------
 SYMBOLIC_ROLES = [
     "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
     "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
     "<fabric>", "<jewelry>",
 ]
+ROLE_ID = {tok: tokenizer.convert_tokens_to_ids(tok) for tok in SYMBOLIC_ROLES}
+missing = [tok for tok, tid in ROLE_ID.items() if tid == tokenizer.unk_token_id]
 if missing:
     sys.exit(f"❌ Tokenizer is missing {missing}")


 # ------------------------------------------------------------------
+# 3. Encoder-only + role-similarity probe ------------------------
 # ------------------------------------------------------------------
 @spaces.GPU
 def encode_and_trace(text: str, selected_roles: list[str]):
     """
+    For each *selected* role:
+      find the contextual token whose hidden state is most similar to that
+      role’s own embedding (cosine similarity)
+      return “role → token (sim)”, using tokens even when the prompt
+      contained no <role> markers at all.
+    Also keeps the older diagnostics.
     """
     with torch.no_grad():
         batch = tokenizer(text, return_tensors="pt").to("cuda")
+        ids, mask = batch.input_ids, batch.attention_mask   # (1, S)
+
+        # ---------- encoder ----------
+        x = emb_drop(emb_ln(embeddings(ids)))
+        msk = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
+        h = encoder(x, attention_mask=msk).squeeze(0)        # (S, H)
+
+        # L2-normalise hidden states once
+        h_norm = F.normalize(h, dim=-1)                      # (S, H)
+
+        # ---------- probe each selected role -----------------------
+        matches = []
+        for role in selected_roles:
+            role_vec = embeddings.word_embeddings.weight[ROLE_ID[role]].to(h.device)
+            role_vec = F.normalize(role_vec, dim=-1)         # (H)
+
+            sims = (h_norm @ role_vec)                       # (S)
+            best_idx = int(sims.argmax().item())
+            best_sim = float(sims[best_idx])
+
+            match_tok = tokenizer.convert_ids_to_tokens(int(ids[0, best_idx]))
+            matches.append(f"{role} → {match_tok} ({best_sim:.2f})")
+
+        match_str = ", ".join(matches) if matches else "(no roles selected)"
+
+        # ---------- string-match diagnostics -----------------------
+        present = [tok for tok_id, tok in zip(ids[0].tolist(),
+                                              tokenizer.convert_ids_to_tokens(ids[0]))
+                   if tok in selected_roles]
+        present_str = ", ".join(present) or "(none)"
+        count = len(present)
+
+        # ---------- hidden-state norm of *explicit* role tokens ----
+        if count:
+            exp_mask = torch.tensor([tid in ROLE_ID.values() for tid in ids[0]], device=h.device)
+            norm_val = f"{h[exp_mask].mean(0).norm().item():.4f}"
         else:
+            norm_val = "0.0000"

+        return present_str, match_str, norm_val, count


 # ------------------------------------------------------------------
 def build_interface():
     with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
         gr.Markdown(
+            "## 🧠 Symbolic Encoder Inspector \n"
+            "Select one or more symbolic *roles* on the left. "
+            "The tool shows which regular tokens (if any) the model thinks "
+            "best fit each role — even when your text doesn’t contain the "
+            "explicit `<role>` marker."
         )

         with gr.Row():
             with gr.Column():
+                txt = gr.Textbox(
+                    label="Input text",
+                    lines=3,
+                    placeholder="Example: A small child in bright red boots jumps over a muddy puddle…",
+                )
+                roles = gr.CheckboxGroup(
+                    choices=SYMBOLIC_ROLES,
+                    label="Roles to probe",
+                )
+                btn = gr.Button("Run encoder probe")
             with gr.Column():
+                out_present = gr.Textbox(label="Explicit role tokens found")
+                out_match = gr.Textbox(label="Role → Best-Match Token (cos θ)")
+                out_norm = gr.Textbox(label="Mean hidden-state norm (explicit)")
+                out_count = gr.Textbox(label="# explicit role tokens")
+
+        btn.click(
+            encode_and_trace,
+            inputs=[txt, roles],
+            outputs=[out_present, out_match, out_norm, out_count],
+        )

     return demo
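
The probe added in this commit is easy to check in isolation: normalise the hidden states and a role embedding to unit length, take the dot product, and report the argmax position. A minimal sketch of that cosine-similarity matching follows; the token list and random tensors are made-up stand-ins for the encoder output and the "<subject>" embedding row, not code from the repo.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

tokens = ["a", "small", "child", "jumps"]   # hypothetical token strings
hidden = torch.randn(len(tokens), 8)        # (S, H) stand-in for encoder hidden states
role_vec = torch.randn(8)                   # stand-in for the "<subject>" embedding row

h_norm = F.normalize(hidden, dim=-1)        # unit-normalise every hidden state
r_norm = F.normalize(role_vec, dim=-1)      # unit-normalise the role vector

sims = h_norm @ r_norm                      # (S,) cosine similarities
best = int(sims.argmax())
print(f"<subject> → {tokens[best]} ({float(sims[best]):.2f})")

In the app the same dot product runs over every position of the already-normalised hidden-state matrix, so each selected role reports exactly one best-matching token plus its similarity score.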