rajistics committed on
Commit
0105d3b
·
1 Parent(s): 0bd968a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -72
app.py CHANGED
@@ -12,7 +12,7 @@ from streamlit import components
12
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
13
  from transformers_interpret import SequenceClassificationExplainer
14
 
15
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
  #logging.basicConfig(
17
  # format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
18
  #)
@@ -23,6 +23,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
 
24
 
25
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=1)
 
26
  def load_model(model_name):
27
  return (
28
  AutoModelForSequenceClassification.from_pretrained(model_name),
@@ -30,80 +31,106 @@ def load_model(model_name):
30
  )
31
 
32
  print ("before main")
33
- def main():
34
 
35
- st.title("Transformers Interpet Demo App")
36
- print ("before main")
37
 
38
- image = Image.open("./images/tight@1920x_transparent.png")
39
- st.sidebar.image(image, use_column_width=True)
40
- st.sidebar.markdown(
41
- "Check out the package on [Github](https://github.com/cdpierse/transformers-interpret)"
42
- )
43
- st.info(
44
- "Due to limited resources only low memory models are available. Run this [app locally](https://github.com/cdpierse/transformers-interpret-streamlit) to run the full selection of available models. "
45
- )
46
 
47
- # uncomment the options below to test out the app with a variety of classification models.
48
- models = {
49
- # "textattack/distilbert-base-uncased-rotten-tomatoes": "",
50
- # "textattack/bert-base-uncased-rotten-tomatoes": "",
51
- # "textattack/roberta-base-rotten-tomatoes": "",
52
- # "mrm8488/bert-mini-finetuned-age_news-classification": "BERT-Mini finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
53
- # "nateraw/bert-base-uncased-ag-news": "BERT finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
54
- "distilbert-base-uncased-finetuned-sst-2-english": "DistilBERT model finetuned on SST-2 sentiment analysis task. Predicts positive/negative sentiment.",
55
- # "ProsusAI/finbert": "BERT model finetuned to predict sentiment of financial text. Finetuned on Financial PhraseBank data. Predicts positive/negative/neutral.",
56
- "sampathkethineedi/industry-classification": "DistilBERT Model to classify a business description into one of 62 industry tags.",
57
- "MoritzLaurer/policy-distilbert-7d": "DistilBERT model finetuned to classify text into one of seven political categories.",
58
- # # "MoritzLaurer/covid-policy-roberta-21": "(Under active development ) RoBERTA model finetuned to identify COVID policy measure classes ",
59
- # "mrm8488/bert-tiny-finetuned-sms-spam-detection": "Tiny bert model finetuned for spam detection. 0 == not spam, 1 == spam",
60
- }
61
- model_name = st.sidebar.selectbox(
62
- "Choose a classification model", list(models.keys())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  )
64
- model, tokenizer = load_model(model_name)
65
- print ("Model loaded")
66
- if model_name.startswith("textattack/"):
67
- model.config.id2label = {0: "NEGATIVE (0) ", 1: "POSITIVE (1)"}
68
- model.eval()
69
- print ("Model Evaluated")
70
- cls_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)
71
- print ("Model Explained")
72
- if cls_explainer.accepts_position_ids:
73
- emb_type_name = st.sidebar.selectbox(
74
- "Choose embedding type for attribution.", ["word", "position"]
75
- )
76
- if emb_type_name == "word":
77
- emb_type_num = 0
78
- if emb_type_name == "position":
79
- emb_type_num = 1
80
- else:
81
  emb_type_num = 0
 
 
 
 
82
 
83
- explanation_classes = ["predicted"] + list(model.config.label2id.keys())
84
- explanation_class_choice = st.sidebar.selectbox(
85
- "Explanation class: The class you would like to explain output with respect to.",
86
- explanation_classes,
87
- )
88
- my_expander = st.beta_expander(
89
- "Click here for a description of models and their tasks"
90
- )
91
- with my_expander:
92
- st.json(models)
93
-
94
- # st.info("Max char limit of 350 (memory management)")
95
- text = st.text_area(
96
- "Enter text to be interpreted",
97
- "I like you, I love you",
98
- height=400,
99
- max_chars=850,
100
- )
101
- print ("Before button")
102
- if st.button('Say hello'):
103
- st.write('Why hello there')
104
- else:
105
- st.write('Goodbye')
106
- print ("After test button")
107
-
108
- if __name__ == "__main__":
109
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
13
  from transformers_interpret import SequenceClassificationExplainer
14
 
15
+ #os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
  #logging.basicConfig(
17
  # format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
18
  #)
 
23
 
24
 
25
  @st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=1)
26
+
27
  def load_model(model_name):
28
  return (
29
  AutoModelForSequenceClassification.from_pretrained(model_name),
 
31
  )
32
 
33
  print ("before main")
 
34
 
 
 
35
 
