m-ric (HF Staff) committed · commit 96879fc · 1 parent: 2343db7
Files changed (1):
  1. app.py +46 -32
app.py CHANGED
@@ -11,7 +11,6 @@ model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", t
 tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 attnlrp.register(model)
 
-
 def really_clean_tokens(tokens):
     tokens = clean_tokens(tokens)
     cleaned_tokens = []
@@ -33,7 +32,7 @@ def generate_and_visualize(prompt, num_tokens=10):
     all_relevances = []
 
     for _ in range(num_tokens):
-        output_logits = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False).logits
+        output_logits = model(inputs_embeds=input_embeds.requires_grad_()).logits
         max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)
 
         max_logits.backward(max_logits)
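
For context: the loop this hunk touches follows the usual AttnLRP recipe from the lxt library: run the forward pass on differentiable input embeddings, take the top logit at the last position, and backpropagate it so per-input-token relevance can be read off the embedding gradients. A minimal sketch of that pattern; the import path, the detach, and the gradient read-out are assumptions, since they sit outside this diff:

    import torch
    from transformers import AutoTokenizer
    from lxt.models.llama import LlamaForCausalLM, attnlrp  # assumed import path for the lxt-patched Llama

    model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    attnlrp.register(model)  # patch the model so backward propagates LRP relevance

    input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
    # detach so the embeddings are a leaf tensor and .grad gets populated (assumption, not shown in the diff)
    input_embeds = model.get_input_embeddings()(input_ids).detach()

    output_logits = model(inputs_embeds=input_embeds.requires_grad_()).logits
    max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)
    max_logits.backward(max_logits)  # seed the backward pass with the winning logit

    # per-input-token relevance: sum the embedding gradient over the hidden dimension (assumed read-out)
    relevance = input_embeds.grad.float().sum(-1)[0]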
@@ -54,7 +53,7 @@ def process_relevances(input_tokens, all_relevances, generated_tokens):
     attention_matrix = np.array([el[:len(all_relevances[0])] for el in all_relevances])
 
     ### FIND ZONES OF INTEREST
-    threshold_per_token = 0.3
+    threshold_per_token = 0.25
     kernel_width = 6
     context_width = 20 # Number of tokens to include as context on each side
     kernel = np.ones((kernel_width, kernel_width))
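
The threshold change above feeds the window test visible in the next hunk: significant_areas marks a cell when the summed relevance over a kernel_width x kernel_width window exceeds kernel_width**2 * threshold_per_token, i.e. when the window's mean relevance per token is above the threshold, so lowering it from 0.3 to 0.25 makes more regions count as significant. A small self-contained sketch; the diff does not show how rolled_sum is built, so the box-filter convolution below is an assumption:

    import numpy as np
    from scipy.signal import convolve2d  # assumed; the actual rolled_sum computation is outside this diff

    threshold_per_token = 0.25          # value after this commit (was 0.3)
    kernel_width = 6
    kernel = np.ones((kernel_width, kernel_width))

    # stand-in for the generated-tokens x input-tokens relevance matrix
    attention_matrix = np.random.rand(12, 40)

    # sum of relevances over each 6x6 window
    rolled_sum = convolve2d(attention_matrix, kernel, mode="same")

    # a cell is significant when its window averages more than threshold_per_token relevance per token
    significant_areas = rolled_sum > kernel_width**2 * threshold_per_token
    print(significant_areas.shape, significant_areas.mean())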
@@ -66,48 +65,58 @@ def process_relevances(input_tokens, all_relevances, generated_tokens):
     significant_areas = rolled_sum > kernel_width**2 * threshold_per_token
 
     def find_largest_contiguous_patch(array):
-        current_patch_end = None
-        best_width, best_patch_end = None, None
+        current_patch_start = None
+        best_width, best_patch_start = None, None
         current_width = 0
         for i in range(len(array)):
             if array[i]:
-                if current_patch_end is not None and current_patch_end == i-1:
+                if current_patch_start is not None and current_patch_start + current_width == i:
                     current_width += 1
-                    current_patch_end = i
                 else:
-                    current_patch_end = i
+                    current_patch_start = i
                     current_width = 1
-                if current_patch_end and (best_width is None or current_width > best_width):
-                    best_patch_end = current_patch_end
+                if current_patch_start and (best_width is None or current_width > best_width):
+                    best_patch_start = current_patch_start
                     best_width = current_width
             else:
                 current_width = 0
-        return best_width, best_patch_end
+        return best_width, best_patch_start
 
     output_with_notes = [(el, None) for el in generated_tokens[:kernel_width]]
     for row in range(kernel_width, len(generated_tokens)):
-        best_width, best_patch_end = find_largest_contiguous_patch(significant_areas[row-kernel_width+1])
+        best_width, best_patch_start = find_largest_contiguous_patch(significant_areas[row-kernel_width+1])
 
         if best_width is not None:
-            # Fuse the notes for consecutive output tokens if necessary
-            for i in range(len(output_with_notes)-2*kernel_width, len(output_with_notes)):
-                token, coords = output_with_notes[i]
-                if coords is not None:
-                    prev_width, prev_patch_end = coords
-                    if prev_patch_end > best_patch_end - best_width:
-                        # then notes are overlapping: thus we delete the first one and make the last wider if needed
-                        output_with_notes[i] = (token, None)
-                        if prev_patch_end - prev_width < best_patch_end - best_width:
-                            best_width = best_patch_end - prev_patch_end - prev_width
-            output_with_notes.append((generated_tokens[row], (best_width, best_patch_end)))
+            output_with_notes.append((generated_tokens[row], (best_width, best_patch_start)))
         else:
             output_with_notes.append((generated_tokens[row], None))
 
-    for i, (token, coords) in enumerate(output_with_notes):
+
+    # Fuse the notes for consecutive output tokens if necessary
+    for i in range(len(output_with_notes)):
+        token, coords = output_with_notes[i]
         if coords is not None:
-            best_width, best_patch_end = coords
-            significant_start = max(0, best_patch_end - best_width)
-            significant_end = best_patch_end + kernel_width
+            best_width, best_patch_start = coords
+            note_width_generated = kernel_width
+            for next_id in output_with_notes[i+1, i+2*kernel_width]:
+                next_token, next_coords = output_with_notes[next_id]
+                if next_coords is not None:
+                    next_width, next_patch_start = next_coords
+                    if best_patch_start + best_width > next_patch_start:
+                        # then notes are overlapping: thus we delete the last one and make the first wider if needed
+                        output_with_notes[next_id] = (next_token, None)
+                        larger_end = max(best_patch_start + best_width, next_patch_start + next_width)
+                        best_width = larger_end - best_patch_start
+                        note_width_generated = kernel_width + (next_id-i)
+            output_with_notes[i] = (token, (best_width, best_patch_start), note_width_generated)
+        else:
+            output_with_notes[i] = (token, None, None)
+
+    for i, (token, coords, width) in enumerate(output_with_notes):
+        if coords is not None:
+            best_width, best_patch_start = coords
+            significant_start = max(0, best_patch_start)
+            significant_end = best_patch_start + kernel_width + best_width
             context_start = max(0, significant_start - context_width)
             context_end = min(len(input_tokens), significant_end + context_width)
             first_part = "".join(input_tokens[context_start:significant_start])
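
To make the rewritten helper easier to follow, here it is in isolation with a toy row of the significance mask (the call at the bottom is illustrative and not part of app.py):

    def find_largest_contiguous_patch(array):
        # returns (width, start) of the longest run of True values;
        # a run that starts at index 0 is never recorded, because `if current_patch_start` is False for 0
        current_patch_start = None
        best_width, best_patch_start = None, None
        current_width = 0
        for i in range(len(array)):
            if array[i]:
                if current_patch_start is not None and current_patch_start + current_width == i:
                    current_width += 1
                else:
                    current_patch_start = i
                    current_width = 1
                if current_patch_start and (best_width is None or current_width > best_width):
                    best_patch_start = current_patch_start
                    best_width = current_width
            else:
                current_width = 0
        return best_width, best_patch_start

    # toy significance row (in app.py this is a row of the boolean significant_areas matrix)
    row = [False, True, True, True, False, False, True, True, False]
    print(find_largest_contiguous_patch(row))  # -> (3, 1): a 3-token patch starting at input position 1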
@@ -115,22 +124,27 @@ def process_relevances(input_tokens, all_relevances, generated_tokens):
             final_part = "".join(input_tokens[significant_end:context_end])
             print("KK", first_part, significant_part, final_part)
 
-            output_with_notes[i] = (token, (first_part, significant_part, final_part))
+            output_with_notes[i] = (token, (first_part, significant_part, final_part), width)
 
     return output_with_notes
 
 def create_html_with_hover(output_with_notes):
     html = "<div id='output-container'>"
     note_number = 0
-    for (text, notes) in output_with_notes:
-        if notes:
+    i = 0
+    while i < len(output_with_notes):
+        (token, notes, width) = output_with_notes[i]
+        if notes is None:
+            html += f'{token}'
+            i +=1
+        else:
+            text = "".join([element[0] for element in output_with_notes[i:i+width]])
             first_part, significant_part, final_part = notes
             formatted_note = f'{first_part}<strong>{significant_part}</strong>{final_part}'
             html += f'<span class="hoverable" data-note-id="note-{note_number}">{text}<sup>[{note_number+1}]</sup>'
             html += f'<span class="hover-note">{formatted_note}</span></span>'
             note_number += 1
-        else:
-            html += f'{text}'
+            i+=width+1
     html += "</div>"
     return html
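
Finally, for reference, this is the markup a single annotated span ends up producing; the f-strings are the ones from the hunk above, filled with made-up values:

    text = " Paris"  # generated token(s) covered by the note (made up for illustration)
    notes = ("...left context ", "highlighted input span", " right context...")  # toy (first, significant, final) parts
    note_number = 0

    first_part, significant_part, final_part = notes
    formatted_note = f'{first_part}<strong>{significant_part}</strong>{final_part}'
    html = f'<span class="hoverable" data-note-id="note-{note_number}">{text}<sup>[{note_number+1}]</sup>'
    html += f'<span class="hover-note">{formatted_note}</span></span>'
    print(html)
    # <span class="hoverable" data-note-id="note-0"> Paris<sup>[1]</sup><span class="hover-note">...left context <strong>highlighted input span</strong> right context...</span></span>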
 
 