cifkao committed on
Commit
bb204b7
·
1 Parent(s): 535e574

Cache individual batches

Browse files
Files changed (1) hide show
  1. app.py +22 -9
app.py CHANGED
@@ -70,14 +70,19 @@ if len(input_ids) < 2:
70
 
71
  @st.cache_data(show_spinner=False)
72
  @torch.inference_mode()
73
- def run_context_length_probing(model_name, text, window_len):
74
- assert model.name_or_path == model_name
75
- del text # needed as a cache key but for the computation we access inputs directly
 
 
 
 
 
76
 
77
  inputs_sliding = get_windows_batched(
78
- inputs,
79
  window_len=window_len,
80
- pad_id=tokenizer.eos_token_id
81
  ).convert_to_tensors("pt")
82
 
83
  logits = []
@@ -88,7 +93,13 @@ def run_context_length_probing(model_name, text, window_len):
88
  for i in range(0, num_items, batch_size):
89
  pbar.progress(i / num_items, f"{i}/{num_items}")
90
  batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
91
- logits.append(model(**batch).logits.to(torch.float16))
 
 
 
 
 
 
92
  logits = torch.cat(logits, dim=0)
93
  pbar.empty()
94
 
@@ -108,9 +119,11 @@ def run_context_length_probing(model_name, text, window_len):
108
  return scores
109
 
110
  scores = run_context_length_probing(
111
- model_name=model_name,
112
- text=text,
113
- window_len=window_len
 
 
114
  )
115
  tokens = ids_to_readable_tokens(tokenizer, input_ids)
116
 
 
70
 
71
  @st.cache_data(show_spinner=False)
72
  @torch.inference_mode()
73
+ def get_logits(_model, _inputs, cache_key):
74
+ del cache_key
75
+ return _model(**_inputs).logits.to(torch.float16)
76
+
77
+ @st.cache_data(show_spinner=False)
78
+ @torch.inference_mode()
79
+ def run_context_length_probing(_model, _tokenizer, _inputs, window_len, cache_key):
80
+ del cache_key
81
 
82
  inputs_sliding = get_windows_batched(
83
+ _inputs,
84
  window_len=window_len,
85
+ pad_id=_tokenizer.eos_token_id
86
  ).convert_to_tensors("pt")
87
 
88
  logits = []
 
93
  for i in range(0, num_items, batch_size):
94
  pbar.progress(i / num_items, f"{i}/{num_items}")
95
  batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
96
+ logits.append(
97
+ get_logits(
98
+ _model,
99
+ batch,
100
+ cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
101
+ )
102
+ )
103
  logits = torch.cat(logits, dim=0)
104
  pbar.empty()
105
 
 
119
  return scores
120
 
121
  scores = run_context_length_probing(
122
+ _model=model,
123
+ _tokenizer=tokenizer,
124
+ _inputs=inputs,
125
+ window_len=window_len,
126
+ cache_key=(model_name, text),
127
  )
128
  tokens = ids_to_readable_tokens(tokenizer, input_ids)
129