# lorenpe2's picture
# FEAT: Code without models
# c186b27
# raw
# history blame
# 4.32 kB
import os
import json
from typing import Dict, List, Tuple, Union
import torch
import pandas
import streamlit as st
import matplotlib.pyplot as plt
from inference_tokenizer import NextSentencePredictionTokenizer
@st.cache_resource
def get_model(model_path):
    """Load a ``BertForNextSentencePrediction`` checkpoint in evaluation mode.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded model
    for the same ``model_path`` instead of reloading it on every rerun.

    Args:
        model_path: Directory (or hub id) accepted by ``from_pretrained``.

    Returns:
        The model with ``eval()`` already applied.
    """
    # Imported inside the function, so ``transformers`` is only loaded when a
    # model is actually requested.
    from transformers import BertForNextSentencePrediction

    # ``nn.Module.eval()`` returns the module itself, so load + switch to
    # inference mode in one expression.
    return BertForNextSentencePrediction.from_pretrained(model_path).eval()
# The one checkpoint whose tokenizer uses the reserved "[unused1]" vocab entry
# as its special token; every other checkpoint gets a plain space.
# TODO: derive this from the checkpoint instead of hard-coding the path.
_UNUSED1_SPECIAL_TOKEN_MODEL = "./model/e09d71f55f4b6fc20135f856bf029322a3265d8d"


@st.cache_resource
def get_tokenizer(tokenizer_path):
    """Build the next-sentence-prediction tokenizer for a checkpoint directory.

    Loads the BERT tokenizer stored under ``<tokenizer_path>/tokenizer``,
    registers the checkpoint-specific special token as an additional special
    token, and wraps it in ``NextSentencePredictionTokenizer`` with fixed
    padding/truncation settings (context up to 256, response up to 64).

    Args:
        tokenizer_path: Checkpoint directory; the exact string is compared
            against ``_UNUSED1_SPECIAL_TOKEN_MODEL`` to choose the special token.

    Returns:
        A configured ``NextSentencePredictionTokenizer``; cached by Streamlit.
    """
    from transformers import BertTokenizer

    bert_tokenizer = BertTokenizer.from_pretrained(os.path.join(tokenizer_path, "tokenizer"))

    special_token = "[unused1]" if tokenizer_path == _UNUSED1_SPECIAL_TOKEN_MODEL else " "
    bert_tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})

    return NextSentencePredictionTokenizer(
        bert_tokenizer,
        special_token=special_token,
        padding="max_length",
        max_length_ctx=256,
        max_length_res=64,
        truncation="only_first",
        # NumPy output; per the original author, conversion to tensors happens
        # later (done that way because of memory problems with tensors).
        return_tensors="np",
        is_split_into_words=True,
    )
# Local checkpoint directories offered in the model picker.
_MODEL_PATHS = (
    './model/c3c3bdb7ad80396e69de171995e2038f900940c8',
    './model/e09d71f55f4b6fc20135f856bf029322a3265d8d',
)

model_option = st.selectbox('Which model do you want to use?', _MODEL_PATHS)

# The same checkpoint directory feeds both loaders: the model is read from it
# directly and the tokenizer from its "tokenizer" sub-folder.
model = get_model(model_option)
inference_tokenizer = get_tokenizer(model_option)
def get_evaluation_data(_context: Union[List[Dict], Dict], special_delimiter=" "):
    """Flatten annotated records into ``[context, sentence, label]`` rows.

    Each record is read as::

        {"context": [<sentence>, ...],
         "answers": {<source>: {<label>: [<sentence>, ...], ...}, ...}}

    Args:
        _context: A list of such records, or a single record (e.g. a pasted
            JSON object rather than a JSON array).
        special_delimiter: String used to join the ``"context"`` sentences.

    Returns:
        One ``[joined_context, sentence, label]`` list per answer sentence,
        in input order.
    """
    # Iterating a bare dict would yield its string keys and then fail on
    # ``key["context"]`` with a TypeError; treat it as a one-record list.
    records = [_context] if isinstance(_context, dict) else _context
    output_data = []
    for record in records:
        joined_context = special_delimiter.join(record["context"])
        for source in record["answers"].values():
            for label, sentences in source.items():
                output_data.extend([joined_context, sentence, label] for sentence in sentences)
    return output_data
def _predict_follow_probabilities(context_text, sentence):
    """Score whether *sentence* follows *context_text* with the selected model.

    Uses the module-level ``inference_tokenizer`` and ``model``.

    Returns:
        ``(prop_follow, prop_not_follow)``: softmax probabilities of logit
        index 0 and index 1 for the first row of the batch.
    """
    input_tensor = inference_tokenizer.get_item(context=context_text, actual_sentence=sentence)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**input_tensor.data).logits
    probabilities = torch.softmax(logits, dim=-1).numpy()[0]
    return probabilities[0], probabilities[1]


def _parse_evaluation_json(raw_json):
    """Turn the pasted JSON into evaluation rows, reporting bad input in the UI.

    Returns:
        The rows from ``get_evaluation_data``, or ``None`` after showing an
        ``st.error`` when the text is not valid JSON or has the wrong shape.
    """
    try:
        return get_evaluation_data(_context=json.loads(raw_json))
    except json.JSONDecodeError as err:
        st.error(f"Invalid JSON: {err}")
    except (KeyError, TypeError, AttributeError) as err:
        st.error(f"Unexpected JSON structure: {err!r}")
    return None


option = st.selectbox("Choose type of evaluation:",
                      ["01 - Raw text (one line)", "02 - JSON (aggregated)"])

with st.form("input_text"):
    if "01" in option:
        context = st.text_area("Insert context here (sentences divided by ||):")
        actual_text = st.text_input("Actual text")
        # Every form must have a submit button.
        submitted = st.form_submit_button("Submit")
        if submitted:
            # Run the model only once the form is submitted, not on the blank
            # inputs of a fresh page load.
            prop_follow, prop_not_follow = _predict_follow_probabilities(context, actual_text)
            fig, ax = plt.subplots()
            ax.pie([prop_follow, prop_not_follow], labels=["Probability - Follow", "Probability - Not Follow"],
                   autopct='%1.1f%%')
            st.pyplot(fig)
            # Release the figure; otherwise each submit leaves one open in pyplot.
            plt.close(fig)
    elif "02" in option:
        context = st.text_area("Insert JSON here")
        # Every form must have a submit button.
        submitted = st.form_submit_button("Submit")
        if submitted:
            evaluation_data = _parse_evaluation_json(context)
            if evaluation_data:
                results = []
                accuracy = []
                for c, s, human_label in evaluation_data:
                    prop_follow, prop_not_follow = _predict_follow_probabilities(c, s)
                    results.append((c, s, human_label, prop_follow, prop_not_follow))
                    # Only "coherent" is a positive label; every other label is
                    # scored as "should not follow".
                    if human_label == "coherent":
                        accuracy.append(int(prop_follow > prop_not_follow))
                    else:
                        accuracy.append(int(prop_not_follow > prop_follow))
                # accuracy holds 0/1 hits, so scale the mean to match the "%" suffix.
                accuracy_pct = 100 * sum(accuracy) / len(accuracy)
                st.metric(label="Accuracy", value=f"{accuracy_pct:.2f} %")
                df = pandas.DataFrame(results, columns=["Context", "Query", "Human Label", "Probability (follow)", "Probability (not-follow)"])
                st.dataframe(df)
            elif evaluation_data is not None:
                # Valid JSON, but no answer sentences: avoid dividing by zero.
                st.warning("The JSON contains no answers to evaluate.")