File size: 10,766 Bytes
3c1036c
 
 
 
119ef11
3c1036c
 
a02bd43
119ef11
a02bd43
 
 
119ef11
3c1036c
 
9570f3d
e90d7e4
 
9570f3d
e90d7e4
 
 
 
 
 
3c1036c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9570f3d
3c1036c
 
 
9570f3d
3c1036c
 
a02bd43
 
 
e90d7e4
719af86
a02bd43
3c1036c
a02bd43
 
3c1036c
719af86
a02bd43
 
 
719af86
a02bd43
719af86
a02bd43
 
 
719af86
a02bd43
 
 
 
e90d7e4
 
a02bd43
 
 
 
 
e90d7e4
 
 
 
a02bd43
719af86
 
 
e90d7e4
 
 
 
 
 
 
 
 
 
 
9570f3d
 
 
 
 
719af86
e90d7e4
3c1036c
 
 
 
 
 
719af86
 
 
 
 
 
 
9570f3d
3c1036c
719af86
3c1036c
719af86
3c1036c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719af86
 
 
9570f3d
719af86
9570f3d
3c1036c
 
 
 
 
 
 
 
a02bd43
3c1036c
 
 
 
 
 
 
 
 
 
 
a02bd43
3c1036c
 
 
 
 
 
 
 
 
 
 
a02bd43
3c1036c
 
 
 
 
 
 
e90d7e4
3c1036c
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import torch
from transformers import AutoTokenizer
from lxt.models.llama import LlamaForCausalLM, attnlrp
from lxt.utils import clean_tokens
import gradio as gr
import numpy as np
import spaces
from scipy.signal import convolve2d

# Load the TinyLlama chat checkpoint (bf16, on CUDA) plus its tokenizer once at
# import time, then register lxt's AttnLRP rules on the model; the relevance
# backward pass in generate_and_visualize runs on this patched model.
_CHECKPOINT = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = LlamaForCausalLM.from_pretrained(
    _CHECKPOINT,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
attnlrp.register(model)


def really_clean_tokens(tokens):
    """Turn raw tokenizer tokens into strings suitable for display.

    Runs lxt's ``clean_tokens`` first, then for every resulting token:

    * maps the SentencePiece word-boundary marker ``▁`` and the ``<s>`` token
      to a space;
    * decodes byte-fallback tokens of the form ``<0xHH>`` to ``chr(0xHH)``
      (e.g. ``<0x0A>`` becomes a newline).

    Literal underscores are deliberately left alone: they are real text
    (identifiers, handles, file names), not word-boundary markers. The former
    blanket ``"_" -> " "`` replacement corrupted them.

    Args:
        tokens: token strings, e.g. from ``tokenizer.convert_ids_to_tokens``.

    Returns:
        A list with one display string per element produced by
        ``clean_tokens`` (the loop below maps 1:1, so positions stay aligned).
    """
    cleaned_tokens = []
    for token in clean_tokens(tokens):
        token = token.replace("▁", " ").replace("<s>", " ")
        if token.startswith("<0x") and token.endswith(">"):
            # Byte-fallback token "<0xHH>": convert the hex byte to a character.
            token = chr(int(token[3:-1], 16))
        cleaned_tokens.append(token)
    return cleaned_tokens

@spaces.GPU
def generate_and_visualize(prompt, num_tokens=10):
    """Greedily generate tokens and record a relevance vector for every step.

    At each step the largest next-token logit is back-propagated (seeded with
    its own value) to the input embeddings. The per-position relevance is that
    gradient summed over the last (hidden) dimension. With AttnLRP registered on
    ``model``, this gradient is what the app treats as relevance.

    Args:
        prompt: text to condition on (special tokens are added).
        num_tokens: number of tokens to generate greedily.

    Returns:
        ``(input_tokens, all_relevances, generated_tokens)``:

        * ``input_tokens``: cleaned tokens of the prompt *followed by* the
          generated continuation.
        * ``all_relevances``: one 1-D float CPU tensor per step, whose length is
          the sequence length at that step.
        * ``generated_tokens``: cleaned generated tokens, one per step.
    """
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)

    generated_tokens_ids = []
    all_relevances = []

    # UI widgets may hand over a float; range() needs an int.
    for _ in range(int(num_tokens)):
        # `.grad` is only populated on leaf tensors. The raw embedding lookup is
        # a non-leaf whenever the embedding weights require grad, so detach it
        # to get a fresh leaf that autograd will fill in.
        input_embeds = model.get_input_embeddings()(input_ids).detach().requires_grad_()
        output_logits = model(inputs_embeds=input_embeds, use_cache=False).logits
        max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)

        max_logits.backward(max_logits)
        relevance = input_embeds.grad.float().sum(-1).cpu()[0]
        all_relevances.append(relevance)

        next_token = max_indices.unsqueeze(0)
        generated_tokens_ids.append(next_token.item())
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

    input_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(input_ids[0]))
    generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(generated_tokens_ids))
    return input_tokens, all_relevances, generated_tokens

