m-ric HF Staff committed on
Commit
9570f3d
·
1 Parent(s): 719af86

Correct context windows

Browse files
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -13,9 +13,10 @@ attnlrp.register(model)
13
 
14
 
15
  def really_clean_tokens(tokens):
 
16
  cleaned_tokens = []
17
  for token in tokens:
18
- token = token.replace("_", " ").replace("▁", " ").replace("<s>", "").strip()
19
  if token.startswith("<0x") and token.endswith(">"):
20
  # Convert hex to character
21
  char_code = int(token[3:-1], 16)
@@ -44,12 +45,11 @@ def generate_and_visualize(prompt, num_tokens=10):
44
 
45
  input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
46
  input_embeds = model.get_input_embeddings()(input_ids)
47
-
48
- input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
49
- input_tokens = really_clean_tokens(input_tokens)
50
  generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(generated_tokens_ids))
51
 
52
  return input_tokens, all_relevances, generated_tokens
 
53
  def process_relevances(input_tokens, all_relevances, generated_tokens):
54
  attention_matrix = np.array([el[:len(all_relevances[0])] for el in all_relevances])
55
 
@@ -103,11 +103,11 @@ def process_relevances(input_tokens, all_relevances, generated_tokens):
103
  for i, (token, coords) in enumerate(output_with_notes):
104
  if coords is not None:
105
  best_width, best_patch_end = coords
106
- start = max(0, best_patch_end - best_width - context_width)
107
- end = min(len(input_tokens), best_patch_end + kernel_width + context_width)
108
- context = input_tokens[start:end]
109
- significant_start = max(0, best_patch_end - best_width - start)
110
- significant_end = significant_start + best_width + kernel_width
111
  output_with_notes[i] = (token, (context, significant_start, significant_end))
112
 
113
  return output_with_notes
@@ -123,7 +123,7 @@ def create_html_with_hover(output_with_notes):
123
  formatted_context.append(f'<strong>{token}</strong>')
124
  else:
125
  formatted_context.append(token)
126
- formatted_note = " ".join(formatted_context)
127
  html += f'<span class="hoverable" data-note-id="note-{i}">{text}<sup>[{i+1}]</sup>'
128
  html += f'<span class="hover-note">{formatted_note}</span></span>'
129
  else:
@@ -144,7 +144,6 @@ css = """
144
  .hover-note {
145
  display: none;
146
  position: absolute;
147
- background-color: #f0f0f0;
148
  padding: 5px;
149
  border-radius: 5px;
150
  bottom: 100%;
@@ -153,8 +152,9 @@ css = """
153
  white-space: normal;
154
  background-color: rgba(240, 240, 240, 1);
155
  max-width: 600px;
 
156
  word-wrap: break-word;
157
- z-index: 1;
158
  }
159
  .hoverable:hover .hover-note { display: block; }
160
  """
 
13
 
14
 
15
  def really_clean_tokens(tokens):
16
+ tokens = clean_tokens(tokens)
17
  cleaned_tokens = []
18
  for token in tokens:
19
+ token = token.replace("_", " ").replace("▁", " ").replace("<s>", " ")
20
  if token.startswith("<0x") and token.endswith(">"):
21
  # Convert hex to character
22
  char_code = int(token[3:-1], 16)
 
45
 
46
  input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
47
  input_embeds = model.get_input_embeddings()(input_ids)
48
+ input_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(input_ids[0]))
 
 
49
  generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(generated_tokens_ids))
50
 
51
  return input_tokens, all_relevances, generated_tokens
52
+
53
  def process_relevances(input_tokens, all_relevances, generated_tokens):
54
  attention_matrix = np.array([el[:len(all_relevances[0])] for el in all_relevances])
55
 
 
103
  for i, (token, coords) in enumerate(output_with_notes):
104
  if coords is not None:
105
  best_width, best_patch_end = coords
106
+ significant_start = max(0, best_patch_end - best_width)
107
+ significant_end = best_patch_end + kernel_width
108
+ context_start = max(0, significant_start - context_width)
109
+ context_end = min(len(input_tokens), significant_end + context_width)
110
+ context = input_tokens[context_start:context_end]
111
  output_with_notes[i] = (token, (context, significant_start, significant_end))
112
 
113
  return output_with_notes
 
123
  formatted_context.append(f'<strong>{token}</strong>')
124
  else:
125
  formatted_context.append(token)
126
+ formatted_note = "".join(formatted_context)
127
  html += f'<span class="hoverable" data-note-id="note-{i}">{text}<sup>[{i+1}]</sup>'
128
  html += f'<span class="hover-note">{formatted_note}</span></span>'
129
  else:
 
144
  .hover-note {
145
  display: none;
146
  position: absolute;
 
147
  padding: 5px;
148
  border-radius: 5px;
149
  bottom: 100%;
 
152
  white-space: normal;
153
  background-color: rgba(240, 240, 240, 1);
154
  max-width: 600px;
155
+ width:500px;
156
  word-wrap: break-word;
157
+ z-index: 10;
158
  }
159
  .hoverable:hover .hover-note { display: block; }
160
  """