Lora committed on
Commit
77e63f5
·
1 Parent(s): 26a15da

visualize senses in language modeling

Files changed (1)
  1. app.py +444 -31
app.py CHANGED
@@ -1,13 +1,23 @@
 import torch
-import pandas as pd
 import transformers
 import gradio as gr

 tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
-vecs = torch.load("senses/all_vecs_mtx.pt")
-lm_head = torch.load("senses/lm_head.pt")


 def visualize_word(word, count=10, remove_space=False):

     if not remove_space:
@@ -19,14 +29,17 @@ def visualize_word(word, count=10, remove_space=False):
     tokens = ", ".join(tokens)  # display tokenization for user
     print(f"Tokenized as: {tokens}")
     # look up sense vectors only for the first token
-    contents = vecs[token_ids[0]]  # torch.Size([16, 768])

     # for pos and neg respectively, create a list (for each sense) of lists (top k) of (word, logit) tuples
     pos_word_lists = []
     neg_word_lists = []
     sense_names = []  # column header
-    for i in range(contents.shape[0]):
-        logits = contents[i, :] @ lm_head.t()  # [768] @ [768, 50257] -> [50257]
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
         sense_names.append('sense {}'.format(i))
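For context, the removed code path scored each of a token's 16 precomputed sense vectors directly against the output vocabulary. A minimal sketch of that computation, using the names and shapes from the comments above (`vecs` and `lm_head` as precomputed tensors, d=768, |V|=50257):

```python
# Sketch of the removed lookup: project one sense vector onto the vocabulary.
contents = vecs[token_ids[0]]           # (16, 768): 16 sense vectors for the first token
logits = contents[0, :] @ lm_head.t()   # (768,) @ (768, 50257) -> (50257,)
top = torch.topk(logits, k=10).indices  # ids of the 10 words this sense scores highest
print([tokenizer.decode(i) for i in top])
```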
@@ -54,30 +67,430 @@ def visualize_word(word, count=10, remove_space=False):

     return pos_df, neg_df, tokens

-with gr.Blocks() as demo:
     gr.Markdown("""
-    ## Backpack visualization: senses lookup
-    > Note: Backpack uses the GPT-2 tokenizer, which includes the space before a word as part of the token, so by default, a space character `' '` is added to the beginning of the word you look up. You can disable this by checking `Remove space before word`, but know this might cause strange behaviors like breaking `afraid` into `af` and `raid`, or `slight` into `s` and `light`.
     """)
-    with gr.Row():
-        word = gr.Textbox(label="Word")
-        token_breakdown = gr.Textbox(label="Token Breakdown (senses are for the first token only)")
-        remove_space = gr.Checkbox(label="Remove space before word", default=False)
-        count = gr.Slider(minimum=1, maximum=20, value=10, label="Top K", step=1)
-    pos_outputs = gr.Dataframe(label="Highest Scoring Senses")
-    neg_outputs = gr.Dataframe(label="Lowest Scoring Senses")
-    gr.Examples(
-        examples=["science", "afraid", "book", "slight"],
-        inputs=[word],
-        outputs=[pos_outputs, neg_outputs, token_breakdown],
-        fn=visualize_word,
-        cache_examples=True,
-    )
-
-    gr.Button("Look up").click(
-        fn=visualize_word,
-        inputs=[word, count, remove_space],
-        outputs=[pos_outputs, neg_outputs, token_breakdown],
-    )
-
-demo.launch(auth=("caesar", "wins"))

1
  import torch
 
2
  import transformers
3
+ from transformers import AutoModelForCausalLM
4
+ import pandas as pd
5
  import gradio as gr
6
 
7
+ # Build model & get some layers
8
  tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
9
+ m = AutoModelForCausalLM.from_pretrained("lora-x/backpack-gpt2", trust_remote_code=True)
10
+ m.eval()
11
 
12
+ lm_head = m.get_lm_head() # (V, d)
13
+ word_embeddings = m.backpack.get_word_embeddings() # (V, d)
14
+ sense_network = m.backpack.get_sense_network() # (V, nv, d)
15
+ num_senses = m.backpack.get_num_senses()
16
+ sense_names = [i for i in range(num_senses)]
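A quick sanity check of the pieces above (a sketch, not part of the commit; the shapes follow the comments in this file, with GPT-2's d=768 and nv=16 senses):

```python
# Sketch: embed one token id and push it through the sense network.
ids = torch.tensor([tokenizer(' science')['input_ids'][0]]).long().unsqueeze(0)  # (1, 1)
with torch.no_grad():
    emb = word_embeddings(ids)    # (1, 1, 768)
    senses = sense_network(emb)   # (1, 16, 1, 768): one 768-d vector per sense
print(num_senses, senses.shape)
```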

+"""
+Single token sense lookup
+"""
 def visualize_word(word, count=10, remove_space=False):

     if not remove_space:

     tokens = ", ".join(tokens)  # display tokenization for user
     print(f"Tokenized as: {tokens}")
     # look up sense vectors only for the first token
+    # contents = vecs[token_ids[0]]  # torch.Size([16, 768])
+    sense_input_embeds = word_embeddings(torch.tensor([token_ids[0]]).long().unsqueeze(0))  # (bs=1, s=1, d); the sense network expects a batch dim
+    senses = sense_network(sense_input_embeds)  # -> (bs=1, nv, s=1, d)
+    senses = torch.squeeze(senses)  # (nv, d) once the size-1 dims are dropped

     # for pos and neg respectively, create a list (for each sense) of lists (top k) of (word, logit) tuples
     pos_word_lists = []
     neg_word_lists = []
     sense_names = []  # column header
+    for i in range(senses.shape[0]):
+        logits = lm_head(senses[i, :])  # (d,) -> (V,)
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
         sense_names.append('sense {}'.format(i))

     return pos_df, neg_df, tokens
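Called outside the UI, the function returns the two sense tables plus the tokenization string (a usage sketch, assuming the definitions above):

```python
# Sketch: top/bottom-scoring words per sense for "science" (leading space added by default).
pos_df, neg_df, token_str = visualize_word("science", count=10)
print(token_str)  # the comma-joined tokenization, e.g. " science"
print(pos_df)     # one column per sense, highest-scoring words first
```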

+"""
+Returns:
+    - tokens: the tokenization of the input sentence, also used as the options to choose from in get_token_contextual_weights
+    - top_k_words_df: a dataframe of the top k next words predicted by the model
+    - length: length of the input sentence, stored as a gr.State variable so other methods can find the
+      contextualization weights for the *last* token, which is the one that matters for prediction
+    - contextualization_weights: gr.State variable that stores the contextualization weights for the input sentence
+"""
+def predict_next_word(sentence, top_k=5, contextualization_weights=None):
+
+    # For better tokenization, by default add a space at the beginning of the sentence
+    # if it doesn't already have one, and remove trailing spaces
+    sentence = sentence.strip()
+    if sentence[0] != ' ':
+        sentence = ' ' + sentence
+    print(f"Sentence: '{sentence}'")
+
+    # Make the input, keeping track of the original length
+    token_ids = tokenizer(sentence)['input_ids']
+    tokens = [[tokenizer.decode(token_id) for token_id in token_ids]]  # a list of a single list, because it is used as a dataframe
+    length = len(token_ids)
+    inp = torch.zeros((1, 512)).long()
+    inp[0, :length] = torch.tensor(token_ids).long()
+
+    # Get the output at the correct index
+    if contextualization_weights is None:
+        print("contextualization_weights IS None, freshly computing contextualization_weights")
+        output = m(inp)
+        logits, contextualization_weights = output.logits[0, length-1, :], output.contextualization
+        # Store contextualization weights and return them as a gr.State var for use by get_token_contextual_weights
+    else:
+        print("contextualization_weights is NOT None, using passed-in contextualization_weights")
+        output = m.run_with_custom_contextualization(inp, contextualization_weights)
+        logits = output.logits[0, length-1, :]
+    probs = logits.softmax(dim=-1)  # probs over the next word
+    probs, indices = torch.sort(probs, descending=True)
+    top_k_words = [(tokenizer.decode(indices[i]), round(probs[i].item(), 3)) for i in range(top_k)]
+    top_k_words_df = pd.DataFrame(top_k_words, columns=['word', 'probability'], index=range(1, top_k+1))
+
+    top_k_words_df = top_k_words_df.T
+
+    print(top_k_words_df)
+
+    return tokens, top_k_words_df, length, contextualization_weights
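A usage sketch (not part of the commit): the first call computes the contextualization weights, and passing them back in reruns the model with (possibly edited) custom weights:

```python
tokens, top_df, length, ctx = predict_next_word("Messi plays for", top_k=5)
print(top_df)  # transposed frame: one row of words, one row of probabilities
# edit ctx (see the slider handlers below), then rerun with the custom weights:
tokens, top_df, length, ctx = predict_next_word("Messi plays for", 5, ctx)
```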


+"""
+Returns, for the selected token, a dataframe of top words per sense and the
+contextualization weight of each sense.
+
+Args:
+    contextualization_weights: a gr.State variable that stores the contextualization weights for the input sentence
+    length: length of the input sentence, used to get the contextualization weights for the last token
+    token: the selected token
+    token_index: the index of the selected token in the input sentence
+    count: how many top words to display for each sense
+"""
+def get_token_contextual_weights(contextualization_weights, length, token, token_index, count=7):
+    print(">>>>> in get_token_contextual_weights")
+    print(f"Selected {token_index}th token: {token}")
+
+    # Get contextualization weights for the selected token.
+    # Only the weights at the last position matter, since that is what contributes to the output.
+    token_contextualization_weights = contextualization_weights[0, :, length-1, token_index]
+    token_contextualization_weights_list = [round(x, 3) for x in token_contextualization_weights.tolist()]
+
+    # Get the sense vectors of the selected token
+    token_ids = tokenizer(token)['input_ids']  # keep as a list, because the sense network expects a sequence dim
+    sense_input_embeds = word_embeddings(torch.tensor(token_ids).long().unsqueeze(0))  # (bs=1, s=1, d); the sense network expects a batch dim
+    senses = sense_network(sense_input_embeds)  # -> (bs=1, nv, s=1, d)
+    senses = torch.squeeze(senses)  # (nv, d)
+
+    # Build dataframes
+    pos_dfs, neg_dfs = [], []
+
+    for i in range(num_senses):
+        logits = lm_head(senses[i, :])  # (d,) -> (V,)
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+
+        pos_sorted_words = [tokenizer.decode(sorted_indices[j]) for j in range(count)]
+        pos_dfs.append(pd.DataFrame(pos_sorted_words))
+
+        neg_sorted_words = [tokenizer.decode(sorted_indices[-j-1]) for j in range(count)]
+        neg_dfs.append(pd.DataFrame(neg_sorted_words))
+
+    sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, \
+        sense6words, sense7words, sense8words, sense9words, sense10words, sense11words, \
+        sense12words, sense13words, sense14words, sense15words = pos_dfs
+
+    sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, \
+        sense6slider, sense7slider, sense8slider, sense9slider, sense10slider, sense11slider, \
+        sense12slider, sense13slider, sense14slider, sense15slider = token_contextualization_weights_list
+
+    return token, token_index, sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, \
+        sense7words, sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words, \
+        sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider, \
+        sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider
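The heart of this function is the indexing into the contextualization tensor. Spelled out as a sketch (not part of the commit; the stand-in tensor just mirrors the [batch, sense, target position, source token] layout implied by the code above):

```python
# Sketch: reading one token's per-sense weights from the contextualization tensor.
ctx = torch.rand(1, 16, 512, 512)     # stand-in for output.contextualization
length, token_index = 5, 2            # hypothetical: 5-token input, 3rd token selected
w = ctx[0, :, length-1, token_index]  # (16,): how much each sense of token 2 feeds the last position
print([round(x, 3) for x in w.tolist()])
```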

+"""
+Wrapper for when the user selects a new token in the tokens dataframe.
+Converts `evt` (the selected token) into the `token` and `token_index` that
+get_token_contextual_weights expects.
+"""
+def new_token_contextual_weights(contextualization_weights, length, evt: gr.SelectData, count=7):
+    print(">>>>> in new_token_contextual_weights")
+    token_index = evt.index[1]  # the selected token is the token_index-th token in the sentence
+    token = evt.value
+    if not token:
+        return (None,) * 34  # one None per output component of tokens.select
+    return get_token_contextual_weights(contextualization_weights, length, token, token_index, count)
+
+def change_sense0_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 0, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense1_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 1, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense2_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 2, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense3_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 3, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense4_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 4, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense5_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 5, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense6_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 6, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense7_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 7, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense8_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 8, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense9_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 9, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense10_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 10, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense11_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 11, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense12_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 12, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense13_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 13, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense14_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 14, length-1, token_index] = new_weight
+    return contextualization_weights
+def change_sense15_weight(contextualization_weights, length, token_index, new_weight):
+    contextualization_weights[0, 15, length-1, token_index] = new_weight
+    return contextualization_weights
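These sixteen near-identical handlers (and the sixteen `.change()` registrations further down) could also be produced programmatically. The commented-out generator at the bottom of the file prints this boilerplate; a `functools.partial` sketch would avoid it entirely (names as in this file; not part of the commit):

```python
import functools

def change_sense_weight(sense_idx, contextualization_weights, length, token_index, new_weight):
    # generic version of change_sense{i}_weight: edit one sense's weight in place
    contextualization_weights[0, sense_idx, length-1, token_index] = new_weight
    return contextualization_weights

change_fns = [functools.partial(change_sense_weight, i) for i in range(16)]

# inside the gr.Blocks context, with sliders = [sense0slider, ..., sense15slider]:
# for slider, fn in zip(sliders, change_fns):
#     slider.change(fn=fn,
#                   inputs=[contextualization_weights, length, token_index, slider],
#                   outputs=[contextualization_weights])
```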

+"""
+Clears all gr.State variables used to share info across methods when the input sentence changes.
+"""
+def clear_states(contextualization_weights, token_index, length):
+    contextualization_weights = None
+    token_index = None
+    length = 0
+    return contextualization_weights, token_index, length
+
+def reset_weights(contextualization_weights):
+    print("Resetting weights...")
+    contextualization_weights = None
+    return contextualization_weights
247
+ with gr.Blocks( css = """#sense0slider, #sense1slider, #sense2slider, #sense3slider, #sense4slider, #sense5slider, #sense6slider, #sense7slider,
248
+ #sense8slider, #sense9slider, #sense1slider0, #sense11slider, #sense12slider, #sense13slider, #sense14slider, #sense15slider
249
+ { height: 200px; width: 200px; transform: rotate(270deg); }""" ) as demo:
250
+
     gr.Markdown("""
+    ## Backpack Sense Visualization
     """)
+
+    with gr.Tab("Language Modeling"):
+        contextualization_weights = gr.State(None)  # store session data for sharing between functions
+        token_index = gr.State(None)
+        length = gr.State(0)
+        top_k = gr.State(10)
+        with gr.Row():
+            with gr.Column(scale=8):
+                input_sentence = gr.Textbox(label="Input Sentence", placeholder='Enter a sentence and click "Predict next word"')
+            with gr.Column(scale=1):
+                predict = gr.Button(value="Predict next word", variant="primary")
+                reset_weights_button = gr.Button("Reset weights")
+        top_k_words = gr.Dataframe(label="Next Word Predictions (top k)")
+        gr.Markdown("""### **Tokens:** click on a token to see its senses and contextualization weights""")
+        tokens = gr.DataFrame(label="")
+        with gr.Row():
+            with gr.Column(scale=1):
+                selected_token = gr.Textbox(label="Current Selected Token", interactive=False)
+            with gr.Column(scale=8):
+                gr.Markdown("""#####
+                Once a token is chosen, you can use the sliders below to change the weights of any senses for that token, \
+                and then click "Predict next word" to see updated next-word predictions. \
+                You can change the weights of *multiple senses of multiple tokens*; \
+                changes will be preserved until you click "Reset weights".
+                """)
+        # sense sliders and top sense words dataframes
+        with gr.Row():
+            with gr.Column(scale=0, min_width=120):
+                sense0slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 0", elem_id="sense0slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense1slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 1", elem_id="sense1slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense2slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 2", elem_id="sense2slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense3slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 3", elem_id="sense3slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense4slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 4", elem_id="sense4slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense5slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 5", elem_id="sense5slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense6slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 6", elem_id="sense6slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense7slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 7", elem_id="sense7slider", interactive=True)
+        with gr.Row():
+            with gr.Column(scale=0, min_width=120):
+                sense0words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense1words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense2words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense3words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense4words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense5words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense6words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense7words = gr.DataFrame()
+        with gr.Row():
+            with gr.Column(scale=0, min_width=120):
+                sense8slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 8", elem_id="sense8slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense9slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 9", elem_id="sense9slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense10slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 10", elem_id="sense10slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense11slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 11", elem_id="sense11slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense12slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 12", elem_id="sense12slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense13slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 13", elem_id="sense13slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense14slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 14", elem_id="sense14slider", interactive=True)
+            with gr.Column(scale=0, min_width=120):
+                sense15slider = gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 15", elem_id="sense15slider", interactive=True)
+        with gr.Row():
+            with gr.Column(scale=0, min_width=120):
+                sense8words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense9words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense10words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense11words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense12words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense13words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense14words = gr.DataFrame()
+            with gr.Column(scale=0, min_width=120):
+                sense15words = gr.DataFrame()
+
+        # gr.Examples(
+        #     examples=[["Messi plays for", top_k, None]],
+        #     inputs=[input_sentence, top_k, contextualization_weights],
+        #     outputs=[tokens, top_k_words, length, contextualization_weights],
+        #     fn=predict_next_word,
+        # )
+
+        sense0slider.change(fn=change_sense0_weight,
+                            inputs=[contextualization_weights, length, token_index, sense0slider],
+                            outputs=[contextualization_weights])
+        sense1slider.change(fn=change_sense1_weight,
+                            inputs=[contextualization_weights, length, token_index, sense1slider],
+                            outputs=[contextualization_weights])
+        sense2slider.change(fn=change_sense2_weight,
+                            inputs=[contextualization_weights, length, token_index, sense2slider],
+                            outputs=[contextualization_weights])
+        sense3slider.change(fn=change_sense3_weight,
+                            inputs=[contextualization_weights, length, token_index, sense3slider],
+                            outputs=[contextualization_weights])
+        sense4slider.change(fn=change_sense4_weight,
+                            inputs=[contextualization_weights, length, token_index, sense4slider],
+                            outputs=[contextualization_weights])
+        sense5slider.change(fn=change_sense5_weight,
+                            inputs=[contextualization_weights, length, token_index, sense5slider],
+                            outputs=[contextualization_weights])
+        sense6slider.change(fn=change_sense6_weight,
+                            inputs=[contextualization_weights, length, token_index, sense6slider],
+                            outputs=[contextualization_weights])
+        sense7slider.change(fn=change_sense7_weight,
+                            inputs=[contextualization_weights, length, token_index, sense7slider],
+                            outputs=[contextualization_weights])
+        sense8slider.change(fn=change_sense8_weight,
+                            inputs=[contextualization_weights, length, token_index, sense8slider],
+                            outputs=[contextualization_weights])
+        sense9slider.change(fn=change_sense9_weight,
+                            inputs=[contextualization_weights, length, token_index, sense9slider],
+                            outputs=[contextualization_weights])
+        sense10slider.change(fn=change_sense10_weight,
+                             inputs=[contextualization_weights, length, token_index, sense10slider],
+                             outputs=[contextualization_weights])
+        sense11slider.change(fn=change_sense11_weight,
+                             inputs=[contextualization_weights, length, token_index, sense11slider],
+                             outputs=[contextualization_weights])
+        sense12slider.change(fn=change_sense12_weight,
+                             inputs=[contextualization_weights, length, token_index, sense12slider],
+                             outputs=[contextualization_weights])
+        sense13slider.change(fn=change_sense13_weight,
+                             inputs=[contextualization_weights, length, token_index, sense13slider],
+                             outputs=[contextualization_weights])
+        sense14slider.change(fn=change_sense14_weight,
+                             inputs=[contextualization_weights, length, token_index, sense14slider],
+                             outputs=[contextualization_weights])
+        sense15slider.change(fn=change_sense15_weight,
+                             inputs=[contextualization_weights, length, token_index, sense15slider],
+                             outputs=[contextualization_weights])
+
+        predict.click(
+            fn=predict_next_word,
+            inputs=[input_sentence, top_k, contextualization_weights],
+            outputs=[tokens, top_k_words, length, contextualization_weights],
+        )
+
+        tokens.select(fn=new_token_contextual_weights,
+                      inputs=[contextualization_weights, length],
+                      outputs=[selected_token, token_index,
+                               sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
+                               sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
+                               sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
+                               sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
+        )
+
+        reset_weights_button.click(
+            fn=reset_weights,
+            inputs=[contextualization_weights],
+            outputs=[contextualization_weights]
+        ).success(
+            fn=predict_next_word,
+            inputs=[input_sentence, top_k, contextualization_weights],
+            outputs=[tokens, top_k_words, length, contextualization_weights],
+        ).success(
+            fn=get_token_contextual_weights,
+            inputs=[contextualization_weights, length, selected_token, token_index],
+            outputs=[selected_token, token_index,
+                     sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
+                     sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
+                     sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
+                     sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
+        )
+
+        input_sentence.change(
+            fn=clear_states,
+            inputs=[contextualization_weights, token_index, length],
+            outputs=[contextualization_weights, token_index, length]
+        )
+
+    with gr.Tab("Individual Word Sense Look Up"):
+        gr.Markdown("""> Note on tokenization: Backpack uses the GPT-2 tokenizer, which includes the space before a word as part \
+        of the token, so by default, a space character `' '` is added to the beginning of the word \
+        you look up. You can disable this by checking `Remove space before word`, but know this might \
+        cause strange behaviors like breaking `afraid` into `af` and `raid`, or `slight` into `s` and `light`.
+        """)
+        with gr.Row():
+            word = gr.Textbox(label="Word", placeholder="e.g. science")
+            token_breakdown = gr.Textbox(label="Token Breakdown (senses are for the first token only)")
+            remove_space = gr.Checkbox(label="Remove space before word", default=False)
+            count = gr.Slider(minimum=1, maximum=20, value=10, label="Top K", step=1)
+        look_up_button = gr.Button("Look up")
+        pos_outputs = gr.Dataframe(label="Highest Scoring Senses")
+        neg_outputs = gr.Dataframe(label="Lowest Scoring Senses")
+        gr.Examples(
+            examples=["science", "afraid", "book", "slight"],
+            inputs=[word],
+            outputs=[pos_outputs, neg_outputs, token_breakdown],
+            fn=visualize_word,
+            cache_examples=True,
+        )
+
+        look_up_button.click(
+            fn=visualize_word,
+            inputs=[word, count, remove_space],
+            outputs=[pos_outputs, neg_outputs, token_breakdown],
+        )
+
+demo.launch(auth=("caesar", "wins"))
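The tokenization note in the look-up tab can be checked directly (a sketch; the `slight` split is the one the note describes):

```python
print([tokenizer.decode(t) for t in tokenizer(" slight")["input_ids"]])  # [' slight']
print([tokenizer.decode(t) for t in tokenizer("slight")["input_ids"]])   # ['s', 'light'], per the note
```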


+# Code for generating the slider functions & event listeners
+
+# for i in range(16):
+#     print(
+#         f"""def change_sense{i}_weight(contextualization_weights, length, token_index, new_weight):
+#     print(f"Changing weight for the {i}th sense of the {{token_index}}th token.")
+#     print("new_weight to be assigned = ", new_weight)
+#     contextualization_weights[0, {i}, length-1, token_index] = new_weight
+#     print("contextualization_weights: ", contextualization_weights[0, :, length-1, token_index])
+#     return contextualization_weights"""
+#     )
+
+# for i in range(16):
+#     print(
+#         f"""        sense{i}slider.change(fn=change_sense{i}_weight,
+#                             inputs=[contextualization_weights, length, token_index, sense{i}slider],
+#                             outputs=[contextualization_weights])"""
+#     )