Spaces:
Running
Running
Cache the results
Browse files
app.py
CHANGED
@@ -62,28 +62,40 @@ if metric_name == "KL divergence":
|
|
62 |
tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name)
|
63 |
model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
)
|
75 |
-
|
76 |
-
logits = model(**inputs_sliding.convert_to_tensors("pt")).logits.to(torch.float16)
|
77 |
-
logits = logits.permute(1, 0, 2)
|
78 |
-
logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
|
79 |
-
logits = logits.view(-1, logits.shape[-1])[:-window_len]
|
80 |
-
logits = logits.view(window_len, len(input_ids) + window_len - 2, logits.shape[-1])
|
81 |
-
|
82 |
-
scores = logits.to(torch.float32).log_softmax(dim=-1)
|
83 |
-
scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
|
84 |
-
scores = scores.diff(dim=0).transpose(0, 1)
|
85 |
-
scores = scores.nan_to_num()
|
86 |
-
scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
|
87 |
-
scores = scores.to(torch.float16)
|
88 |
|
89 |
highlighted_text_component(tokens=tokens, scores=scores.tolist())
|
|
|
62 |
tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name)
|
63 |
model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
|
64 |
|
65 |
+
@st.cache_data(show_spinner=False)
|
66 |
+
def run_context_length_probing(model_name, text, window_len):
|
67 |
+
assert model.name_or_path == model_name
|
68 |
+
|
69 |
+
inputs = tokenizer([text])
|
70 |
+
[input_ids] = inputs["input_ids"]
|
71 |
+
window_len = min(window_len, len(input_ids))
|
72 |
+
|
73 |
+
inputs_sliding = get_windows_batched(
|
74 |
+
inputs,
|
75 |
+
window_len=window_len,
|
76 |
+
pad_id=tokenizer.eos_token_id
|
77 |
+
)
|
78 |
+
with torch.inference_mode():
|
79 |
+
logits = model(**inputs_sliding.convert_to_tensors("pt")).logits.to(torch.float16)
|
80 |
+
logits = logits.permute(1, 0, 2)
|
81 |
+
logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
|
82 |
+
logits = logits.view(-1, logits.shape[-1])[:-window_len]
|
83 |
+
logits = logits.view(window_len, len(input_ids) + window_len - 2, logits.shape[-1])
|
84 |
|
85 |
+
scores = logits.to(torch.float32).log_softmax(dim=-1)
|
86 |
+
scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
|
87 |
+
scores = scores.diff(dim=0).transpose(0, 1)
|
88 |
+
scores = scores.nan_to_num()
|
89 |
+
scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
|
90 |
+
scores = scores.to(torch.float16)
|
91 |
+
|
92 |
+
return input_ids, scores
|
93 |
+
|
94 |
+
input_ids, scores = run_context_length_probing(
|
95 |
+
model_name=model_name,
|
96 |
+
text=text,
|
97 |
+
window_len=window_len
|
98 |
)
|
99 |
+
tokens = ids_to_readable_tokens(tokenizer, input_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
highlighted_text_component(tokens=tokens, scores=scores.tolist())
|