def _find_largest_contiguous_patch(flags):
    """Return ``(width, end)`` of the longest run of truthy values in *flags*.

    ``end`` is the index of the run's last element; ties keep the earliest run.
    Returns ``(None, None)`` when no element is truthy. A run ending at index 0
    counts too (a truthiness test on the end index used to drop it).
    """
    best_width, best_end = None, None
    run_width = 0
    for idx, flag in enumerate(flags):
        if flag:
            run_width += 1
            if best_width is None or run_width > best_width:
                best_width, best_end = run_width, idx
        else:
            run_width = 0
    return best_width, best_end


def process_relevances(input_tokens, all_relevances, generated_tokens, *,
                       threshold_per_token=0.3, kernel_width=6, context_width=20):
    """Attach an input-context note to generated tokens that draw on a dense input region.

    Builds a (generated x prompt) relevance matrix and sums it over every
    ``kernel_width x kernel_width`` window. A window is marked significant when
    its sum exceeds ``kernel_width**2 * threshold_per_token``. Window row ``r``
    spans generated tokens ``r .. r+k-1`` and is attributed to the last one.
    That token's note is the longest run of significant windows in the row.
    Notes within the preceding ``2 * kernel_width`` tokens whose run overlaps
    the new one are dropped.

    Args:
        input_tokens: display strings for the whole token sequence (prompt first).
        all_relevances: one 1-D relevance vector per generated token. Only the
            first ``len(all_relevances[0])`` entries (the prompt positions) of
            each are used.
        generated_tokens: display strings of the generated tokens, aligned with
            ``all_relevances``.
        threshold_per_token: average relevance per cell a window must exceed.
        kernel_width: side length of the square summation window.
        context_width: extra input tokens shown on each side of the highlight.

    Returns:
        One ``(token, note)`` pair per generated token. ``note`` is either
        ``None`` or ``(context, start, end)``, where ``context`` is a slice of
        ``input_tokens`` and ``context[start:end]`` is the highlighted span.
        The offsets are relative to ``context``, which is how the HTML renderer
        indexes them.
    """
    no_notes = [(token, None) for token in generated_tokens]
    if not all_relevances:
        return no_notes

    prompt_len = len(all_relevances[0])
    attention_matrix = np.array([el[:prompt_len] for el in all_relevances])
    # 'valid' convolution needs the matrix to be at least as large as the kernel
    # in both dimensions; otherwise scipy raises ValueError (e.g. < k tokens).
    if attention_matrix.shape[0] < kernel_width or attention_matrix.shape[1] < kernel_width:
        return no_notes

    # rolled_sum[r, c] = sum of attention_matrix[r:r+k, c:c+k].
    kernel = np.ones((kernel_width, kernel_width))
    rolled_sum = convolve2d(attention_matrix, kernel, mode="valid")
    significant_areas = rolled_sum > kernel_width**2 * threshold_per_token

    # Window row r ends at generated token r + k - 1, so the first k - 1 tokens
    # cannot carry a note.
    first_row = kernel_width - 1
    output_with_notes = [(token, None) for token in generated_tokens[:first_row]]
    for row in range(first_row, len(generated_tokens)):
        best_width, best_patch_end = _find_largest_contiguous_patch(significant_areas[row - first_row])
        if best_width is None:
            output_with_notes.append((generated_tokens[row], None))
            continue

        # Fuse notes of nearby output tokens: drop earlier notes whose run ends
        # at or after this run's first window.
        new_start = best_patch_end - best_width + 1
        recent_start = max(0, len(output_with_notes) - 2 * kernel_width)
        for i in range(recent_start, len(output_with_notes)):
            token, coords = output_with_notes[i]
            if coords is not None and coords[1] >= new_start:
                output_with_notes[i] = (token, None)
        output_with_notes.append((generated_tokens[row], (best_width, best_patch_end)))

    for i, (token, coords) in enumerate(output_with_notes):
        if coords is None:
            continue
        best_width, best_patch_end = coords
        # Window column c covers input tokens c .. c+k-1, so a run of w windows
        # ending at column e covers input tokens [e - w + 1, e + k).
        significant_start = best_patch_end - best_width + 1
        significant_end = best_patch_end + kernel_width
        context_start = max(0, significant_start - context_width)
        context_end = min(len(input_tokens), significant_end + context_width)
        context = input_tokens[context_start:context_end]
        output_with_notes[i] = (
            token,
            (context, significant_start - context_start, significant_end - context_start),
        )

    return output_with_notes

