AbstractPhil commited on
Commit
6cdb3e9
Β·
verified Β·
1 Parent(s): e6ff6d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -59
app.py CHANGED
@@ -83,89 +83,96 @@ def pool_accuracy(ids, logits, pool_mask):
83
  return (preds==gold).float().mean().item()
84
 
85
 
 
86
  @spaces.GPU
87
  def encode_and_trace(text, selected_roles):
88
- # if user unchecked everything we treat as "all"
89
  if not selected_roles:
90
  selected_roles = SYMBOLIC_ROLES
91
- sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
 
92
 
93
- # ---- Tokenise & encode once ----
94
  batch = tokenizer(text, return_tensors="pt").to("cuda")
95
- ids, att = batch.input_ids, batch.attention_mask
96
- x = emb_drop(emb_ln(embeddings(ids)))
97
- ext = full_model.bert.get_extended_attention_mask(att, x.shape[:-1])
98
- enc = encoder(x, attention_mask=ext)[0, :, :] # (S,H)
99
-
100
- # ---- compute max-cos per token (F-0/F-1) ----
101
- role_mat = embeddings.word_embeddings(
102
- torch.tensor(sorted(sel_ids), device=enc.device)
103
- ) # (R,H)
104
- cos = cosine(enc.unsqueeze(1), role_mat.unsqueeze(0)) # (S,R)
105
- maxcos, argrole = cos.max(-1) # (S,)
106
-
107
- # ---- split tokens into High / Low half (F-2) ----
108
- S = len(ids[0])
 
 
 
 
 
 
109
  sort_idx = maxcos.argsort(descending=True)
