Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -90,7 +90,6 @@ def encode_and_trace(text: str, selected_roles: list[str]):
     ids, mask = batch.input_ids, batch.attention_mask
 
     x = emb_drop(emb_ln(embeddings(ids)))
-
     ext_mask = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
     enc = encoder(x, attention_mask=ext_mask)  # (1, S, H)
 
@@ -98,20 +97,20 @@ def encode_and_trace(text: str, selected_roles: list[str]):
     flags = torch.tensor([tid in sel_ids for tid in ids[0].tolist()],
                          device=enc.device)
 
-
-
+    found_tokens = [tokenizer.convert_ids_to_tokens([tid])[0]
+                    for tid in ids[0].tolist() if tid in sel_ids]
+    tokens_str = ", ".join(found_tokens) or "(none)"
 
     if flags.any():
-        vec
+        vec = enc[0][flags].mean(0)
         norm = f"{vec.norm().item():.4f}"
     else:
         norm = "0.0000"
 
-
-
-
-
-    }
+    count = int(flags.sum().item())
+    # >>> return *three* scalars, not one dict <<<
+    return tokens_str, norm, count
+
 
 
 # ------------------------------------------------------------------
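
Not shown in this commit is the Gradio wiring that consumes the new three-value return. As the added comment notes, encode_and_trace now returns tokens_str, norm, and count instead of a single dict, so its event handler needs three output components. Below is a minimal sketch under that assumption; the component names, labels, and role choices are hypothetical, not taken from app.py.

import gradio as gr

# Hypothetical UI wiring -- assumes encode_and_trace is defined above and
# that the Space uses gr.Blocks; names and labels here are illustrative only.
with gr.Blocks() as demo:
    text_in = gr.Textbox(label="Input text")
    roles_in = gr.CheckboxGroup(choices=["role_a", "role_b"],  # placeholder choices
                                label="selected_roles")
    tokens_out = gr.Textbox(label="Matched tokens")    # receives tokens_str
    norm_out = gr.Textbox(label="Mean-pooled norm")    # receives norm (a string)
    count_out = gr.Number(label="Match count")         # receives count (an int)
    # One output component per returned value -- a single JSON/dict output
    # would no longer match the patched return signature.
    gr.Button("Encode").click(encode_and_trace,
                              inputs=[text_in, roles_in],
                              outputs=[tokens_out, norm_out, count_out])

demo.launch()

Returning a flat tuple keeps the Gradio contract explicit: each returned value maps positionally onto one output component, which is harder to get wrong than a dict whose keys must line up with component names.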