def create_html_with_hover(output_with_notes):
    """Render generated tokens as HTML, with a hover note on annotated tokens.

    Args:
        output_with_notes: list of ``(text, notes)`` pairs of strings. ``notes``
            is falsy for plain tokens, or ``(context, start, end)`` where
            ``context`` is a list of input-token strings and
            ``context[start:end]`` is rendered in bold inside the note.

    Returns:
        An HTML string wrapped in ``<div id='output-container'>``. Every token
        is HTML-escaped. Tokens come from the user's prompt and the model output
        and may contain ``<``, ``>`` or ``&`` (e.g. ``</s>``). Inserted raw,
        they would break the markup or inject arbitrary HTML.
    """
    from html import escape

    parts = ["<div id='output-container'>"]
    for i, (text, notes) in enumerate(output_with_notes):
        if not notes:
            parts.append(escape(text))
            continue
        context, start, end = notes
        formatted_note = "".join(
            f"<strong>{escape(token)}</strong>" if start <= j < end else escape(token)
            for j, token in enumerate(context)
        )
        # The superscript is the token's 1-based position in the output.
        parts.append(f'<span class="hoverable" data-note-id="note-{i}">{escape(text)}<sup>[{i+1}]</sup>')
        parts.append(f'<span class="hover-note">{formatted_note}</span></span>')
    parts.append("</div>")
    return "".join(parts)

@spaces.GPU
def on_generate(prompt, num_tokens):
    """Click handler: generate a continuation and return it as annotated HTML.

    Args:
        prompt: text from the prompt textbox.
        num_tokens: value of the token-count slider.

    Returns:
        The HTML string shown in the output component.
    """
    tokens_in, relevances, tokens_out = generate_and_visualize(prompt, num_tokens)
    notes = process_relevances(tokens_in, relevances, tokens_out)
    return create_html_with_hover(notes)

