AbstractPhil committed on
Commit 45ddfa3
·
verified ·
1 Parent(s): 20331bc

Update app.py

Files changed (1):
  app.py +18 -16
app.py CHANGED

@@ -90,24 +90,24 @@ def encode_and_trace(text, selected_roles):
     sel_ids = [tokenizer.convert_tokens_to_ids(t) for t in selected_roles]
     sel_ids_tensor = torch.tensor(sel_ids, device="cuda")
 
-    # Tokenize input
+    # Tokenize
     batch = tokenizer(text, return_tensors="pt").to("cuda")
     ids, attn = batch.input_ids, batch.attention_mask
     S = ids.shape[1]
 
-    # Safe encoder forward
+    # Encode helper
     def encode(input_ids, attn_mask):
         x = embeddings(input_ids)
         if emb_ln: x = emb_ln(x)
         if emb_drop: x = emb_drop(x)
         ext = full_model.bert.get_extended_attention_mask(attn_mask, x.shape[:-1])
-        return encoder(x, attention_mask=ext)[0]
+        return encoder(x, attention_mask=ext)[0]  # shape: (1, S, H)
 
     encoded = encode(ids, attn)
 
-    # Get raw symbolic token embeddings directly
-    symbolic_embeds = embeddings.word_embeddings(sel_ids_tensor)  # ✅ FIXED
-    sim = cosine(encoded.unsqueeze(1), symbolic_embeds.unsqueeze(0))  # (S, R)
+    # Project symbolic token embeddings
+    symbolic_embeds = embeddings.word_embeddings(sel_ids_tensor)  # shape: (R, H)
+    sim = cosine(encoded[0].unsqueeze(1), symbolic_embeds.unsqueeze(0))  # (S, R)
     maxcos, argrole = sim.max(-1)  # (S,)
     top_roles = [selected_roles[i] for i in argrole.tolist()]
     sort_idx = maxcos.argsort(descending=True)
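Note: the substantive change in this hunk is `encoded[0].unsqueeze(1)`. A minimal shape sketch, assuming `cosine` is `torch.nn.functional.cosine_similarity` over the last dimension (dummy sizes):

    import torch
    import torch.nn.functional as F

    S, R, H = 7, 4, 16                       # tokens, roles, hidden size (dummy values)
    encoded = torch.randn(1, S, H)           # encoder output, batch of one
    symbolic_embeds = torch.randn(R, H)      # one embedding row per selected role

    # The old code broadcast (1, 1, S, H) against (1, R, H), which errors
    # whenever S != R; dropping the batch dim first gives a clean pairing.
    sim = F.cosine_similarity(
        encoded[0].unsqueeze(1),             # (S, 1, H)
        symbolic_embeds.unsqueeze(0),        # (1, R, H)
        dim=-1,
    )                                        # -> (S, R)
    maxcos, argrole = sim.max(-1)            # both (S,)
    assert sim.shape == (S, R)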
@@ -116,7 +116,7 @@ def encode_and_trace(text, selected_roles):
 
     MASK_ID = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]")
 
-    # 🔧 Pass ids into this function
+    # Final pool evaluator
     def evaluate_pool(idx_order, label, ids):
         best_pool, best_acc = [], 0.0
         ptr = 0
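Note: inside `evaluate_pool`, every token outside the candidate pool is replaced with the mask token via `Tensor.where` (next hunk). A toy sketch with assumed values (103 is bert-base-uncased's [MASK] id):

    import torch

    MASK_ID = 103                                          # [MASK] id (assumed)
    ids = torch.tensor([[101, 2023, 2003, 1037, 102]])     # toy input ids, (1, S)
    mask_flags = torch.tensor([[True, False, False, True, True]])  # True = keep

    # where() keeps ids at True positions and writes MASK_ID elsewhere.
    masked_input = ids.where(mask_flags, torch.tensor(MASK_ID))
    print(masked_input)  # tensor([[ 101,  103,  103, 1037,  102]])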
@@ -130,16 +130,17 @@ def encode_and_trace(text, selected_roles):
         masked_input = ids.where(mask_flags, MASK_ID)
 
         encoded_m = encode(masked_input, attn)
-        logits = mlm_head(encoded_m)[0]  # ✅ FIXED — direct tensor
-        preds = logits.argmax(-1)
+        logits = mlm_head(encoded_m)  # (1, S, V)
+        preds = logits.argmax(-1)  # (1, S)
 
-        masked_positions = (~mask_flags[0]).nonzero(as_tuple=False).squeeze(-1)
+        masked_positions = (~mask_flags[0]).nonzero(as_tuple=True)[0]  # 1D tensor
         if masked_positions.numel() == 0:
             continue
 
-        # ✅ FIXED: indexing from explicitly passed ids
-        gold = ids[0][masked_positions]
-        correct = (preds[masked_positions] == gold).float()
+        # Extract both predicted and gold tokens
+        pred_tokens = preds[0, masked_positions]
+        gold_tokens = ids[0, masked_positions]
+        correct = (pred_tokens == gold_tokens).float()
         acc = correct.mean().item()
 
         if acc > best_acc:
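Note: a quick check of the indexing fix above, with toy tensors shaped like the app's (`preds` and `ids` are (1, S)):

    import torch

    preds = torch.tensor([[5, 9, 7, 2]])                     # logits.argmax(-1), (1, S)
    ids = torch.tensor([[5, 8, 7, 2]])                       # gold token ids, (1, S)
    mask_flags = torch.tensor([[True, False, False, True]])  # False = masked position

    # as_tuple=True returns one index tensor per dimension; [0] is the 1-D
    # position vector, equivalent to the old nonzero(...).squeeze(-1) here.
    masked_positions = (~mask_flags[0]).nonzero(as_tuple=True)[0]  # tensor([1, 2])

    # Index row 0 explicitly so predictions and gold ids stay aligned.
    pred_tokens = preds[0, masked_positions]                  # tensor([9, 7])
    gold_tokens = ids[0, masked_positions]                    # tensor([8, 7])
    acc = (pred_tokens == gold_tokens).float().mean().item()  # 0.5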
@@ -149,18 +150,18 @@ def encode_and_trace(text, selected_roles):
 
         return best_pool, best_acc
 
-    # Run both pool evaluations
+    # Run both pools
    pool_hi, acc_hi = evaluate_pool(hi_idx, "high", ids)
    pool_lo, acc_lo = evaluate_pool(lo_idx, "low", ids)
 
-    # Per-token symbolic trace
+    # Alignment trace
    decoded_tokens = tokenizer.convert_ids_to_tokens(ids[0])
    role_trace = [
        f"{tok:<15} → {role} cos={score:.4f}"
        for tok, role, score in zip(decoded_tokens, top_roles, maxcos.tolist())
    ]
 
-    # Output JSON
+    # Return results
    res_json = {
        "High-pool tokens": tokenizer.decode(ids[0, pool_hi]),
        "High accuracy": f"{acc_hi:.3f}",
@@ -174,6 +175,7 @@ def encode_and_trace(text, selected_roles):
 
 
 
+
 # ------------------------------------------------------------------
 # 4. Gradio UI -----------------------------------------------------
 def build_interface():
 