# Backpack / app.py
# Lora
# remove auth
# 440043c
# raw
# history blame
# 28.2 kB
import torch
import transformers
from transformers import AutoModelForCausalLM
import pandas as pd
import gradio as gr
# Load the GPT-2 tokenizer and the Backpack model once at import time, then keep
# handles to the components the visualizations call directly.
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
m = AutoModelForCausalLM.from_pretrained("lora-x/backpack-gpt2", trust_remote_code=True)
m.eval()  # inference-only app

lm_head = m.get_lm_head()  # (V, d)
_backpack = m.backpack
word_embeddings = _backpack.get_word_embeddings()  # (V, d)
sense_network = _backpack.get_sense_network()  # (V, nv, d)
num_senses = _backpack.get_num_senses()
sense_names = list(range(num_senses))
def _create_sense_dataframe(word_lists, sense_names, count):
    """Lay out per-sense ``(word, logit)`` lists as a table.

    Args:
        word_lists: one list of ``(word, logit)`` pairs per sense.
        sense_names: column headers, one per sense, in the same order as ``word_lists``.
        count: number of rows, i.e. pairs per sense.

    Returns:
        A ``count``-row DataFrame with one column per sense whose cells read
        ``"word (logit)"``. A token consisting of a single space is labelled
        ``"space"`` so it does not render as a blank-looking cell.
    """
    df = pd.DataFrame(index=list(range(count)), columns=list(sense_names))
    for sense_name, word_list in zip(sense_names, word_lists):
        for row, (word, logit) in enumerate(word_list):
            label = "space" if word == " " else word
            df.at[row, sense_name] = "{} ({:.2f})".format(label, logit)
    return df


def visualize_word(word, count=10, remove_space=False):
    """Single token sense lookup.

    Tokenize ``word`` with the GPT-2 tokenizer and, for the FIRST token only,
    project each of its sense vectors through the LM head. For each sense, list
    the vocabulary items it scores highest and lowest.

    Args:
        word: the word to look up.
        count: how many highest/lowest scoring words to show per sense.
        remove_space: when False (default) a space is prepended to ``word``,
            because GPT-2 tokens include the space before a word.

    Returns:
        ``(pos_df, neg_df, tokens)``. The two DataFrames each have one column per
        sense (``"sense 0"``, ``"sense 1"``, ...) and ``count`` rows of
        ``"word (logit)"`` cells. ``tokens`` is the comma-joined tokenization
        shown to the user.

    Raises:
        gr.Error: if ``word`` tokenizes to nothing (e.g. empty with ``remove_space``).
    """
    count = int(count)  # slider values may arrive as floats; range() needs an int
    if not remove_space:
        word = ' ' + word
    print(f"Looking up word '{word}'...")
    token_ids = tokenizer(word)['input_ids']
    if not token_ids:
        raise gr.Error("Please enter a word to look up.")
    tokens = ", ".join(tokenizer.decode(token_id) for token_id in token_ids)  # display tokenization for user
    print(f"Tokenized as: {tokens}")

    pos_word_lists, neg_word_lists, sense_names = [], [], []
    with torch.no_grad():  # inference only; no autograd graph needed
        # Look up sense vectors only for the first token; sense_network expects a batch dim.
        sense_input_embeds = word_embeddings(torch.tensor([token_ids[0]]).long().unsqueeze(0))  # (bs=1, s=1, d)
        senses = torch.squeeze(sense_network(sense_input_embeds))  # (bs=1, nv, s=1, d) -> (nv, d)
        for i in range(senses.shape[0]):
            sorted_logits, sorted_indices = torch.sort(lm_head(senses[i, :]), descending=True)
            sense_names.append('sense {}'.format(i))
            pos_word_lists.append(
                [(tokenizer.decode(sorted_indices[j]), sorted_logits[j].item()) for j in range(count)]
            )
            neg_word_lists.append(
                [(tokenizer.decode(sorted_indices[-j - 1]), sorted_logits[-j - 1].item()) for j in range(count)]
            )

    pos_df = _create_sense_dataframe(pos_word_lists, sense_names, count)
    neg_df = _create_sense_dataframe(neg_word_lists, sense_names, count)
    return pos_df, neg_df, tokens
