AbstractPhil committed
Commit 096fe3a · verified · 1 Parent(s): ec302cb

Update app.py

Files changed (1):
  app.py  +102 -54
app.py CHANGED
@@ -1,55 +1,36 @@
-# Updating the app to use only the encoder from the model, ensuring symbolic support
 
-import spaces
-from bert_handler import create_handler_from_checkpoint
 import torch
 import gradio as gr
-import re
-from pathlib import Path
-from huggingface_hub import snapshot_download
-
-# Load checkpoint using BERTHandler (loads tokenizer and full model)
-checkpoint_path = snapshot_download(
-    repo_id="AbstractPhil/bert-beatrix-2048",
-    revision="main",
-    local_dir="bert-beatrix-2048",
-    local_dir_use_symlinks=False
-)
-handler, model, tokenizer = create_handler_from_checkpoint(checkpoint_path)
-model = model.eval().cuda()
-
-# Extract encoder only (NomicBertModel -> encoder)
-encoder = model.bert.encoder
-embeddings = model.bert.embeddings
-emb_ln = model.bert.emb_ln
-emb_drop = model.bert.emb_drop
 
-@spaces.GPU
-def encode_and_predict(text: str, selected_roles: list[str]):
-    with torch.no_grad():
-        inputs = tokenizer(text, return_tensors="pt").to("cuda")
-        input_ids = inputs.input_ids
-        attention_mask = inputs.attention_mask
-
-        # Run embedding + encoder pipeline
-        x = embeddings(input_ids)
-        x = emb_ln(x)
-        x = emb_drop(x)
-        encoded = encoder(x, attention_mask=attention_mask.bool())
-
-        symbolic_ids = [tokenizer.convert_tokens_to_ids(tok) for tok in selected_roles]
-        symbolic_mask = torch.isin(input_ids, torch.tensor(symbolic_ids, device=input_ids.device))
-
-        masked_tokens = [tokenizer.convert_ids_to_tokens([tid])[0] for tid in input_ids[0] if tid in symbolic_ids]
-        role_reprs = encoded[symbolic_mask].mean(dim=0) if symbolic_mask.any() else torch.zeros_like(encoded[0, 0])
-
-    return {
-        "Symbolic Tokens": masked_tokens,
-        "Embedding Norm": f"{role_reprs.norm().item():.4f}",
-        "Symbolic Token Count": symbolic_mask.sum().item(),
-    }
-
-symbolic_roles = [
     "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
     "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
     "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
@@ -58,26 +39,93 @@ symbolic_roles = [
     "<fabric>", "<jewelry>"
 ]
 
 def build_interface():
-    with gr.Blocks() as demo:
-        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
         with gr.Row():
             with gr.Column():
-                input_text = gr.Textbox(label="Input with Symbolic Tokens", lines=3)
-                selected_roles = gr.CheckboxGroup(
-                    choices=symbolic_roles,
-                    label="Which symbolic tokens should be traced?"
                 )
                 run_btn = gr.Button("Encode & Trace")
             with gr.Column():
-                symbolic_tokens = gr.Textbox(label="Symbolic Tokens Found")
-                embedding_norm = gr.Textbox(label="Mean Norm of Symbolic Embeddings")
-                token_count = gr.Textbox(label="Count of Symbolic Tokens")
 
-        run_btn.click(fn=encode_and_predict, inputs=[input_text, selected_roles], outputs=[symbolic_tokens, embedding_norm, token_count])
 
     return demo
 
 if __name__ == "__main__":
     demo = build_interface()
     demo.launch()
 
+# app.py encoder-only demo for bert-beatrix-2048
+# -----------------------------------------------
+# launch: python app.py
+# (gradio UI appears at http://localhost:7860)
+
 import torch
 import gradio as gr
+import spaces
+from bert_handler import create_handler_from_checkpoint
+
+# ------------------------------------------------------------------
+# 1. Model / tokenizer
+# ------------------------------------------------------------------
+#
+# We load the repo *once*, via its canonical name.
+# BERTHandler handles the VRAM-safe cleanup and guarantees that the
+# tokenizer already contains all special tokens saved in the checkpoint.
+
+REPO_ID = "AbstractPhil/bert-beatrix-2048"
+
+handler, full_model, tokenizer = create_handler_from_checkpoint(REPO_ID)
+full_model = full_model.eval().cuda()
+
+# Grab the encoder + embedding stack only
+encoder = full_model.bert.encoder
+embeddings = full_model.bert.embeddings
+emb_ln = full_model.bert.emb_ln
+emb_drop = full_model.bert.emb_drop
+
+# ------------------------------------------------------------------
+# 2. Symbolic token set
+# ------------------------------------------------------------------
+SYMBOLIC_ROLES = [
     "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
     "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
     "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
     "<fabric>", "<jewelry>"
 ]
 
+# Quick sanity check – these should *never* resolve to unk
+# (convert_tokens_to_ids falls back to unk_token_id for unknown tokens)
+missing = [tok for tok in SYMBOLIC_ROLES
+           if tokenizer.convert_tokens_to_ids(tok) == tokenizer.unk_token_id]
+if missing:
+    raise RuntimeError(f"Tokenizer is missing {len(missing)} special tokens: {missing}")
+
+# ------------------------------------------------------------------
+# 3. Encoder-only inference util
+# ------------------------------------------------------------------
+@spaces.GPU
+def encode_and_trace(text: str, selected_roles: list[str]):
+    """
+    • encodes `text`
+    • pulls out the hidden states for any of the `selected_roles`
+    • returns some quick stats so we can verify everything is wired up
+    """
+    with torch.no_grad():
+        batch = tokenizer(text, return_tensors="pt").to("cuda")
+        inp_ids, attn_mask = batch.input_ids, batch.attention_mask
+
+        # --- embedding + LayerNorm/dropout ---
+        x = embeddings(inp_ids)
+        x = emb_drop(emb_ln(x))
+
+        # --- proper *additive* attention mask ---
+        ext_mask = full_model.bert.get_extended_attention_mask(
+            attn_mask, x.shape[:-1]
+        )
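+        # (get_extended_attention_mask comes from HF's ModuleUtilsMixin: it
+        #  broadcasts the (B, S) padding mask to (B, 1, 1, S) and converts it
+        #  to additive form, 0.0 where attention is allowed and a large
+        #  negative value where it is not, ready to add to attention scores.)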
+
+        encoded = encoder(x, attention_mask=ext_mask)  # (B, S, H)
+
+        # --- pick out the positions that match selected_roles ---
+        sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
+        ids_list = inp_ids.squeeze(0).tolist()  # python ints
+        keep_mask = torch.tensor([tid in sel_ids for tid in ids_list],
+                                 device=encoded.device)
+
+        tokens_found = [tokenizer.convert_ids_to_tokens([tid])[0]
+                        for tid in ids_list if tid in sel_ids]
+        if keep_mask.any():
+            repr_vec = encoded.squeeze(0)[keep_mask].mean(0)
+            norm_val = f"{repr_vec.norm().item():.4f}"
+        else:
+            norm_val = "0.0000"
+
+        # Gradio assigns multiple outputs by position, so return a tuple in the
+        # same order as the click() outputs (a plain dict keyed by strings
+        # would not be routed to the three Textboxes).
+        return (
+            ", ".join(tokens_found) or "(none)",
+            norm_val,
+            int(keep_mask.sum().item()),
+        )
+
+# ------------------------------------------------------------------
+# 4. Gradio UI
+# ------------------------------------------------------------------
 def build_interface():
+    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
+
+        gr.Markdown("## 🧠 Symbolic Encoder Inspector\n"
+                    "Paste some text containing the special `<role>` tokens and "
+                    "inspect their encoder representations.")
+
         with gr.Row():
             with gr.Column():
+                input_text = gr.Textbox(
+                    label="Input with Symbolic Tokens",
+                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
+                    lines=3,
+                )
+                role_selector = gr.CheckboxGroup(
+                    choices=SYMBOLIC_ROLES,
+                    label="Trace these symbolic roles"
                 )
                 run_btn = gr.Button("Encode & Trace")
             with gr.Column():
+                out_tokens = gr.Textbox(label="Symbolic Tokens Found")
+                out_norm = gr.Textbox(label="Mean Norm")
+                out_count = gr.Textbox(label="Token Count")
 
+        run_btn.click(
+            fn=encode_and_trace,
+            inputs=[input_text, role_selector],
+            outputs=[out_tokens, out_norm, out_count],
+        )
 
     return demo
 
+
 if __name__ == "__main__":
     demo = build_interface()
     demo.launch()
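
A quick sanity run of the new encoder-only path, outside the UI (a sketch, not part of this commit: it assumes `bert_handler` is importable from the working directory, that a CUDA device is available, and that `@spaces.GPU` acts as a pass-through when run outside a Space; `smoke_test.py` is a made-up name):

    # smoke_test.py (hypothetical) - calls encode_and_trace directly.
    # Importing app triggers the checkpoint load and .cuda() at import time.
    from app import encode_and_trace

    sample = "A <subject> wearing <upper_body_clothing> in soft <lighting>"
    tokens, norm, count = encode_and_trace(
        sample, ["<subject>", "<upper_body_clothing>", "<lighting>"]
    )
    print(tokens)  # comma-separated role tokens found, or "(none)"
    print(norm)    # norm of the mean hidden state over those positions
    print(count)   # how many token positions matched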