File size: 5,082 Bytes
7420aa9
 
 
 
 
 
 
 
 
 
 
 
 
41bb40c
 
 
 
 
 
 
 
 
efd38a2
 
7420aa9
 
efd38a2
 
7420aa9
 
 
 
 
efd38a2
7420aa9
 
efd38a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41bb40c
 
efd38a2
41bb40c
 
 
 
 
efd38a2
41bb40c
 
 
 
 
efd38a2
41bb40c
 
 
 
 
efd38a2
 
 
 
7420aa9
41bb40c
 
 
 
 
7420aa9
 
41bb40c
 
 
7420aa9
 
efd38a2
 
7420aa9
efd38a2
41bb40c
efd38a2
 
7420aa9
efd38a2
 
 
41bb40c
efd38a2
 
7420aa9
efd38a2
7420aa9
efd38a2
 
7420aa9
efd38a2
41bb40c
7420aa9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""Gradio app that showcases Scandinavian zero-shot text classification models."""

import logging

import gradio as gr
from transformers import pipeline
from luga import language as detect_language


# Zero-shot NLI pipeline shared by every request. It is created once, when the
# module is first executed, rather than per call.
classifier = pipeline(
    task="zero-shot-classification",
    model="alexandrainst/scandi-nli-large",
)


# Markdown description shown at the top of the Gradio interface.
DESCRIPTION = (
    "Classify text in Danish, Swedish or Norwegian into categories, without\n"
    "any training data!\n"
    "\n"
    "Note that the models will most likely *not* work as well as a finetuned model on your\n"
    "specific data, but they can be used as a starting point for your own classification\n"
    "task ✨"
)


def classification(task: str, doc: str) -> str:
    """Classify text into categories.

    Args:
        task (str):
            Task to perform.
        doc (str):
            Text to classify.

    Returns:
        str:
            The predicted label.
    """
    # Detect the language of the text
    language = detect_language(doc.replace('\n', ' ')).name

    # Define the confidence string based on the language
    if language == "sv" or language == "no":
        confidence_str = "konfidensnivå"
    else:
        confidence_str = "konfidensniveau"

    # If the task is sentiment, classify the text into positive, negative or neutral
    if task == "Sentiment classification":
        if language == "sv":
            hypothesis_template = "Detta exempel är {}."
            candidate_labels = ["positivt", "negativt", "neutralt"]
        elif language == "no":
            hypothesis_template = "Dette eksemplet er {}."
            candidate_labels = ["positivt", "negativt", "nøytralt"]
        else:
            hypothesis_template = "Dette eksempel er {}."
            candidate_labels = ["positivt", "negativt", "neutralt"]

    # Else if the task is topic, classify the text into a topic
    elif task == "News topic classification":
        if language == "sv":
            hypothesis_template = "Detta exempel handlar om {}."
            candidate_labels = [
                "krig",
                "politik",
                "utbildning",
                "hälsa",
                "ekonomi",
                "mode",
                "sport",
            ]
        elif language == "no":
            hypothesis_template = "Dette eksemplet handler om {}."
            candidate_labels = [
                "krig",
                "politikk",
                "utdanning",
                "helse",
                "økonomi",
                "mote",
                "sport",
            ]
        else:
            hypothesis_template = "Denne nyhedsartikel handler primært om {}."
            candidate_labels = [
                "krig",
                "politik",
                "uddannelse",
                "sundhed",
                "økonomi",
                "mode",
                "sport",
            ]

    # Else if the task is spam detection, classify the text into spam or not spam
    elif task == "Spam detection":
        if language == "sv":
            hypothesis_template = "Det här e-postmeddelandet ser {}"
            candidate_labels = {
                "ut som ett skräppostmeddelande": "Spam",
                "inte ut som ett skräppostmeddelande": "Inte spam",
            }
        elif language == "no":
            hypothesis_template = "Denne e-posten ser {}"
            candidate_labels = {
                "ut som en spam-e-post": "Spam",
                "ikke ut som en spam-e-post": "Ikke spam",
            }
        else:
            hypothesis_template = "Denne e-mail ligner {}"
            candidate_labels = {
                "en spam e-mail": "Spam",
                "ikke en spam e-mail": "Ikke spam",
            }

    # Else the task is not supported, so raise an error
    else:
        raise ValueError(f"Task {task} not supported.")

    # If `candidate_labels` is a list then convert it to a dictionary, where the keys
    # are the entries in the list and the values are the keys capitalized
    if isinstance(candidate_labels, list):
        candidate_labels = {label: label.capitalize() for label in candidate_labels}

    # Run the classifier on the text
    result = classifier(
        doc,
        candidate_labels=list(candidate_labels.keys()),
        hypothesis_template=hypothesis_template,
    )

    print(result)

    # Return the predicted label
    return (
        f"{candidate_labels[result['labels'][0]]}\n"
        f"({confidence_str}: {result['scores'][0]:.0%})"
    )

# Dropdown for choosing the zero-shot task. The choices must match the task
# names accepted by `classification`, which raises ValueError for any other.
dropdown = gr.inputs.Dropdown(
    label="Task",
    choices=["Sentiment classification", "News topic classification", "Spam detection"],
    default="Sentiment classification",
)

# Wire the task dropdown and a free-text box to `classification`, whose string
# result (label plus confidence) is displayed in a Label output.
interface = gr.Interface(
    fn=classification,
    inputs=[dropdown, gr.inputs.Textbox(label="Text")],
    outputs=gr.outputs.Label(type="text"),
    title="Scandinavian zero-shot text classification",
    description=DESCRIPTION,
)

# Start the web app only when this file is executed as a script, so importing
# the module does not launch a server as a side effect.
if __name__ == "__main__":
    interface.launch()