dar-tau committed
Commit 0a22698 · verified · 1 Parent(s): a5aded9

Update app.py

Files changed (1):
  1. app.py +8 -7
app.py CHANGED
@@ -138,17 +138,19 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
     generation_texts = tokenizer.batch_decode(generated)

     # try identifying important layers
-    # vectors_to_compare = interpreted_vectors # torch.tensor(global_state.sentence_transformer.encode(generation_texts))
-    # diff_score = F.normalize(vectors_to_compare, dim=-1).diff(dim=0).norm(dim=-1)
+    vectors_to_compare = interpreted_vectors # torch.tensor(global_state.sentence_transformer.encode(generation_texts))
+    diff_score1 = F.normalize(vectors_to_compare, dim=-1).diff(dim=0).norm(dim=-1)
+
     bags_of_words = [set(tokenizer.tokenize(text)) for text in generation_texts]
-    diff_score = torch.tensor([
+    diff_score2 = torch.tensor([
         -len(bags_of_words[i+1] & bags_of_words[i]) / np.sqrt(len(bags_of_words[i+1]) * len(bags_of_words[i]))
         for i in range(len(bags_of_words)-1)
     ])
+    diff_score = diff_score1 / diff_score1.median() + diff_score2 / diff_score2.median()
     avoid_first, avoid_last = 2, 1 # layers that are usually never important
     assert avoid_first >= 1 # due to .diff() we will not be able to compute a score for the first layer
-    diff_score = diff_score[avoid_first-1 : len(diff_score)-avoid_last]
-    important_idxs = avoid_first + diff_score.topk(k=int(np.ceil(0.1 * len(generation_texts)))).indices.cpu().numpy()
+    diff_score = diff_score[avoid_first-1:len(diff_score)-avoid_last]
+    important_idxs = avoid_first + diff_score.topk(k=4).indices.cpu().numpy() # k=int(np.ceil(0.15 * len(generation_texts)))

     # create GUI output
     print(f'{important_idxs=}')
@@ -248,8 +250,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
         btn.render()

     progress_dummy = gr.Markdown('', elem_id='progress_dummy')
-    interpretation_bubbles = [gr.Textbox('', container=False, visible=False)
-                              for i in range(MAX_NUM_LAYERS)]
+    interpretation_bubbles = [gr.Textbox('', container=False, visible=False) for i in range(MAX_NUM_LAYERS)]

     # event listeners
     for i, btn in enumerate(tokens_container):
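
For readers skimming the change: the first hunk reactivates the previously commented-out embedding-based score (diff_score1), keeps the bag-of-words score (diff_score2), normalizes each by its median, and sums them before taking a fixed top-k of "important" layers instead of a proportional one. Below is a minimal, self-contained sketch of that scoring logic, not the app's actual code: combined_layer_scores is a hypothetical helper name, and the random vectors, dummy texts, and whitespace split stand in for interpreted_vectors, the decoded generation_texts, and tokenizer.tokenize in app.py.

# Minimal sketch of the combined layer-importance scoring (illustration only).
import numpy as np
import torch
import torch.nn.functional as F

def combined_layer_scores(layer_vectors, layer_texts, avoid_first=2, avoid_last=1, k=4):
    # Signal 1: how far the unit-normalized interpretation vector moves
    # between consecutive layers.
    diff_score1 = F.normalize(layer_vectors, dim=-1).diff(dim=0).norm(dim=-1)

    # Signal 2: negative normalized token-set overlap between consecutive
    # layers' decoded texts (fewer shared tokens -> higher value).
    bags = [set(text.split()) for text in layer_texts]  # stand-in for tokenizer.tokenize
    diff_score2 = torch.tensor([
        -len(bags[i+1] & bags[i]) / np.sqrt(len(bags[i+1]) * len(bags[i]))
        for i in range(len(bags) - 1)
    ])

    # Median-normalize each signal so neither dominates, then combine.
    diff_score = diff_score1 / diff_score1.median() + diff_score2 / diff_score2.median()

    # Skip layers that are rarely important; .diff() already dropped index 0,
    # so avoid_first must be at least 1.
    assert avoid_first >= 1
    diff_score = diff_score[avoid_first - 1 : len(diff_score) - avoid_last]

    # Map the top-k positions back to original layer indices.
    return avoid_first + diff_score.topk(k=k).indices.cpu().numpy()

# Toy usage with fake data: 12 layers, 16-dim vectors, short dummy decodings.
torch.manual_seed(0)
fake_vectors = torch.randn(12, 16)
fake_texts = [f"layer {i} mostly repeats topic {i // 4}" for i in range(12)]
print(combined_layer_scores(fake_vectors, fake_texts))  # prints 4 layer indices

The sketch keeps k as a parameter; the commit itself hard-codes k=4 and leaves the earlier proportional choice (np.ceil(0.15 * len(generation_texts))) as a trailing comment.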