# Styles for the markup built by create_html_with_hover. Annotated tokens
# (.hoverable) are blue. Their .hover-note popup is hidden (display: none)
# until the token is hovered, then it is shown positioned above the token
# (bottom: 100%).
css = """
#output-container { font-size: 18px; line-height: 1.5; }
.hoverable { color: blue; cursor: pointer; position: relative; }
.hover-note {
    display: none;
    position: absolute;
    padding: 5px;
    border-radius: 5px;
    bottom: 100%;
    left: 50%;
    transform: translateX(-50%);
    white-space: normal;
    background-color: rgba(240, 240, 240, 1);
    max-width: 600px;
    width:500px;
    word-wrap: break-word;
    z-index: 10;
}
.hoverable:hover .hover-note { display: block; }
"""
# Example inputs as [prompt, num_tokens] pairs. Each prompt stops mid-answer so
# the model's continuation is the answer. examples[0][0] is the textbox default.
# Emoji / Devanagari / accented characters were restored from mojibake (UTF-8
# bytes that had been decoded with a Greek codepage, e.g. "πŸ”₯" for "🔥").
examples = [
    [
        """Context: Mount Everest attracts many climbers, including highly experienced mountaineers. There are two main climbing routes, one approaching the summit from the southeast in Nepal (known as the standard route) and the other from the north in Tibet. While not posing substantial technical climbing challenges on the standard route, Everest presents dangers such as altitude sickness, weather, and wind, as well as hazards from avalanches and the Khumbu Icefall. As of November 2022, 310 people have died on Everest. Over 200 bodies remain on the mountain and have not been removed due to the dangerous conditions. The first recorded efforts to reach Everest's summit were made by British mountaineers. As Nepal did not allow foreigners to enter the country at the time, the British made several attempts on the north ridge route from the Tibetan side. After the first reconnaissance expedition by the British in 1921 reached 7,000 m (22,970 ft) on the North Col, the 1922 expedition pushed the north ridge route up to 8,320 m (27,300 ft), marking the first time a human had climbed above 8,000 m (26,247 ft). The 1924 expedition resulted in one of the greatest mysteries on Everest to this day: George Mallory and Andrew Irvine made a final summit attempt on 8 June but never returned, sparking debate as to whether they were the first to reach the top. Tenzing Norgay and Edmund Hillary made the first documented ascent of Everest in 1953, using the southeast ridge route. Norgay had reached 8,595 m (28,199 ft) the previous year as a member of the 1952 Swiss expedition. The Chinese mountaineering team of Wang Fuzhou, Gonpo, and Qu Yinhua made the first reported ascent of the peak from the north ridge on 25 May 1960.

Question: How high did they climb in 1922? According to the text, the 1922 expedition reached 8,""",
        40
    ],
    [
        """Hurricane Katrina killed hundreds of people as it made landfall on New Orleans in 2005 - many of these deaths could have been avoided if alerts had been given one day earlier. Accurate weather forecasts are really life-saving.

🔥 Now, NASA and IBM just dropped a game-changing new model: the first ever foundation model for weather! This means, it's the first time we have a generalist model not restricted to one task, but able to predict 160 weather variables!

Prithvi WxC (Prithvi, "पृथ्वी", is the Sanskrit name for Earth) - is a 2.3 billion parameter model, with an architecture close to previous vision transformers like Hiera.

💡 But it comes with some important tweaks: under the hood, Prithvi WxC uses a clever transformer-based architecture with 25 encoder and 5 decoder blocks. It alternates between "local" and "global" attention to capture both regional and global weather patterns.

How many weather variables can Prithvi predict? Prithvi can""",
        40
    ],
    [
        """Transformers v4.45.0 released: includes a lightning-fast method to build tools! ⚡️

During user research with colleagues @MoritzLaurer and @Jofthomas , we discovered that the class definition currently in used to define a Tool in transformers.agents is a bit tedious to use, because it goes in great detail.

➡️ So I've made an easier way to build tools: just make a function with type hints + a docstring, and add a @tool decorator in front.

✅ Voilà, you're good to go!

How can you build tools simply in transformers? Just use the decorator""",
        40
    ]
]

# Gradio UI: a prompt box and a token-count slider feed on_generate, whose HTML
# (tokens with hover notes) is shown in output_html.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Token Generation with Hover Notes")

    input_text = gr.Textbox(label="Enter your prompt:", lines=10, value=examples[0][0])
    num_tokens = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of tokens to generate")
    # Offer every entry of `examples` (before, only examples[0][0] was reachable,
    # as the textbox default). Each entry fills [prompt, num_tokens].
    gr.Examples(examples=examples, inputs=[input_text, num_tokens])
    generate_button = gr.Button("Generate")

    output_html = gr.HTML(label="Generated Output")

    generate_button.click(
        on_generate,
        inputs=[input_text, num_tokens],
        outputs=[output_html]
    )

    gr.Markdown("Hover over the blue text with superscript numbers to see the important input tokens for that group.")

if __name__ == "__main__":
    demo.launch()