Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -113,49 +113,41 @@ def encode_and_trace(text, selected_roles):
|
|
113 |
# container for summary text
|
114 |
report_lines = []
|
115 |
|
116 |
-
#
|
117 |
-
#
|
118 |
-
#
|
119 |
-
|
|
|
|
|
120 |
"""
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
Returns (best_pool:list[int], best_acc:float)
|
125 |
"""
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
best_acc = 0.0
|
133 |
-
|
134 |
-
for i in range(0, len(indices), 2): # 2 at a time
|
135 |
-
cand = indices[i : i + 2] # plain list[int]
|
136 |
-
trial = best_pool + cand # grow pool
|
137 |
-
|
138 |
-
# ---- build masked input ----------------------------------
|
139 |
-
mask_flags = torch.ones_like(ids).bool() # mask everything
|
140 |
-
mask_flags[0, trial] = False # β¦except the pool
|
141 |
-
masked_ids = ids.where(~mask_flags, mask_token_id)
|
142 |
|
143 |
-
#
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
enc_m = encoder(x_m, attention_mask=ext_m)
|
148 |
-
logits = mlm_head(enc_m)[0] # (S, V)
|
149 |
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
152 |
|
153 |
-
if
|
154 |
-
best_acc
|
155 |
-
|
156 |
-
|
157 |
-
break # early exit
|
158 |
|
|
|
159 |
return best_pool, best_acc
|
160 |
|
161 |
|
|
|
113 |
# container for summary text
|
114 |
report_lines = []
|
115 |
|
116 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
117 |
+
# 3. Encoder-only inference util (FIXED) β
|
118 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
119 |
+
MASK_ID = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]") # <- NEW
|
120 |
+
|
121 |
+
def greedy_pool(idx_order, tag):
|
122 |
"""
|
123 |
+
idx_order : tensor of token-indices sorted hiβlo or loβhi
|
124 |
+
tag : "high" | "low" (for the debug print)
|
125 |
+
returns : (best_pool_indices , best_accuracy)
|
|
|
126 |
"""
|
127 |
+
best_pool, best_acc = [], 0.0
|
128 |
+
ptr = 0
|
129 |
+
while ptr < len(idx_order):
|
130 |
+
cand = idx_order[ptr : ptr + 2] # 2-at-a-time
|
131 |
+
pool = best_pool + cand.tolist() # grow pool
|
132 |
+
ptr += 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
+
# --- build *mask* for βeverything NOT in poolβ ----------
|
135 |
+
mask_flags = torch.zeros_like(ids, dtype=torch.bool)
|
136 |
+
mask_flags[0, pool] = True # keep these un-masked
|
137 |
+
masked_ids = ids.where(mask_flags, MASK_ID) # <- uses the constant
|
|
|
|
|
138 |
|
139 |
+
# re-encode & score
|
140 |
+
enc_m = encode(masked_ids, mask) # helper already defined
|
141 |
+
logits = mlm_head(enc_m).logits[0] # (S, V)
|
142 |
+
preds = logits.argmax(-1)
|
143 |
+
acc = (preds[~mask_flags] == ids[0][~mask_flags]).float().mean().item()
|
144 |
|
145 |
+
if acc > best_acc: # accept pool only on gain
|
146 |
+
best_pool, best_acc = pool, acc
|
147 |
+
if acc >= 0.50: # early-stop rule
|
148 |
+
break
|
|
|
149 |
|
150 |
+
print(f"{tag:>4s}-pool {best_pool} acc={best_acc:.3f}")
|
151 |
return best_pool, best_acc
|
152 |
|
153 |
|