File size: 3,563 Bytes
fddf3ff
 
 
 
 
 
 
 
 
 
2e6b9d1
fddf3ff
 
 
0105d3b
eae7c24
 
 
fddf3ff
 
eae7c24
 
fddf3ff
 
0bd968a
0105d3b
fddf3ff
 
 
 
 
 
cac1b0e
fddf3ff
 
0105d3b
 
fddf3ff
0105d3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fddf3ff
0105d3b
fddf3ff
0105d3b
 
 
 
fddf3ff
0105d3b
 
c7add9e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import asyncio
import gc
import logging
import os

import pandas as pd
import psutil
import streamlit as st
from PIL import Image
from streamlit import components
#from streamlit.caching import clear_cache
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

#os.environ["TOKENIZERS_PARALLELISM"] = "false"
#logging.basicConfig(
#    format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
#)


#def print_memory_usage():
#    logging.info(f"RAM memory % used: {psutil.virtual_memory()[2]}")


@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=1)
def load_model(model_name):
    """Load and cache a Hugging Face sequence-classification model + tokenizer.

    Args:
        model_name: Hub identifier of the model to load.

    Returns:
        Tuple of (AutoModelForSequenceClassification, AutoTokenizer).
    """
    # max_entries=1 keeps only the most recently selected model in memory,
    # which matters on the resource-limited hosted deployment.
    # NOTE(review): st.cache is deprecated in newer Streamlit releases
    # (st.cache_resource is the replacement) — confirm the pinned version.
    return (
        AutoModelForSequenceClassification.from_pretrained(model_name),
        AutoTokenizer.from_pretrained(model_name),
    )

# --- Page header and sidebar branding ---------------------------------------
# (leftover duplicated debug prints removed; title typo "Interpet" fixed)
st.title("Transformers Interpret Demo App")

image = Image.open("./images/tight@1920x_transparent.png")
st.sidebar.image(image, use_column_width=True)
st.sidebar.markdown(
    "Check out the package on [Github](https://github.com/cdpierse/transformers-interpret)"
)
st.info(
    "Due to limited resources only low memory models are available. Run this [app locally](https://github.com/cdpierse/transformers-interpret-streamlit) to run the full selection of available models. "
)

# Model catalogue: Hub id -> short description. Only lightweight models are
# enabled for the hosted demo; uncomment entries below to try the rest when
# running locally.
available_models = {
    # "textattack/distilbert-base-uncased-rotten-tomatoes": "",
    # "textattack/bert-base-uncased-rotten-tomatoes": "",
    # "textattack/roberta-base-rotten-tomatoes": "",
    # "mrm8488/bert-mini-finetuned-age_news-classification": "BERT-Mini finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
    # "nateraw/bert-base-uncased-ag-news": "BERT finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
    "distilbert-base-uncased-finetuned-sst-2-english": "DistilBERT model finetuned on SST-2 sentiment analysis task. Predicts positive/negative sentiment.",
    # "ProsusAI/finbert": "BERT model finetuned to predict sentiment of financial text. Finetuned on Financial PhraseBank data. Predicts positive/negative/neutral.",
    "sampathkethineedi/industry-classification": "DistilBERT Model to classify a business description into one of 62 industry tags.",
    "MoritzLaurer/policy-distilbert-7d": "DistilBERT model finetuned to classify text into one of seven political categories.",
    # # "MoritzLaurer/covid-policy-roberta-21": "(Under active development ) RoBERTA model finetuned to identify COVID policy measure classes ",
    # "mrm8488/bert-tiny-finetuned-sms-spam-detection": "Tiny bert model finetuned for spam detection. 0 == not spam, 1 == spam",
}

# Dicts iterate over their keys, so list(d) == list(d.keys()).
model_name = st.sidebar.selectbox(
    "Choose a classification model", list(available_models)
)
# --- Load the selected model and build the attribution explainer ------------
model, tokenizer = load_model(model_name)
if model_name.startswith("textattack/"):
    # textattack checkpoints ship with generic label names; give them
    # human-readable ones for display.
    model.config.id2label = {0: "NEGATIVE (0) ", 1: "POSITIVE (1)"}
model.eval()  # inference mode (disables dropout, etc.)

cls_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)

# Embedding type used for attribution: 0 == word embeddings,
# 1 == position embeddings. The position option is only offered when the
# model actually accepts position ids.
if cls_explainer.accepts_position_ids:
    emb_type_name = st.sidebar.selectbox(
        "Choose embedding type for attribution.", ["word", "position"]
    )
    # Exhaustive mapping — the original if/if chain could leave
    # emb_type_num unbound if neither branch matched.
    emb_type_num = 1 if emb_type_name == "position" else 0
else:
    emb_type_num = 0