File size: 4,324 Bytes
c186b27
 
 
 
 
 
6457b4b
c186b27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6457b4b
c186b27
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import json
from typing import Dict, List, Tuple, Union

import torch
import pandas
import streamlit as st
import matplotlib.pyplot as plt

from inference_tokenizer import NextSentencePredictionTokenizer


@st.cache_resource
def get_model(model_path):
    """Load the BERT next-sentence-prediction model stored at *model_path*.

    The function is wrapped in ``st.cache_resource``, so Streamlit keeps the
    loaded model around instead of re-reading it on every script rerun.

    Args:
        model_path: Location handed to
            ``BertForNextSentencePrediction.from_pretrained``.

    Returns:
        The loaded model, switched to evaluation mode.
    """
    from transformers import BertForNextSentencePrediction

    nsp_model = BertForNextSentencePrediction.from_pretrained(model_path)
    # The app only runs inference, so put the network in eval mode once here.
    nsp_model.eval()
    return nsp_model


@st.cache_resource
def get_tokenizer(tokenizer_path):
    """Build the inference tokenizer that belongs to the model at *tokenizer_path*.

    The BERT vocabulary is read from the ``tokenizer`` sub-directory of
    *tokenizer_path* and wrapped in a ``NextSentencePredictionTokenizer``.
    The function is wrapped in ``st.cache_resource``, so the result is reused
    across script reruns.

    Args:
        tokenizer_path: Model directory; also decides which delimiter token
            is handed to the wrapper (see the TODO below).

    Returns:
        A configured ``NextSentencePredictionTokenizer``.
    """
    from transformers import BertTokenizer

    vocabulary_dir = os.path.join(tokenizer_path, "tokenizer")
    bert_tokenizer = BertTokenizer.from_pretrained(vocabulary_dir)

    # todo better than hardcoded
    if tokenizer_path == "./model/e09d71f55f4b6fc20135f856bf029322a3265d8d":
        # This model uses "[unused1]" as its delimiter, so register it with
        # the tokenizer as an additional special token.
        delimiter_token = "[unused1]"
        bert_tokenizer.add_special_tokens({"additional_special_tokens": [delimiter_token]})
    else:
        delimiter_token = " "

    return NextSentencePredictionTokenizer(
        bert_tokenizer,
        special_token=delimiter_token,
        padding="max_length",
        max_length_ctx=256,
        max_length_res=64,
        truncation="only_first",
        # NumPy output; per the original author's note the arrays are turned
        # into tensors later because of a memory problem with tensors.
        return_tensors="np",
        is_split_into_words=True,
    )


# Let the user choose which model directory to evaluate with.
model_option = st.selectbox(
    label='Which model do you want to use?',
    options=(
        './model/c3c3bdb7ad80396e69de171995e2038f900940c8',
        './model/e09d71f55f4b6fc20135f856bf029322a3265d8d',
    ),
)

# The model and its tokenizer are both loaded from the selected directory.
model = get_model(model_option)
inference_tokenizer = get_tokenizer(model_option)


def get_evaluation_data(_context: List, special_delimiter=" "):
    """Flatten aggregated evaluation entries into ``[context, sentence, label]`` rows.

    Each entry of *_context* is a mapping with two keys:

    * ``"context"`` - a sequence of strings, joined with *special_delimiter*;
    * ``"answers"`` - a mapping whose values are themselves mappings from a
      label to a list of candidate sentences.

    Args:
        _context: The entries to flatten.
        special_delimiter: String placed between the context sentences.

    Returns:
        A list with one ``[joined_context, sentence, label]`` row per
        candidate sentence, in the iteration order of the input.
    """
    rows = []
    for entry in _context:
        joined_context = special_delimiter.join(entry["context"])
        rows.extend(
            [joined_context, candidate, label]
            for labelled_candidates in entry["answers"].values()
            for label, candidates in labelled_candidates.items()
            for candidate in candidates
        )
    return rows


def _predict_follow_probabilities(context_text, sentence):
    """Return ``(p_follow, p_not_follow)`` for *sentence* given *context_text*.

    The pair is tokenized with the module-level ``inference_tokenizer`` and
    scored by the module-level ``model``; a softmax turns the two logits into
    probabilities (index 0 is read as "follows", index 1 as "does not follow").
    """
    input_tensor = inference_tokenizer.get_item(context=context_text, actual_sentence=sentence)
    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**input_tensor.data).logits
        probabilities = torch.softmax(logits, dim=-1).numpy()[0]
    return probabilities[0], probabilities[1]


def _render_json_evaluation(raw_json):
    """Score every pair described by *raw_json* and render accuracy plus a table.

    Problems with the pasted text (invalid JSON, unexpected structure, no
    sentences) are reported in the page instead of raising.
    """
    try:
        evaluation_data = get_evaluation_data(_context=json.loads(raw_json))
    except json.JSONDecodeError as err:
        st.error(f"The input is not valid JSON: {err}")
        return
    except (KeyError, TypeError, AttributeError) as err:
        st.error(f"The JSON does not have the expected structure: {err!r}")
        return

    if not evaluation_data:
        # Avoids dividing by zero when computing the accuracy below.
        st.warning("The JSON contains no sentences to evaluate.")
        return

    results = []
    correct = 0
    for c, s, human_label in evaluation_data:
        prop_follow, prop_not_follow = _predict_follow_probabilities(c, s)
        results.append((c, s, human_label, prop_follow, prop_not_follow))
        # Any label other than "coherent" counts as "should not follow".
        if human_label == "coherent":
            correct += int(prop_follow > prop_not_follow)
        else:
            correct += int(prop_not_follow > prop_follow)

    # Scale the fraction to a real percentage so it matches the "%" suffix.
    st.metric(label="Accuracy", value=f"{100 * correct / len(evaluation_data):.1f} %")
    df = pandas.DataFrame(results, columns=["Context", "Query", "Human Label", "Probability (follow)", "Probability (not-follow)"])
    st.dataframe(df)


option = st.selectbox("Choose type of evaluation:",
                      ["01 - Raw text (one line)", "02 - JSON (aggregated)"])

with st.form("input_text"):
    if "01" in option:
        context = st.text_area("Insert context here (sentences divided by ||):")
        actual_text = st.text_input("Actual text")

        # Every form must have a submit button.
        submitted = st.form_submit_button("Submit")
        if submitted:
            # Run the model only once the user has actually submitted input,
            # not on every script run with empty widgets.
            prop_follow, prop_not_follow = _predict_follow_probabilities(context, actual_text)

            fig, ax = plt.subplots()
            ax.pie([prop_follow, prop_not_follow], labels=["Probability - Follow", "Probability - Not Follow"],
                   autopct='%1.1f%%')
            st.pyplot(fig)
            # pyplot keeps figures alive until closed; release it after rendering.
            plt.close(fig)
    elif "02" in option:
        context = st.text_area("Insert JSON here")

        # Every form must have a submit button.
        submitted = st.form_submit_button("Submit")
        if submitted:
            _render_json_evaluation(context)