app.py
CHANGED
@@ -11,7 +11,6 @@ model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", t
 tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 attnlrp.register(model)
 
-
 def really_clean_tokens(tokens):
     tokens = clean_tokens(tokens)
     cleaned_tokens = []
@@ -33,7 +32,7 @@ def generate_and_visualize(prompt, num_tokens=10):
     all_relevances = []
 
     for _ in range(num_tokens):
-        output_logits = model(inputs_embeds=input_embeds.requires_grad_()
+        output_logits = model(inputs_embeds=input_embeds.requires_grad_()).logits
         max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)
 
         max_logits.backward(max_logits)
@@ -54,7 +53,7 @@ def process_relevances(input_tokens, all_relevances, generated_tokens):
     attention_matrix = np.array([el[:len(all_relevances[0])] for el in all_relevances])
 
     ### FIND ZONES OF INTEREST
-    threshold_per_token = 0.
+    threshold_per_token = 0.25
     kernel_width = 6
     context_width = 20 # Number of tokens to include as context on each side
     kernel = np.ones((kernel_width, kernel_width))
@@ -66,48 +65,58 @@ def process_relevances(input_tokens, all_relevances, generated_tokens):
     significant_areas = rolled_sum > kernel_width**2 * threshold_per_token
 
     def find_largest_contiguous_patch(array):
-
-        best_width,
+        current_patch_start = None
+        best_width, best_patch_start = None, None
         current_width = 0
         for i in range(len(array)):
             if array[i]:
-                if
+                if current_patch_start is not None and current_patch_start + current_width == i:
                     current_width += 1
-                    current_patch_end = i
                 else:
-
+                    current_patch_start = i
                     current_width = 1
-                if
-
+                if current_patch_start and (best_width is None or current_width > best_width):
+                    best_patch_start = current_patch_start
                     best_width = current_width
             else:
                 current_width = 0
-        return best_width,
+        return best_width, best_patch_start
 
     output_with_notes = [(el, None) for el in generated_tokens[:kernel_width]]
     for row in range(kernel_width, len(generated_tokens)):
-        best_width,
+        best_width, best_patch_start = find_largest_contiguous_patch(significant_areas[row-kernel_width+1])
 
         if best_width is not None:
-
-            for i in range(len(output_with_notes)-2*kernel_width, len(output_with_notes)):
-                token, coords = output_with_notes[i]
-                if coords is not None:
-                    prev_width, prev_patch_end = coords
-                    if prev_patch_end > best_patch_end - best_width:
-                        # then notes are overlapping: thus we delete the first one and make the last wider if needed
-                        output_with_notes[i] = (token, None)
-                        if prev_patch_end - prev_width < best_patch_end - best_width:
-                            best_width = best_patch_end - prev_patch_end - prev_width
-            output_with_notes.append((generated_tokens[row], (best_width, best_patch_end)))
+            output_with_notes.append((generated_tokens[row], (best_width, best_patch_start)))
         else:
             output_with_notes.append((generated_tokens[row], None))
 
-
+
+    # Fuse the notes for consecutive output tokens if necessary
+    for i in range(len(output_with_notes)):
+        token, coords = output_with_notes[i]
         if coords is not None:
-            best_width,
-
-
+            best_width, best_patch_start = coords
+            note_width_generated = kernel_width
+            for next_id in range(i+1, min(i+2*kernel_width, len(output_with_notes))):
+                next_token, next_coords = output_with_notes[next_id]
+                if next_coords is not None:
+                    next_width, next_patch_start = next_coords
+                    if best_patch_start + best_width > next_patch_start:
+                        # then notes are overlapping: thus we delete the last one and make the first wider if needed
+                        output_with_notes[next_id] = (next_token, None)
+                        larger_end = max(best_patch_start + best_width, next_patch_start + next_width)
+                        best_width = larger_end - best_patch_start
+                        note_width_generated = kernel_width + (next_id-i)
+            output_with_notes[i] = (token, (best_width, best_patch_start), note_width_generated)
+        else:
+            output_with_notes[i] = (token, None, None)
+
+    for i, (token, coords, width) in enumerate(output_with_notes):
+        if coords is not None:
+            best_width, best_patch_start = coords
+            significant_start = max(0, best_patch_start)
+            significant_end = best_patch_start + kernel_width + best_width
             context_start = max(0, significant_start - context_width)
             context_end = min(len(input_tokens), significant_end + context_width)
             first_part = "".join(input_tokens[context_start:significant_start])
@@ -115,22 +124,27 @@ def process_relevances(input_tokens, all_relevances, generated_tokens):
             final_part = "".join(input_tokens[significant_end:context_end])
             print("KK", first_part, significant_part, final_part)
 
-            output_with_notes[i] = (token, (first_part, significant_part, final_part))
+            output_with_notes[i] = (token, (first_part, significant_part, final_part), width)
 
     return output_with_notes
 
 def create_html_with_hover(output_with_notes):
     html = "<div id='output-container'>"
     note_number = 0
-
-
+    i = 0
+    while i < len(output_with_notes):
+        (token, notes, width) = output_with_notes[i]
+        if notes is None:
+            html += f'{token}'
+            i += 1
+        else:
+            text = "".join([element[0] for element in output_with_notes[i:i+width]])
             first_part, significant_part, final_part = notes
             formatted_note = f'{first_part}<strong>{significant_part}</strong>{final_part}'
            html += f'<span class="hoverable" data-note-id="note-{note_number}">{text}<sup>[{note_number+1}]</sup>'
            html += f'<span class="hover-note">{formatted_note}</span></span>'
            note_number += 1
-
-            html += f'{text}'
+            i += width + 1
     html += "</div>"
     return html
 
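The hunks above test significance against a rolled_sum built from kernel = np.ones((kernel_width, kernel_width)), but the line that computes rolled_sum sits outside the diff. Below is a minimal, self-contained sketch of how significant_areas could be derived, assuming rolled_sum is a kernel_width x kernel_width box-filter convolution of the relevance matrix; the convolution call and the toy matrix are illustrative assumptions, not code from this Space.

import numpy as np
from scipy.signal import convolve2d

# Toy relevance matrix: rows = generated tokens, columns = input tokens (random placeholder values).
attention_matrix = np.random.rand(12, 30)

kernel_width = 6
threshold_per_token = 0.25

# Assumption: rolled_sum pools relevance over a kernel_width x kernel_width window.
kernel = np.ones((kernel_width, kernel_width))
rolled_sum = convolve2d(attention_matrix, kernel, mode="valid")

# A window is significant when its summed relevance exceeds kernel_width**2 * threshold_per_token,
# i.e. its average relevance per token is above the threshold (the same test as in the diff).
significant_areas = rolled_sum > kernel_width**2 * threshold_per_token
print(significant_areas.shape)  # (7, 25) for the toy 12 x 30 matrix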
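For reference, a standalone sketch of the patch-finding helper introduced by this commit: it tracks where the current run of True values starts, keeps the widest run seen so far, and returns (best_width, best_patch_start). The toy input row is illustrative, not data from the Space.

def find_largest_contiguous_patch(array):
    # Track the start of the current run of True values and remember the widest run seen so far.
    current_patch_start = None
    best_width, best_patch_start = None, None
    current_width = 0
    for i in range(len(array)):
        if array[i]:
            if current_patch_start is not None and current_patch_start + current_width == i:
                current_width += 1
            else:
                current_patch_start = i
                current_width = 1
            if current_patch_start is not None and (best_width is None or current_width > best_width):
                best_patch_start = current_patch_start
                best_width = current_width
        else:
            current_width = 0
    return best_width, best_patch_start

# One row of a boolean significance map: True marks input positions whose pooled
# relevance exceeded the threshold for a given generated token.
row = [False, True, True, False, True, True, True, True, False]
print(find_largest_contiguous_patch(row))  # -> (4, 4): the widest patch is 4 wide and starts at index 4

In process_relevances this helper is called once per generated token on a row of significant_areas, and a non-None best_width is what triggers attaching a hover note to that token.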