Spaces:
Running
Running
Cache individual batches
Browse files
app.py
CHANGED
@@ -70,14 +70,19 @@ if len(input_ids) < 2:
|
|
70 |
|
71 |
@st.cache_data(show_spinner=False)
|
72 |
@torch.inference_mode()
|
73 |
-
def
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
inputs_sliding = get_windows_batched(
|
78 |
-
|
79 |
window_len=window_len,
|
80 |
-
pad_id=
|
81 |
).convert_to_tensors("pt")
|
82 |
|
83 |
logits = []
|
@@ -88,7 +93,13 @@ def run_context_length_probing(model_name, text, window_len):
|
|
88 |
for i in range(0, num_items, batch_size):
|
89 |
pbar.progress(i / num_items, f"{i}/{num_items}")
|
90 |
batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
|
91 |
-
logits.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
logits = torch.cat(logits, dim=0)
|
93 |
pbar.empty()
|
94 |
|
@@ -108,9 +119,11 @@ def run_context_length_probing(model_name, text, window_len):
|
|
108 |
return scores
|
109 |
|
110 |
scores = run_context_length_probing(
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
114 |
)
|
115 |
tokens = ids_to_readable_tokens(tokenizer, input_ids)
|
116 |
|
|
|
70 |
|
71 |
@st.cache_data(show_spinner=False)
|
72 |
@torch.inference_mode()
|
73 |
+
def get_logits(_model, _inputs, cache_key):
|
74 |
+
del cache_key
|
75 |
+
return _model(**_inputs).logits.to(torch.float16)
|
76 |
+
|
77 |
+
@st.cache_data(show_spinner=False)
|
78 |
+
@torch.inference_mode()
|
79 |
+
def run_context_length_probing(_model, _tokenizer, _inputs, window_len, cache_key):
|
80 |
+
del cache_key
|
81 |
|
82 |
inputs_sliding = get_windows_batched(
|
83 |
+
_inputs,
|
84 |
window_len=window_len,
|
85 |
+
pad_id=_tokenizer.eos_token_id
|
86 |
).convert_to_tensors("pt")
|
87 |
|
88 |
logits = []
|
|
|
93 |
for i in range(0, num_items, batch_size):
|
94 |
pbar.progress(i / num_items, f"{i}/{num_items}")
|
95 |
batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
|
96 |
+
logits.append(
|
97 |
+
get_logits(
|
98 |
+
_model,
|
99 |
+
batch,
|
100 |
+
cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
|
101 |
+
)
|
102 |
+
)
|
103 |
logits = torch.cat(logits, dim=0)
|
104 |
pbar.empty()
|
105 |
|
|
|
119 |
return scores
|
120 |
|
121 |
scores = run_context_length_probing(
|
122 |
+
_model=model,
|
123 |
+
_tokenizer=tokenizer,
|
124 |
+
_inputs=inputs,
|
125 |
+
window_len=window_len,
|
126 |
+
cache_key=(model_name, text),
|
127 |
)
|
128 |
tokens = ids_to_readable_tokens(tokenizer, input_ids)
|
129 |
|