def predict_next_word(sentence, top_k=5, contextualization_weights=None):
    """Predict the most likely next words for ``sentence``.

    Args:
        sentence: the input text. Surrounding whitespace is stripped and one
            leading space is added, because GPT-2 tokens include the space
            before a word.
        top_k: how many next-word candidates to return.
        contextualization_weights: ``None`` to run the model normally. Otherwise,
            a weights tensor previously returned by this function (possibly
            edited by the sense sliders), which is fed to
            ``m.run_with_custom_contextualization``.

    Returns:
        - tokens: the tokenization of the input sentence, as a list holding one
          list (used as a dataframe). Also used as the options to choose from for
          get_token_contextual_weights.
        - top_k_words_df: a dataframe of the top k words predicted by the model
          and their probabilities (one column per rank).
        - length: number of tokens in the input sentence. Stored as a gr.State
          variable so other methods can find the contextualization weights for
          the *last* token.
        - contextualization_weights: gr.State variable storing the
          contextualization weights for the input sentence.

    Raises:
        gr.Error: if the sentence is empty or longer than the model input buffer.
    """
    max_len = 512  # the model input is always zero-padded to this many positions

    sentence = sentence.strip()
    if not sentence:
        raise gr.Error("Please enter a sentence.")
    # After strip() there is never a leading space, so always add exactly one.
    sentence = ' ' + sentence
    print(f"Sentence: '{sentence}'")

    # Make input, keeping track of original length
    token_ids = tokenizer(sentence)['input_ids']
    length = len(token_ids)
    if length > max_len:
        raise gr.Error(f"The sentence is {length} tokens long; at most {max_len} tokens are supported.")
    tokens = [[tokenizer.decode(token_id) for token_id in token_ids]]  # a list of a single list because used as dataframe
    inp = torch.zeros((1, max_len)).long()
    inp[0, :length] = torch.tensor(token_ids).long()

    # no_grad: inference only, and the weights kept in gr.State then carry no autograd history.
    with torch.no_grad():
        if contextualization_weights is None:
            print("contextualization_weights IS None, freshly computing contextualization_weights")
            output = m(inp)
            # Store contextualization weights and return them as a gr.State var for get_token_contextual_weights
            contextualization_weights = output.contextualization
        else:
            print("contextualization_weights is NOT None, using passed in contextualization_weights")
            output = m.run_with_custom_contextualization(inp, contextualization_weights)
        # Read the prediction at the last real (non-padding) position.
        logits = output.logits[0, length - 1, :]

    probs = logits.softmax(dim=-1)  # probs over next word
    probs, indices = torch.sort(probs, descending=True)
    top_k = int(top_k)
    top_k_words = [(tokenizer.decode(indices[i]), round(probs[i].item(), 3)) for i in range(top_k)]
    top_k_words_df = pd.DataFrame(top_k_words, columns=['word', 'probability'], index=range(1, top_k + 1)).T
    print(top_k_words_df)
    return tokens, top_k_words_df, length, contextualization_weights
