File size: 2,472 Bytes
ed60c2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Gradio app that showcases Danish offensive text models."""

import gradio as gr
from transformers import pipeline
from shap import Explainer
import numpy as np


def main() -> None:
    """Build and launch the Gradio demo for Danish offensive-text detection.

    Loads the classification pipeline and a SHAP explainer once at startup,
    then serves an interface showing both the predicted label and a per-word
    highlight of which tokens drove the prediction.
    """
    pipe = pipeline(
        task="text-classification",
        model="alexandrainst/da-offensive-detection-small",
    )

    # Build the explainer once at startup; constructing it on every request
    # (as before) repeats expensive setup work per classification call.
    explainer = Explainer(pipe)

    examples = [
        "Din store idiot.",
        "Jeg er glad for at være her.",
        "Hvem tror du, du er?",
        "Har du hæklefejl i kysen?",
        "Hej med dig, jeg hedder Peter.",
        "Fuck hvor er det dejligt, det her :)",
        "🍆",
        "😊",
    ]

    def classification(text: str) -> tuple[dict[str, float], dict]:
        """Classify ``text`` and attribute the prediction to individual words.

        Args:
            text: The input sentence to classify.

        Returns:
            A pair of (label-to-score mapping for ``gr.Label``, a
            text-with-entities payload for ``gr.HighlightedText``).
        """
        # The pipeline returns a list with one result dict per input.
        output: dict = pipe(text)[0]

        explanation = explainer([text])
        shap_values = explanation.values[0].sum(axis=1)

        # Words whose |SHAP| exceeds this boundary get highlighted. If no
        # word crosses it, lower the boundary just below the maximum so the
        # most influential word is still shown.
        boundary = 0.03
        max_abs_shap = np.abs(shap_values).max()
        if max_abs_shap <= boundary:
            boundary = max_abs_shap - 1e-6

        words: list[str] = explanation.data[0]
        records: list[dict] = []
        char_idx = 0
        for word, shap_value in zip(words, shap_values):
            if abs(shap_value) <= boundary:
                entity = "O"
            else:
                entity = output["label"].lower().replace(" ", "-")

            # Skip empty tokenizer artefacts; they carry no characters.
            if word:
                start = char_idx
                char_idx += len(word)
                records.append(
                    dict(
                        entity=entity,
                        word=word,
                        score=abs(shap_value),
                        start=start,
                        end=char_idx,
                    )
                )

        return ({output["label"]: output["score"]}, dict(text=text, entities=records))

    color_map = {"offensive": "red", "not-offensive": "green", "O": "white"}
    demo = gr.Interface(
        fn=classification,
        inputs=gr.Textbox(placeholder="Enter sentence here...", value=examples[0]),
        outputs=[gr.Label(), gr.HighlightedText(color_map=color_map)],
        examples=examples,
        title="Danish Offensive Text Detection",
        description="""
Detect offensive text in Danish. Write any text in the box below, and the model will predict whether the text is offensive or not:

_Also, be patient, as this demo is running on a CPU!_""",
    )

    demo.launch()

# Launch the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()