Spaces:
Runtime error
Runtime error
FEAT: Code without models
Browse files- .gitignore +6 -0
- app.py +107 -2
- inference_tokenizer.py +34 -0
- requirements.txt +6 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.idea
|
2 |
+
*.pyc
|
3 |
+
__pycache__
|
4 |
+
local
|
5 |
+
*wandb*
|
6 |
+
*temp*
|
app.py
CHANGED
@@ -1,4 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from typing import Dict, List, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import pandas
|
7 |
import streamlit as st
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
from inference_tokenizer import NextSentencePredictionTokenizer
|
11 |
+
|
12 |
+
|
13 |
+
@st.cache_resource
def get_model(model_path):
    """Load the BERT next-sentence-prediction checkpoint stored at ``model_path``.

    The function is decorated with ``st.cache_resource`` and returns the model
    after switching it to evaluation mode.
    """
    # transformers is imported inside the function rather than at module level.
    from transformers import BertForNextSentencePrediction

    nsp_model = BertForNextSentencePrediction.from_pretrained(model_path)
    nsp_model.eval()
    return nsp_model
|
19 |
+
|
20 |
+
|
21 |
+
@st.cache_resource
def get_tokenizer(tokenizer_path):
    """Build the inference tokenizer for the checkpoint directory ``tokenizer_path``.

    The BERT tokenizer is read from the ``tokenizer`` sub-directory and wrapped in a
    ``NextSentencePredictionTokenizer`` with fixed context (256) and response (64)
    length budgets.
    """
    from transformers import BertTokenizer

    # todo better than hardcoded
    needs_unused_token = tokenizer_path == "./model/e09d71f55f4b6fc20135f856bf029322a3265d8d"
    special_token = "[unused1]" if needs_unused_token else " "

    bert_tokenizer = BertTokenizer.from_pretrained(os.path.join(tokenizer_path, "tokenizer"))
    if needs_unused_token:
        # Register the sentence separator as an additional special token.
        bert_tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})

    return NextSentencePredictionTokenizer(
        bert_tokenizer,
        special_token=special_token,
        padding="max_length",
        max_length_ctx=256,
        max_length_res=64,
        truncation="only_first",
        # NumPy arrays are requested here; they are converted to tensors later.
        return_tensors="np",
        is_split_into_words=True,
    )
|
41 |
+
|
42 |
+
|
43 |
+
# Let the user choose which checkpoint directory to evaluate.
model_option = st.selectbox(
    label='Which model do you want to use?',
    options=(
        './model/c3c3bdb7ad80396e69de171995e2038f900940c8',
        './model/e09d71f55f4b6fc20135f856bf029322a3265d8d',
    ),
)

# The model and its tokenizer are both loaded from the selected directory.
model = get_model(model_option)
inference_tokenizer = get_tokenizer(model_option)
|
49 |
+
|
50 |
+
|
51 |
+
def get_evaluation_data(_context: List, special_delimiter=" "):
    """Flatten aggregated annotation records into ``[context, sentence, label]`` rows.

    Every record in ``_context`` provides a ``"context"`` list, which is joined with
    ``special_delimiter``, and an ``"answers"`` mapping of source -> label -> sentences.
    One row is emitted per sentence, in iteration order.
    """
    rows = []
    for record in _context:
        joined_context = special_delimiter.join(record["context"])
        for sentences_by_label in record["answers"].values():
            for label, sentences in sentences_by_label.items():
                rows.extend([joined_context, sentence, label] for sentence in sentences)
    return rows
|
61 |
+
|
62 |
+
|
63 |
+
option = st.selectbox("Choose type of evaluation:",
                      ["01 - Raw text (one line)", "02 - JSON (aggregated)"])


def _predict_probabilities(_context, _sentence):
    """Return ``(follow, not_follow)`` probabilities for one context/sentence pair.

    Uses the module-level ``inference_tokenizer`` and ``model``. Index 0 of the
    softmax output is read as "follow" and index 1 as "not follow", matching the
    original script.
    """
    input_tensor = inference_tokenizer.get_item(context=_context, actual_sentence=_sentence)
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**input_tensor.data).logits
    probabilities = torch.softmax(logits, dim=-1).detach().numpy()[0]
    return probabilities[0], probabilities[1]


