AbstractPhil committed
Commit 81a0ae4 · verified · 1 Parent(s): 096fe3a

Update app.py

Files changed (1):
  1. app.py +40 -75

app.py CHANGED
@@ -1,34 +1,33 @@
 # app.py – encoder-only demo for bert-beatrix-2048
 # -----------------------------------------------
 # launch: python app.py
-# (gradio UI appears at http://localhost:7860)
-
+import spaces
 import torch
 import gradio as gr
-import spaces
+from huggingface_hub import snapshot_download
 from bert_handler import create_handler_from_checkpoint
 
 # ------------------------------------------------------------------
-# 1. Model / tokenizer -------------------------------------------------
+# 1. Download *once* and load locally -----------------------------
 # ------------------------------------------------------------------
-#
-# • We load one repo *once*, via its canonical name.
-# • BERTHandler handles the VRAM-safe cleanup & guarantees that the
-#   tokenizer already contains all special tokens saved in the checkpoint.
-
-REPO_ID = "AbstractPhil/bert-beatrix-2048"
-
-handler, full_model, tokenizer = create_handler_from_checkpoint(REPO_ID)
+LOCAL_CKPT = snapshot_download(
+    repo_id="AbstractPhil/bert-beatrix-2048",
+    revision="main",
+    local_dir="bert-beatrix-2048",
+    local_dir_use_symlinks=False
+)
+
+handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
 full_model = full_model.eval().cuda()
 
-# Grab the encoder + embedding stack only
+# --- pull encoder & embeddings only --------------------------------
 encoder = full_model.bert.encoder
 embeddings = full_model.bert.embeddings
 emb_ln = full_model.bert.emb_ln
 emb_drop = full_model.bert.emb_drop
 
 # ------------------------------------------------------------------
-# 2. Symbolic token set -------------------------------------------
+# 2. Symbolic token list ------------------------------------------
 # ------------------------------------------------------------------
 SYMBOLIC_ROLES = [
     "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
@@ -39,55 +38,40 @@ SYMBOLIC_ROLES = [
     "<fabric>", "<jewelry>"
 ]
 
-# Quick sanity check should *never* be unk
-missing = [tok for tok in SYMBOLIC_ROLES
-           if tokenizer.convert_tokens_to_ids(tok) == tokenizer.unk_token_id]
+# Sanity-check: every role must be known by the tokenizer
+missing = [t for t in SYMBOLIC_ROLES
+           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
 if missing:
-    raise RuntimeError(f"Tokenizer is missing {len(missing)} special tokens: {missing}")
+    raise RuntimeError(f"Tokenizer is missing special tokens: {missing}")
 
 # ------------------------------------------------------------------
 # 3. Encoder-only inference util ----------------------------------
 # ------------------------------------------------------------------
 @spaces.GPU
 def encode_and_trace(text: str, selected_roles: list[str]):
-    """
-    • encodes `text`
-    • pulls out the hidden states for any of the `selected_roles`
-    • returns some quick stats so we can verify everything’s wired up
-    """
     with torch.no_grad():
         batch = tokenizer(text, return_tensors="pt").to("cuda")
-        inp_ids, attn_mask = batch.input_ids, batch.attention_mask
-
-        # --- embedding + LayerNorm/dropout ---
-        x = embeddings(inp_ids)
-        x = emb_drop(emb_ln(x))
+        ids, mask = batch.input_ids, batch.attention_mask
 
-        # --- proper *additive* attention mask ---
-        ext_mask = full_model.bert.get_extended_attention_mask(
-            attn_mask, x.shape[:-1]
-        )
+        x = emb_drop(emb_ln(embeddings(ids)))
 
-        encoded = encoder(x, attention_mask=ext_mask)  # (B, S, H)
+        ext_mask = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
+        enc = encoder(x, attention_mask=ext_mask)      # (1, S, H)
 
-        # --- pick out the positions that match selected_roles ---
-        sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
-        ids_list = inp_ids.squeeze(0).tolist()         # python ints
-        keep_mask = torch.tensor([tid in sel_ids for tid in ids_list],
-                                 device=encoded.device)
+        want = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
+        keep = torch.tensor([tid in want for tid in ids[0].tolist()], device=enc.device)
 
-        tokens_found = [tokenizer.convert_ids_to_tokens([tid])[0]
-                        for tid in ids_list if tid in sel_ids]
-        if keep_mask.any():
-            repr_vec = encoded.squeeze(0)[keep_mask].mean(0)
-            norm_val = f"{repr_vec.norm().item():.4f}"
+        found = [tokenizer.convert_ids_to_tokens([tid])[0] for tid in ids[0].tolist() if tid in want]
+        if keep.any():
+            vec = enc[0][keep].mean(0)
+            norm = f"{vec.norm().item():.4f}"
         else:
-            norm_val = "0.0000"
+            norm = "0.0000"
 
         return {
-            "Symbolic Tokens": ", ".join(tokens_found) or "(none)",
-            "Embedding Norm": norm_val,
-            "Symbolic Token Count": int(keep_mask.sum().item()),
+            "Symbolic Tokens": ", ".join(found) or "(none)",
+            "Mean Norm": norm,
+            "Token Count": int(keep.sum().item()),
         }
 
 # ------------------------------------------------------------------
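The rewritten encode_and_trace boils down to masked mean pooling: run the encoder, build a boolean mask over positions whose token id belongs to the selected roles, and average those hidden states. The membership test only works on Python ints (hence the .tolist() in the new lines): iterating a tensor yields 0-dim tensors, which hash by identity and will generally never match plain ints stored in a set. A self-contained sketch of the pooling step (names hidden/input_ids/role_ids are illustrative, not from the repo):

    import torch

    def mean_over_roles(hidden: torch.Tensor,     # (1, S, H) encoder output
                        input_ids: torch.Tensor,  # (1, S) token ids
                        role_ids: set[int]) -> torch.Tensor:
        # Compare Python ints, not 0-dim tensors, against the id set.
        keep = torch.tensor([tid in role_ids for tid in input_ids[0].tolist()],
                            device=hidden.device)
        if not keep.any():                        # no role tokens present
            return torch.zeros(hidden.shape[-1], device=hidden.device)
        return hidden[0][keep].mean(dim=0)        # (H,) pooled role vector

    # toy check: S=4, H=8, pretend ids 7 and 9 are role tokens
    h = torch.randn(1, 4, 8)
    ids = torch.tensor([[101, 7, 9, 102]])
    print(mean_over_roles(h, ids, {7, 9}).shape)  # torch.Size([8])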
@@ -95,37 +79,18 @@ def encode_and_trace(text: str, selected_roles: list[str]):
 # ------------------------------------------------------------------
 def build_interface():
     with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
-
-        gr.Markdown("## 🧠 Symbolic Encoder Inspector\n"
-                    "Paste some text containing the special `<role>` tokens and "
-                    "inspect their encoder representations.")
-
+        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
         with gr.Row():
             with gr.Column():
-                input_text = gr.Textbox(
-                    label="Input with Symbolic Tokens",
-                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
-                    lines=3,
-                )
-                role_selector = gr.CheckboxGroup(
-                    choices=SYMBOLIC_ROLES,
-                    label="Trace these symbolic roles"
-                )
-                run_btn = gr.Button("Encode & Trace")
+                txt = gr.Textbox(label="Input with Symbolic Tokens", lines=3)
+                chk = gr.CheckboxGroup(choices=SYMBOLIC_ROLES, label="Trace these roles")
+                btn = gr.Button("Encode & Trace")
             with gr.Column():
-                out_tokens = gr.Textbox(label="Symbolic Tokens Found")
-                out_norm = gr.Textbox(label="Mean Norm")
-                out_count = gr.Textbox(label="Token Count")
-
-        run_btn.click(
-            fn=encode_and_trace,
-            inputs=[input_text, role_selector],
-            outputs=[out_tokens, out_norm, out_count],
-        )
-
+                out_tok = gr.Textbox(label="Symbolic Tokens Found")
+                out_norm = gr.Textbox(label="Mean Norm")
+                out_cnt = gr.Textbox(label="Token Count")
+        btn.click(encode_and_trace, [txt, chk], [out_tok, out_norm, out_cnt])
 
     return demo
 
-
 if __name__ == "__main__":
-    demo = build_interface()
-    demo.launch()
+    build_interface().launch()
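One wiring detail worth noting in both versions: encode_and_trace returns a dict keyed by strings, while btn.click lists three Textbox outputs. As far as Gradio's API goes, a dict return is matched to outputs by component object, not by label, so with listed components a plain tuple in outputs-list order is the safe shape. A minimal, self-contained sketch (fake_trace is a stand-in, not the repo's function):

    import gradio as gr

    def fake_trace(text: str, roles: list[str]):
        # One value per output component, in outputs-list order.
        return ", ".join(roles) or "(none)", "0.0000", str(len(roles))

    with gr.Blocks() as demo:
        txt = gr.Textbox(label="Input with Symbolic Tokens", lines=3)
        chk = gr.CheckboxGroup(choices=["<subject>", "<pose>"], label="Roles")
        btn = gr.Button("Encode & Trace")
        out_tok = gr.Textbox(label="Symbolic Tokens Found")
        out_norm = gr.Textbox(label="Mean Norm")
        out_cnt = gr.Textbox(label="Token Count")
        btn.click(fake_trace, inputs=[txt, chk], outputs=[out_tok, out_norm, out_cnt])

    if __name__ == "__main__":
        demo.launch()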