def get_token_contextual_weights(contextualization_weights, length, token, token_index, count=7):
    """Return the senses and contextualization weights for the selected token.

    Args:
        contextualization_weights: a gr.State variable that stores the
            contextualization weights for the input sentence. It is indexed here
            as ``[0, sense, position, token_index]``.
        length: length of the input sentence, used to get the contextualization
            weights for the last token (position ``length - 1``).
        token: the selected token.
        token_index: the index of the selected token in the input sentence.
        count: how many top words to display for each sense.

    Returns:
        A flat tuple ``(token, token_index, *sense_word_dfs, *sense_weights)``. It
        holds one single-column DataFrame of top words and one weight (rounded to
        3 decimals) per sense, in the order of the gradio outputs list. If no
        token is selected or no weights are available, every element is ``None``.
    """
    print(">>>>>in get_token_contextual_weights")
    print(f"Selected {token_index}th token: {token}")
    if contextualization_weights is None or token_index is None or not token:
        # Nothing to show, e.g. "Reset weights" was clicked before any token was
        # chosen, or the sentence was edited after the last prediction.
        return (None,) * (2 + 2 * num_senses)
    count = int(count)

    # Only care about the weights for the last word, since that's what contributes to the output
    token_weights = contextualization_weights[0, :, length - 1, token_index]
    weight_list = [round(w, 3) for w in token_weights.tolist()]

    # Re-encoding a decoded token can yield several ids (e.g. partial UTF-8 bytes decode
    # to a replacement character); keep just the first so the lookup is one position wide.
    token_ids = tokenizer(token)['input_ids'][:1]

    word_dfs = []
    with torch.no_grad():  # inference only
        sense_input_embeds = word_embeddings(torch.tensor(token_ids).long().unsqueeze(0))  # (bs=1, s=1, d)
        senses = torch.squeeze(sense_network(sense_input_embeds))  # (bs=1, nv, s=1, d) -> (nv, d)
        for i in range(num_senses):
            logits = lm_head(senses[i, :])  # (vocab,)
            top_indices = torch.sort(logits, descending=True).indices[:count]
            word_dfs.append(pd.DataFrame([tokenizer.decode(idx) for idx in top_indices]))

    return (token, token_index, *word_dfs, *weight_list)
def new_token_contextual_weights(contextualization_weights, length, evt: gr.SelectData, count=7):
    """Wrapper for when the user selects a new token in the tokens dataframe.

    Converts ``evt`` (the selected token) to ``token`` and ``token_index`` and
    forwards them to get_token_contextual_weights.

    Returns:
        A flat tuple ``(token, token_index, *sense_word_dfs, *sense_weights)``,
        one value per output component. For an empty cell every element is ``None``.
    """
    print(">>>>>in new_token_contextual_weights")
    token_index = evt.index[1]  # selected token is the token_index-th token in the sentence
    token = evt.value
    if not token:
        # Exactly one None per output: selected_token, token_index, then a words
        # dataframe and a slider for every sense.
        return (None,) * (2 + 2 * num_senses)
    return get_token_contextual_weights(contextualization_weights, length, token, token_index, count)
def _set_sense_weight(sense_idx, contextualization_weights, length, token_index, new_weight):
    """Overwrite one contextualization weight of the selected token in place.

    Sets ``contextualization_weights[0, sense_idx, length - 1, token_index]``, the
    weight used at the last position for sense ``sense_idx`` of the token at
    ``token_index``. Returns the same (mutated) tensor for the gr.State.

    The update is skipped, and the input returned unchanged, when there are no
    weights yet, no token is selected, or the slider delivered no value. Indexing
    with ``token_index=None`` would insert a new axis and overwrite the weight
    at *every* token position.
    """
    if contextualization_weights is None or token_index is None or new_weight is None:
        return contextualization_weights
    contextualization_weights[0, sense_idx, length - 1, token_index] = new_weight
    return contextualization_weights


def _make_change_sense_weight(sense_idx):
    """Build the slider ``.change`` handler that edits sense ``sense_idx``."""
    def handler(contextualization_weights, length, token_index, new_weight):
        return _set_sense_weight(sense_idx, contextualization_weights, length, token_index, new_weight)

    # Keep the historical function names (gradio derives endpoint names from __name__).
    handler.__name__ = handler.__qualname__ = f"change_sense{sense_idx}_weight"
    return handler


