AbstractPhil committed (verified) · Commit b235205 · 1 Parent(s): 45ddfa3

Update app.py

Files changed (1)
  1. app.py +38 -74
app.py CHANGED
@@ -87,93 +87,57 @@ def pool_accuracy(ids, logits, pool_mask):
 def encode_and_trace(text, selected_roles):
     if not selected_roles:
         selected_roles = SYMBOLIC_ROLES
     sel_ids = [tokenizer.convert_tokens_to_ids(t) for t in selected_roles]
-    sel_ids_tensor = torch.tensor(sel_ids, device="cuda")

-    # Tokenize
     batch = tokenizer(text, return_tensors="pt").to("cuda")
-    ids, attn = batch.input_ids, batch.attention_mask
-    S = ids.shape[1]

-    # Encode helper
     def encode(input_ids, attn_mask):
-        x = embeddings(input_ids)
         if emb_ln: x = emb_ln(x)
         if emb_drop: x = emb_drop(x)
-        ext = full_model.bert.get_extended_attention_mask(attn_mask, x.shape[:-1])
-        return encoder(x, attention_mask=ext)[0]  # shape: (1, S, H)
-
-    encoded = encode(ids, attn)
-
-    # Project symbolic token embeddings
-    symbolic_embeds = embeddings.word_embeddings(sel_ids_tensor)  # shape: (R, H)
-    sim = cosine(encoded[0].unsqueeze(1), symbolic_embeds.unsqueeze(0))  # (S, R)
-    maxcos, argrole = sim.max(-1)  # (S,)
-    top_roles = [selected_roles[i] for i in argrole.tolist()]
-    sort_idx = maxcos.argsort(descending=True)
-    hi_idx = sort_idx[:S // 2]
-    lo_idx = sort_idx[S // 2:]
-
-    MASK_ID = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]")
-
-    # Final pool evaluator
-    def evaluate_pool(idx_order, label, ids):
-        best_pool, best_acc = [], 0.0
-        ptr = 0
-        while ptr < len(idx_order):
-            cand = idx_order[ptr:ptr + 2]
-            pool = best_pool + cand.tolist()
-            ptr += 2
-
-            mask_flags = torch.zeros_like(ids, dtype=torch.bool)
-            mask_flags[0, pool] = True
-            masked_input = ids.where(mask_flags, MASK_ID)
-
-            encoded_m = encode(masked_input, attn)
-            logits = mlm_head(encoded_m)  # (1, S, V)
-            preds = logits.argmax(-1)  # (1, S)
-
-            masked_positions = (~mask_flags[0]).nonzero(as_tuple=True)[0]  # 1D tensor
-            if masked_positions.numel() == 0:
-                continue
-
-            # Extract both predicted and gold tokens
-            pred_tokens = preds[0, masked_positions]
-            gold_tokens = ids[0, masked_positions]
-            correct = (pred_tokens == gold_tokens).float()
-            acc = correct.mean().item()
-
-            if acc > best_acc:
-                best_pool, best_acc = pool, acc
-            if acc >= 0.5:
-                break
-
-        return best_pool, best_acc
-
-    # Run both pools
-    pool_hi, acc_hi = evaluate_pool(hi_idx, "high", ids)
-    pool_lo, acc_lo = evaluate_pool(lo_idx, "low", ids)
-
-    # Alignment trace
-    decoded_tokens = tokenizer.convert_ids_to_tokens(ids[0])
     role_trace = [
-        f"{tok:<15} → {role} cos={score:.4f}"
-        for tok, role, score in zip(decoded_tokens, top_roles, maxcos.tolist())
     ]

-    # Return results
     res_json = {
-        "High-pool tokens": tokenizer.decode(ids[0, pool_hi]),
-        "High accuracy": f"{acc_hi:.3f}",
-        "Low-pool tokens": tokenizer.decode(ids[0, pool_lo]),
-        "Low accuracy": f"{acc_lo:.3f}",
-        "Token–Symbolic Role Alignment": role_trace
     }

-    return json.dumps(res_json, indent=2), f"{maxcos.max():.4f}", len(selected_roles)
-
-
-


 # ------------------------------------------------------------------

 def encode_and_trace(text, selected_roles):
     if not selected_roles:
         selected_roles = SYMBOLIC_ROLES
+
+    # Convert symbolic role tokens to IDs
     sel_ids = [tokenizer.convert_tokens_to_ids(t) for t in selected_roles]
+    sel_ids_tensor = torch.tensor(sel_ids, device="cuda").unsqueeze(0)  # shape: (1, R)

+    # Tokenize user prompt
     batch = tokenizer(text, return_tensors="pt").to("cuda")
+    input_ids, attention_mask = batch.input_ids, batch.attention_mask
+    S = input_ids.shape[1]

+    # === Shared encoder logic with RoPE ===
     def encode(input_ids, attn_mask):
+        x = embeddings(input_ids)  # (B, S, H)
         if emb_ln: x = emb_ln(x)
         if emb_drop: x = emb_drop(x)
+        ext = full_model.bert.get_extended_attention_mask(attn_mask, input_ids.shape)
+        return encoder(x, attention_mask=ext)[0]  # (B, S, H)
+
+    # Encode prompt
+    encoded_prompt = encode(input_ids, attention_mask)[0]  # (S, H)
+
+    # Encode symbolic roles through same pipeline
+    symbolic_attn = torch.ones_like(sel_ids_tensor)
+    encoded_roles = encode(sel_ids_tensor, symbolic_attn)[0]  # (R, H)
+
+    # === Symbolic classification via cosine similarity ===
+    # Compare each token to each symbolic role → shape: (S, R)
+    token_exp = encoded_prompt.unsqueeze(1).expand(-1, encoded_roles.size(0), -1)  # (S, R, H)
+    role_exp = encoded_roles.unsqueeze(0).expand(encoded_prompt.size(0), -1, -1)   # (S, R, H)
+    sim = F.cosine_similarity(token_exp, role_exp, dim=-1)  # → (S, R)
+
+    argmax_ids = sim.argmax(dim=-1)       # (S,)
+    max_scores = sim.max(dim=-1).values   # (S,)
+    predicted_roles = [selected_roles[i] for i in argmax_ids.tolist()]
+    decoded_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
+
+    # === Build readable trace
     role_trace = [
+        f"{tok:<15} → {role:<22} score={score:.4f}"
+        for tok, role, score in zip(decoded_tokens, predicted_roles, max_scores.tolist())
     ]

+    # === Final output
     res_json = {
+        "Prompt": text,
+        "Predicted symbolic roles": predicted_roles,
+        "Max alignment score": f"{max_scores.max().item():.4f}",
+        "Per-token classification": role_trace
     }

+    return json.dumps(res_json, indent=2), f"{max_scores.max().item():.4f}", len(selected_roles)


 # ------------------------------------------------------------------
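
Below is a minimal, self-contained sketch of the per-token cosine classification that the updated encode_and_trace performs. The sizes S, R, H and the random tensors are placeholders standing in for the real encoder outputs (encoded_prompt and encoded_roles); only the expand + F.cosine_similarity + argmax pattern mirrors the code above.

import torch
import torch.nn.functional as F

# Placeholder sizes: S prompt tokens, R symbolic roles, hidden width H (illustrative only).
S, R, H = 6, 4, 32
encoded_prompt = torch.randn(S, H)   # stands in for encode(input_ids, attention_mask)[0]
encoded_roles = torch.randn(R, H)    # stands in for encode(sel_ids_tensor, symbolic_attn)[0]

# Same pattern as the commit: compare every token against every role -> (S, R).
token_exp = encoded_prompt.unsqueeze(1).expand(-1, R, -1)   # (S, R, H)
role_exp = encoded_roles.unsqueeze(0).expand(S, -1, -1)     # (S, R, H)
sim = F.cosine_similarity(token_exp, role_exp, dim=-1)      # (S, R)

best_role = sim.argmax(dim=-1)        # (S,) index of the closest role per token
best_score = sim.max(dim=-1).values   # (S,) its cosine score

# F.cosine_similarity also broadcasts, so the explicit expand() calls are optional:
sim_b = F.cosine_similarity(encoded_prompt[:, None, :], encoded_roles[None, :, :], dim=-1)
assert torch.allclose(sim, sim_b)

In the app itself, both operands come from the same encode() helper, so prompt tokens and symbolic roles are compared in the encoder's output space rather than against raw word_embeddings as in the removed version.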