AbstractPhil committed
Commit 20331bc · verified · 1 Parent(s): 9228a8c

Update app.py

Files changed (1):
  1. app.py +16 -16
app.py CHANGED
@@ -90,11 +90,12 @@ def encode_and_trace(text, selected_roles):
     sel_ids = [tokenizer.convert_tokens_to_ids(t) for t in selected_roles]
     sel_ids_tensor = torch.tensor(sel_ids, device="cuda")
 
-    # ========== Tokenize & Embed ==========
+    # Tokenize input
     batch = tokenizer(text, return_tensors="pt").to("cuda")
     ids, attn = batch.input_ids, batch.attention_mask
     S = ids.shape[1]
 
+    # Safe encoder forward
     def encode(input_ids, attn_mask):
         x = embeddings(input_ids)
         if emb_ln: x = emb_ln(x)
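
The new "Safe encoder forward" closure drives the BERT stack manually: embed, optional LayerNorm, extended attention mask, encoder (its body completes at the top of the next hunk). A minimal standalone sketch of that pattern, assuming a stock Hugging Face BERT checkpoint; bert-base-uncased and the handle names are assumptions, and the Space's model evidently exposes a separate emb_ln, which a stock BertEmbeddings does not:

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
full_model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased").eval()

embeddings = full_model.bert.embeddings   # note: already applies LayerNorm internally
encoder = full_model.bert.encoder         # transformer layer stack

batch = tokenizer("a quick test", return_tensors="pt")
ids, attn = batch.input_ids, batch.attention_mask

with torch.no_grad():
    x = embeddings(ids)                                               # (1, S, H)
    ext = full_model.bert.get_extended_attention_mask(attn, x.shape[:-1])
    hidden = encoder(x, attention_mask=ext)[0]                        # (1, S, H)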
@@ -102,25 +103,21 @@ def encode_and_trace(text, selected_roles):
         ext = full_model.bert.get_extended_attention_mask(attn_mask, x.shape[:-1])
         return encoder(x, attention_mask=ext)[0]
 
-    # Full unmasked encoding pass
     encoded = encode(ids, attn)
 
-    # ========== Cosine Similarity ==========
-    symbolic_embeds = embeddings.word_embeddings(sel_ids_tensor)
-
+    # Get raw symbolic token embeddings directly
+    symbolic_embeds = embeddings.word_embeddings(sel_ids_tensor)  # ✅ FIXED
     sim = cosine(encoded.unsqueeze(1), symbolic_embeds.unsqueeze(0))  # (S, R)
     maxcos, argrole = sim.max(-1)                                     # (S,)
     top_roles = [selected_roles[i] for i in argrole.tolist()]
-
-    # ========== Sorting into High / Low Alignment Pools ==========
     sort_idx = maxcos.argsort(descending=True)
     hi_idx = sort_idx[:S // 2]
     lo_idx = sort_idx[S // 2:]
 
-    # ========== Greedy Pool Testing ==========
     MASK_ID = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]")
 
-    def evaluate_pool(idx_order, label):
+    # 🔧 Pass ids into this function
+    def evaluate_pool(idx_order, label, ids):
        best_pool, best_acc = [], 0.0
        ptr = 0
        while ptr < len(idx_order):
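
The ✅ FIXED line is the substantive change in this hunk: symbolic_embeds now comes straight from the static word_embeddings rows, so each contextual token vector is compared against a fixed per-role anchor. A toy sketch of the shape logic; the sizes are made up and F.cosine_similarity stands in for the module's cosine:

import torch
import torch.nn.functional as F

S, R, H = 7, 4, 768                      # toy: sequence length, roles, hidden size
encoded = torch.randn(S, H)              # one contextual vector per token
symbolic_embeds = torch.randn(R, H)      # static word_embeddings rows, one per role

sim = F.cosine_similarity(
    encoded.unsqueeze(1),                # (S, 1, H)
    symbolic_embeds.unsqueeze(0),        # (1, R, H)
    dim=-1,
)                                        # -> (S, R)
maxcos, argrole = sim.max(-1)            # best role per token, both (S,)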
@@ -133,15 +130,16 @@ def encode_and_trace(text, selected_roles):
             masked_input = ids.where(mask_flags, MASK_ID)
 
             encoded_m = encode(masked_input, attn)
-            logits = mlm_head(encoded_m)[0]  # shape: (S, V)
-
+            logits = mlm_head(encoded_m)[0]  # ✅ FIXED — direct tensor
             preds = logits.argmax(-1)
 
             masked_positions = (~mask_flags[0]).nonzero(as_tuple=False).squeeze(-1)
             if masked_positions.numel() == 0:
                 continue
 
-            correct = (preds[masked_positions] == ids[0][masked_positions]).float()
+            # ✅ FIXED: indexing from explicitly passed ids
+            gold = ids[0][masked_positions]
+            correct = (preds[masked_positions] == gold).float()
             acc = correct.mean().item()
 
             if acc > best_acc:
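
This hunk keeps the MLM scoring but reads the gold labels from the explicitly passed ids. The accuracy step in isolation, with toy tensors (the token ids and predictions below are invented for illustration):

import torch

ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])             # (1, S) toy input
mask_flags = torch.tensor([[True, False, True, True, False, True]])  # False = masked
preds = torch.tensor([2023, 9999, 2003, 1037, 3231, 102])            # pretend argmax, (S,)
MASK_ID = 103

masked_input = ids.where(mask_flags, torch.tensor(MASK_ID))  # [MASK] at False slots
masked_positions = (~mask_flags[0]).nonzero(as_tuple=False).squeeze(-1)  # tensor([1, 4])

gold = ids[0][masked_positions]                    # original tokens at masked slots
correct = (preds[masked_positions] == gold).float()
print(correct.mean().item())                       # 0.5 for these toy values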
@@ -151,17 +149,18 @@ def encode_and_trace(text, selected_roles):
 
         return best_pool, best_acc
 
-    pool_hi, acc_hi = evaluate_pool(hi_idx, "high")
-    pool_lo, acc_lo = evaluate_pool(lo_idx, "low")
+    # Run both pool evaluations
+    pool_hi, acc_hi = evaluate_pool(hi_idx, "high", ids)
+    pool_lo, acc_lo = evaluate_pool(lo_idx, "low", ids)
 
-    # ========== Per-token Symbolic Role Trace ==========
+    # Per-token symbolic trace
     decoded_tokens = tokenizer.convert_ids_to_tokens(ids[0])
     role_trace = [
         f"{tok:<15} → {role} cos={score:.4f}"
         for tok, role, score in zip(decoded_tokens, top_roles, maxcos.tolist())
     ]
 
-    # ========== JSON Result ==========
+    # Output JSON
     res_json = {
         "High-pool tokens": tokenizer.decode(ids[0, pool_hi]),
         "High accuracy": f"{acc_hi:.3f}",
@@ -174,6 +173,7 @@ def encode_and_trace(text, selected_roles):
 
 
 
+
 # ------------------------------------------------------------------
 # 4. Gradio UI -----------------------------------------------------
 def build_interface():
 
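
The diff stops at the build_interface() stub. For context, a hypothetical wiring of encode_and_trace into a Gradio Blocks UI; the component choices, labels, and role-token list are assumptions, since the real layout is not shown in this commit:

import gradio as gr

def build_interface():
    with gr.Blocks() as demo:
        text = gr.Textbox(label="Input text")
        roles = gr.CheckboxGroup(
            ["<subject>", "<verb>", "<object>"],  # hypothetical role tokens
            label="Symbolic roles",
        )
        out = gr.JSON(label="Trace result")
        gr.Button("Encode & trace").click(encode_and_trace, inputs=[text, roles], outputs=out)
    return demo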