AbstractPhil committed
Commit 872b08b · verified · 1 Parent(s): 71b4610

Update app.py

Files changed (1):
  1. app.py +65 -39
app.py CHANGED
@@ -1,62 +1,68 @@
  # app.py – encoder-only demo for bert-beatrix-2048
  # -----------------------------------------------
  # launch: python app.py
+ # (gradio UI appears at http://localhost:7860)
+
+ import json
+ import re
+ import sys
+ from pathlib import Path, PurePosixPath  # ← PurePosixPath import added
+
+ import gradio as gr
  import spaces
  import torch
- import gradio as gr
- import json
  from huggingface_hub import snapshot_download
+
  from bert_handler import create_handler_from_checkpoint
- from pathlib import Path
+

  # ------------------------------------------------------------------
- # 1. Download *once* and load locally -----------------------------
+ # 0. Download & patch config.json --------------------------------
  # ------------------------------------------------------------------
- LOCAL_CKPT = snapshot_download(
-     repo_id="AbstractPhil/bert-beatrix-2048",
+ REPO_ID = "AbstractPhil/bert-beatrix-2048"
+ LOCAL_CKPT = "bert-beatrix-2048"  # cached dir name
+
+ snapshot_download(
+     repo_id=REPO_ID,
      revision="main",
-     local_dir="bert-beatrix-2048",
+     local_dir=LOCAL_CKPT,
      local_dir_use_symlinks=False,
  )

+ # ── one-time patch: strip the “repo--” prefix that confuses AutoModel ──
  cfg_path = Path(LOCAL_CKPT) / "config.json"
- with open(cfg_path) as f:
+ with cfg_path.open() as f:
      cfg = json.load(f)

  auto_map = cfg.get("auto_map", {})
  changed = False
  for k, v in auto_map.items():
-     # v looks like "AbstractPhil/bert-beatrix-2048--modeling_hf_nomic_bert.…"
-     if "--" in v:
+     if "--" in v:  # v looks like "repo--module.Class"
          auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
          changed = True

  if changed:
      cfg["auto_map"] = auto_map
-     with open(cfg_path, "w") as f:
+     with cfg_path.open("w") as f:
          json.dump(cfg, f, indent=2)
-     print("🔧 Patched auto_map → now points to local modules only")
+     print("🛠️ Patched config.json → auto_map now points at local modules")

- # also drop any *previously* imported remote modules in this session
- for name in list(sys.modules):
-     if name.startswith("transformers_modules.AbstractPhil.bert-beatrix-2048"):
-         del sys.modules[name]

  # ------------------------------------------------------------------
- # 1. normal load via BERTHandler ---------------------------------
+ # 1. Model / tokenizer -------------------------------------------
  # ------------------------------------------------------------------
- from bert_handler import create_handler_from_checkpoint
  handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
  full_model = full_model.eval().cuda()

- # --- pull encoder & embeddings only --------------------------------
+ # Grab encoder + embedding stack only
  encoder = full_model.bert.encoder
  embeddings = full_model.bert.embeddings
  emb_ln = full_model.bert.emb_ln
  emb_drop = full_model.bert.emb_drop

+
  # ------------------------------------------------------------------
- # 2. Symbolic token list ------------------------------------------
+ # 2. Symbolic token set ------------------------------------------
  # ------------------------------------------------------------------
  SYMBOLIC_ROLES = [
      "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
@@ -64,17 +70,18 @@ SYMBOLIC_ROLES = [
      "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
      "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
      "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
-     "<fabric>", "<jewelry>"
+     "<fabric>", "<jewelry>",
  ]

- # Sanity-check: every role must be known by the tokenizer
- missing = [t for t in SYMBOLIC_ROLES
-            if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
+ # quick sanity check
+ missing = [tok for tok in SYMBOLIC_ROLES
+            if tokenizer.convert_tokens_to_ids(tok) == tokenizer.unk_token_id]
  if missing:
-     raise RuntimeError(f"Tokenizer is missing special tokens: {missing}")
+     sys.exit(f"❌ Tokenizer is missing {missing}")
+

  # ------------------------------------------------------------------
- # 3. Encoder-only inference util ----------------------------------
+ # 3. Encoder-only inference util ---------------------------------
  # ------------------------------------------------------------------
  @spaces.GPU
  def encode_and_trace(text: str, selected_roles: list[str]):
@@ -85,41 +92,60 @@ def encode_and_trace(text: str, selected_roles: list[str]):
      x = emb_drop(emb_ln(embeddings(ids)))

      ext_mask = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
-     enc = encoder(x, attention_mask=ext_mask)  # (1, S, H)
+     enc = encoder(x, attention_mask=ext_mask)  # (1, S, H)

-     want = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
-     keep = torch.tensor([tid in want for tid in ids[0]], device=enc.device)
-
-     found = [tokenizer.convert_ids_to_tokens([tid])[0] for tid in ids[0] if tid in want]
-     if keep.any():
-         vec = enc[0][keep].mean(0)
+     sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
+     flags = torch.tensor([tid in sel_ids for tid in ids[0].tolist()],
+                          device=enc.device)
+
+     found = [tokenizer.convert_ids_to_tokens([tid])[0]
+              for tid in ids[0].tolist() if tid in sel_ids]
+
+     if flags.any():
+         vec = enc[0][flags].mean(0)
          norm = f"{vec.norm().item():.4f}"
      else:
          norm = "0.0000"

      return {
          "Symbolic Tokens": ", ".join(found) or "(none)",
-         "Mean Norm": norm,
-         "Token Count": int(keep.sum().item()),
+         "Embedding Norm": norm,
+         "Symbolic Token Count": int(flags.sum().item()),
      }

+
  # ------------------------------------------------------------------
- # 4. Gradio UI -----------------------------------------------------
+ # 4. Gradio UI ----------------------------------------------------
  # ------------------------------------------------------------------
  def build_interface():
      with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
-         gr.Markdown("## 🧠 Symbolic Encoder Inspector")
+         gr.Markdown(
+             "## 🧠 Symbolic Encoder Inspector\n"
+             "Paste some text containing the special `<role>` tokens and "
+             "inspect their encoder representations."
+         )
+
          with gr.Row():
              with gr.Column():
-                 txt = gr.Textbox(label="Input with Symbolic Tokens", lines=3)
-                 chk = gr.CheckboxGroup(choices=SYMBOLIC_ROLES, label="Trace these roles")
+                 txt = gr.Textbox(
+                     label="Input with Symbolic Tokens",
+                     placeholder="Example: A <subject> wearing <upper_body_clothing> …",
+                     lines=3,
+                 )
+                 roles = gr.CheckboxGroup(
+                     choices=SYMBOLIC_ROLES,
+                     label="Trace these symbolic roles",
+                 )
                  btn = gr.Button("Encode & Trace")
              with gr.Column():
                  out_tok = gr.Textbox(label="Symbolic Tokens Found")
                  out_norm = gr.Textbox(label="Mean Norm")
                  out_cnt = gr.Textbox(label="Token Count")
-         btn.click(encode_and_trace, [txt, chk], [out_tok, out_norm, out_cnt])
+
+         btn.click(encode_and_trace, [txt, roles], [out_tok, out_norm, out_cnt])
+
      return demo

+
  if __name__ == "__main__":
      build_interface().launch()
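
The pivotal change in this commit is the config.json patch: Hugging Face records remote-code classes in auto_map as "repo--module.Class", and the loop rewrites each value to a bare "module.Class" so the classes resolve against the files in the local snapshot rather than the repo-qualified path. A minimal standalone sketch of that transformation (the auto_map values below are illustrative, not copied from the actual checkpoint):

from pathlib import PurePosixPath

# illustrative auto_map, shaped like the one this commit patches
auto_map = {
    "AutoConfig": "AbstractPhil/bert-beatrix-2048--configuration_hf_nomic_bert.NomicBertConfig",
    "AutoModel": "AbstractPhil/bert-beatrix-2048--modeling_hf_nomic_bert.NomicBertModel",
}

for key, value in auto_map.items():
    if "--" in value:  # "repo--module.Class" → keep only "module.Class"
        auto_map[key] = PurePosixPath(value.split("--", 1)[1]).as_posix()

print(auto_map["AutoModel"])  # modeling_hf_nomic_bert.NomicBertModel

The reworked encode_and_trace keeps the same recipe as before: flag the positions whose token ids belong to the selected roles, mean-pool the encoder states at those positions, and report the pooled vector's norm plus the match count. A toy sketch of that masking-and-pooling step on random tensors (all shapes and ids below are made up):

import torch

enc = torch.randn(1, 6, 8)                    # (batch, seq, hidden) toy encoder output
ids = torch.tensor([[101, 5, 7, 5, 9, 102]])  # toy token ids for one sequence
sel_ids = {5, 7}                              # ids of the traced roles

flags = torch.tensor([tid in sel_ids for tid in ids[0].tolist()])
if flags.any():
    vec = enc[0][flags].mean(0)               # mean-pool the matching positions
    print(f"{vec.norm().item():.4f}", int(flags.sum().item()))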