AbstractPhil committed on
Commit
b60c583
·
verified ·
1 Parent(s): 872b08b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -90,7 +90,6 @@ def encode_and_trace(text: str, selected_roles: list[str]):
90
  ids, mask = batch.input_ids, batch.attention_mask
91
 
92
  x = emb_drop(emb_ln(embeddings(ids)))
93
-
94
  ext_mask = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
95
  enc = encoder(x, attention_mask=ext_mask) # (1, S, H)
96
 
@@ -98,20 +97,20 @@ def encode_and_trace(text: str, selected_roles: list[str]):
98
  flags = torch.tensor([tid in sel_ids for tid in ids[0].tolist()],
99
  device=enc.device)
100
 
101
- found = [tokenizer.convert_ids_to_tokens([tid])[0]
102
- for tid in ids[0].tolist() if tid in sel_ids]
 
103
 
104
  if flags.any():
105
- vec = enc[0][flags].mean(0)
106
  norm = f"{vec.norm().item():.4f}"
107
  else:
108
  norm = "0.0000"
109
 
110
- return {
111
- "Symbolic Tokens": ", ".join(found) or "(none)",
112
- "Embedding Norm": norm,
113
- "Symbolic Token Count": int(flags.sum().item()),
114
- }
115
 
116
 
117
  # ------------------------------------------------------------------
 
90
  ids, mask = batch.input_ids, batch.attention_mask
91
 
92
  x = emb_drop(emb_ln(embeddings(ids)))
 
93
  ext_mask = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
94
  enc = encoder(x, attention_mask=ext_mask) # (1, S, H)
95
 
 
97
  flags = torch.tensor([tid in sel_ids for tid in ids[0].tolist()],
98
  device=enc.device)
99
 
100
+ found_tokens = [tokenizer.convert_ids_to_tokens([tid])[0]
101
+ for tid in ids[0].tolist() if tid in sel_ids]
102
+ tokens_str = ", ".join(found_tokens) or "(none)"
103
 
104
  if flags.any():
105
+ vec = enc[0][flags].mean(0)
106
  norm = f"{vec.norm().item():.4f}"
107
  else:
108
  norm = "0.0000"
109
 
110
+ count = int(flags.sum().item())
111
+ # >>> return *three* scalars, not one dict <<<
112
+ return tokens_str, norm, count
113
+
 
114
 
115
 
116
  # ------------------------------------------------------------------