app.py
CHANGED
@@ -11,7 +11,6 @@ model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", t
 tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 attnlrp.register(model)

-
 def really_clean_tokens(tokens):
     tokens = clean_tokens(tokens)
     cleaned_tokens = []
@@ -33,7 +32,7 @@ def generate_and_visualize(prompt, num_tokens=10):
     all_relevances = []

     for _ in range(num_tokens):
-        output_logits = model(inputs_embeds=input_embeds.requires_grad_()
+        output_logits = model(inputs_embeds=input_embeds.requires_grad_()).logits
         max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)

         max_logits.backward(max_logits)
@@ -54,7 +53,7 @@ def process_relevances(input_tokens, all_relevances, generated_tokens):
     attention_matrix = np.array([el[:len(all_relevances[0])] for el in all_relevances])

     ### FIND ZONES OF INTEREST
-    threshold_per_token = 0.
+    threshold_per_token = 0.25
     kernel_width = 6
     context_width = 20 # Number of tokens to include as context on each side
     kernel = np.ones((kernel_width, kernel_width))
@@ -66,48 +65,58 @@ def process_relevances(input_tokens, all_relevances, generated_tokens):
     significant_areas = rolled_sum > kernel_width**2 * threshold_per_token

     def find_largest_contiguous_patch(array):
-
-        best_width,
+        current_patch_start = None
+        best_width, best_patch_start = None, None
         current_width = 0
         for i in range(len(array)):
             if array[i]:
-                if
+                if current_patch_start is not None and current_patch_start + current_width == i:
                     current_width += 1
-                    current_patch_end = i
                 else:
-
+                    current_patch_start = i
                     current_width = 1
-                if
-
+                if current_patch_start and (best_width is None or current_width > best_width):
+                    best_patch_start = current_patch_start
                     best_width = current_width
             else:
                 current_width = 0
-        return best_width,
+        return best_width, best_patch_start

     output_with_notes = [(el, None) for el in generated_tokens[:kernel_width]]
     for row in range(kernel_width, len(generated_tokens)):
-        best_width,
+        best_width, best_patch_start = find_largest_contiguous_patch(significant_areas[row-kernel_width+1])

         if best_width is not None:
-
-            for i in range(len(output_with_notes)-2*kernel_width, len(output_with_notes)):
-                token, coords = output_with_notes[i]
-                if coords is not None:
-                    prev_width, prev_patch_end = coords
-                    if prev_patch_end > best_patch_end - best_width:
-                        # then notes are overlapping: thus we delete the first one and make the last wider if needed
-                        output_with_notes[i] = (token, None)
-                        if prev_patch_end - prev_width < best_patch_end - best_width:
-                            best_width = best_patch_end - prev_patch_end - prev_width
-            output_with_notes.append((generated_tokens[row], (best_width, best_patch_end)))
+            output_with_notes.append((generated_tokens[row], (best_width, best_patch_start)))
         else:
             output_with_notes.append((generated_tokens[row], None))

-
+
+    # Fuse the notes for consecutive output tokens if necessary
+    for i in range(len(output_with_notes)):
+        token, coords = output_with_notes[i]
         if coords is not None:
-            best_width,
-
-
+            best_width, best_patch_start = coords
+            note_width_generated = kernel_width
+            for next_id in output_with_notes[i+1, i+2*kernel_width]:
+                next_token, next_coords = output_with_notes[next_id]
+                if next_coords is not None:
+                    next_width, next_patch_start = next_coords
+                    if best_patch_start + best_width > next_patch_start:
+                        # then notes are overlapping: thus we delete the last one and make the first wider if needed
+                        output_with_notes[next_id] = (next_token, None)
+                        larger_end = max(best_patch_start + best_width, next_patch_start + next_width)
+                        best_width = larger_end - best_patch_start
+                        note_width_generated = kernel_width + (next_id-i)
+            output_with_notes[i] = (token, (best_width, best_patch_start), note_width_generated)
+        else:
+            output_with_notes[i] = (token, None, None)
+
+    for i, (token, coords, width) in enumerate(output_with_notes):
+        if coords is not None:
+            best_width, best_patch_start = coords
+            significant_start = max(0, best_patch_start)
+            significant_end = best_patch_start + kernel_width + best_width
             context_start = max(0, significant_start - context_width)
             context_end = min(len(input_tokens), significant_end + context_width)
             first_part = "".join(input_tokens[context_start:significant_start])
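A note on the fusion pass added in the hunk above: output_with_notes[i+1, i+2*kernel_width] indexes a Python list with a tuple, which raises a TypeError as soon as one of the generated tokens carries a note. The loop presumably means to walk the indices of the next few entries. Here is a self-contained sketch of that reading; the helper name fuse_overlapping_notes and the range()/min() bounds are assumptions, the rest mirrors the diff.

    # Sketch of the assumed intent of the fusion pass: scan the next few notes and
    # merge any whose input patch overlaps the current one. Entries come in as
    # (token, (width, start)) or (token, None) and leave as 3-tuples, as in the diff.
    def fuse_overlapping_notes(output_with_notes, kernel_width):
        for i in range(len(output_with_notes)):
            token, coords = output_with_notes[i]
            if coords is None:
                output_with_notes[i] = (token, None, None)
                continue
            best_width, best_patch_start = coords
            note_width_generated = kernel_width
            for next_id in range(i + 1, min(i + 2 * kernel_width, len(output_with_notes))):
                next_token, next_coords = output_with_notes[next_id]
                if next_coords is None:
                    continue
                next_width, next_patch_start = next_coords
                if best_patch_start + best_width > next_patch_start:
                    # Overlapping notes: drop the later one and widen the first if needed.
                    output_with_notes[next_id] = (next_token, None)
                    larger_end = max(best_patch_start + best_width, next_patch_start + next_width)
                    best_width = larger_end - best_patch_start
                    note_width_generated = kernel_width + (next_id - i)
            output_with_notes[i] = (token, (best_width, best_patch_start), note_width_generated)

Each kept note ends up as a (token, (width, start), note_width) triple, which is what the reworked create_html_with_hover in the next hunk consumes.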
@@ -115,22 +124,27 @@ def process_relevances(input_tokens, all_relevances, generated_tokens):
             final_part = "".join(input_tokens[significant_end:context_end])
             print("KK", first_part, significant_part, final_part)

-            output_with_notes[i] = (token, (first_part, significant_part, final_part))
+            output_with_notes[i] = (token, (first_part, significant_part, final_part), width)

     return output_with_notes

 def create_html_with_hover(output_with_notes):
     html = "<div id='output-container'>"
     note_number = 0
-
-
+    i = 0
+    while i < len(output_with_notes):
+        (token, notes, width) = output_with_notes[i]
+        if notes is None:
+            html += f'{token}'
+            i +=1
+        else:
+            text = "".join([element[0] for element in output_with_notes[i:i+width]])
             first_part, significant_part, final_part = notes
             formatted_note = f'{first_part}<strong>{significant_part}</strong>{final_part}'
             html += f'<span class="hoverable" data-note-id="note-{note_number}">{text}<sup>[{note_number+1}]</sup>'
             html += f'<span class="hover-note">{formatted_note}</span></span>'
             note_number += 1
-
-            html += f'{text}'
+            i+=width+1
     html += "</div>"
     return html

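For reference, the new find_largest_contiguous_patch helper can be exercised on its own. A minimal sketch with a made-up thresholded row; the boolean values are hypothetical test data, not output of the Space.

    import numpy as np

    # Copy of the helper added in the diff above, unchanged.
    def find_largest_contiguous_patch(array):
        current_patch_start = None
        best_width, best_patch_start = None, None
        current_width = 0
        for i in range(len(array)):
            if array[i]:
                if current_patch_start is not None and current_patch_start + current_width == i:
                    current_width += 1
                else:
                    current_patch_start = i
                    current_width = 1
                # Note: a run starting at index 0 never updates the best patch, because
                # 0 is falsy here; "current_patch_start is not None" may be the intended test.
                if current_patch_start and (best_width is None or current_width > best_width):
                    best_patch_start = current_patch_start
                    best_width = current_width
            else:
                current_width = 0
        return best_width, best_patch_start

    # One hypothetical row of the thresholded significance mask (significant_areas).
    row = np.array([False, True, True, True, False, True, True])
    print(find_largest_contiguous_patch(row))  # (3, 1): widest run is 3 tokens long, starting at index 1

In the diff this helper runs once per generated token on a single row of significant_areas, and the (width, start) pair it returns is what later becomes the highlighted span of that token's hover note.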