110
- hi_idx = sort_idx[: S//2]
111
- lo_idx = sort_idx[S//2:]
112
-
113
- # container for summary text
114
- report_lines = []
115
-
116
- # ───────────────────────────────────────────────────────────────
117
- # 3. Encoder-only inference util (FIXED) β”‚
118
- # ───────────────────────────────────────────────────────────────
119
- MASK_ID = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]") # <- NEW
120
-
121
- def greedy_pool(idx_order, tag):
122
- """
123
- idx_order : tensor of token-indices sorted hi→lo or lo→hi
124
- tag : "high" | "low" (for the debug print)
125
- returns : (best_pool_indices , best_accuracy)
126
- """
127
  best_pool, best_acc = [], 0.0
128
  ptr = 0
129
  while ptr < len(idx_order):
130
- cand = idx_order[ptr : ptr + 2] # 2-at-a-time
131
- pool = best_pool + cand.tolist() # grow pool
132
  ptr += 2
133
-
134
- # --- build *mask* for β€œeverything NOT in pool” ----------
135
  mask_flags = torch.zeros_like(ids, dtype=torch.bool)
136
- mask_flags[0, pool] = True # keep these un-masked
137
- masked_ids = ids.where(mask_flags, MASK_ID) # <- uses the constant
138
-
139
- # re-encode & score
140
- enc_m = encode(masked_ids, mask) # helper already defined
141
- logits = mlm_head(enc_m).logits[0] # (S, V)
142
- preds = logits.argmax(-1)
143
- acc = (preds[~mask_flags] == ids[0][~mask_flags]).float().mean().item()
144
-
145
- if acc > best_acc: # accept pool only on gain
 
 
 
 
 
146
  best_pool, best_acc = pool, acc
147
- if acc >= 0.50: # early-stop rule
148
  break
149
-
150
- print(f"{tag:>4s}-pool {best_pool} acc={best_acc:.3f}")
151
  return best_pool, best_acc
152
 
 
 
153
 
154
- pool_lo, acc_lo = greedy_pool(lo_idx, "low")
155
- pool_hi, acc_hi = greedy_pool(hi_idx, "high")
 
 
 
 
156
 
157
- # ---- package textual result ----
158
  res_json = {
 
 
159
  "Low-pool tokens": tokenizer.decode(ids[0, pool_lo]),
160
- "Low accuracy": f"{acc_lo:.2f}",
161
- "High-pool tokens":tokenizer.decode(ids[0, pool_hi]),
162
- "High accuracy": f"{acc_hi:.2f}",
163
- "Trace": "\n".join(report_lines)
164
  }
165
- # three outputs expected by UI
166
  return json.dumps(res_json, indent=2), f"{maxcos.max():.4f}", len(selected_roles)
167
 
168
 
 
169
  # ------------------------------------------------------------------
170
  # 4. Gradio UI -----------------------------------------------------
171
  def build_interface():
 
83
  return (preds==gold).float().mean().item()
84
 
85
 
86
+ @spaces.GPU
87
  @spaces.GPU
88
  def encode_and_trace(text, selected_roles):
 
89
  if not selected_roles:
90
  selected_roles = SYMBOLIC_ROLES
91
+ sel_ids = [tokenizer.convert_tokens_to_ids(t) for t in selected_roles]
92
+ sel_ids_tensor = torch.tensor(sel_ids, device="cuda")
93
 
94
+ # ========== Tokenize & Embed ==========
95
  batch = tokenizer(text, return_tensors="pt").to("cuda")
96
+ ids, attn = batch.input_ids, batch.attention_mask
97
+ S = ids.shape[1]
98
+
99
+ def encode(input_ids, attn_mask):
100
+ x = embeddings(input_ids)
101
+ if emb_ln: x = emb_ln(x)
102
+ if emb_drop: x = emb_drop(x)
103
+ ext = full_model.bert.get_extended_attention_mask(attn_mask, x.shape[:-1])
104
+ return encoder(x, attention_mask=ext)[0]
105
+
106
+ # Full unmasked encoding pass
107
+ encoded = encode(ids, attn)
108
+
109
+ # ========== Cosine Similarity ==========
110
+ symbolic_embeds = embeddings(sel_ids_tensor) # (R, H)
111
+ sim = cosine(encoded.unsqueeze(1), symbolic_embeds.unsqueeze(0)) # (S, R)
112
+ maxcos, argrole = sim.max(-1) # (S,)
113
+ top_roles = [selected_roles[i] for i in argrole.tolist()]
114
+
115
+ # ========== Sorting into High / Low Alignment Pools ==========
116
  sort_idx = maxcos.argsort(descending=True)
117
+ hi_idx = sort_idx[:S // 2]
118
+ lo_idx = sort_idx[S // 2:]
119
+
120
+ # ========== Greedy Pool Testing ==========
121
+ MASK_ID = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]")
122
+
123
+ def evaluate_pool(idx_order, label):
 
 
 
 
 
 
 
 
 
 
124
  best_pool, best_acc = [], 0.0
125
  ptr = 0
126
  while ptr < len(idx_order):
127
+ cand = idx_order[ptr:ptr + 2]
128
+ pool = best_pool + cand.tolist()
129
  ptr += 2
130
+
 
131
  mask_flags = torch.zeros_like(ids, dtype=torch.bool)
132
+ mask_flags[0, pool] = True
133
+ masked_input = ids.where(mask_flags, MASK_ID)
134
+
135
+ encoded_m = encode(masked_input, attn)
136
+ logits = mlm_head(encoded_m).logits[0]
137
+ preds = logits.argmax(-1)
138
+
139
+ masked_positions = (~mask_flags[0]).nonzero(as_tuple=False).squeeze(-1)
140
+ if masked_positions.numel() == 0:
141
+ continue
142
+
143
+ correct = (preds[masked_positions] == ids[0][masked_positions]).float()
144
+ acc = correct.mean().item()
145
+
146
+ if acc > best_acc:
147
  best_pool, best_acc = pool, acc
148
+ if acc >= 0.5:
149
  break
150
+
 
151
  return best_pool, best_acc
152
 
153
+ pool_hi, acc_hi = evaluate_pool(hi_idx, "high")
154
+ pool_lo, acc_lo = evaluate_pool(lo_idx, "low")
155
 
156
+ # ========== Per-token Symbolic Role Trace ==========
157
+ decoded_tokens = tokenizer.convert_ids_to_tokens(ids[0])
158
+ role_trace = [
159
+ f"{tok:<15} β†’ {role} cos={score:.4f}"
160
+ for tok, role, score in zip(decoded_tokens, top_roles, maxcos.tolist())
161
+ ]
162
 
163
+ # ========== JSON Result ==========
164
  res_json = {
165
+ "High-pool tokens": tokenizer.decode(ids[0, pool_hi]),
166
+ "High accuracy": f"{acc_hi:.3f}",
167
  "Low-pool tokens": tokenizer.decode(ids[0, pool_lo]),
168
+ "Low accuracy": f"{acc_lo:.3f}",
169
+ "Token–Symbolic Role Alignment": role_trace
 
 
170
  }
171
+
172
  return json.dumps(res_json, indent=2), f"{maxcos.max():.4f}", len(selected_roles)
173
 
174
 
175
+
176
  # ------------------------------------------------------------------
177
  # 4. Gradio UI -----------------------------------------------------
178
  def build_interface():