Lora
committed on
Commit
·
6d50883
1
Parent(s):
1e64c5d
load vecs and lm head only once
Browse files
app.py
CHANGED
@@ -3,6 +3,10 @@ import pandas as pd
|
|
3 |
import transformers
|
4 |
import gradio as gr
|
5 |
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def visualize_word(word, count=10, remove_space=False):
|
8 |
|
@@ -10,11 +14,6 @@ def visualize_word(word, count=10, remove_space=False):
|
|
10 |
word = ' ' + word
|
11 |
print(f"Looking up word '{word}'...")
|
12 |
|
13 |
-
# very dumb to have to load the tokenizer every time, trying to figure out how to pass a non-interface element into the function in gradio
|
14 |
-
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
|
15 |
-
vecs = torch.load("senses/all_vecs_mtx.pt")
|
16 |
-
lm_head = torch.load("senses/lm_head.pt")
|
17 |
-
|
18 |
token_ids = tokenizer(word)['input_ids']
|
19 |
tokens = [tokenizer.decode(token_id) for token_id in token_ids]
|
20 |
tokens = ", ".join(tokens) # display tokenization for user
|
@@ -45,12 +44,14 @@ def visualize_word(word, count=10, remove_space=False):
|
|
45 |
columns=list(data.keys()))
|
46 |
for prop, word_list in data.items():
|
47 |
for i, word_pair in enumerate(word_list):
|
|
|
48 |
cell_value = "{} ({:.2f})".format(word_pair[0], word_pair[1])
|
49 |
df.at[i, prop] = cell_value
|
50 |
return df
|
51 |
|
52 |
pos_df = create_dataframe(pos_word_lists, sense_names, count)
|
53 |
neg_df = create_dataframe(neg_word_lists, sense_names, count)
|
|
|
54 |
|
55 |
return pos_df, neg_df, tokens
|
56 |
|
|
|
3 |
import transformers
|
4 |
import gradio as gr
|
5 |
|
6 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
|
7 |
+
vecs = torch.load("senses/all_vecs_mtx.pt")
|
8 |
+
lm_head = torch.load("senses/lm_head.pt")
|
9 |
+
|
10 |
|
11 |
def visualize_word(word, count=10, remove_space=False):
|
12 |
|
|
|
14 |
word = ' ' + word
|
15 |
print(f"Looking up word '{word}'...")
|
16 |
|
|
|
|
|
|
|
|
|
|
|
17 |
token_ids = tokenizer(word)['input_ids']
|
18 |
tokens = [tokenizer.decode(token_id) for token_id in token_ids]
|
19 |
tokens = ", ".join(tokens) # display tokenization for user
|
|
|
44 |
columns=list(data.keys()))
|
45 |
for prop, word_list in data.items():
|
46 |
for i, word_pair in enumerate(word_list):
|
47 |
+
cell_value = "space ({:.2f})".format(word_pair[1])
|
48 |
cell_value = "{} ({:.2f})".format(word_pair[0], word_pair[1])
|
49 |
df.at[i, prop] = cell_value
|
50 |
return df
|
51 |
|
52 |
pos_df = create_dataframe(pos_word_lists, sense_names, count)
|
53 |
neg_df = create_dataframe(neg_word_lists, sense_names, count)
|
54 |
+
print(pos_df)
|
55 |
|
56 |
return pos_df, neg_df, tokens
|
57 |
|