change_sense0_weight = _make_change_sense_weight(0)
change_sense1_weight = _make_change_sense_weight(1)
change_sense2_weight = _make_change_sense_weight(2)
change_sense3_weight = _make_change_sense_weight(3)
change_sense4_weight = _make_change_sense_weight(4)
change_sense5_weight = _make_change_sense_weight(5)
change_sense6_weight = _make_change_sense_weight(6)
change_sense7_weight = _make_change_sense_weight(7)
change_sense8_weight = _make_change_sense_weight(8)
change_sense9_weight = _make_change_sense_weight(9)
change_sense10_weight = _make_change_sense_weight(10)
change_sense11_weight = _make_change_sense_weight(11)
change_sense12_weight = _make_change_sense_weight(12)
change_sense13_weight = _make_change_sense_weight(13)
change_sense14_weight = _make_change_sense_weight(14)
change_sense15_weight = _make_change_sense_weight(15)
def clear_states(contextualization_weights, token_index, length):
    """Reset every per-sentence gr.State when the input sentence changes.

    The incoming values are ignored. The weights and selected token index are
    cleared to ``None``, and the sentence length goes back to ``0``.
    """
    del contextualization_weights, token_index, length  # only present to match the gradio inputs list
    return None, None, 0
def reset_weights(contextualization_weights):
    """Drop any slider-edited weights; predict_next_word recomputes them when given ``None``."""
    print("Resetting weights...")
    return None
# Sense widgets are laid out in rows of 8: a row of weight sliders, then below it
# a row of dataframes with each sense's top words.
_NUM_SENSE_WIDGETS = 16  # must match num_senses and the change_sense{i}_weight handlers
_SENSES_PER_ROW = 8
_SENSE_WEIGHT_HANDLERS = [
    change_sense0_weight, change_sense1_weight, change_sense2_weight, change_sense3_weight,
    change_sense4_weight, change_sense5_weight, change_sense6_weight, change_sense7_weight,
    change_sense8_weight, change_sense9_weight, change_sense10_weight, change_sense11_weight,
    change_sense12_weight, change_sense13_weight, change_sense14_weight, change_sense15_weight,
]
# Rotate the sliders so they stand vertically above their word lists.
_SLIDER_CSS = (
    ", ".join(f"#sense{i}slider" for i in range(_NUM_SENSE_WIDGETS))
    + "\n{ height: 200px; width: 200px; transform: rotate(270deg); }"
)

