"""Gradio app that showcases Danish offensive text models.""" import warnings from numba.core.errors import NumbaDeprecationWarning warnings.filterwarnings("ignore", category=NumbaDeprecationWarning) import gradio as gr from transformers import pipeline from shap import Explainer import numpy as np from typing import Tuple, Dict, List def main(): pipe = pipeline( task="text-classification", model="alexandrainst/da-offensive-detection-small", ) examples = [ "Din store idiot.", "Jeg er glad for at være her.", "Hvem tror du, du er?", "Har du hæklefejl i kysen?", "Hej med dig, jeg hedder Peter.", "Fuck hvor er det dejligt, det her :)", "🍆", "😊", ] def classification(text) -> Tuple[Dict[str, float], dict]: output: List[dict] = pipe(text)[0] print(output) explainer = Explainer(pipe) explanation = explainer([text]) shap_values = explanation.values[0].sum(axis=1) # Find the SHAP boundary boundary = 0.03 if np.abs(shap_values).max() <= boundary: boundary = np.abs(shap_values).max() - 1e-6 words: List[str] = explanation.data[0] records = list() char_idx = 0 for word, shap_value in zip(words, shap_values): if abs(shap_value) <= boundary: entity = 'O' else: entity = output['label'].lower().replace(' ', '-') if len(word): start = char_idx char_idx += len(word) end = char_idx records.append(dict( entity=entity, word=word, score=abs(shap_value), start=start, end=end, )) print(records) return ({output["label"]: output["score"]}, dict(text=text, entities=records)) color_map = {"offensive": "red", "not-offensive": "green", 'O': 'white'} demo = gr.Interface( fn=classification, inputs=gr.Textbox(placeholder="Enter sentence here...", value=examples[0]), outputs=[gr.Label(), gr.HighlightedText().style(color_map=color_map)], examples=examples, title="Danish Offensive Text Detection", description=""" Detect offensive text in Danish. Write any text in the box below, and the model will predict whether the text is offensive or not: _Also, be patient, as this demo is running on a CPU!_""", ) demo.launch() if __name__ == "__main__": main()