import random

import streamlit as st
from bs4 import BeautifulSoup
from transformers import pipeline
from transformers_interpret import SequenceClassificationExplainer

# Map model names to URLs
model_names_to_URLs = {
    'ml6team/distilbert-base-dutch-cased-toxic-comments':
        'https://huggingface.co/ml6team/distilbert-base-dutch-cased-toxic-comments',
    'ml6team/robbert-dutch-base-toxic-comments':
        'https://huggingface.co/ml6team/robbert-dutch-base-toxic-comments',
}
about_page_markdown = """# 🤬 Dutch Toxic Comment Detection Space
Made by [ML6](https://ml6.eu/).
Token attribution is performed using [transformers-interpret](https://github.com/cdpierse/transformers-interpret).
"""
regular_emojis = [
    '😐', '🙂', '👶', '😇',
]
undecided_emojis = [
    '🤨', '🧐', '🥸', '🥴', '🤷',
]
potty_mouth_emojis = [
    '🤐', '👿', '😡', '🤬', '☠️', '☣️', '☢️',
]
# Page setup
st.set_page_config(
    page_title="Toxic Comment Detection Space",
    page_icon="🤬",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items={
        'Get help': None,
        'Report a bug': None,
        'About': about_page_markdown,
    }
)
# Model setup
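# Note: load_pipeline() runs on every classification, so the model is rebuilt
# from the local Hub cache each time. A minimal sketch of memoising the loader
# (assuming a Streamlit version that provides st.cache_resource, i.e. >= 1.18)
# would be to decorate it:
#
#     @st.cache_resource
#     def load_pipeline(model_name):
#         ...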
def load_pipeline(model_name):
    with st.spinner('Loading model (this might take a while)...'):
        toxicity_pipeline = pipeline(
            'text-classification',
            model=model_name,
            tokenizer=model_name)
        cls_explainer = SequenceClassificationExplainer(
            toxicity_pipeline.model,
            toxicity_pipeline.tokenizer)
    return toxicity_pipeline, cls_explainer
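# Example usage (illustrative output; the exact label strings and scores
# depend on the checkpoint):
#
#     pipe, explainer = load_pipeline(
#         'ml6team/distilbert-base-dutch-cased-toxic-comments')
#     pipe('Wat een mooie dag!')  # e.g. [{'label': 'non-toxic', 'score': 0.98}]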
# Auxiliary functions
def format_explainer_html(html_string):
    """Extract tokens with attribution-based background color."""
    inside_token_prefix = '##'
    soup = BeautifulSoup(html_string, 'html.parser')
    p = soup.new_tag('p',
                     attrs={'style': 'color: black; background-color: white;'})
    # Select token elements and remove model-specific tokens
    current_word = None
    for token in soup.find_all('td')[-1].find_all('mark')[1:-1]:
        text = token.font.text.strip()
        if text.startswith(inside_token_prefix):
            text = text[len(inside_token_prefix):]
        else:
            # Create a new span for each word (sequence of sub-tokens)
            if current_word is not None:
                p.append(current_word)
                p.append(' ')
            current_word = soup.new_tag('span')
        token.string = text
        token.attrs['style'] = f"{token.attrs['style']}; padding: 0.2em 0em;"
        current_word.append(token)
    # Add the last word
    p.append(current_word)
    # Add left and right padding to each word
    for span in p.find_all('span'):
        span.find_all('mark')[0].attrs['style'] = (
            f"{span.find_all('mark')[0].attrs['style']}; padding-left: 0.2em;")
        span.find_all('mark')[-1].attrs['style'] = (
            f"{span.find_all('mark')[-1].attrs['style']}; padding-right: 0.2em;")
    return p
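# format_explainer_html() returns a BeautifulSoup <p> tag; interpolating it
# into an f-string (as done in the results section below) converts it to its
# HTML string form.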
def classify_comment(comment, selected_model):
    """Classify the given comment and augment it with additional information."""
    toxicity_pipeline, cls_explainer = load_pipeline(selected_model)
    result = toxicity_pipeline(comment)[0]
    result['model_name'] = selected_model
    # Add explanation
    result['word_attribution'] = cls_explainer(comment, class_name="non-toxic")
    result['visualisation_html'] = cls_explainer.visualize()._repr_html_()
    result['tokens_with_background'] = format_explainer_html(
        result['visualisation_html'])
    # Choose emoji reaction
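    # (The pipeline score is the confidence of the predicted label, so a
    # result only falls through to the "undecided" emojis when that
    # confidence is very low.)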
    label, score = result['label'], result['score']
    if label == 'toxic' and score > 0.1:
        emoji = random.choice(potty_mouth_emojis)
    elif label in ['non_toxic', 'non-toxic'] and score > 0.1:
        emoji = random.choice(regular_emojis)
    else:
        emoji = random.choice(undecided_emojis)
    result.update({'text': comment, 'emoji': emoji})
    # Add the result to the session
    st.session_state.results.append(result)
# Start session
if 'results' not in st.session_state:
    st.session_state.results = []
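# st.session_state survives Streamlit's top-to-bottom reruns, so previously
# classified comments remain visible after each new submission.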
# Page
st.title('🤬 Dutch Toxic Comment Detection')
st.markdown("""This demo showcases two Dutch toxic comment detection models.""")

# Introduction
st.markdown("""Both models were trained on a sequence classification task using a translated [Jigsaw Toxicity dataset](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge), which contains toxic online comments.
The first model is a fine-tuned multilingual [DistilBERT](https://huggingface.co/distilbert-base-multilingual-cased) model, whereas the second is a fine-tuned Dutch RoBERTa-based model called [RobBERT](https://huggingface.co/pdelobelle/robbert-v2-dutch-base).""")
st.markdown(f"""For a more comprehensive overview of the models, check out their model cards on the 🤗 Model Hub: [distilbert-base-dutch-toxic-comments]({model_names_to_URLs['ml6team/distilbert-base-dutch-cased-toxic-comments']}) and [RobBERT-dutch-base-toxic-comments]({model_names_to_URLs['ml6team/robbert-dutch-base-toxic-comments']}).
""")
st.markdown("""Enter a comment that you want to classify below. The model will determine the probability that it is toxic and highlights how much each token contributes to its decision: | |
<font color="black"> | |
<span style="background-color: rgb(250, 219, 219); opacity: 1;">r</span><span style="background-color: rgb(244, 179, 179); opacity: 1;">e</span><span style="background-color: rgb(238, 135, 135); opacity: 1;">d</span> | |
</font> | |
tokens indicate toxicity whereas | |
<font color="black"> | |
<span style="background-color: rgb(224, 251, 224); opacity: 1;">g</span><span style="background-color: rgb(197, 247, 197); opacity: 1;">re</span><span style="background-color: rgb(121, 236, 121); opacity: 1;">en</span> | |
</font> tokens indicate the opposite. | |
Try it yourself! π""", | |
unsafe_allow_html=True) | |
# Demo
with st.form("dutch-toxic-comment-detection-input", clear_on_submit=False):
    selected_model = st.selectbox('Select a model:',
                                  model_names_to_URLs.keys())
    text = st.text_area(
        label='Enter the comment you want to classify below (in Dutch):')
    _, rightmost_col = st.columns([6, 1])
    submitted = rightmost_col.form_submit_button("Classify",
                                                 help="Classify comment")
# Listener
if submitted:
    if text:
        with st.spinner('Analysing comment...'):
            classify_comment(text, selected_model)
    else:
        st.error('**Error**: No comment to classify. Please provide a comment.')
# Results
if 'results' in st.session_state and st.session_state.results:
    first = True
    for result in st.session_state.results[::-1]:
        if not first:
            st.markdown("---")
        st.markdown(f"Text:\n> {result['text']}")
        col_1, col_2, col_3 = st.columns([1, 2, 2])
        col_1.metric(label='', value=f"{result['emoji']}")
        col_2.metric(label='Label', value=f"{result['label']}")
        col_3.metric(label='Score', value=f"{result['score']:.3f}")
        st.markdown(f"Token Attribution:\n{result['tokens_with_background']}",
                    unsafe_allow_html=True)
        st.caption(f"Model: {result['model_name']}")
        first = False
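# To run this demo locally (assuming the imports above are installed, e.g.
# `pip install torch transformers transformers-interpret beautifulsoup4
# streamlit`), save this file as app.py and start it with:
#
#     streamlit run app.py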