Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -83,89 +83,96 @@ def pool_accuracy(ids, logits, pool_mask):
|
|
83 |
return (preds==gold).float().mean().item()
|
84 |
|
85 |
|
|
|
86 |
@spaces.GPU
|
87 |
def encode_and_trace(text, selected_roles):
|
88 |
-
# if user unchecked everything we treat as "all"
|
89 |
if not selected_roles:
|
90 |
selected_roles = SYMBOLIC_ROLES
|
91 |
-
sel_ids =
|
|
|
92 |
|
93 |
-
#
|
94 |
batch = tokenizer(text, return_tensors="pt").to("cuda")
|
95 |
-
ids,
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
sort_idx = maxcos.argsort(descending=True)
|
110 |
-
hi_idx
|
111 |
-
lo_idx
|
112 |
-
|
113 |
-
#
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
# 3. Encoder-only inference util (FIXED) β
|
118 |
-
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
119 |
-
MASK_ID = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]") # <- NEW
|
120 |
-
|
121 |
-
def greedy_pool(idx_order, tag):
|
122 |
-
"""
|
123 |
-
idx_order : tensor of token-indices sorted hiβlo or loβhi
|
124 |
-
tag : "high" | "low" (for the debug print)
|
125 |
-
returns : (best_pool_indices , best_accuracy)
|
126 |
-
"""
|
127 |
best_pool, best_acc = [], 0.0
|
128 |
ptr = 0
|
129 |
while ptr < len(idx_order):
|
130 |
-
cand = idx_order[ptr
|
131 |
-
pool = best_pool + cand.tolist()
|
132 |
ptr += 2
|
133 |
-
|
134 |
-
# --- build *mask* for βeverything NOT in poolβ ----------
|
135 |
mask_flags = torch.zeros_like(ids, dtype=torch.bool)
|
136 |
-
mask_flags[0, pool] = True
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
146 |
best_pool, best_acc = pool, acc
|
147 |
-
if acc >= 0.
|
148 |
break
|
149 |
-
|
150 |
-
print(f"{tag:>4s}-pool {best_pool} acc={best_acc:.3f}")
|
151 |
return best_pool, best_acc
|
152 |
|
|
|
|
|
153 |
|
154 |
-
|
155 |
-
|
|
|
|
|
|
|
|
|
156 |
|
157 |
-
#
|
158 |
res_json = {
|
|
|
|
|
159 |
"Low-pool tokens": tokenizer.decode(ids[0, pool_lo]),
|
160 |
-
"Low accuracy":
|
161 |
-
"
|
162 |
-
"High accuracy": f"{acc_hi:.2f}",
|
163 |
-
"Trace": "\n".join(report_lines)
|
164 |
}
|
165 |
-
|
166 |
return json.dumps(res_json, indent=2), f"{maxcos.max():.4f}", len(selected_roles)
|
167 |
|
168 |
|
|
|
169 |
# ------------------------------------------------------------------
|
170 |
# 4. Gradio UI -----------------------------------------------------
|
171 |
def build_interface():
|
|
|
83 |
return (preds==gold).float().mean().item()
|
84 |
|
85 |
|
86 |
+
@spaces.GPU
|
87 |
@spaces.GPU
|
88 |
def encode_and_trace(text, selected_roles):
|
|
|
89 |
if not selected_roles:
|
90 |
selected_roles = SYMBOLIC_ROLES
|
91 |
+
sel_ids = [tokenizer.convert_tokens_to_ids(t) for t in selected_roles]
|
92 |
+
sel_ids_tensor = torch.tensor(sel_ids, device="cuda")
|
93 |
|
94 |
+
# ========== Tokenize & Embed ==========
|
95 |
batch = tokenizer(text, return_tensors="pt").to("cuda")
|
96 |
+
ids, attn = batch.input_ids, batch.attention_mask
|
97 |
+
S = ids.shape[1]
|
98 |
+
|
99 |
+
def encode(input_ids, attn_mask):
|
100 |
+
x = embeddings(input_ids)
|
101 |
+
if emb_ln: x = emb_ln(x)
|
102 |
+
if emb_drop: x = emb_drop(x)
|
103 |
+
ext = full_model.bert.get_extended_attention_mask(attn_mask, x.shape[:-1])
|
104 |
+
return encoder(x, attention_mask=ext)[0]
|
105 |
+
|
106 |
+
# Full unmasked encoding pass
|
107 |
+
encoded = encode(ids, attn)
|
108 |
+
|
109 |
+
# ========== Cosine Similarity ==========
|
110 |
+
symbolic_embeds = embeddings(sel_ids_tensor) # (R, H)
|
111 |
+
sim = cosine(encoded.unsqueeze(1), symbolic_embeds.unsqueeze(0)) # (S, R)
|
112 |
+
maxcos, argrole = sim.max(-1) # (S,)
|
113 |
+
top_roles = [selected_roles[i] for i in argrole.tolist()]
|
114 |
+
|
115 |
+
# ========== Sorting into High / Low Alignment Pools ==========
|
116 |
sort_idx = maxcos.argsort(descending=True)
|
117 |
+
hi_idx = sort_idx[:S // 2]
|
118 |
+
lo_idx = sort_idx[S // 2:]
|
119 |
+
|
120 |
+
# ========== Greedy Pool Testing ==========
|
121 |
+
MASK_ID = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]")
|
122 |
+
|
123 |
+
def evaluate_pool(idx_order, label):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
best_pool, best_acc = [], 0.0
|
125 |
ptr = 0
|
126 |
while ptr < len(idx_order):
|
127 |
+
cand = idx_order[ptr:ptr + 2]
|
128 |
+
pool = best_pool + cand.tolist()
|
129 |
ptr += 2
|
130 |
+
|
|
|
131 |
mask_flags = torch.zeros_like(ids, dtype=torch.bool)
|
132 |
+
mask_flags[0, pool] = True
|
133 |
+
masked_input = ids.where(mask_flags, MASK_ID)
|
134 |
+
|
135 |
+
encoded_m = encode(masked_input, attn)
|
136 |
+
logits = mlm_head(encoded_m).logits[0]
|
137 |
+
preds = logits.argmax(-1)
|
138 |
+
|
139 |
+
masked_positions = (~mask_flags[0]).nonzero(as_tuple=False).squeeze(-1)
|
140 |
+
if masked_positions.numel() == 0:
|
141 |
+
continue
|
142 |
+
|
143 |
+
correct = (preds[masked_positions] == ids[0][masked_positions]).float()
|
144 |
+
acc = correct.mean().item()
|
145 |
+
|
146 |
+
if acc > best_acc:
|
147 |
best_pool, best_acc = pool, acc
|
148 |
+
if acc >= 0.5:
|
149 |
break
|
150 |
+
|
|
|
151 |
return best_pool, best_acc
|
152 |
|
153 |
+
pool_hi, acc_hi = evaluate_pool(hi_idx, "high")
|
154 |
+
pool_lo, acc_lo = evaluate_pool(lo_idx, "low")
|
155 |
|
156 |
+
# ========== Per-token Symbolic Role Trace ==========
|
157 |
+
decoded_tokens = tokenizer.convert_ids_to_tokens(ids[0])
|
158 |
+
role_trace = [
|
159 |
+
f"{tok:<15} β {role} cos={score:.4f}"
|
160 |
+
for tok, role, score in zip(decoded_tokens, top_roles, maxcos.tolist())
|
161 |
+
]
|
162 |
|
163 |
+
# ========== JSON Result ==========
|
164 |
res_json = {
|
165 |
+
"High-pool tokens": tokenizer.decode(ids[0, pool_hi]),
|
166 |
+
"High accuracy": f"{acc_hi:.3f}",
|
167 |
"Low-pool tokens": tokenizer.decode(ids[0, pool_lo]),
|
168 |
+
"Low accuracy": f"{acc_lo:.3f}",
|
169 |
+
"TokenβSymbolic Role Alignment": role_trace
|
|
|
|
|
170 |
}
|
171 |
+
|
172 |
return json.dumps(res_json, indent=2), f"{maxcos.max():.4f}", len(selected_roles)
|
173 |
|
174 |
|
175 |
+
|
176 |
# ------------------------------------------------------------------
|
177 |
# 4. Gradio UI -----------------------------------------------------
|
178 |
def build_interface():
|