AbstractPhil committed on
Commit
aaae56c
·
verified ·
1 Parent(s): f22905a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -27
app.py CHANGED
@@ -112,50 +112,53 @@ def encode_and_trace(text, selected_roles):
112
 
113
  # container for summary text
114
  report_lines = []
115
-
116
  # ------------------------------------------------------------------
117
- # Greedy pool helper (drop-in replacement)
118
  # ------------------------------------------------------------------
119
- def greedy_pool(index_list, which: str):
120
  """
121
- index_list – indices (list[int]) to start pooling from
122
- which – "low" → walk upward
123
- "high" → walk downward
124
- Returns (best_pool, best_acc)
125
  """
126
- step = +1 if which == "low" else -1
127
- best_pool, best_acc = [], 0.0
 
 
 
 
 
128
 
129
- for i in range(0, len(index_list), 2): # 2 at a time
130
- # current candidate indices to keep un-masked
131
- cand = index_list[i : i + 2]
132
- pool = best_pool + cand # grow pool
133
- mask_flags = torch.ones_like(ids).bool() # mask *everything*
134
- mask_flags[0, pool] = False # ...except pool
135
- masked_ids = ids.masked_fill(~mask_flags, tokenizer.mask_token_id)
136
 
137
- # ---------- second forward-pass on MASKED input ----------
 
 
 
 
 
138
  with torch.no_grad():
139
  x_m = emb_drop(emb_ln(embeddings(masked_ids)))
140
  ext_m = full_model.bert.get_extended_attention_mask(mask, x_m.shape[:-1])
141
- enc_m = encoder(x_m, attention_mask=ext_m) # (1,S,H)
142
- logits = mlm_head(enc_m).squeeze(0) # (S,V)
143
- # ---------------------------------------------------------
144
 
145
- # accuracy of predicting original tokens only at *masked* positions
146
  pred = logits.argmax(-1)
147
  corr = (pred[mask_flags] == ids[mask_flags]).float().mean().item()
148
 
149
- if corr > best_acc: # greedy improve
150
- best_acc = corr
151
- best_pool = pool
152
-
153
- # stop early if we already exceed 0.50
154
  if best_acc >= 0.50:
155
- break
156
 
157
  return best_pool, best_acc
158
 
 
159
  pool_lo, acc_lo = greedy_pool(lo_idx, "low")
160
  pool_hi, acc_hi = greedy_pool(hi_idx, "high")
161
 
 
112
 
113
# container for summary text
report_lines = []

# ------------------------------------------------------------------
# Greedy pool helper – tensor-safe version
# ------------------------------------------------------------------
def greedy_pool(index_tensor: torch.Tensor, which: str):
    """Greedily grow a pool of un-masked token positions.

    Converts *index_tensor* to plain Python ints, then walks it from the
    "low" or "high" end two positions at a time; a candidate pair is kept
    only if un-masking it improves MLM reconstruction accuracy at the
    masked positions. Stops early once accuracy reaches 0.50.

    Parameters
    ----------
    index_tensor : torch.Tensor
        1-D tensor of token indices (already on the model's device).
    which : str
        "low"  → walk upward (from the front),
        "high" → walk downward (from the back).

    Returns
    -------
    tuple[list[int], float]
        (best_pool, best_acc)
    """
    # ---- make everything vanilla Python ints ---------------------
    indices = index_tensor.tolist()                # e.g. [7, 10, 13, …]
    if which == "high":
        indices = indices[::-1]                    # reverse for top-down

    best_pool: list[int] = []
    best_acc = 0.0

    for i in range(0, len(indices), 2):            # 2 at a time
        cand = indices[i : i + 2]                  # plain list[int]
        trial = best_pool + cand                   # grow pool

        # ---- build masked input ----------------------------------
        mask_flags = torch.ones_like(ids).bool()   # mask everything
        mask_flags[0, trial] = False               # …except the pool
        # NOTE(review): `mask_token_id` comes from the enclosing scope —
        # confirm it matches tokenizer.mask_token_id.
        masked_ids = ids.where(~mask_flags, mask_token_id)

        # ---- second forward-pass ---------------------------------
        with torch.no_grad():
            x_m = emb_drop(emb_ln(embeddings(masked_ids)))
            ext_m = full_model.bert.get_extended_attention_mask(mask, x_m.shape[:-1])
            enc_m = encoder(x_m, attention_mask=ext_m)
            # BUG FIX: `mlm_head(enc_m)[0]` dropped the batch dim, making
            # `pred` 1-D while `mask_flags` (indexed as mask_flags[0, …],
            # hence 2-D) stays 2-D — the boolean index below would then
            # fail. Keep the batch dim instead.
            logits = mlm_head(enc_m)               # (1, S, V)

        pred = logits.argmax(-1)                   # (1, S)
        corr = (pred[mask_flags] == ids[mask_flags]).float().mean().item()

        if corr > best_acc:
            best_acc = corr
            best_pool = trial                      # accept improvement

        if best_acc >= 0.50:
            break                                  # early exit

    return best_pool, best_acc


pool_lo, acc_lo = greedy_pool(lo_idx, "low")
pool_hi, acc_hi = greedy_pool(hi_idx, "high")
164