AbstractPhil committed on
Commit 415afa1 · verified · 1 Parent(s): b60c583

Update app.py

Files changed (1)
  1. app.py +49 -58
app.py CHANGED
@@ -1,13 +1,9 @@
 # app.py – encoder-only demo for bert-beatrix-2048
-# -----------------------------------------------
-# launch: python app.py
-# (gradio UI appears at http://localhost:7860)
-
-import json
-import re
-import sys
-from pathlib import Path, PurePosixPath  # ← PurePosixPath import added
+# ------------------------------------------------
+# launch: python app.py → http://localhost:7860
 
+import json, re, sys
+from pathlib import Path, PurePosixPath  # ← PurePosixPath import
 import gradio as gr
 import spaces
 import torch
@@ -17,10 +13,10 @@ from bert_handler import create_handler_from_checkpoint
 
 
 # ------------------------------------------------------------------
-# 0. Download & patch config.json --------------------------------
+# 0. Download & patch config.json ---------------------------------
 # ------------------------------------------------------------------
-REPO_ID = "AbstractPhil/bert-beatrix-2048"
-LOCAL_CKPT = "bert-beatrix-2048"  # cached dir name
+REPO_ID = "AbstractPhil/bert-beatrix-2048"
+LOCAL_CKPT = "bert-beatrix-2048"  # cache dir name
 
 snapshot_download(
     repo_id=REPO_ID,
@@ -29,32 +25,30 @@ snapshot_download(
     local_dir_use_symlinks=False,
 )
 
-# ── one-time patch: strip the “repo--” prefix that confuses AutoModel ──
 cfg_path = Path(LOCAL_CKPT) / "config.json"
 with cfg_path.open() as f:
     cfg = json.load(f)
 
 auto_map = cfg.get("auto_map", {})
-changed = False
+patched = False
 for k, v in auto_map.items():
-    if "--" in v:  # v looks like "repo--module.Class"
+    if "--" in v:  # strip repo--module.Class
         auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
-        changed = True
+        patched = True
 
-if changed:
+if patched:
     cfg["auto_map"] = auto_map
     with cfg_path.open("w") as f:
         json.dump(cfg, f, indent=2)
-    print("🛠️ Patched config.json → auto_map now points at local modules")
+    print("🛠️ Patched config.json → auto_map now points to local modules")
 
 
 # ------------------------------------------------------------------
-# 1. Model / tokenizer -------------------------------------------
+# 1. Load model / tokenizer ---------------------------------------
 # ------------------------------------------------------------------
 handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
 full_model = full_model.eval().cuda()
 
-# Grab encoder + embedding stack only
 encoder = full_model.bert.encoder
 embeddings = full_model.bert.embeddings
 emb_ln = full_model.bert.emb_ln
@@ -62,7 +56,7 @@ emb_drop = full_model.bert.emb_drop
 
 
 # ------------------------------------------------------------------
-# 2. Symbolic token set ------------------------------------------
+# 2. Symbolic roles ------------------------------------------------
 # ------------------------------------------------------------------
 SYMBOLIC_ROLES = [
     "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
@@ -73,44 +67,47 @@ SYMBOLIC_ROLES = [
     "<fabric>", "<jewelry>",
 ]
 
-# quick sanity check
-missing = [tok for tok in SYMBOLIC_ROLES
-           if tokenizer.convert_tokens_to_ids(tok) == tokenizer.unk_token_id]
+missing = [t for t in SYMBOLIC_ROLES
+           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
 if missing:
     sys.exit(f"❌ Tokenizer is missing {missing}")
 
 
 # ------------------------------------------------------------------
-# 3. Encoder-only inference util ---------------------------------
+# 3. Encoder-only helper ------------------------------------------
 # ------------------------------------------------------------------
 @spaces.GPU
 def encode_and_trace(text: str, selected_roles: list[str]):
+    """
+    Returns **exactly three scalars** matching the 3 gradio outputs:
+      1) tokens_str (comma-separated list found)
+      2) norm_str   (mean L2-norm of those embeddings)
+      3) count_int  (# tokens matched)
+    """
     with torch.no_grad():
         batch = tokenizer(text, return_tensors="pt").to("cuda")
-        ids, mask = batch.input_ids, batch.attention_mask
+        ids, attn = batch.input_ids, batch.attention_mask
 
         x = emb_drop(emb_ln(embeddings(ids)))
-        ext_mask = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
-        enc = encoder(x, attention_mask=ext_mask)  # (1, S, H)
+        ext = full_model.bert.get_extended_attention_mask(attn, x.shape[:-1])
+        enc = encoder(x, attention_mask=ext)  # (1, S, H)
 
-        sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
-        flags = torch.tensor([tid in sel_ids for tid in ids[0].tolist()],
-                             device=enc.device)
+        role_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
+        mask = torch.tensor([tid in role_ids for tid in ids[0].tolist()],
+                            device=enc.device, dtype=torch.bool)
 
-        found_tokens = [tokenizer.convert_ids_to_tokens([tid])[0]
-                        for tid in ids[0].tolist() if tid in sel_ids]
-        tokens_str = ", ".join(found_tokens) or "(none)"
+        found = [tokenizer.convert_ids_to_tokens([tid])[0]
+                 for tid in ids[0].tolist() if tid in role_ids]
+        tokens_str = ", ".join(found) or "(none)"
 
-        if flags.any():
-            vec = enc[0][flags].mean(0)
-            norm = f"{vec.norm().item():.4f}"
+        if mask.any():
+            mean_vec = enc[0][mask].mean(0)
+            norm_str = f"{mean_vec.norm().item():.4f}"
         else:
-            norm = "0.0000"
-
-        count = int(flags.sum().item())
-        # >>> return *three* scalars, not one dict <<<
-        return tokens_str, norm, count
+            norm_str = "0.0000"
 
+        count_int = int(mask.sum().item())
+        return tokens_str, norm_str, count_int  # ← three outputs!
 
 
 # ------------------------------------------------------------------
@@ -119,29 +116,23 @@ def encode_and_trace(text: str, selected_roles: list[str]):
 def build_interface():
     with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
         gr.Markdown(
-            "## 🧠 Symbolic Encoder Inspector\n"
-            "Paste some text containing the special `<role>` tokens and "
-            "inspect their encoder representations."
+            "### 🧠 Symbolic Encoder Inspector\n"
+            "Paste text that includes the `<role>` tokens and inspect their "
+            "hidden-state statistics."
         )
 
         with gr.Row():
             with gr.Column():
-                txt = gr.Textbox(
-                    label="Input with Symbolic Tokens",
-                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
-                    lines=3,
-                )
-                roles = gr.CheckboxGroup(
-                    choices=SYMBOLIC_ROLES,
-                    label="Trace these symbolic roles",
-                )
-                btn = gr.Button("Encode & Trace")
+                txt = gr.Textbox(label="Input", lines=3,
+                                 placeholder="A <subject> wearing <upper_body_clothing> …")
+                chk = gr.CheckboxGroup(SYMBOLIC_ROLES, label="Roles to trace")
+                run = gr.Button("Encode & Trace")
             with gr.Column():
-                out_tok = gr.Textbox(label="Symbolic Tokens Found")
-                out_norm = gr.Textbox(label="Mean Norm")
-                out_cnt = gr.Textbox(label="Token Count")
+                out_tok = gr.Textbox(label="Tokens found")
+                out_norm = gr.Textbox(label="Mean norm")
+                out_cnt = gr.Textbox(label="Token count")
 
-        btn.click(encode_and_trace, [txt, roles], [out_tok, out_norm, out_cnt])
+        run.click(encode_and_trace, [txt, chk], [out_tok, out_norm, out_cnt])
 
     return demo
 
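
For reference, the config patch near the top of this diff rewrites each auto_map entry from a "repo--module.Class" form to a bare module path. Below is a minimal standalone sketch of that same rewrite; the example entry and class name are hypothetical, the real values come from the downloaded checkpoint's config.json.

# Standalone sketch of the auto_map patch above.
# The entry below is a hypothetical example, not read from the real checkpoint.
from pathlib import PurePosixPath

auto_map = {"AutoModel": "AbstractPhil/bert-beatrix-2048--modeling.BertBeatrix"}  # hypothetical

for key, value in auto_map.items():
    if "--" in value:  # value looks like "repo--module.Class"
        auto_map[key] = PurePosixPath(value.split("--", 1)[1]).as_posix()

print(auto_map)  # {'AutoModel': 'modeling.BertBeatrix'}

Per the commit's own print message, this is what lets AutoModel resolve the custom classes from the locally cached modules rather than the remote repo prefix.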