Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| from typing import Dict, List, Tuple, Union | |
| import torch | |
| import pandas | |
| import streamlit as st | |
| import matplotlib.pyplot as plt | |
| from inference_tokenizer import NextSentencePredictionTokenizer | |
def get_model(model_path):
    """Load a BERT next-sentence-prediction model from *model_path* for inference.

    ``transformers`` is imported inside the function so the import only happens
    when a model is actually requested.

    Args:
        model_path: Directory (or hub id) accepted by ``from_pretrained``.

    Returns:
        The loaded ``BertForNextSentencePrediction`` switched to evaluation mode.
    """
    from transformers import BertForNextSentencePrediction

    loaded = BertForNextSentencePrediction.from_pretrained(model_path)
    # nn.Module.eval() puts the module in evaluation mode and returns the module itself.
    return loaded.eval()
def get_tokenizer(tokenizer_path, special_token=None):
    """Build the next-sentence-prediction tokenizer for the model at *tokenizer_path*.

    The Hugging Face ``BertTokenizer`` is loaded from the ``tokenizer``
    sub-directory of *tokenizer_path*. The chosen special token is registered on
    it, and it is then wrapped in ``NextSentencePredictionTokenizer`` together
    with fixed padding/truncation settings.

    Args:
        tokenizer_path: Model directory that contains a ``tokenizer`` sub-directory.
        special_token: Optional explicit delimiter token. When ``None`` (the
            default), ``"[unused1]"`` is used for the ``e09d71f5…`` model directory
            and a single space for any other model. The path comparison is done on
            normalised paths, so ``"model/<id>"`` and ``"./model/<id>/"`` both match.

    Returns:
        The wrapping ``NextSentencePredictionTokenizer`` instance.
    """
    from transformers import BertTokenizer

    # The only model directory that uses "[unused1]" as its special token.
    unused1_model_dir = "./model/e09d71f55f4b6fc20135f856bf029322a3265d8d"

    tokenizer = BertTokenizer.from_pretrained(os.path.join(tokenizer_path, "tokenizer"))
    tokenizer_args = {
        "padding": "max_length",
        "max_length_ctx": 256,
        "max_length_res": 64,
        "truncation": "only_first",
        # will be transfer to tensor later during the training (because of some memory problem with tensors)
        "return_tensors": "np",
        "is_split_into_words": True,
    }

    if special_token is None:
        # Normalise both sides so equivalent spellings of the same directory match.
        is_unused1_model = os.path.normpath(tokenizer_path) == os.path.normpath(unused1_model_dir)
        special_token = "[unused1]" if is_unused1_model else " "

    tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
    return NextSentencePredictionTokenizer(tokenizer, special_token=special_token, **tokenizer_args)
# Let the user choose one of the bundled checkpoints. The model and its matching
# tokenizer are then loaded from that same directory. No caching is applied here,
# so both loads happen every time this script executes.
model_option = st.selectbox(
    'Which model do you want to use?',
    (
        './model/c3c3bdb7ad80396e69de171995e2038f900940c8',
        './model/e09d71f55f4b6fc20135f856bf029322a3265d8d',
    ),
)
model = get_model(model_option)
inference_tokenizer = get_tokenizer(model_option)
def get_evaluation_data(_context: List, special_delimiter=" "):
    """Flatten aggregated evaluation records into ``[context, sentence, label]`` rows.

    Each record in *_context* is indexed as ``record["context"]``, which holds the
    context sentences and is joined with *special_delimiter*. It is also indexed as
    ``record["answers"]``, a mapping whose values are themselves mappings of
    ``label -> iterable of candidate sentences``. One row is emitted per candidate
    sentence, in iteration order.

    Args:
        _context: Parsed evaluation records (a list of dicts).
        special_delimiter: String placed between the joined context sentences.

    Returns:
        A list of ``[joined_context, candidate_sentence, label]`` lists.
    """
    rows = []
    for record in _context:
        # Join once per record, before looking at its answers.
        joined_context = special_delimiter.join(record["context"])
        rows.extend(
            [joined_context, candidate, label]
            for per_source in record["answers"].values()
            for label, candidates in per_source.items()
            for candidate in candidates
        )
    return rows
def _predict_follow_probabilities(context_text, sentence):
    """Return ``(prop_follow, prop_not_follow)`` for *sentence* given *context_text*.

    The pair is tokenized with the module-level ``inference_tokenizer`` and scored
    by the module-level ``model``. A softmax is applied over the last logits axis
    and the first row is taken. Index 0 is read as "follows" and index 1 as
    "does not follow".
    """
    input_tensor = inference_tokenizer.get_item(context=context_text, actual_sentence=sentence)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**input_tensor.data).logits
    probabilities = torch.softmax(logits, dim=-1).numpy()[0]
    return probabilities[0], probabilities[1]


def _render_json_evaluation(evaluation_data):
    """Score every ``(context, sentence, human_label)`` row and show accuracy plus a table.

    A row counts as correct when the more probable class agrees with the human
    label. Any label other than ``"coherent"`` is treated as "does not follow".
    """
    results = []
    accuracy = []
    for c, s, human_label in evaluation_data:
        prop_follow, prop_not_follow = _predict_follow_probabilities(c, s)
        results.append((c, s, human_label, prop_follow, prop_not_follow))
        if human_label == "coherent":
            accuracy.append(int(prop_follow > prop_not_follow))
        else:
            accuracy.append(int(prop_not_follow > prop_follow))

    if accuracy:
        # accuracy holds 0/1 hits; scale the mean so the "%" suffix is truthful.
        st.metric(label="Accuracy", value=f"{100 * sum(accuracy) / len(accuracy):.2f} %")
    else:
        # Avoid dividing by zero when the JSON contains no candidate sentences.
        st.warning("No sentences to evaluate were found in the JSON.")

    df = pandas.DataFrame(results, columns=["Context", "Query", "Human Label", "Probability (follow)", "Probability (not-follow)"])
    st.dataframe(df)


option = st.selectbox("Choose type of evaluation:",
                      ["01 - Raw text (one line)", "02 - JSON (aggregated)"])

with st.form("input_text"):
    if "01" in option:
        context = st.text_area("Insert context here (sentences divided by ||):")
        actual_text = st.text_input("Actual text")
        # Every form must have a submit button.
        submitted = st.form_submit_button("Submit")
        # Run the model only after submission, never on the empty initial render.
        if submitted:
            prop_follow, prop_not_follow = _predict_follow_probabilities(context, actual_text)
            fig, ax = plt.subplots()
            ax.pie([prop_follow, prop_not_follow], labels=["Probability - Follow", "Probability - Not Follow"],
                   autopct='%1.1f%%')
            st.pyplot(fig)
            # Release the figure so open matplotlib figures do not accumulate.
            plt.close(fig)
    elif "02" in option:
        context = st.text_area("Insert JSON here")
        # Every form must have a submit button. Render it unconditionally so the
        # form stays valid while the text area is still empty.
        submitted = st.form_submit_button("Submit")
        if submitted and "{" in context:
            try:
                evaluation_data = get_evaluation_data(_context=json.loads(context))
            except json.JSONDecodeError as err:
                st.error(f"Could not parse the JSON input: {err}")
            except (KeyError, TypeError, AttributeError) as err:
                st.error(f"The JSON input does not have the expected structure: {err!r}")
            else:
                _render_json_evaluation(evaluation_data)