File size: 5,252 Bytes
fddf3ff
 
 
73feb6e
fddf3ff
 
 
 
d832896
 
 
 
 
 
0105d3b
eae7c24
 
 
fddf3ff
 
eae7c24
 
75ebf28
fddf3ff
75ebf28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fddf3ff
0105d3b
9b041cc
fddf3ff
0105d3b
 
c7add9e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import pathlib

import pandas as pd
import streamlit as st
#from PIL import Image
from streamlit import components
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

# Smoke-test output rendered at the very top of the page.
st.title("Hello World!")

# Show the home directory of the account the app is running under.
st.write(pathlib.Path.home())

# NOTE: this script never imports ``pydicom``, so writing that module object
# to the page could only raise NameError and stop the script; the call is
# therefore not made here.

#os.environ["TOKENIZERS_PARALLELISM"] = "false"
#logging.basicConfig(
#    format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
#)


#def print_memory_usage():
#    logging.info(f"RAM memory % used: {psutil.virtual_memory()[2]}")
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=1)
def load_model(model_name):
    """Load a sequence-classification model together with its tokenizer.

    Both objects are created with ``from_pretrained(model_name)``. The
    function is wrapped in ``st.cache`` configured with ``max_entries=1``.

    Args:
        model_name: Identifier handed unchanged to both ``from_pretrained``
            calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # The model is loaded before the tokenizer, as in a tuple literal.
    classifier = AutoModelForSequenceClassification.from_pretrained(model_name)
    matching_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return (classifier, matching_tokenizer)

# Progress trace written to stdout.
print("before main")


# Main page heading.
st.title("Transformers Interpret Demo App")

print("before main")

#image = Image.open("./images/tight@1920x_transparent.png")
#st.sidebar.image(image, use_column_width=True)

# Sidebar link to the library being demonstrated.
st.sidebar.markdown(
    "Check out the package on [Github](https://github.com/cdpierse/transformers-interpret)"
)
# Notice explaining why only a reduced set of models is offered.
st.info(
    "Due to limited resources only low memory models are available. Run this [app locally](https://github.com/cdpierse/transformers-interpret-streamlit) to run the full selection of available models. "
)

# Catalogue of selectable models: model identifier -> human-readable description.
# uncomment the options below to test out the app with a variety of classification models.
models = {
    # "textattack/distilbert-base-uncased-rotten-tomatoes": "",
    # "textattack/bert-base-uncased-rotten-tomatoes": "",
    # "textattack/roberta-base-rotten-tomatoes": "",
    # "mrm8488/bert-mini-finetuned-age_news-classification": "BERT-Mini finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
    # "nateraw/bert-base-uncased-ag-news": "BERT finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
    "distilbert-base-uncased-finetuned-sst-2-english": "DistilBERT model finetuned on SST-2 sentiment analysis task. Predicts positive/negative sentiment.",
    # "ProsusAI/finbert": "BERT model finetuned to predict sentiment of financial text. Finetuned on Financial PhraseBank data. Predicts positive/negative/neutral.",
    "sampathkethineedi/industry-classification": "DistilBERT Model to classify a business description into one of 62 industry tags.",
    "MoritzLaurer/policy-distilbert-7d": "DistilBERT model finetuned to classify text into one of seven political categories.",
    # # "MoritzLaurer/covid-policy-roberta-21": "(Under active development ) RoBERTA model finetuned to identify COVID policy measure classes ",
    # "mrm8488/bert-tiny-finetuned-sms-spam-detection": "Tiny bert model finetuned for spam detection. 0 == not spam, 1 == spam",
}

# Sidebar dropdown listing the catalogue keys in insertion order.
model_name = st.sidebar.selectbox("Choose a classification model", list(models))

# Fetch the model/tokenizer pair for the chosen identifier.
model, tokenizer = load_model(model_name)


print("Model loaded")

# Models under the "textattack/" namespace get explicit label names assigned.
if model_name.startswith("textattack/"):
    model.config.id2label = {0: "NEGATIVE (0) ", 1: "POSITIVE (1)"}

model.eval()
print("Model Evaluated")

# Explainer object used later to compute and visualise word attributions.
cls_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)
print("Model Explained")
# Pick the embedding type used for attribution: 0 for "word", 1 for "position".
# The choice is only offered when the explainer reports that it accepts
# position ids; otherwise word embeddings (0) are used.
if not cls_explainer.accepts_position_ids:
    emb_type_num = 0
else:
    emb_type_name = st.sidebar.selectbox(
        "Choose embedding type for attribution.", ["word", "position"]
    )
    if emb_type_name == "word":
        emb_type_num = 0
    elif emb_type_name == "position":
        emb_type_num = 1

# Classes the attribution can be computed against: the literal "predicted"
# followed by every label name in the model config.
explanation_classes = ["predicted", *model.config.label2id]
explanation_class_choice = st.sidebar.selectbox(
    "Explanation class: The class you would like to explain output with respect to.",
    explanation_classes,
)

# ``st.beta_expander`` was renamed to ``st.expander`` and the beta alias was
# later removed from Streamlit. Prefer the current name and fall back to the
# beta one so the script runs on both old and new installations.
_make_description_expander = getattr(st, "expander", None) or st.beta_expander

# Collapsible section showing the model catalogue as JSON.
my_expander = _make_description_expander(
    "Click here for a description of models and their tasks"
)
with my_expander:
    st.json(models)

# st.info("Max char limit of 350 (memory management)")
# Input box for the text to explain, pre-filled with a short example.
text = st.text_area(
    "Enter text to be interpreted",
    "I like you, I love you",
    height=400,
    max_chars=850,
)

print("Before button")
# Test button: writes one of two fixed messages depending on whether it was clicked.
hello_message = 'Why hello there' if st.button('Say hello') else 'Goodbye'
st.write(hello_message)
print("After test button")

# Run the explainer on the entered text when the user clicks the button.
if st.button("Interpret Text"):
    #print_memory_usage()
    st.text("Output")
    with st.spinner("Interpreting your text (This may take some time)"):
        print("Interpreting text")
        # Arguments shared by both modes; ``class_name`` is only supplied when
        # the user picked a specific class instead of "predicted".
        explainer_kwargs = {
            "embedding_type": emb_type_num,
            "internal_batch_size": 2,
        }
        if explanation_class_choice != "predicted":
            explainer_kwargs["class_name"] = explanation_class_choice
        word_attributions = cls_explainer(text, **explainer_kwargs)

    if word_attributions:
        print("Word Attributions")
        # ``st.beta_expander`` was renamed to ``st.expander`` and the beta
        # alias was later removed from Streamlit. Prefer the current name and
        # fall back to the beta one so both old and new installations work.
        _make_attributions_expander = (
            getattr(st, "expander", None) or st.beta_expander
        )
        word_attributions_expander = _make_attributions_expander(
            "Click here for raw word attributions"
        )
        with word_attributions_expander:
            st.json(word_attributions)
        # Embed the explainer's HTML visualisation in a scrollable frame.
        components.v1.html(
            cls_explainer.visualize()._repr_html_(), scrolling=True, height=350
        )
    print("end of stuff")


# Final progress trace written to stdout on every script run.
print("end of total file")