AbstractPhil committed on
Commit
f22905a
·
verified ·
1 Parent(s): fd4a12a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -22
app.py CHANGED
@@ -113,29 +113,48 @@ def encode_and_trace(text, selected_roles):
113
  # container for summary text
114
  report_lines = []
115
 
116
- # ---- pool builder helper (uses S-4…S-7) ----
117
- def greedy_pool(token_indices, direction):
118
- # direction=='hi' or 'lo'
119
- pool = []
120
- μ = None
121
- for tix in token_indices:
122
- t_vec = enc[tix]
123
- # incremental update (S-7)
124
- μ = t_vec if μ is None else (μ*len(pool) + t_vec)/(len(pool)+1)
125
- pool.append(tix)
126
- # mask all tokens *except* this pool
127
- masked_ids = ids.clone()
128
- keep = torch.tensor(pool, device=ids.device)
129
- mask_mask = torch.ones_like(ids, dtype=torch.bool)
130
- mask_mask[0, keep] = False
131
- masked_ids[mask_mask] = tokenizer.mask_token_id
132
- # run MLM
 
 
 
 
 
133
  with torch.no_grad():
134
- logits = mlm_head(full_model.bert.emb_dl(enc.unsqueeze(0))).logits[0]
135
- acc = pool_accuracy(ids[0], logits, ~mask_mask[0])
136
- report_lines.append(f"{direction}-pool size {len(pool)} → acc={acc:.2f}")
137
- if acc >= 0.5: break
138
- return pool, acc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  pool_lo, acc_lo = greedy_pool(lo_idx, "low")
141
  pool_hi, acc_hi = greedy_pool(hi_idx, "high")
 
113
  # container for summary text
114
  report_lines = []
115
 
116
# ------------------------------------------------------------------
# Greedy pool helper (drop-in replacement)
# ------------------------------------------------------------------
def greedy_pool(index_list, which: str):
    """
    Greedily grow a pool of kept (un-masked) token positions, two at a
    time, until the MLM recovers >= 50% of the *masked* tokens.

    Parameters
    ----------
    index_list : list[int]
        Candidate token positions. Assumed pre-ordered by the caller
        ("low" -> one direction, "high" -> the other); the loop itself
        always consumes the list front-to-back.
    which : str
        Caller-supplied label ("low"/"high"). Kept for interface
        compatibility; the previous version derived an unused ``step``
        from it, which has been removed as dead code.

    Returns
    -------
    tuple[list[int], float]
        ``(best_pool, best_acc)`` — best-scoring pool and its accuracy
        on masked positions.

    Notes
    -----
    Reads enclosing-scope variables: ``ids`` (1, S) token ids, ``mask``
    attention mask, ``tokenizer``, ``embeddings``, ``emb_ln``,
    ``emb_drop``, ``encoder``, ``full_model``, ``mlm_head``.
    NOTE(review): assumes ``mlm_head(enc_m)`` returns a raw logits
    tensor (the removed version accessed ``.logits``) — confirm.
    """
    best_pool, best_acc = [], 0.0

    for i in range(0, len(index_list), 2):          # grow two at a time
        cand = index_list[i : i + 2]                # next candidate positions
        pool = best_pool + cand                     # tentative larger pool

        # True at positions to MASK; the pool stays visible.
        mask_flags = torch.ones_like(ids).bool()
        mask_flags[0, pool] = False
        # BUG FIX: previous code used ``masked_fill(~mask_flags, ...)``,
        # which masked ONLY the pool and left every evaluated position
        # visible to the model. Fill where mask_flags is True instead.
        masked_ids = ids.masked_fill(mask_flags, tokenizer.mask_token_id)

        # ---------- second forward pass on the MASKED input ----------
        with torch.no_grad():
            x_m = emb_drop(emb_ln(embeddings(masked_ids)))
            ext_m = full_model.bert.get_extended_attention_mask(mask, x_m.shape[:-1])
            enc_m = encoder(x_m, attention_mask=ext_m)   # (1, S, H)
            logits = mlm_head(enc_m).squeeze(0)          # (S, V)
        # -------------------------------------------------------------

        # Accuracy of recovering the original ids at *masked* positions.
        # BUG FIX: ``pred`` is (S,) after squeeze, while mask_flags/ids
        # are (1, S); index with the squeezed row so shapes agree.
        pred = logits.argmax(-1)                    # (S,)
        flags = mask_flags[0]                       # (S,)
        corr = (pred[flags] == ids[0][flags]).float().mean().item()

        if corr > best_acc:                         # greedy: keep only improvements
            best_acc = corr
            best_pool = pool

        if best_acc >= 0.50:                        # early exit at target accuracy
            break

    return best_pool, best_acc
158
 
159
  pool_lo, acc_lo = greedy_pool(lo_idx, "low")
160
  pool_hi, acc_hi = greedy_pool(hi_idx, "high")