36
+ st.title("Transformers Interpet Demo App")
37
+ print ("before main")
 
 
 
 
 
 
38
 
39
+ image = Image.open("./images/tight@1920x_transparent.png")
40
+ st.sidebar.image(image, use_column_width=True)
41
+ st.sidebar.markdown(
42
+ "Check out the package on [Github](https://github.com/cdpierse/transformers-interpret)"
43
+ )
44
+ st.info(
45
+ "Due to limited resources only low memory models are available. Run this [app locally](https://github.com/cdpierse/transformers-interpret-streamlit) to run the full selection of available models. "
46
+ )
47
+
48
+ # uncomment the options below to test out the app with a variety of classification models.
49
+ models = {
50
+ # "textattack/distilbert-base-uncased-rotten-tomatoes": "",
51
+ # "textattack/bert-base-uncased-rotten-tomatoes": "",
52
+ # "textattack/roberta-base-rotten-tomatoes": "",
53
+ # "mrm8488/bert-mini-finetuned-age_news-classification": "BERT-Mini finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
54
+ # "nateraw/bert-base-uncased-ag-news": "BERT finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
55
+ "distilbert-base-uncased-finetuned-sst-2-english": "DistilBERT model finetuned on SST-2 sentiment analysis task. Predicts positive/negative sentiment.",
56
+ # "ProsusAI/finbert": "BERT model finetuned to predict sentiment of financial text. Finetuned on Financial PhraseBank data. Predicts positive/negative/neutral.",
57
+ "sampathkethineedi/industry-classification": "DistilBERT Model to classify a business description into one of 62 industry tags.",
58
+ "MoritzLaurer/policy-distilbert-7d": "DistilBERT model finetuned to classify text into one of seven political categories.",
59
+ # # "MoritzLaurer/covid-policy-roberta-21": "(Under active development ) RoBERTA model finetuned to identify COVID policy measure classes ",
60
+ # "mrm8488/bert-tiny-finetuned-sms-spam-detection": "Tiny bert model finetuned for spam detection. 0 == not spam, 1 == spam",
61
+ }
62
+ model_name = st.sidebar.selectbox(
63
+ "Choose a classification model", list(models.keys())
64
+ )
65
+ model, tokenizer = load_model(model_name)
66
+ print ("Model loaded")
67
+ if model_name.startswith("textattack/"):
68
+ model.config.id2label = {0: "NEGATIVE (0) ", 1: "POSITIVE (1)"}
69
+ model.eval()
70
+ print ("Model Evaluated")
71
+ cls_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)
72
+ print ("Model Explained")
73
+ if cls_explainer.accepts_position_ids:
74
+ emb_type_name = st.sidebar.selectbox(
75
+ "Choose embedding type for attribution.", ["word", "position"]
76
  )
77
+ if emb_type_name == "word":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  emb_type_num = 0
79
+ if emb_type_name == "position":
80
+ emb_type_num = 1
81
+ else:
82
+ emb_type_num = 0
83
 
84
+ explanation_classes = ["predicted"] + list(model.config.label2id.keys())
85
+ explanation_class_choice = st.sidebar.selectbox(
86
+ "Explanation class: The class you would like to explain output with respect to.",
87
+ explanation_classes,
88
+ )
89
+ my_expander = st.beta_expander(
90
+ "Click here for a description of models and their tasks"
91
+ )
92
+ with my_expander:
93
+ st.json(models)
94
+
95
+ # st.info("Max char limit of 350 (memory management)")
96
+ text = st.text_area(
97
+ "Enter text to be interpreted",
98
+ "I like you, I love you",
99
+ height=400,
100
+ max_chars=850,
101
+ )
102
+ print ("Before button")
103
+ if st.button('Say hello'):
104
+ st.write('Why hello there')
105
+ else:
106
+ st.write('Goodbye')
107
+ print ("After test button")
108
+
109
+ if st.button("Interpret Text"):
110
+ #print_memory_usage()
111
+ st.text("Output")
112
+ with st.spinner("Interpreting your text (This may take some time)"):
113
+ print ("Interpreting text")
114
+ if explanation_class_choice != "predicted":
115
+ word_attributions = cls_explainer(
116
+ text,
117
+ class_name=explanation_class_choice,
118
+ embedding_type=emb_type_num,
119
+ internal_batch_size=2,
120
+ )
121
+ else:
122
+ word_attributions = cls_explainer(
123
+ text, embedding_type=emb_type_num, internal_batch_size=2
124
+ )
125
+
126
+ if word_attributions:
127
+ print ("Word Attributions")
128
+ word_attributions_expander = st.beta_expander(
129
+ "Click here for raw word attributions"
130
+ )
131
+ with word_attributions_expander:
132
+ st.json(word_attributions)
133
+ components.v1.html(
134
+ cls_explainer.visualize()._repr_html_(), scrolling=True, height=350
135
+ )
136
+ print ("end of stuff")