Commit 5ea9bf1
Parent(s): 0b41e50
Update app.py

app.py CHANGED
@@ -20,8 +20,8 @@ from sentence_transformers import SentenceTransformer, util
 
 #model_sts = gr.Interface.load('huggingface/sentence-transformers/stsb-distilbert-base')
 
-
-model_sts = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')
+model_sts = SentenceTransformer('stsb-distilbert-base')
+#model_sts = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')
 #batch_size = 1
 #scorer = LMScorer.from_pretrained('gpt2' , device=device, batch_size=batch_size)
 
@@ -72,11 +72,7 @@ def cloze_prob(text):
     text_list = text.split()
     stem = ' '.join(text_list[:-1])
     stem_encoding = tokenizer.encode(stem)
-    # cw_encoding is just the difference between whole_text_encoding and stem_encoding
-    # note: this might not correspond exactly to the word itself
     cw_encoding = whole_text_encoding[len(stem_encoding):]
-    # Run the entire sentence through the model. Then go "back in time" to look at what the model predicted for each token, starting at the stem.
-    # Put the whole text encoding into a tensor, and get the model's comprehensive output
     tokens_tensor = torch.tensor([whole_text_encoding])
 
     with torch.no_grad():
@@ -93,10 +89,7 @@ def cloze_prob(text):
 
         logprobs.append(np.log(softmax(raw_output)))
 
-
-    # [ [0.412, 0.001, ... ] ,[0.213, 0.004, ...], [0.002,0.001, 0.93 ...]]
-    # Then for the i'th token we want to find its associated probability
-    # this is just: raw_probabilities[i][token_index]
+
     conditional_probs = []
     for cw,prob in zip(cw_encoding,logprobs):
         conditional_probs.append(prob[cw])
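
The only substantive change in this commit is the sentence-embedding model swap in the first hunk: `roberta-large-nli-stsb-mean-tokens` is replaced by the lighter `stsb-distilbert-base`, with the old choice kept as a comment. The code that actually uses `model_sts` is not part of this diff, so the snippet below is only a minimal sketch of how such a model is typically paired with the `util` helper that app.py already imports; the `similarity` function and the example sentences are illustrative, not taken from the app.

# Minimal sketch, not the app's scoring code: the swapped-in model used for
# sentence similarity via the `util` helper that app.py already imports.
from sentence_transformers import SentenceTransformer, util

model_sts = SentenceTransformer('stsb-distilbert-base')

def similarity(sentence_a, sentence_b):
    # Encode both sentences and compare them with cosine similarity.
    embeddings = model_sts.encode([sentence_a, sentence_b], convert_to_tensor=True)
    return util.pytorch_cos_sim(embeddings[0], embeddings[1]).item()

# Hypothetical usage
print(similarity("The cat sat on the mat.", "A cat is sitting on a mat."))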
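
The second and third hunks only delete explanatory comments from `cloze_prob`, but those comments describe what the function does: encode the whole sentence and its stem, treat the trailing tokens (the critical word) as the difference between the two encodings, run the full sentence through the language model once, and then look up the model's log probability for each critical-word token. Below is a self-contained sketch of that flow assuming a GPT-2 tokenizer and model; the loop bounds and the final aggregation are reconstructions from the visible lines and comments, not the app's exact code.

# Hedged reconstruction of the cloze-probability flow described by the deleted comments.
# Assumes GPT-2; `cloze_prob` in app.py may differ in loop bounds and return value.
import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

def softmax(x):
    exps = np.exp(x - np.max(x))
    return exps / exps.sum()

def cloze_prob(text):
    # Encode the whole sentence and its stem (everything but the last word);
    # the critical-word encoding is whatever tokens remain after the stem.
    whole_text_encoding = tokenizer.encode(text)
    text_list = text.split()
    stem = ' '.join(text_list[:-1])
    stem_encoding = tokenizer.encode(stem)
    cw_encoding = whole_text_encoding[len(stem_encoding):]

    # Run the entire sentence through the model once, then go "back in time"
    # and read off what it predicted at each position from the end of the stem on.
    tokens_tensor = torch.tensor([whole_text_encoding])
    with torch.no_grad():
        predictions = model(tokens_tensor)[0]

    logprobs = []
    for i in range(len(stem_encoding) - 1, len(whole_text_encoding) - 1):
        raw_output = predictions[0, i, :].numpy()
        logprobs.append(np.log(softmax(raw_output)))

    # For each critical-word token, its log probability is logprobs[i][token_id].
    conditional_probs = [prob[cw] for cw, prob in zip(cw_encoding, logprobs)]

    # Joint probability of the critical word given the stem (aggregation assumed).
    return float(np.exp(np.sum(conditional_probs)))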