Spaces:
Runtime error
Runtime error
Correct context windows
Browse files
app.py
CHANGED
@@ -13,9 +13,10 @@ attnlrp.register(model)
|
|
13 |
|
14 |
|
15 |
def really_clean_tokens(tokens):
|
|
|
16 |
cleaned_tokens = []
|
17 |
for token in tokens:
|
18 |
-
token = token.replace("_", " ").replace("β", " ").replace("<s>", "")
|
19 |
if token.startswith("<0x") and token.endswith(">"):
|
20 |
# Convert hex to character
|
21 |
char_code = int(token[3:-1], 16)
|
@@ -44,12 +45,11 @@ def generate_and_visualize(prompt, num_tokens=10):
|
|
44 |
|
45 |
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
|
46 |
input_embeds = model.get_input_embeddings()(input_ids)
|
47 |
-
|
48 |
-
input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
49 |
-
input_tokens = really_clean_tokens(input_tokens)
|
50 |
generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(generated_tokens_ids))
|
51 |
|
52 |
return input_tokens, all_relevances, generated_tokens
|
|
|
53 |
def process_relevances(input_tokens, all_relevances, generated_tokens):
|
54 |
attention_matrix = np.array([el[:len(all_relevances[0])] for el in all_relevances])
|
55 |
|
@@ -103,11 +103,11 @@ def process_relevances(input_tokens, all_relevances, generated_tokens):
|
|
103 |
for i, (token, coords) in enumerate(output_with_notes):
|
104 |
if coords is not None:
|
105 |
best_width, best_patch_end = coords
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
output_with_notes[i] = (token, (context, significant_start, significant_end))
|
112 |
|
113 |
return output_with_notes
|
@@ -123,7 +123,7 @@ def create_html_with_hover(output_with_notes):
|
|
123 |
formatted_context.append(f'<strong>{token}</strong>')
|
124 |
else:
|
125 |
formatted_context.append(token)
|
126 |
-
formatted_note = "
|
127 |
html += f'<span class="hoverable" data-note-id="note-{i}">{text}<sup>[{i+1}]</sup>'
|
128 |
html += f'<span class="hover-note">{formatted_note}</span></span>'
|
129 |
else:
|
@@ -144,7 +144,6 @@ css = """
|
|
144 |
.hover-note {
|
145 |
display: none;
|
146 |
position: absolute;
|
147 |
-
background-color: #f0f0f0;
|
148 |
padding: 5px;
|
149 |
border-radius: 5px;
|
150 |
bottom: 100%;
|
@@ -153,8 +152,9 @@ css = """
|
|
153 |
white-space: normal;
|
154 |
background-color: rgba(240, 240, 240, 1);
|
155 |
max-width: 600px;
|
|
|
156 |
word-wrap: break-word;
|
157 |
-
z-index:
|
158 |
}
|
159 |
.hoverable:hover .hover-note { display: block; }
|
160 |
"""
|
|
|
13 |
|
14 |
|
15 |
def really_clean_tokens(tokens):
|
16 |
+
tokens = clean_tokens(tokens)
|
17 |
cleaned_tokens = []
|
18 |
for token in tokens:
|
19 |
+
token = token.replace("_", " ").replace("β", " ").replace("<s>", " ")
|
20 |
if token.startswith("<0x") and token.endswith(">"):
|
21 |
# Convert hex to character
|
22 |
char_code = int(token[3:-1], 16)
|
|
|
45 |
|
46 |
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
|
47 |
input_embeds = model.get_input_embeddings()(input_ids)
|
48 |
+
input_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(input_ids[0]))
|
|
|
|
|
49 |
generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(generated_tokens_ids))
|
50 |
|
51 |
return input_tokens, all_relevances, generated_tokens
|
52 |
+
|
53 |
def process_relevances(input_tokens, all_relevances, generated_tokens):
|
54 |
attention_matrix = np.array([el[:len(all_relevances[0])] for el in all_relevances])
|
55 |
|
|
|
103 |
for i, (token, coords) in enumerate(output_with_notes):
|
104 |
if coords is not None:
|
105 |
best_width, best_patch_end = coords
|
106 |
+
significant_start = max(0, best_patch_end - best_width)
|
107 |
+
significant_end = best_patch_end + kernel_width
|
108 |
+
context_start = max(0, significant_start - context_width)
|
109 |
+
context_end = min(len(input_tokens), significant_end + context_width)
|
110 |
+
context = input_tokens[context_start:context_end]
|
111 |
output_with_notes[i] = (token, (context, significant_start, significant_end))
|
112 |
|
113 |
return output_with_notes
|
|
|
123 |
formatted_context.append(f'<strong>{token}</strong>')
|
124 |
else:
|
125 |
formatted_context.append(token)
|
126 |
+
formatted_note = "".join(formatted_context)
|
127 |
html += f'<span class="hoverable" data-note-id="note-{i}">{text}<sup>[{i+1}]</sup>'
|
128 |
html += f'<span class="hover-note">{formatted_note}</span></span>'
|
129 |
else:
|
|
|
144 |
.hover-note {
|
145 |
display: none;
|
146 |
position: absolute;
|
|
|
147 |
padding: 5px;
|
148 |
border-radius: 5px;
|
149 |
bottom: 100%;
|
|
|
152 |
white-space: normal;
|
153 |
background-color: rgba(240, 240, 240, 1);
|
154 |
max-width: 600px;
|
155 |
+
width:500px;
|
156 |
word-wrap: break-word;
|
157 |
+
z-index: 10;
|
158 |
}
|
159 |
.hoverable:hover .hover-note { display: block; }
|
160 |
"""
|