AbstractPhil committed on
Commit
e6ff6d3
·
verified ·
1 Parent(s): aaae56c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -37
app.py CHANGED
@@ -113,49 +113,41 @@ def encode_and_trace(text, selected_roles):
113
  # container for summary text
114
  report_lines = []
115
 
116
- # ------------------------------------------------------------------
117
- # Greedy pool helper – tensor-safe version
118
- # ------------------------------------------------------------------
119
- def greedy_pool(index_tensor: torch.Tensor, which: str):
 
 
120
  """
121
- index_tensor – 1-D tensor of token indices (already on CUDA)
122
- which – "low" → walk upward
123
- "high" → walk downward
124
- Returns (best_pool:list[int], best_acc:float)
125
  """
126
- # ---- make everything vanilla Python ints ---------------------
127
- indices = index_tensor.tolist() # e.g. [7, 10, 13, …]
128
- if which == "high":
129
- indices = indices[::-1] # reverse for top-down
130
-
131
- best_pool: list[int] = []
132
- best_acc = 0.0
133
-
134
- for i in range(0, len(indices), 2): # 2 at a time
135
- cand = indices[i : i + 2] # plain list[int]
136
- trial = best_pool + cand # grow pool
137
-
138
- # ---- build masked input ----------------------------------
139
- mask_flags = torch.ones_like(ids).bool() # mask everything
140
- mask_flags[0, trial] = False # …except the pool
141
- masked_ids = ids.where(~mask_flags, mask_token_id)
142
 
143
- # ---- second forward-pass ---------------------------------
144
- with torch.no_grad():
145
- x_m = emb_drop(emb_ln(embeddings(masked_ids)))
146
- ext_m = full_model.bert.get_extended_attention_mask(mask, x_m.shape[:-1])
147
- enc_m = encoder(x_m, attention_mask=ext_m)
148
- logits = mlm_head(enc_m)[0] # (S, V)
149
 
150
- pred = logits.argmax(-1)
151
- corr = (pred[mask_flags] == ids[mask_flags]).float().mean().item()
 
 
 
152
 
153
- if corr > best_acc:
154
- best_acc = corr
155
- best_pool = trial # accept improvement
156
- if best_acc >= 0.50:
157
- break # early exit
158
 
 
159
  return best_pool, best_acc
160
 
161
 
 
113
  # container for summary text
114
  report_lines = []
115
 
116
+ # ───────────────────────────────────────────────────────────────
117
+ # 3. Encoder-only inference util (FIXED) │
118
+ # ───────────────────────────────────────────────────────────────
119
+ MASK_ID = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]") # <- NEW
120
+
121
+ def greedy_pool(idx_order, tag):
122
  """
123
+ idx_order : tensor of token-indices sorted hi→lo or lo→hi
124
+ tag : "high" | "low" (for the debug print)
125
+ returns : (best_pool_indices , best_accuracy)
 
126
  """
127
+ best_pool, best_acc = [], 0.0
128
+ ptr = 0
129
+ while ptr < len(idx_order):
130
+ cand = idx_order[ptr : ptr + 2] # 2-at-a-time
131
+ pool = best_pool + cand.tolist() # grow pool
132
+ ptr += 2
 
 
 
 
 
 
 
 
 
 
133
 
134
+ # --- build *mask* for “everything NOT in pool” ----------
135
+ mask_flags = torch.zeros_like(ids, dtype=torch.bool)
136
+ mask_flags[0, pool] = True # keep these un-masked
137
+ masked_ids = ids.where(mask_flags, MASK_ID) # <- uses the constant
 
 
138
 
139
+ # re-encode & score
140
+ enc_m = encode(masked_ids, mask) # helper already defined
141
+ logits = mlm_head(enc_m).logits[0] # (S, V)
142
+ preds = logits.argmax(-1)
143
+ acc = (preds[~mask_flags] == ids[0][~mask_flags]).float().mean().item()
144
 
145
+ if acc > best_acc: # accept pool only on gain
146
+ best_pool, best_acc = pool, acc
147
+ if acc >= 0.50: # early-stop rule
148
+ break
 
149
 
150
+ print(f"{tag:>4s}-pool {best_pool} acc={best_acc:.3f}")
151
  return best_pool, best_acc
152
 
153