with gr.Blocks(css=_SLIDER_CSS) as demo:
    gr.Markdown("""
## Backpack Sense Visualization
""")
    with gr.Tab("Language Modeling"):
        contextualization_weights = gr.State(None)  # store session data for sharing between functions
        token_index = gr.State(None)
        length = gr.State(0)
        top_k = gr.State(10)
        with gr.Row():
            with gr.Column(scale=8):
                input_sentence = gr.Textbox(label="Input Sentence", placeholder='Enter a sentence and click "Predict next word"')
            with gr.Column(scale=1):
                predict = gr.Button(value="Predict next word", variant="primary")
                reset_weights_button = gr.Button("Reset weights")
        top_k_words = gr.Dataframe(label="Next Word Predictions (top k)")
        gr.Markdown("""### **Tokens:** click on a token to see its senses and contextualization weights""")
        tokens = gr.DataFrame(label="")
        with gr.Row():
            with gr.Column(scale=1):
                selected_token = gr.Textbox(label="Current Selected Token", interactive=False)
            with gr.Column(scale=8):
                gr.Markdown(
                    "#####\n"
                    "Once a token is chosen, you can use the sliders below to change the weights of any senses for that token, "
                    "and then click \"Predict next word\" to see updated next-word predictions. "
                    "You can change the weights of *multiple senses of multiple tokens*; "
                    "changes will be preserved until you click \"Reset weights\".\n"
                )

        # sense sliders and top sense words dataframes
        sense_sliders, sense_words = [], []
        for row_start in range(0, _NUM_SENSE_WIDGETS, _SENSES_PER_ROW):
            row_senses = range(row_start, row_start + _SENSES_PER_ROW)
            with gr.Row():
                for i in row_senses:
                    with gr.Column(scale=0, min_width=120):
                        sense_sliders.append(
                            gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label=f"Sense {i}",
                                      elem_id=f"sense{i}slider", interactive=True)
                        )
            with gr.Row():
                for _ in row_senses:
                    with gr.Column(scale=0, min_width=120):
                        sense_words.append(gr.DataFrame())

        # gr.Examples(
        #     examples=[["Messi plays for", top_k, None]],
        #     inputs=[input_sentence, top_k, contextualization_weights],
        #     outputs=[tokens, top_k_words, length, contextualization_weights],
        #     fn=predict_next_word,
        # )

        # Outputs shared by every handler that shows a token's senses: the token,
        # its index, one words dataframe per sense, then one slider per sense.
        token_outputs = [selected_token, token_index, *sense_words, *sense_sliders]

        # NOTE(review): `.change` also fires when get_token_contextual_weights pushes
        # values into the sliders, writing the displayed (rounded) value back into the
        # weights. Consider a user-input-only event if the installed gradio has one.
        for slider, handler in zip(sense_sliders, _SENSE_WEIGHT_HANDLERS):
            slider.change(fn=handler,
                          inputs=[contextualization_weights, length, token_index, slider],
                          outputs=[contextualization_weights])

        predict.click(
            fn=predict_next_word,
            inputs=[input_sentence, top_k, contextualization_weights],
            outputs=[tokens, top_k_words, length, contextualization_weights],
        )
        tokens.select(fn=new_token_contextual_weights,
                      inputs=[contextualization_weights, length],
                      outputs=token_outputs)
        reset_weights_button.click(
            fn=reset_weights,
            inputs=[contextualization_weights],
            outputs=[contextualization_weights]
        ).success(
            fn=predict_next_word,
            inputs=[input_sentence, top_k, contextualization_weights],
            outputs=[tokens, top_k_words, length, contextualization_weights],
        ).success(
            fn=get_token_contextual_weights,
            inputs=[contextualization_weights, length, selected_token, token_index],
            outputs=token_outputs,
        )
        input_sentence.change(
            fn=clear_states,
            inputs=[contextualization_weights, token_index, length],
            outputs=[contextualization_weights, token_index, length]
        )

    with gr.Tab("Individual Word Sense Look Up"):
        gr.Markdown(
            "> Note on tokenization: Backpack uses the GPT-2 tokenizer, which includes the space before a word as part "
            "of the token, so by default, a space character `' '` is added to the beginning of the word "
            "you look up. You can disable this by checking `Remove space before word`, but know this might "
            "cause strange behaviors like breaking `afraid` into `af` and `raid`, or `slight` into `s` and `light`.\n"
        )
        with gr.Row():
            word = gr.Textbox(label="Word", placeholder="e.g. science")
            token_breakdown = gr.Textbox(label="Token Breakdown (senses are for the first token only)")
            remove_space = gr.Checkbox(label="Remove space before word", value=False)
            count = gr.Slider(minimum=1, maximum=20, value=10, label="Top K", step=1)
        look_up_button = gr.Button("Look up")
        pos_outputs = gr.Dataframe(label="Highest Scoring Senses")
        neg_outputs = gr.Dataframe(label="Lowest Scoring Senses")
        gr.Examples(
            examples=["science", "afraid", "book", "slight"],
            inputs=[word],
            outputs=[pos_outputs, neg_outputs, token_breakdown],
            fn=visualize_word,
            cache_examples=True,
        )
        look_up_button.click(
            fn=visualize_word,
            inputs=[word, count, remove_space],
            outputs=[pos_outputs, neg_outputs, token_breakdown],
        )

if __name__ == "__main__":
    demo.launch()
# Code for generating slider functions & event listners
# for i in range(16):
# print(
# f"""def change_sense{i}_weight(contextualization_weights, length, token_index, new_weight):
# print(f"Changing weight for the {i}th sense of the {{token_index}}th token.")
# print("new_weight to be assigned = ", new_weight)
# contextualization_weights[0, {i}, length-1, token_index] = new_weight
# print("contextualization_weights: ", contextualization_weights[0, :, length-1, token_index])
# return contextualization_weights"""
# )
# for i in range(16):
# print(
# f""" sense{i}slider.change(fn=change_sense{i}_weight,
# inputs=[contextualization_weights, length, token_index, sense{i}slider],
# outputs=[contextualization_weights])"""
# )