Spaces:
Runtime error
Runtime error
import os | |
import json | |
from typing import Dict, List, Tuple, Union | |
import torch | |
import pandas | |
import streamlit as st | |
import matplotlib.pyplot as plt | |
from inference_tokenizer import NextSentencePredictionTokenizer | |
def get_model(model_path):
    """Load a ``BertForNextSentencePrediction`` checkpoint for inference.

    :param model_path: directory (or hub id) passed to ``from_pretrained``.
    :return: the loaded model, already switched to evaluation mode.
    """
    from transformers import BertForNextSentencePrediction

    # ``nn.Module.eval()`` returns the module itself, so load + eval can be chained.
    return BertForNextSentencePrediction.from_pretrained(model_path).eval()
def get_tokenizer(tokenizer_path):
    """Build the next-sentence-prediction inference tokenizer for a model directory.

    Loads a ``BertTokenizer`` from ``<tokenizer_path>/tokenizer``, registers the
    model-specific special token and wraps everything in
    ``NextSentencePredictionTokenizer`` with fixed padding/truncation settings.

    :param tokenizer_path: model directory, e.g. ``./model/<hash>``.
    :return: a ``NextSentencePredictionTokenizer`` instance.
    """
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained(os.path.join(tokenizer_path, "tokenizer"))
    tokenizer_args = {
        "padding": "max_length",
        "max_length_ctx": 256,
        "max_length_res": 64,
        "truncation": "only_first",
        # NumPy output on purpose: conversion to tensors happens later
        # (working with tensors at this stage caused memory problems).
        "return_tensors": "np",
        "is_split_into_words": True,
    }
    # Model directories that were trained with "[unused1]" as special token
    # instead of a plain space. NOTE(review): presumably this token separates the
    # context sentences inside NextSentencePredictionTokenizer - confirm there.
    # Matching on the normalised last path component means "./model/<hash>",
    # "model/<hash>" and "./model/<hash>/" are all recognised, not just one spelling.
    unused1_models = {"e09d71f55f4b6fc20135f856bf029322a3265d8d"}
    model_id = os.path.basename(os.path.normpath(tokenizer_path))
    special_token = "[unused1]" if model_id in unused1_models else " "
    tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
    return NextSentencePredictionTokenizer(tokenizer, special_token=special_token, **tokenizer_args)
model_option = st.selectbox(
    'Which model do you want to use?',
    ('./model/c3c3bdb7ad80396e69de171995e2038f900940c8', './model/e09d71f55f4b6fc20135f856bf029322a3265d8d'))

# Streamlit re-executes this whole script on every interaction (form submit,
# selectbox change). Without caching, the model weights and tokenizer would be
# re-read from disk on every rerun. Use ``st.cache_resource`` (keyed on the
# function and ``model_option``) when this Streamlit version provides it;
# otherwise fall back to loading directly.
# NOTE(review): cached objects are shared across user sessions; this assumes
# inference_tokenizer.get_item does not mutate shared state - confirm.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    model = _cache_resource(get_model)(model_option)
    inference_tokenizer = _cache_resource(get_tokenizer)(model_option)
else:
    model = get_model(model_option)
    inference_tokenizer = get_tokenizer(model_option)
def get_evaluation_data(_context: List, special_delimiter=" "):
    """Flatten aggregated evaluation records into ``[context, sentence, label]`` rows.

    Each record must provide ``"context"`` (strings joined with
    *special_delimiter*) and ``"answers"`` (a mapping whose values are mappings
    from label to a list of sentences). Rows keep the input order.

    :param _context: list of record dicts as described above.
    :param special_delimiter: separator used to join the context strings.
    :return: list of ``[joined_context, sentence, label]`` lists.
    """
    rows = []
    for record in _context:
        joined = special_delimiter.join(record["context"])
        rows.extend(
            [joined, sentence, label]
            for by_label in record["answers"].values()
            for label, sentences in by_label.items()
            for sentence in sentences
        )
    return rows
def _predict_follow_probabilities(ctx, sentence):
    """Score how likely *sentence* follows *ctx* with the module-level ``model``.

    :param ctx: context text (as typed by the user or joined by ``get_evaluation_data``).
    :param sentence: candidate next sentence.
    :return: ``(prop_follow, prop_not_follow)`` - softmax probabilities of logit
        index 0 and index 1 for the first example in the batch.
    """
    input_tensor = inference_tokenizer.get_item(context=ctx, actual_sentence=sentence)
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        logits = model(**input_tensor.data).logits
    probabilities = torch.softmax(logits, dim=-1).detach().numpy()[0]
    return probabilities[0], probabilities[1]


def _parse_evaluation_json(raw_text):
    """Convert the JSON text-area content into ``[context, sentence, label]`` rows.

    Shows an error in the UI and returns ``None`` when *raw_text* is not valid
    JSON or lacks the structure ``get_evaluation_data`` expects.
    """
    try:
        parsed = json.loads(raw_text)
    except json.JSONDecodeError as err:
        st.error(f"Could not parse the input as JSON: {err}")
        return None
    try:
        return get_evaluation_data(_context=parsed)
    except (KeyError, TypeError, AttributeError) as err:
        st.error(f"The JSON does not have the expected structure: {err!r}")
        return None


option = st.selectbox("Choose type of evaluation:",
                      ["01 - Raw text (one line)", "02 - JSON (aggregated)"])

with st.form("input_text"):
    if "01" in option:
        context = st.text_area("Insert context here (sentences divided by ||):")
        actual_text = st.text_input("Actual text")
        # Every form must have a submit button.
        submitted = st.form_submit_button("Submit")
        if submitted:
            # Score only after submit so empty initial inputs are never fed to the model.
            prop_follow, prop_not_follow = _predict_follow_probabilities(context, actual_text)
            fig, ax = plt.subplots()
            ax.pie([prop_follow, prop_not_follow], labels=["Probability - Follow", "Probability - Not Follow"],
                   autopct='%1.1f%%')
            st.pyplot(fig)
            # pyplot keeps every figure alive until closed; release it after rendering.
            plt.close(fig)
    elif "02" in option:
        context = st.text_area("Insert JSON here")
        # Every form must have a submit button.
        submitted = st.form_submit_button("Submit")
        if submitted:
            # Parse on submit and report bad input in the UI instead of crashing
            # (invalid JSON / missing data would otherwise raise).
            evaluation_data = _parse_evaluation_json(context)
            if evaluation_data is not None and not evaluation_data:
                st.warning("The JSON does not contain any sentences to evaluate.")
            elif evaluation_data:
                results = []
                accuracy = []
                for c, s, human_label in evaluation_data:
                    prop_follow, prop_not_follow = _predict_follow_probabilities(c, s)
                    results.append((c, s, human_label, prop_follow, prop_not_follow))
                    if human_label == "coherent":
                        accuracy.append(int(prop_follow > prop_not_follow))
                    else:
                        accuracy.append(int(prop_not_follow > prop_follow))
                # accuracy holds 0/1 hits; scale the mean by 100 so the "%" label is correct.
                st.metric(label="Accuracy", value=f"{100 * sum(accuracy) / len(accuracy):.2f} %")
                df = pandas.DataFrame(results, columns=["Context", "Query", "Human Label", "Probability (follow)", "Probability (not-follow)"])
                st.dataframe(df)