AbstractPhil commited on
Commit
da8b548
·
verified ·
1 Parent(s): 6c28416

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -118
app.py CHANGED
@@ -1,24 +1,21 @@
1
- # app.py – encoder-only demo for bert-beatrix-2048
 
2
  # launch: python app.py
3
-
4
- import json
5
- import sys
6
  from pathlib import Path, PurePosixPath
7
- from itertools import islice
8
 
9
- import gradio as gr
10
- import spaces
11
- import torch
12
- import torch.nn.functional as F
13
  from huggingface_hub import snapshot_download
14
 
15
  from bert_handler import create_handler_from_checkpoint
16
 
17
  # ------------------------------------------------------------------
18
- # 0. Download & patch config.json --------------------------------
19
  # ------------------------------------------------------------------
20
- REPO_ID = "AbstractPhil/bert-beatrix-2048"
21
- LOCAL_CKPT = "bert-beatrix-2048"
22
 
23
  snapshot_download(
24
  repo_id=REPO_ID,
@@ -28,35 +25,32 @@ snapshot_download(
28
  )
29
 
30
  cfg_path = Path(LOCAL_CKPT) / "config.json"
31
- with cfg_path.open() as f:
32
- cfg = json.load(f)
33
-
34
- auto_map = cfg.get("auto_map", {})
35
- patched = False
36
  for k, v in auto_map.items():
37
- if "--" in v: # "repo--module.Class"
38
  auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
39
- patched = True
40
-
41
- if patched:
42
- cfg["auto_map"] = auto_map
43
- with cfg_path.open("w") as f:
44
- json.dump(cfg, f, indent=2)
45
- print("🛠️ Patched config.json → auto_map paths fixed")
46
 
47
  # ------------------------------------------------------------------
48
- # 1. Model / tokenizer -------------------------------------------
49
  # ------------------------------------------------------------------
50
  handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
51
  full_model = full_model.eval().cuda()
52
 
 
53
  encoder = full_model.bert.encoder
54
  embeddings = full_model.bert.embeddings
 
55
  emb_ln = full_model.bert.emb_ln
56
  emb_drop = full_model.bert.emb_drop
57
 
58
  # ------------------------------------------------------------------
59
- # 2. Symbolic token set ------------------------------------------
60
  # ------------------------------------------------------------------
61
  SYMBOLIC_ROLES = [
62
  "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
@@ -66,114 +60,108 @@ SYMBOLIC_ROLES = [
66
  "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
67
  "<fabric>", "<jewelry>",
68
  ]
69
-
70
- unk = tokenizer.unk_token_id
71
- missing = [t for t in SYMBOLIC_ROLES if tokenizer.convert_tokens_to_ids(t) == unk]
72
  if missing:
73
- sys.exit(f"❌ Tokenizer is missing {missing}")
74
 
75
- # ------------------------------------------------------------------
76
- # 3. helper: merge lowest + highest until 3 remain ----------------
77
- # ------------------------------------------------------------------
78
- def reduce_to_three(table):
79
- """
80
- table : list of dicts {role, token, score}
81
- repeatedly remove lowest and highest,
82
- replace with their average,
83
- until len(table)==3.
84
- """
85
- working = table[:]
86
- working.sort(key=lambda x: x["score"])
87
- while len(working) > 3:
88
- low = working.pop(0)
89
- high = working.pop(-1)
90
- merged = {
91
- "role": f"{high['role']}|{low['role']}",
92
- "token": f"{high['token']}/{low['token']}",
93
- "score": (high["score"] + low["score"]) / 2.0,
94
- }
95
- working.append(merged)
96
- working.sort(key=lambda x: x["score"])
97
- # highest first for display
98
- working.sort(key=lambda x: x["score"], reverse=True)
99
- return working
100
 
101
  # ------------------------------------------------------------------
102
- # 4. Encoder-only inference util ---------------------------------
103
- # ------------------------------------------------------------------
104
- @spaces.GPU
105
- def encode_and_trace(text: str, selected_roles: list[str]):
 
 
 
 
 
 
 
106
  with torch.no_grad():
107
- if not text.strip():
108
- return "(no input)","",""
109
-
110
- batch = tokenizer(text, return_tensors="pt").to("cuda")
111
- ids, attn = batch.input_ids, batch.attention_mask
112
-
113
- # encoder forward
114
- x = emb_drop(emb_ln(embeddings(ids)))
115
- ext = full_model.bert.get_extended_attention_mask(attn, x.shape[:-1])
116
- hs = encoder(x, attention_mask=ext) # (1,S,H)
117
-
118
- # token-level embeddings (before LN) for similarity calc
119
- token_emb = embeddings(ids).squeeze(0) # (S,H)
120
-
121
- rows = []
122
- for role in selected_roles:
123
- rid = tokenizer.convert_tokens_to_ids(role)
124
- rvec = embeddings.word_embeddings.weight[rid] # (H,)
125
- # cosine similarity to every *input* token embedding
126
- sims = F.cosine_similarity(rvec.unsqueeze(0), token_emb, dim=-1)
127
- best = torch.argmax(sims).item()
128
- rows.append({
129
- "role" : role,
130
- "token": tokenizer.convert_ids_to_tokens([ids[0, best].item()])[0],
131
- "score": sims[best].item()
132
- })
133
-
134
- if not rows:
135
- return "(none selected)","",""
136
-
137
- final3 = reduce_to_three(rows)
138
- out_strs = [f"{r['role']} ↔ {r['token']} ({r['score']:+.2f})" for r in final3]
139
- # pad so we always return 3 strings
140
- while len(out_strs) < 3:
141
- out_strs.append("")
142
- return out_strs[0], out_strs[1], out_strs[2]
143
 
144
  # ------------------------------------------------------------------
145
- # 5. Gradio UI ----------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  # ------------------------------------------------------------------
147
- def build_interface():
 
148
  with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
149
  gr.Markdown(
150
- "### 🧠 Symbolic Encoder Inspector\n"
151
- "Paste text with `<role>` tokens, pick roles to track, then we\n"
152
- "• compute role ↔ token cosine scores\n"
153
- "• iteratively merge low+high pairs until **3 composite buckets** remain."
154
  )
155
-
156
  with gr.Row():
157
  with gr.Column():
158
- txt = gr.Textbox(
159
- label="Input with Symbolic Tokens",
160
- placeholder="Example: A <subject> wearing <upper_body_clothing> …",
161
- lines=3,
162
- )
163
  roles = gr.CheckboxGroup(
164
- choices=SYMBOLIC_ROLES,
165
- label="Trace these symbolic roles",
 
166
  )
167
- btn = gr.Button("Encode & Merge")
168
  with gr.Column():
169
- cat1 = gr.Textbox(label="Category 1 (highest)")
170
- cat2 = gr.Textbox(label="Category 2")
171
- cat3 = gr.Textbox(label="Category 3 (lowest)")
172
-
173
- btn.click(encode_and_trace, [txt, roles], [cat1, cat2, cat3])
174
 
 
175
  return demo
176
 
177
-
178
  if __name__ == "__main__":
179
- build_interface().launch()
 
1
+ # app.py – encoder-only demo + pool-and-test prototype
2
+ # ----------------------------------------------------
3
  # launch: python app.py
4
+ # UI: http://localhost:7860
5
+ import json, re, sys, math
 
6
  from pathlib import Path, PurePosixPath
 
7
 
8
+ import torch, torch.nn.functional as F
9
+ import gradio as gr, spaces
 
 
10
  from huggingface_hub import snapshot_download
11
 
12
  from bert_handler import create_handler_from_checkpoint
13
 
14
  # ------------------------------------------------------------------
15
+ # 0. One-time patch of auto_map in config.json
16
  # ------------------------------------------------------------------
17
+ REPO_ID = "AbstractPhil/bert-beatrix-2048"
18
+ LOCAL_CKPT = "bert-beatrix-2048"
19
 
20
  snapshot_download(
21
  repo_id=REPO_ID,
 
25
  )
26
 
27
  cfg_path = Path(LOCAL_CKPT) / "config.json"
28
+ cfg = json.loads(cfg_path.read_text())
29
+ auto_map = cfg.get("auto_map", {})
30
+ changed = False
 
 
31
  for k, v in auto_map.items():
32
+ if "--" in v: # strip “repo--”
33
  auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
34
+ changed = True
35
+ if changed:
36
+ cfg_path.write_text(json.dumps(cfg, indent=2))
37
+ print("🛠️ Patched config.json → auto_map points to local modules")
 
 
 
38
 
39
  # ------------------------------------------------------------------
40
+ # 1. Load model + tokenizer with BERTHandler
41
  # ------------------------------------------------------------------
42
  handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
43
  full_model = full_model.eval().cuda()
44
 
45
+ # pull encoder & embedding stack
46
  encoder = full_model.bert.encoder
47
  embeddings = full_model.bert.embeddings
48
+ emb_weight = embeddings.word_embeddings.weight # <- correct tensor
49
  emb_ln = full_model.bert.emb_ln
50
  emb_drop = full_model.bert.emb_drop
51
 
52
  # ------------------------------------------------------------------
53
+ # 2. Symbolic roles
54
  # ------------------------------------------------------------------
55
  SYMBOLIC_ROLES = [
56
  "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
 
60
  "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
61
  "<fabric>", "<jewelry>",
62
  ]
63
+ missing = [t for t in SYMBOLIC_ROLES
64
+ if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
 
65
  if missing:
66
+ sys.exit(f"❌ Tokenizer missing {missing}")
67
 
68
+ MASK_ID = tokenizer.mask_token_id
69
+ MASK_TOK = tokenizer.mask_token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  # ------------------------------------------------------------------
72
+ # helpers -----------------------------------------------------------
73
+ def contextual_vectors(ids, mask):
74
+ """run through embedding→encoder, return (S,H) hidden states"""
75
+ x = emb_drop(emb_ln(embeddings(ids))) # (1,S,H)
76
+ ext = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
77
+ return encoder(x, attention_mask=ext).squeeze(0) # (S,H)
78
+
79
+ def pool_accuracy(ids, mask, pool_positions):
80
+ """mask positions in pool, predict, calc accuracy"""
81
+ masked = ids.clone()
82
+ masked[0, pool_positions] = MASK_ID
83
  with torch.no_grad():
84
+ logits = full_model(masked, attention_mask=mask).logits[0]
85
+ preds = logits.argmax(-1)
86
+ gold = ids.squeeze(0)
87
+ correct = (preds[pool_positions] == gold[pool_positions]).sum().item()
88
+ return correct / len(pool_positions) if pool_positions else 0.0
89
+
90
+ # cosine utility
91
+ def cos(a, b): return F.cosine_similarity(a, b, dim=-1, eps=1e-8).item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  # ------------------------------------------------------------------
94
+ # 3. Core routine ---------------------------------------------------
95
+ @spaces.GPU
96
+ def encode_and_trace(text: str, picked_roles: list[str]):
97
+ # -------- tokenise ----------
98
+ batch = tokenizer(text, return_tensors="pt").to("cuda")
99
+ ids, attn = batch.input_ids, batch.attention_mask
100
+ hid = contextual_vectors(ids, attn) # (S,H)
101
+
102
+ # -------- decide which roles we analyse ----------
103
+ present = {tid: pos for pos, tid in enumerate(ids[0].tolist())
104
+ if tid in {tokenizer.convert_tokens_to_ids(r) for r in SYMBOLIC_ROLES}}
105
+ if picked_roles:
106
+ present = {tid: pos for tid, pos in present.items()
107
+ if tokenizer.convert_ids_to_tokens([tid])[0] in picked_roles}
108
+ if not present:
109
+ return "No symbolic tokens in sentence", "", ""
110
+
111
+ # -------- similarity scores ----------
112
+ sims = []
113
+ for tid, pos in present.items():
114
+ rvec = emb_weight[tid] # static embedding
115
+ cvec = hid[pos] # contextual
116
+ sims.append((cos(cvec, rvec), tid, pos))
117
+ sims.sort() # low → high
118
+ # pools: bottom-2, top-2 (expand later)
119
+ low_pool, high_pool = sims[:2], sims[-2:]
120
+ accepted = []
121
+
122
+ for grow in range(1 + math.ceil(len(sims)/2)): # ≤26 shots
123
+ for tag, pool in [("low", low_pool), ("high", high_pool)]:
124
+ pool_pos = [p for _,_,p in pool]
125
+ acc = pool_accuracy(ids, attn, pool_pos)
126
+ if acc >= 0.5: # category accepted
127
+ roles = [tokenizer.convert_ids_to_tokens([tid])[0] for _,tid,_ in pool]
128
+ accepted.append(f"{tag}:{roles} (acc {acc:.2f})")
129
+ if accepted: break # stop once something passed
130
+ # grow pools by two (if any left)
131
+ next_lo = sims[2+grow*2 : 4+grow*2]
132
+ next_hi = sims[-4-grow*2 : -2-grow*2] if 4+grow*2 <= len(sims) else []
133
+ low_pool += next_lo
134
+ high_pool += next_hi
135
+
136
+ if not accepted:
137
+ accepted = ["(none hit 50 %)"]
138
+
139
+ return ", ".join(accepted), f"{len(present)} roles analysed", f"{text[:80]}…"
140
+
141
  # ------------------------------------------------------------------
142
+ # 4. UI -------------------------------------------------------------
143
+ def build_ui():
144
  with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
145
  gr.Markdown(
146
+ "## 🧠 Symbolic Encoder Inspector \n"
147
+ "Select roles, paste text, and watch the pool-and-test prototype work."
 
 
148
  )
 
149
  with gr.Row():
150
  with gr.Column():
151
+ txt = gr.Textbox(lines=3, label="Input")
 
 
 
 
152
  roles = gr.CheckboxGroup(
153
+ SYMBOLIC_ROLES,
154
+ value=SYMBOLIC_ROLES,
155
+ label="Roles to consider (else all present)"
156
  )
157
+ btn = gr.Button("Run")
158
  with gr.Column():
159
+ out_cat = gr.Textbox(label="Accepted categories")
160
+ out_info= gr.Textbox(label="Debug")
161
+ out_excerpt = gr.Textbox(label="Excerpt")
 
 
162
 
163
+ btn.click(encode_and_trace, [txt, roles], [out_cat, out_info, out_excerpt])
164
  return demo
165
 
 
166
  if __name__ == "__main__":
167
+ build_ui().launch()