Update app.py
app.py CHANGED
@@ -138,17 +138,19 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
     generation_texts = tokenizer.batch_decode(generated)
 
     # try identifying important layers
-
-
+    vectors_to_compare = interpreted_vectors  # torch.tensor(global_state.sentence_transformer.encode(generation_texts))
+    diff_score1 = F.normalize(vectors_to_compare, dim=-1).diff(dim=0).norm(dim=-1)
+
     bags_of_words = [set(tokenizer.tokenize(text)) for text in generation_texts]
-
+    diff_score2 = torch.tensor([
         -len(bags_of_words[i+1] & bags_of_words[i]) / np.sqrt(len(bags_of_words[i+1]) * len(bags_of_words[i]))
         for i in range(len(bags_of_words)-1)
     ])
+    diff_score = diff_score1 / diff_score1.median() + diff_score2 / diff_score2.median()
     avoid_first, avoid_last = 2, 1  # layers that are usually never important
     assert avoid_first >= 1  # due to .diff() we will not be able to compute a score for the first layer
-    diff_score = diff_score[avoid_first-1
-    important_idxs = avoid_first + diff_score.topk(k=int(np.ceil(0.
+    diff_score = diff_score[avoid_first-1:len(diff_score)-avoid_last]
+    important_idxs = avoid_first + diff_score.topk(k=4).indices.cpu().numpy()  # k=int(np.ceil(0.15 * len(generation_texts)))
 
     # create GUI output
     print(f'{important_idxs=}')
@@ -248,8 +250,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
     btn.render()
 
     progress_dummy = gr.Markdown('', elem_id='progress_dummy')
-    interpretation_bubbles = [gr.Textbox('', container=False, visible=False)
-                              for i in range(MAX_NUM_LAYERS)]
+    interpretation_bubbles = [gr.Textbox('', container=False, visible=False) for i in range(MAX_NUM_LAYERS)]
 
     # event listeners
     for i, btn in enumerate(tokens_container):
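
For reference, the heuristic introduced in the first hunk rates how much each layer's interpretation changes relative to the previous one using two signals: the distance between consecutive L2-normalized interpretation vectors (diff_score1) and the negative cosine-style overlap of the bags of tokens in the decoded texts (diff_score2). The two are put on a comparable scale by dividing each by its median, and the layers that follow the largest combined jumps are selected with topk. The sketch below reproduces that heuristic on toy data; layer_vectors, layer_texts and bags are illustrative stand-ins for interpreted_vectors, generation_texts and bags_of_words, not names taken from app.py.

# Minimal sketch of the layer-importance heuristic, on toy data (assumed names).
import numpy as np
import torch
import torch.nn.functional as F

layer_vectors = torch.randn(10, 16)                              # toy: one vector per layer
layer_texts = [f"layer {i} words {i // 3}" for i in range(10)]   # toy: decoded text per layer

# signal 1: how far consecutive (unit-normalized) layer vectors are from each other
diff_score1 = F.normalize(layer_vectors, dim=-1).diff(dim=0).norm(dim=-1)

# signal 2: negative cosine-style overlap of consecutive bags of tokens
bags = [set(t.split()) for t in layer_texts]                     # tokenizer.tokenize(...) in the real code
diff_score2 = torch.tensor([
    -len(bags[i+1] & bags[i]) / np.sqrt(len(bags[i+1]) * len(bags[i]))
    for i in range(len(bags) - 1)
])

# put both signals on a comparable scale, then pick the layers after the biggest jumps
diff_score = diff_score1 / diff_score1.median() + diff_score2 / diff_score2.median()
avoid_first, avoid_last = 2, 1                                   # .diff() leaves no score for layer 0
diff_score = diff_score[avoid_first - 1 : len(diff_score) - avoid_last]
important_idxs = avoid_first + diff_score.topk(k=4).indices.cpu().numpy()
print(important_idxs)

Note that the commit also swaps the proportional cutoff (the old k=int(np.ceil(0.15 * len(generation_texts))), kept as a trailing comment) for a fixed k=4, so four layers are flagged regardless of how many layers were decoded.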