AbstractPhil committed
Commit fd4a12a · verified · 1 Parent(s): 5e20a2a

Update app.py

Files changed (1)
  1. app.py +127 -115
app.py CHANGED
@@ -1,57 +1,53 @@
  # app.py – encoder-only demo for bert-beatrix-2048
- # ------------------------------------------------------------------
  # launch: python app.py
- # ------------------------------------------------------------------
-
- import json, re, sys
+ # -----------------------------------------------
+ import json, re, sys, math
  from pathlib import Path, PurePosixPath

+ import torch, torch.nn.functional as F
  import gradio as gr
  import spaces
- import torch
  from huggingface_hub import snapshot_download

  from bert_handler import create_handler_from_checkpoint


  # ------------------------------------------------------------------
- # 0. Download & patch config.json --------------------------------
- # ------------------------------------------------------------------
+ # 0. Download & patch HF checkpoint --------------------------------
  REPO_ID = "AbstractPhil/bert-beatrix-2048"
- LOCAL_DIR = "bert-beatrix-2048"
+ LOCAL_CKPT = "bert-beatrix-2048"

- snapshot_download(REPO_ID, revision="main",
-                   local_dir=LOCAL_DIR, local_dir_use_symlinks=False)
+ snapshot_download(
+     repo_id=REPO_ID,
+     revision="main",
+     local_dir=LOCAL_CKPT,
+     local_dir_use_symlinks=False,
+ )

- cfg_path = Path(LOCAL_DIR) / "config.json"
- cfg      = json.loads(cfg_path.read_text())
+ # strip repo prefix in auto_map (one-time)
+ cfg_path = Path(LOCAL_CKPT) / "config.json"
+ with cfg_path.open() as f: cfg = json.load(f)

- auto_map, changed = cfg.get("auto_map", {}), False
- for k, v in auto_map.items():
+ amap = cfg.get("auto_map", {})
+ for k,v in amap.items():
      if "--" in v:
-         auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
-         changed = True
- if changed:
-     cfg["auto_map"] = auto_map
-     cfg_path.write_text(json.dumps(cfg, indent=2))
-     print("🛠️ Patched config.json → auto_map now points at local modules")
+         amap[k] = PurePosixPath(v.split("--",1)[1]).as_posix()
+ cfg["auto_map"] = amap
+ with cfg_path.open("w") as f: json.dump(cfg,f,indent=2)

-
- # ------------------------------------------------------------------
- # 1. Model / tokenizer -------------------------------------------
  # ------------------------------------------------------------------
- handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_DIR)
+ # 1. Load model & components --------------------------------------
+ handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
  full_model = full_model.eval().cuda()

- encoder    = full_model.bert.encoder
- embeddings = full_model.bert.embeddings
- emb_ln     = full_model.bert.emb_ln
- emb_drop   = full_model.bert.emb_drop
+ encoder    = full_model.bert.encoder
+ embeddings = full_model.bert.embeddings
+ emb_ln     = full_model.bert.emb_ln
+ emb_drop   = full_model.bert.emb_drop
+ mlm_head   = full_model.cls              # prediction head

-
- # ------------------------------------------------------------------
- # 2. Symbolic token set ------------------------------------------
  # ------------------------------------------------------------------
+ # 2. Symbolic roles -------------------------------------------------
  SYMBOLIC_ROLES = [
      "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
      "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
@@ -60,108 +56,124 @@ SYMBOLIC_ROLES = [
      "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
      "<fabric>", "<jewelry>",
  ]
+ if any(tokenizer.convert_tokens_to_ids(t)==tokenizer.unk_token_id
+        for t in SYMBOLIC_ROLES):
+     sys.exit("❌ tokenizer missing special tokens")

- missing = [t for t in SYMBOLIC_ROLES
-            if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
- if missing:
-     sys.exit(f"❌ Tokenizer is missing {missing}")
+ # Quick helpers
+ MASK = tokenizer.mask_token


  # ------------------------------------------------------------------
- # 3. Encoder + *mask-inference* util ------------------------------
- # ------------------------------------------------------------------
- MASK = tokenizer.mask_token or "[MASK]"
+ # 3. Encoder-plus-MLM logic ---------------------------------------
+ def cosine(a,b):
+     return torch.nn.functional.cosine_similarity(a,b,dim=-1)

- @spaces.GPU
- def encode_and_trace(text: str, _ignored):          # all roles auto-selected
+ def pool_accuracy(ids, logits, pool_mask):
      """
-     1. run encoder pass cosine report (as before)
-     2. mask **every** symbolic token one-at-a-time
-        and ask the full model to predict it back.
-        Accuracy over those positions is returned.
+     ids       : (S,) gold token ids
+     logits    : (S,V) MLM logits
+     pool_mask : bool (S,) which tokens belong to the candidate pool
+     returns accuracy over masked positions only (if none, return 0)
      """
-     if not text.strip():
-         return "(empty)", "0.0000", 0, "0 / 0 (0.0%)"
-
-     with torch.no_grad():
-         # -------- ENCODER PROBE (unchanged) ------------------
-         batch = tokenizer(text, return_tensors="pt").to("cuda")
-         ids, mask = batch.input_ids, batch.attention_mask
-
-         x  = emb_drop(emb_ln(embeddings(ids)))
-         am = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
-         enc = encoder(x, attention_mask=am)                     # (1,S,H)
-
-         sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in SYMBOLIC_ROLES}
-         flags   = torch.tensor([tid in sel_ids for tid in ids[0].tolist()],
-                                device=enc.device)
-
-         found = [tokenizer.convert_ids_to_tokens([tid])[0]
-                  for tid in ids[0].tolist() if tid in sel_ids]
-         tokens_str = ", ".join(found) if found else "(none)"
-
-         if flags.any():
-             vec  = enc[0][flags].mean(0)
-             norm = f"{vec.norm().item():.4f}"
-         else:
-             norm = "0.0000"
-
-         # -------- MASK-AND-PREDICT ACCURACY ------------------
-         correct, total = 0, 0
-         for pos, tid in enumerate(ids[0].tolist()):
-             if tid in sel_ids:                                  # symbolic
-                 total += 1
-                 masked_ids = ids.clone()
-                 masked_ids[0, pos] = tokenizer.mask_token_id
-                 out = full_model(input_ids=masked_ids,
-                                  attention_mask=mask).logits    # (1,S,V)
-                 pred = out[0, pos].argmax(-1).item()
-                 if pred == tid:
-                     correct += 1
-
-     acc_str = f"{correct} / {total} ({(correct/total*100 if total else 0):.1f}%)"
-
-     return tokens_str, norm, len(found), acc_str
+     idx = pool_mask.nonzero(as_tuple=False).flatten()
+     if idx.numel()==0: return 0.0
+     preds = logits.argmax(-1)[idx]
+     gold  = ids[idx]
+     return (preds==gold).float().mean().item()
+
+
+ @spaces.GPU
+ def encode_and_trace(text, selected_roles):
+     # if user unchecked everything we treat as "all"
+     if not selected_roles:
+         selected_roles = SYMBOLIC_ROLES
+     sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
+
+     # ---- Tokenise & encode once ----
+     batch = tokenizer(text, return_tensors="pt").to("cuda")
+     ids, att = batch.input_ids, batch.attention_mask
+     x   = emb_drop(emb_ln(embeddings(ids)))
+     ext = full_model.bert.get_extended_attention_mask(att, x.shape[:-1])
+     enc = encoder(x, attention_mask=ext)[0, :, :]            # (S,H)
+
+     # ---- compute max-cos per token (F-0/F-1) ----
+     role_mat = embeddings.word_embeddings(
+         torch.tensor(sorted(sel_ids), device=enc.device)
+     )                                                        # (R,H)
+     cos = cosine(enc.unsqueeze(1), role_mat.unsqueeze(0))    # (S,R)
+     maxcos, argrole = cos.max(-1)                            # (S,)
+
+     # ---- split tokens into High / Low half (F-2) ----
+     S = len(ids[0])
+     sort_idx = maxcos.argsort(descending=True)
+     hi_idx = sort_idx[: S//2]
+     lo_idx = sort_idx[S//2:]
+
+     # container for summary text
+     report_lines = []
+
+     # ---- pool builder helper (uses S-4…S-7) ----
+     def greedy_pool(token_indices, direction):
+         # direction=='hi' or 'lo'
+         pool = []
+         μ = None
+         for tix in token_indices:
+             t_vec = enc[tix]
+             # incremental update (S-7)
+             μ = t_vec if μ is None else (μ*len(pool) + t_vec)/(len(pool)+1)
+             pool.append(tix)
+             # mask all tokens *except* this pool
+             masked_ids = ids.clone()
+             keep = torch.tensor(pool, device=ids.device)
+             mask_mask = torch.ones_like(ids, dtype=torch.bool)
+             mask_mask[0, keep] = False
+             masked_ids[mask_mask] = tokenizer.mask_token_id
+             # run MLM
+             with torch.no_grad():
+                 logits = mlm_head(full_model.bert.emb_dl(enc.unsqueeze(0))).logits[0]
+             acc = pool_accuracy(ids[0], logits, ~mask_mask[0])
+             report_lines.append(f"{direction}-pool size {len(pool)} → acc={acc:.2f}")
+             if acc >= 0.5: break
+         return pool, acc
+
+     pool_lo, acc_lo = greedy_pool(lo_idx, "low")
+     pool_hi, acc_hi = greedy_pool(hi_idx, "high")
+
+     # ---- package textual result ----
+     res_json = {
+         "Low-pool tokens":  tokenizer.decode(ids[0, pool_lo]),
+         "Low accuracy":     f"{acc_lo:.2f}",
+         "High-pool tokens": tokenizer.decode(ids[0, pool_hi]),
+         "High accuracy":    f"{acc_hi:.2f}",
+         "Trace":            "\n".join(report_lines)
+     }
+     # three outputs expected by UI
+     return json.dumps(res_json, indent=2), f"{maxcos.max():.4f}", len(selected_roles)


  # ------------------------------------------------------------------
- # 4. Gradio UI ----------------------------------------------------
- # ------------------------------------------------------------------
+ # 4. Gradio UI -----------------------------------------------------
  def build_interface():
      with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
-         gr.Markdown(
-             "## 🧠 Symbolic Encoder Inspector\n"
-             "Enter text containing the `<role>` tokens.\n"
-             "Cosine probe **and** real mask-prediction accuracy are shown."
-         )
+         gr.Markdown("## 🧠 Symbolic Encoder Inspector")

          with gr.Row():
              with gr.Column():
-                 txt = gr.Textbox(
-                     label="Input with Symbolic Tokens",
-                     placeholder="Example: A <subject> wearing <upper_body_clothing> …",
-                     lines=3,
+                 txt  = gr.Textbox(label="Prompt", lines=3)
+                 roles= gr.CheckboxGroup(
+                     choices=SYMBOLIC_ROLES, label="Roles",
+                     value=SYMBOLIC_ROLES            # pre-checked
                  )
-                 # checkbox group kept (pre-checked, disabled)
-                 roles = gr.CheckboxGroup(
-                     choices=SYMBOLIC_ROLES,
-                     label="(all roles auto-selected)",
-                     value=SYMBOLIC_ROLES,
-                     interactive=False,
-                 )
-                 btn = gr.Button("Run probe + MLM check")
+                 btn = gr.Button("Run")
              with gr.Column():
-                 out_tok  = gr.Textbox(label="Symbolic Tokens Found")
-                 out_norm = gr.Textbox(label="Vector-norm (mean)")
-                 out_cnt  = gr.Textbox(label="Token Count")
-                 out_acc  = gr.Textbox(label="Mask-prediction accuracy")
-
-         btn.click(encode_and_trace,
-                   inputs=[txt, roles],
-                   outputs=[out_tok, out_norm, out_cnt, out_acc])
+                 out_json = gr.Textbox(label="Result JSON")
+                 out_max  = gr.Textbox(label="Max cos")
+                 out_cnt  = gr.Textbox(label="# roles")

+         btn.click(encode_and_trace, [txt,roles], [out_json,out_max,out_cnt])
      return demo


- if __name__ == "__main__":
-     build_interface().launch()
+ if __name__=="__main__":
+     build_interface().launch()
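
For reference, the core idea both versions share can be tried outside the Space: below is a minimal, self-contained sketch of the one-token-at-a-time mask-and-predict accuracy loop that encode_and_trace builds on. It assumes a stock Hugging Face masked-LM (bert-base-uncased as a stand-in) instead of the bert_handler checkpoint loader, symbolic role tokens, and @spaces.GPU decorator used in the app.

# Hypothetical stand-alone sketch (not part of the commit): mask each target
# token one at a time and count how often a masked-LM predicts it back.
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tok = AutoTokenizer.from_pretrained("bert-base-uncased")          # stand-in model
mlm = AutoModelForMaskedLM.from_pretrained("bert-base-uncased").eval()

def mask_predict_accuracy(text: str, target_ids: set) -> float:
    """Fraction of target tokens the model recovers when each is masked in turn."""
    batch = tok(text, return_tensors="pt")
    ids = batch.input_ids
    correct = total = 0
    for pos, tid in enumerate(ids[0].tolist()):
        if tid not in target_ids:
            continue
        total += 1
        masked = ids.clone()
        masked[0, pos] = tok.mask_token_id           # hide only this position
        with torch.no_grad():
            logits = mlm(input_ids=masked,
                         attention_mask=batch.attention_mask).logits
        if logits[0, pos].argmax(-1).item() == tid:  # did it predict the original?
            correct += 1
    return correct / total if total else 0.0

# Example: how reliably is "paris" recovered from a cloze-style sentence?
targets = {tok.convert_tokens_to_ids("paris")}
print(mask_predict_accuracy("paris is the capital of france", targets))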