cifkao committed on
Commit
b6ab215
·
1 Parent(s): faa3816

Cache the results

Browse files
Files changed (1) hide show
  1. app.py +33 -21
app.py CHANGED
@@ -62,28 +62,40 @@ if metric_name == "KL divergence":
62
  tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name)
63
  model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
64
 
65
- inputs = tokenizer([text])
66
- [input_ids] = inputs["input_ids"]
67
- window_len = min(window_len, len(input_ids))
68
- tokens = ids_to_readable_tokens(tokenizer, input_ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- inputs_sliding = get_windows_batched(
71
- inputs,
72
- window_len=window_len,
73
- pad_id=tokenizer.eos_token_id
 
 
 
 
 
 
 
 
 
74
  )
75
- with torch.inference_mode():
76
- logits = model(**inputs_sliding.convert_to_tensors("pt")).logits.to(torch.float16)
77
- logits = logits.permute(1, 0, 2)
78
- logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
79
- logits = logits.view(-1, logits.shape[-1])[:-window_len]
80
- logits = logits.view(window_len, len(input_ids) + window_len - 2, logits.shape[-1])
81
-
82
- scores = logits.to(torch.float32).log_softmax(dim=-1)
83
- scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
84
- scores = scores.diff(dim=0).transpose(0, 1)
85
- scores = scores.nan_to_num()
86
- scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
87
- scores = scores.to(torch.float16)
88
 
89
  highlighted_text_component(tokens=tokens, scores=scores.tolist())
 
62
  tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name)
63
  model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
64
 
65
+ @st.cache_data(show_spinner=False)
66
+ def run_context_length_probing(model_name, text, window_len):
67
+ assert model.name_or_path == model_name
68
+
69
+ inputs = tokenizer([text])
70
+ [input_ids] = inputs["input_ids"]
71
+ window_len = min(window_len, len(input_ids))
72
+
73
+ inputs_sliding = get_windows_batched(
74
+ inputs,
75
+ window_len=window_len,
76
+ pad_id=tokenizer.eos_token_id
77
+ )
78
+ with torch.inference_mode():
79
+ logits = model(**inputs_sliding.convert_to_tensors("pt")).logits.to(torch.float16)
80
+ logits = logits.permute(1, 0, 2)
81
+ logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
82
+ logits = logits.view(-1, logits.shape[-1])[:-window_len]
83
+ logits = logits.view(window_len, len(input_ids) + window_len - 2, logits.shape[-1])
84
 
85
+ scores = logits.to(torch.float32).log_softmax(dim=-1)
86
+ scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
87
+ scores = scores.diff(dim=0).transpose(0, 1)
88
+ scores = scores.nan_to_num()
89
+ scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
90
+ scores = scores.to(torch.float16)
91
+
92
+ return input_ids, scores
93
+
94
+ input_ids, scores = run_context_length_probing(
95
+ model_name=model_name,
96
+ text=text,
97
+ window_len=window_len
98
  )
99
+ tokens = ids_to_readable_tokens(tokenizer, input_ids)
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  highlighted_text_component(tokens=tokens, scores=scores.tolist())