Spaces:
Runtime error
Runtime error
File size: 3,563 Bytes
fddf3ff 2e6b9d1 fddf3ff 0105d3b eae7c24 fddf3ff eae7c24 fddf3ff 0bd968a 0105d3b fddf3ff cac1b0e fddf3ff 0105d3b fddf3ff 0105d3b fddf3ff 0105d3b fddf3ff 0105d3b fddf3ff 0105d3b c7add9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import asyncio
import gc
import logging
import os
import pandas as pd
import psutil
import streamlit as st
from PIL import Image
from streamlit import components
#from streamlit.caching import clear_cache
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer
#os.environ["TOKENIZERS_PARALLELISM"] = "false"
#logging.basicConfig(
# format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
#)
#def print_memory_usage():
# logging.info(f"RAM memory % used: {psutil.virtual_memory()[2]}")
# Choose the caching decorator for the installed Streamlit version.
# `st.cache` is deprecated (since Streamlit 1.18) in favour of `st.cache_resource`,
# the cache intended for shared, unhashable objects such as models; it returns the
# cached object itself, which covers what `allow_output_mutation=True` provided.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource(max_entries=1)
else:  # Older Streamlit: keep the original legacy-cache configuration.
    _cache_model = st.cache(
        allow_output_mutation=True, suppress_st_warning=True, max_entries=1
    )


@_cache_model
def load_model(model_name):
    """Load a sequence-classification model together with its tokenizer.

    Args:
        model_name: Model identifier passed to ``from_pretrained`` for both the
            model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple. The result is cached with
        ``max_entries=1``, so at most one model/tokenizer pair is held in the
        cache at a time.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
print("before main")
st.title("Transformers Interpret Demo App")

# Sidebar branding. Pillow's Image.open is lazy and keeps the file open until the
# pixel data is read; the context manager guarantees the handle is released, and
# the image is handed to Streamlit while the file is still open.
# NOTE(review): newer Streamlit releases deprecate `use_column_width` in favour of
# `use_container_width`; kept as-is for compatibility with older versions.
with Image.open("./images/tight@1920x_transparent.png") as image:
    st.sidebar.image(image, use_column_width=True)

st.sidebar.markdown(
    "Check out the package on [Github](https://github.com/cdpierse/transformers-interpret)"
)
st.info(
    "Due to limited resources only low memory models are available. Run this [app locally](https://github.com/cdpierse/transformers-interpret-streamlit) to run the full selection of available models. "
)
# Classification models selectable in the sidebar, mapped to a one-line
# description shown to the user. Only low-memory models are enabled; to try other
# classifiers (e.g. when running locally), add any of these entries to the dict:
#   "textattack/distilbert-base-uncased-rotten-tomatoes": "",
#   "textattack/bert-base-uncased-rotten-tomatoes": "",
#   "textattack/roberta-base-rotten-tomatoes": "",
#   "mrm8488/bert-mini-finetuned-age_news-classification": "BERT-Mini finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
#   "nateraw/bert-base-uncased-ag-news": "BERT finetuned on AG News dataset. Predicts news class (sports/tech/business/world) of text.",
#   "ProsusAI/finbert": "BERT model finetuned to predict sentiment of financial text. Finetuned on Financial PhraseBank data. Predicts positive/negative/neutral.",
#   "MoritzLaurer/covid-policy-roberta-21": "(Under active development) RoBERTA model finetuned to identify COVID policy measure classes",
#   "mrm8488/bert-tiny-finetuned-sms-spam-detection": "Tiny bert model finetuned for spam detection. 0 == not spam, 1 == spam",
models = {
    "distilbert-base-uncased-finetuned-sst-2-english": (
        "DistilBERT model finetuned on SST-2 sentiment analysis task. "
        "Predicts positive/negative sentiment."
    ),
    "sampathkethineedi/industry-classification": (
        "DistilBERT Model to classify a business description "
        "into one of 62 industry tags."
    ),
    "MoritzLaurer/policy-distilbert-7d": (
        "DistilBERT model finetuned to classify text "
        "into one of seven political categories."
    ),
}
# Let the user pick one of the enabled models, then load it (cached by load_model).
model_name = st.sidebar.selectbox("Choose a classification model", list(models))
model, tokenizer = load_model(model_name)
print("Model loaded")

# Give textattack checkpoints readable binary-sentiment label names. The labels
# are kept free of trailing whitespace so both classes render consistently.
if model_name.startswith("textattack/"):
    model.config.id2label = {0: "NEGATIVE (0)", 1: "POSITIVE (1)"}

# Inference only: switch the model to evaluation mode before explaining it.
model.eval()
print("Model Evaluated")

cls_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)
# Only the explainer object is constructed here; no explanation is computed yet.
print("Explainer created")
# Choose which embedding layer attributions are computed against. The sidebar
# option is only offered when the model accepts position ids; otherwise word
# embeddings (index 0) are used.
# NOTE(review): emb_type_num is presumably passed as the explainer's embedding
# type further on — that usage is not visible in this file.
if cls_explainer.accepts_position_ids:
    emb_type_name = st.sidebar.selectbox(
        "Choose embedding type for attribution.", ["word", "position"]
    )
    # A single expression keeps emb_type_num bound for every selectbox value
    # (the original pair of independent `if`s could leave it undefined).
    emb_type_num = 1 if emb_type_name == "position" else 0
else:
    emb_type_num = 0

print("end of total file")