with st.form("input_text"):
    if "01" in option:
        context = st.text_area("Insert context here (sentences divided by ||):")
        actual_text = st.text_input("Actual text")

        # Every form must have a submit button.
        submitted = st.form_submit_button("Submit")
        if submitted:
            # Run the model only once the user has submitted the form, instead of
            # on every script rerun with empty inputs.
            prop_follow, prop_not_follow = _predict_probabilities(context, actual_text)

            fig, ax = plt.subplots()
            ax.pie([prop_follow, prop_not_follow], labels=["Probability - Follow", "Probability - Not Follow"],
                   autopct='%1.1f%%')
            st.pyplot(fig)
    elif "02" in option:
        context = st.text_area("Insert JSON here")

        # Every form must have a submit button.
        submitted = st.form_submit_button("Submit")
        if submitted:
            try:
                evaluation_data = get_evaluation_data(_context=json.loads(context))
            except (ValueError, KeyError, TypeError, AttributeError) as error:
                # json.JSONDecodeError is a ValueError; the remaining types cover
                # JSON that parses but does not have the expected record layout.
                evaluation_data = None
                st.error(f"Could not read the evaluation JSON: {error}")

            if evaluation_data is not None and not evaluation_data:
                st.warning("The JSON does not contain any sentences to evaluate.")
            elif evaluation_data:
                results = []
                accuracy = []
                for c, s, human_label in evaluation_data:
                    prop_follow, prop_not_follow = _predict_probabilities(c, s)
                    results.append((c, s, human_label, prop_follow, prop_not_follow))
                    # A prediction counts as correct when the more probable class
                    # agrees with the human label.
                    if human_label == "coherent":
                        accuracy.append(int(prop_follow > prop_not_follow))
                    else:
                        accuracy.append(int(prop_not_follow > prop_follow))

                # Report a real percentage (the value is suffixed with "%").
                st.metric(label="Accuracy", value=f"{100 * sum(accuracy) / len(accuracy):.1f} %")
                df = pandas.DataFrame(
                    results,
                    columns=["Context", "Query", "Human Label", "Probability (follow)", "Probability (not-follow)"],
                )
                st.dataframe(df)
|
inference_tokenizer.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
|
5 |
+
class NextSentencePredictionTokenizer:
    """Turn a (context, sentence) pair into model-ready tensors.

    Wraps a tokenizer callable together with the keyword arguments it is invoked
    with. The context and the candidate sentence are split on single spaces, cut
    to their last ``max_length_ctx`` / ``max_length_res`` words, and passed to the
    tokenizer as a pair of word lists.
    """

    def __init__(self, _tokenizer, special_token, **_tokenizer_args):
        """Store the tokenizer and derive its ``max_length`` argument.

        Args:
            _tokenizer: Callable invoked as ``_tokenizer(ctx_words, res_words, **args)``.
            special_token: Text substituted for every ``||`` in the context.
            **_tokenizer_args: Keyword arguments forwarded to the tokenizer. Must
                contain ``max_length_ctx`` and ``max_length_res``; both are removed
                and replaced by ``max_length`` (their sum).

        Raises:
            KeyError: If ``max_length_ctx`` or ``max_length_res`` is missing.
        """
        self.tokenizer = _tokenizer
        self.tokenizer_args = _tokenizer_args
        # pop() reads and removes each budget in a single step.
        self.max_length_ctx = self.tokenizer_args.pop("max_length_ctx")
        self.max_length_res = self.tokenizer_args.pop("max_length_res")
        self.tokenizer_args["max_length"] = self.max_length_ctx + self.max_length_res
        self.special_token = special_token

    def get_item(self, context: str, actual_sentence: str):
        """Tokenize one pair and return the encoding with ``(1, -1)``-shaped tensors.

        Every ``||`` in ``context`` is replaced by ``special_token`` first. Each
        NumPy array in the encoding's ``data`` mapping is converted to a torch
        tensor and reshaped so that it has a leading dimension of 1.
        """
        actual_item = {"ctx": context.replace("||", self.special_token), "res": actual_sentence}
        tokenized = self._tokenize_row(actual_item)

        # Iterate over a snapshot of the keys because the values are replaced in place.
        for key in list(tokenized.data):
            tokenized.data[key] = torch.reshape(torch.from_numpy(tokenized.data[key]), (1, -1))
        return tokenized

    @staticmethod
    def _last_words(words, limit):
        """Return the last ``limit`` entries of ``words`` (none when ``limit`` <= 0)."""
        # A plain words[-limit:] would return the whole list for limit == 0.
        return words[-limit:] if limit > 0 else []

    def _tokenize_row(self, row: Dict):
        """Split ``row["ctx"]`` and ``row["res"]`` into words and call the tokenizer.

        Only the trailing words within each budget are kept. Any remaining
        truncation or padding is left to the tokenizer arguments.
        """
        ctx_tokens = self._last_words(row["ctx"].split(" "), self.max_length_ctx)
        res_tokens = self._last_words(row["res"].split(" "), self.max_length_res)
        return self.tokenizer(ctx_tokens, res_tokens, **self.tokenizer_args)
|
34 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
streamlit
|
4 |
+
matplotlib
|
5 |
+
numpy
|
6 |
+
pandas
|