Lora committed
Commit 6d50883 · 1 parent: 1e64c5d

load vecs and lm head only once

Files changed (1)
  1. app.py +6 -5
app.py CHANGED
@@ -3,6 +3,10 @@ import pandas as pd
 import transformers
 import gradio as gr
 
+tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
+vecs = torch.load("senses/all_vecs_mtx.pt")
+lm_head = torch.load("senses/lm_head.pt")
+
 
 def visualize_word(word, count=10, remove_space=False):
 
@@ -10,11 +14,6 @@ def visualize_word(word, count=10, remove_space=False):
     word = ' ' + word
     print(f"Looking up word '{word}'...")
 
-    # very dumb to have to load the tokenizer every time, trying to figure out how to pass a non-interface element into the function in gradio
-    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
-    vecs = torch.load("senses/all_vecs_mtx.pt")
-    lm_head = torch.load("senses/lm_head.pt")
-
     token_ids = tokenizer(word)['input_ids']
     tokens = [tokenizer.decode(token_id) for token_id in token_ids]
     tokens = ", ".join(tokens) # display tokenization for user
@@ -45,12 +44,14 @@ def visualize_word(word, count=10, remove_space=False):
                               columns=list(data.keys()))
         for prop, word_list in data.items():
             for i, word_pair in enumerate(word_list):
+                cell_value = "space ({:.2f})".format(word_pair[1])
                 cell_value = "{} ({:.2f})".format(word_pair[0], word_pair[1])
                 df.at[i, prop] = cell_value
         return df
 
     pos_df = create_dataframe(pos_word_lists, sense_names, count)
     neg_df = create_dataframe(neg_word_lists, sense_names, count)
+    print(pos_df)
 
     return pos_df, neg_df, tokens
 
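The pattern this commit applies: the expensive objects (the GPT-2 tokenizer, the sense-vector matrix, the LM head) are loaded once at module import and reused by the Gradio callback through closure, instead of being reloaded on every call. A minimal sketch of that pattern, assuming the same file paths as this repo; the callback body and the gr.Interface wiring below are illustrative, not the app's actual ones:

import torch
import transformers
import gradio as gr

# Loaded once, at import time; every call to the callback below reuses them.
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
vecs = torch.load("senses/all_vecs_mtx.pt")   # sense-vector matrix (path from this repo)
lm_head = torch.load("senses/lm_head.pt")     # LM head weights (path from this repo)

def lookup_tokens(word):
    # The callback closes over the module-level objects, so nothing is reloaded per call.
    token_ids = tokenizer(' ' + word)['input_ids']
    return ", ".join(tokenizer.decode(t) for t in token_ids)

# Hypothetical wiring; the real app returns two DataFrames plus the token string.
demo = gr.Interface(fn=lookup_tokens, inputs="text", outputs="text")
demo.launch()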