AbstractPhil committed on
Commit 0a14990 · verified · 1 Parent(s): da8b548

Update app.py

Files changed (1)
  1. app.py +104 -128
app.py CHANGED
@@ -1,57 +1,45 @@
- # app.py – encoder-only demo + pool-and-test prototype
- # ----------------------------------------------------
- # launch: python app.py
- # UI: http://localhost:7860
- import json, re, sys, math
  from pathlib import Path, PurePosixPath

- import torch, torch.nn.functional as F
- import gradio as gr, spaces
  from huggingface_hub import snapshot_download
-
  from bert_handler import create_handler_from_checkpoint

  # ------------------------------------------------------------------
- # 0. One-time patch of auto_map in config.json
- # ------------------------------------------------------------------
- REPO_ID = "AbstractPhil/bert-beatrix-2048"
- LOCAL_CKPT = "bert-beatrix-2048"
-
- snapshot_download(
-     repo_id=REPO_ID,
-     revision="main",
-     local_dir=LOCAL_CKPT,
-     local_dir_use_symlinks=False,
- )
-
- cfg_path = Path(LOCAL_CKPT) / "config.json"
- cfg = json.loads(cfg_path.read_text())
- auto_map = cfg.get("auto_map", {})
- changed = False
- for k, v in auto_map.items():
-     if "--" in v:  # strip “repo--”
-         auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
-         changed = True
- if changed:
-     cfg_path.write_text(json.dumps(cfg, indent=2))
-     print("🛠️ Patched config.json → auto_map points to local modules")

  # ------------------------------------------------------------------
- # 1. Load model + tokenizer with BERTHandler
- # ------------------------------------------------------------------
- handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
  full_model = full_model.eval().cuda()

- # pull encoder & embedding stack
- encoder = full_model.bert.encoder
- embeddings = full_model.bert.embeddings
- emb_weight = embeddings.word_embeddings.weight  # <- correct tensor
- emb_ln = full_model.bert.emb_ln
- emb_drop = full_model.bert.emb_drop

  # ------------------------------------------------------------------
- # 2. Symbolic roles
- # ------------------------------------------------------------------
  SYMBOLIC_ROLES = [
      "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
      "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
@@ -60,108 +48,96 @@ SYMBOLIC_ROLES = [
      "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
      "<fabric>", "<jewelry>",
  ]
- missing = [t for t in SYMBOLIC_ROLES
-            if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
- if missing:
-     sys.exit(f"❌ Tokenizer missing {missing}")
-
- MASK_ID = tokenizer.mask_token_id
- MASK_TOK = tokenizer.mask_token

  # ------------------------------------------------------------------
- # helpers -----------------------------------------------------------
- def contextual_vectors(ids, mask):
-     """run through embedding→encoder, return (S,H) hidden states"""
-     x = emb_drop(emb_ln(embeddings(ids)))  # (1,S,H)
-     ext = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
-     return encoder(x, attention_mask=ext).squeeze(0)  # (S,H)
-
- def pool_accuracy(ids, mask, pool_positions):
-     """mask positions in pool, predict, calc accuracy"""
-     masked = ids.clone()
-     masked[0, pool_positions] = MASK_ID
-     with torch.no_grad():
-         logits = full_model(masked, attention_mask=mask).logits[0]
-     preds = logits.argmax(-1)
-     gold = ids.squeeze(0)
-     correct = (preds[pool_positions] == gold[pool_positions]).sum().item()
-     return correct / len(pool_positions) if pool_positions else 0.0

- # cosine utility
- def cos(a, b): return F.cosine_similarity(a, b, dim=-1, eps=1e-8).item()

- # ------------------------------------------------------------------
- # 3. Core routine ---------------------------------------------------
- @spaces.GPU
- def encode_and_trace(text: str, picked_roles: list[str]):
-     # -------- tokenise ----------
-     batch = tokenizer(text, return_tensors="pt").to("cuda")
-     ids, attn = batch.input_ids, batch.attention_mask
-     hid = contextual_vectors(ids, attn)  # (S,H)
-
-     # -------- decide which roles we analyse ----------
-     present = {tid: pos for pos, tid in enumerate(ids[0].tolist())
-                if tid in {tokenizer.convert_tokens_to_ids(r) for r in SYMBOLIC_ROLES}}
-     if picked_roles:
-         present = {tid: pos for tid, pos in present.items()
-                    if tokenizer.convert_ids_to_tokens([tid])[0] in picked_roles}
-     if not present:
-         return "No symbolic tokens in sentence", "", ""
-
-     # -------- similarity scores ----------
-     sims = []
-     for tid, pos in present.items():
-         rvec = emb_weight[tid]  # static embedding
-         cvec = hid[pos]  # contextual
-         sims.append((cos(cvec, rvec), tid, pos))
-     sims.sort()  # low → high
-     # pools: bottom-2, top-2 (expand later)
-     low_pool, high_pool = sims[:2], sims[-2:]
-     accepted = []
-
-     for grow in range(1 + math.ceil(len(sims)/2)):  # ≤26 shots
-         for tag, pool in [("low", low_pool), ("high", high_pool)]:
-             pool_pos = [p for _, _, p in pool]
-             acc = pool_accuracy(ids, attn, pool_pos)
-             if acc >= 0.5:  # category accepted
-                 roles = [tokenizer.convert_ids_to_tokens([tid])[0] for _, tid, _ in pool]
-                 accepted.append(f"{tag}:{roles} (acc {acc:.2f})")
-         if accepted: break  # stop once something passed
-         # grow pools by two (if any left)
-         next_lo = sims[2+grow*2 : 4+grow*2]
-         next_hi = sims[-4-grow*2 : -2-grow*2] if 4+grow*2 <= len(sims) else []
-         low_pool += next_lo
-         high_pool += next_hi
-
-     if not accepted:
-         accepted = ["(none hit 50 %)"]
-
-     return ", ".join(accepted), f"{len(present)} roles analysed", f"{text[:80]}…"

  # ------------------------------------------------------------------
- # 4. UI -------------------------------------------------------------
- def build_ui():
      with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
          gr.Markdown(
              "## 🧠 Symbolic Encoder Inspector \n"
-             "Select roles, paste text, and watch the pool-and-test prototype work."
          )
          with gr.Row():
              with gr.Column():
-                 txt = gr.Textbox(lines=3, label="Input")
                  roles = gr.CheckboxGroup(
-                     SYMBOLIC_ROLES,
-                     value=SYMBOLIC_ROLES,
-                     label="Roles to consider (else all present)"
                  )
-                 btn = gr.Button("Run")
              with gr.Column():
-                 out_cat = gr.Textbox(label="Accepted categories")
-                 out_info = gr.Textbox(label="Debug")
-                 out_excerpt = gr.Textbox(label="Excerpt")

-         btn.click(encode_and_trace, [txt, roles], [out_cat, out_info, out_excerpt])
      return demo

  if __name__ == "__main__":
-     build_ui().launch()

+ # app.py – encoder-only + masking accuracy demo for bert-beatrix-2048
+ # -----------------------------------------------------------------
+ # launch: python app.py (UI at http://localhost:7860)
+
+ import json, re, sys
  from pathlib import Path, PurePosixPath

+ import gradio as gr
+ import spaces
+ import torch
  from huggingface_hub import snapshot_download
  from bert_handler import create_handler_from_checkpoint

  # ------------------------------------------------------------------
+ # 0. download repo + patch auto_map --------------------------------
+ REPO_ID = "AbstractPhil/bert-beatrix-2048"
+ LOCAL_CK = "bert-beatrix-2048"
+ snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_CK, local_dir_use_symlinks=False)
+
+ cfg_p = Path(LOCAL_CK) / "config.json"
+ with cfg_p.open() as f:
+     cfg = json.load(f)
+ for k, v in cfg.get("auto_map", {}).items():
+     if "--" in v:
+         cfg["auto_map"][k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
+ with cfg_p.open("w") as f:
+     json.dump(cfg, f, indent=2)

  # ------------------------------------------------------------------
+ # 1. load model / tokenizer ---------------------------------------
+ handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CK)
  full_model = full_model.eval().cuda()

+ encoder = full_model.bert.encoder
+ embeddings = full_model.bert.embeddings
+ emb_ln = full_model.bert.emb_ln
+ emb_drop = full_model.bert.emb_drop
+
+ MASK = tokenizer.mask_token or "[MASK]"

  # ------------------------------------------------------------------
+ # 2. symbolic role list -------------------------------------------
  SYMBOLIC_ROLES = [
      "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
      "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",

      "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
      "<fabric>", "<jewelry>",
  ]
+ miss = [t for t in SYMBOLIC_ROLES
+         if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
+ if miss:
+     sys.exit(f"❌ Tokenizer missing {miss}")

  # ------------------------------------------------------------------
+ # 3. inference util ----------------------------------------------
+ @spaces.GPU
+ def encode_and_trace(text: str, selected_roles: list[str]):
+     # ----- 3-A. build masked version & encode original --------------
+     sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}

+     # tokenised “plain” text
+     plain = tokenizer(text, return_tensors="pt").to("cuda")
+     ids_plain = plain.input_ids

+     # make masked string (regex to avoid partial hits)
+     masked_txt = text
+     for tok in selected_roles:
+         masked_txt = re.sub(re.escape(tok), MASK, masked_txt)
+
+     masked = tokenizer(masked_txt, return_tensors="pt").to("cuda")
+     ids_masked = masked.input_ids
+
+     # ----- 3-B. run model on masked text ----------------------------
+     with torch.no_grad():
+         logits = full_model(**masked).logits[0]  # (S, V)
+     preds = logits.argmax(-1)  # (S,)
+
+     # ----- 3-C. gather stats per masked role ------------------------
+     found_tokens, correct = [], 0
+     role_flags = []
+     for i, (orig_id, pred_id) in enumerate(zip(ids_plain[0], preds)):
+         if orig_id.item() in sel_ids and ids_masked[0, i].item() == tokenizer.mask_token_id:
+             found_tokens.append(tokenizer.convert_ids_to_tokens([orig_id])[0])
+             correct += int(orig_id.item() == pred_id.item())
+             role_flags.append(i)
+
+     total = len(role_flags)
+     acc = correct / total if total else 0.0
+
+     # ----- 3-D. encoder rep pooling for *all* selected roles --------
+     with torch.no_grad():
+         # embeddings -> normed reps
+         x = emb_drop(emb_ln(embeddings(ids_plain)))
+         attn = full_model.bert.get_extended_attention_mask(
+             plain.attention_mask, x.shape[:-1]
+         )
+         enc = encoder(x, attention_mask=attn)  # (1,S,H)
+     mask_vec = torch.tensor(
+         [tid in sel_ids for tid in ids_plain[0].tolist()], device=enc.device
+     )
+     if mask_vec.any():
+         pooled = enc[0][mask_vec].mean(0)
+         norm = f"{pooled.norm().item():.4f}"
+     else:
+         norm = "0.0000"
+
+     tokens_str = ", ".join(found_tokens) or "(none)"
+     return tokens_str, norm, f"{acc*100:.1f}%"

  # ------------------------------------------------------------------
+ # 4. gradio UI ----------------------------------------------------
+ def app():
      with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
          gr.Markdown(
              "## 🧠 Symbolic Encoder Inspector \n"
+             "1. Model side: we *mask* every chosen role token, run the LM, and report how often it recovers the original. \n"
+             "2. Encoder side: we also pool hidden-state vectors for those roles and give their mean L2-norm."
          )
          with gr.Row():
              with gr.Column():
+                 txt = gr.Textbox(
+                     label="Input with Symbolic Tokens",
+                     lines=3,
+                     placeholder="Example: A <subject> wearing <upper_body_clothing> …",
+                 )
                  roles = gr.CheckboxGroup(
+                     choices=SYMBOLIC_ROLES,
+                     value=SYMBOLIC_ROLES,  # <- all pre-selected
+                     label="Roles to mask & trace",
                  )
+                 run = gr.Button("Run")
              with gr.Column():
+                 o_tok = gr.Textbox(label="Masked-role tokens found")
+                 o_norm = gr.Textbox(label="Mean hidden-state L2-norm")
+                 o_acc = gr.Textbox(label="Recovery accuracy")

+         run.click(encode_and_trace, [txt, roles], [o_tok, o_norm, o_acc])
      return demo

  if __name__ == "__main__":
+     app().launch()
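
For readers who want to try the mask-and-recover check that the new encode_and_trace implements without spinning up the Space, here is a minimal standalone sketch of the same idea. It uses a stock fill-mask checkpoint (bert-base-uncased) and ordinary words standing in for the <role> tokens, since bert-beatrix-2048 requires the custom bert_handler; the model name, example text, and target words are illustrative assumptions, not part of this commit. It prints the same two quantities the UI reports: recovery accuracy and the mean L2-norm of the hidden states pooled over the masked positions.

# sketch_mask_recover.py – standalone approximation of encode_and_trace
# (model name, text, and target words below are illustrative assumptions)
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
mlm = AutoModelForMaskedLM.from_pretrained("bert-base-uncased").eval()

text = "a cat sits on the mat next to a red lamp"
targets = {"cat", "mat"}  # stand-ins for the symbolic <role> tokens

plain = tok(text, return_tensors="pt")
ids = plain.input_ids[0]

# mask every occurrence of a target token; doing it on token ids (instead of
# the re.sub string substitution used in the Space) keeps sequence lengths aligned
target_ids = {tok.convert_tokens_to_ids(w) for w in targets}
positions = [i for i, t in enumerate(ids.tolist()) if t in target_ids]
masked_ids = ids.clone()
masked_ids[positions] = tok.mask_token_id

with torch.no_grad():
    out = mlm(input_ids=masked_ids.unsqueeze(0),
              attention_mask=plain.attention_mask,
              output_hidden_states=True)
preds = out.logits[0].argmax(-1)  # (S,)

# recovery accuracy: how often the LM restores the original token
correct = sum(int(preds[i].item() == ids[i].item()) for i in positions)
acc = correct / len(positions) if positions else 0.0

# mean L2-norm of the last hidden state at the masked positions,
# analogous to the pooled-norm box in the UI
hidden = out.hidden_states[-1][0]  # (S, H)
norm = hidden[positions].mean(0).norm().item() if positions else 0.0

print(f"recovered {correct}/{len(positions)}  acc={acc:.2f}  pooled norm={norm:.4f}")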