|
"""Visualize some sense vectors""" |
|
|
|
import torch |
|
import argparse |
|
|
|
import transformers |
|
|
|
def visualize_word(word, tokenizer, vecs, lm_head, count=20, contents=None):
    """Print the top- and bottom-scoring vocabulary words for each sense vector.

    For each sense (row) of the word's sense matrix, computes logits against
    the LM output embeddings and prints the `count` highest- and
    lowest-scoring vocabulary entries.

    Args:
        word: Word to look up. Only its FIRST token id is used, so
            multi-token words are effectively truncated to their first piece.
        tokenizer: Tokenizer used to encode `word` and decode vocab ids
            (assumed HuggingFace-style: callable, with a `.decode` method).
        vecs: Sense-vector table indexed by token id; `vecs[token_id]` is
            expected to be a (num_senses, dim) tensor — TODO confirm against
            how the vecs checkpoint was produced.
        lm_head: (vocab_size, dim) output-embedding matrix.
        count: Number of top/bottom words to print per sense.
        contents: Optional precomputed (num_senses, dim) sense matrix; when
            given, `word` and `vecs` are not consulted for lookup.

    Returns:
        The (num_senses, dim) sense matrix that was scored.
    """
    if contents is None:
        print(word)
        # NOTE: only the first token of a multi-token word is used.
        token_id = tokenizer(word)['input_ids'][0]
        contents = vecs[token_id]

    for i in range(contents.shape[0]):
        print('~~~~~~~~~~~~~~~~~~~~~~~{}~~~~~~~~~~~~~~~~~~~~~~~~'.format(i))
        # Score this sense against every vocabulary embedding.
        logits = contents[i, :] @ lm_head.t()
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        print('~~~Positive~~~')
        for j in range(count):
            print(tokenizer.decode(sorted_indices[j]), '\t', '{:.2f}'.format(sorted_logits[j].item()))
        print('~~~Negative~~~')
        for j in range(count):
            print(tokenizer.decode(sorted_indices[-j - 1]), '\t', '{:.2f}'.format(sorted_logits[-j - 1].item()))
    return contents
|
|
|
if __name__ == '__main__':
    # Guarded so importing this module (e.g. to reuse visualize_word) does not
    # parse CLI args or prompt for input.
    argp = argparse.ArgumentParser(description='Visualize sense vectors for a word.')
    argp.add_argument('vecs_path')  # torch-saved sense-vector table
    argp.add_argument('lm_head_path')  # torch-saved LM output-embedding matrix
    args = argp.parse_args()

    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
    # map_location='cpu' so checkpoints saved on GPU still load on CPU-only
    # machines.
    vecs = torch.load(args.vecs_path, map_location='cpu')
    lm_head = torch.load(args.lm_head_path, map_location='cpu')

    visualize_word(input('Enter a word:'), tokenizer, vecs, lm_head, count=5)