|
"""Gradio app that showcases Scandinavian zero-shot text classification models.""" |
|
|
|
import logging

import gradio as gr
from luga import language as detect_language
from transformers import pipeline
|
|
|
|
|
|
|
# Hugging Face model used for all zero-shot predictions in this app.
_MODEL_ID = "alexandrainst/scandi-nli-large"

# The pipeline is built once at module level and reused by `classification`
# for every request.
classifier = pipeline("zero-shot-classification", model=_MODEL_ID)
|
|
|
|
|
|
|
# Blurb passed to `gr.Interface(description=...)`. It consists of three
# paragraphs separated by exactly one blank line; lines inside a paragraph are
# joined by a single newline so that sentences are not broken into separate
# paragraphs when the text is rendered.
DESCRIPTION = """Classify text in Danish, Swedish or Norwegian into categories, without
finetuning on any training data!

Note that the models will most likely *not* work as well as a finetuned model on your
specific data, but they can be used as a starting point for your own classification
task ✨

Also, be patient, as this demo is running on a CPU!"""
|
|
|
|
|
# Module-level logger; used instead of printing raw pipeline output to stdout.
_logger = logging.getLogger(__name__)

# Language whose wording is used when the detected language is neither Swedish
# nor Norwegian.
_FALLBACK_LANGUAGE = "da"

# Detector codes that should share the configuration of another language.
# NOTE(review): assumes the language detector reports Norwegian Nynorsk as
# "nn" (the fastText lid.176 label) — confirm against the installed `luga`.
_LANGUAGE_ALIASES = {"nn": "no"}

# Word for "confidence level" shown next to the score, per language.
_CONFIDENCE_LABELS = {
    "sv": "konfidensnivå",
    "no": "konfidensnivå",
    "da": "konfidensniveau",
}

# task -> language -> (hypothesis template, candidate labels).
#
# Candidate labels are either a list (the label is shown capitalised) or a dict
# mapping the text inserted into the hypothesis to the label shown to the user.
_TASK_CONFIGS = {
    "Sentiment classification": {
        "sv": ("Detta exempel är {}.", ["positivt", "negativt", "neutralt"]),
        "no": ("Dette eksemplet er {}.", ["positivt", "negativt", "nøytralt"]),
        "da": ("Dette eksempel er {}.", ["positivt", "negativt", "neutralt"]),
    },
    "News topic classification": {
        "sv": (
            "Detta exempel handlar om {}.",
            [
                "krig",
                "politik",
                "utbildning",
                "hälsa",
                "ekonomi",
                "mode",
                "sport",
            ],
        ),
        "no": (
            "Dette eksemplet handler om {}.",
            [
                "krig",
                "politikk",
                "utdanning",
                "helse",
                "økonomi",
                "mote",
                "sport",
            ],
        ),
        "da": (
            "Denne nyhedsartikel handler primært om {}.",
            [
                "krig",
                "politik",
                "uddannelse",
                "sundhed",
                "økonomi",
                "mode",
                "sport",
            ],
        ),
    },
    "Spam detection": {
        "sv": (
            "Det här e-postmeddelandet ser {}.",
            {
                "ut som ett skräppostmeddelande": "Spam",
                "inte ut som ett skräppostmeddelande": "Inte spam",
            },
        ),
        "no": (
            "Denne e-posten ser {}.",
            {
                "ut som en spam-e-post": "Spam",
                "ikke ut som en spam-e-post": "Ikke spam",
            },
        ),
        "da": (
            "Denne e-mail ligner {}.",
            {
                "en spam e-mail": "Spam",
                "ikke en spam e-mail": "Ikke spam",
            },
        ),
    },
    "Product feedback detection": {
        "sv": (
            "Den här kommentaren är {}.",
            {
                "en recension av en produkt": "Produktfeedback",
                "inte en recension av en produkt": "Inte produktfeedback",
            },
        ),
        "no": (
            "Denne kommentaren er {}.",
            {
                "en anmeldelse av et produkt": "Produkttilbakemelding",
                "ikke en anmeldelse av et produkt": "Ikke produkttilbakemelding",
            },
        ),
        "da": (
            "Denne kommentar er {}.",
            {
                "en anmeldelse af et produkt": "Produktfeedback",
                "ikke en anmeldelse af et produkt": "Ikke produktfeedback",
            },
        ),
    },
}


def classification(task: str, doc: str) -> str:
    """Classify text into categories.

    The language of the text is detected first and used to pick the hypothesis
    template, the candidate labels and the wording of the confidence string.
    Swedish and Norwegian have their own wording; every other detected
    language uses the Danish wording.

    Args:
        task (str):
            Task to perform. Must be one of the keys of `_TASK_CONFIGS`.
        doc (str):
            Text to classify.

    Returns:
        str:
            The predicted label, followed on a new line by the confidence of
            the prediction as a percentage in parentheses.

    Raises:
        ValueError:
            If the task is not supported.
    """
    # Reject unknown tasks before running language detection or the model.
    per_language = _TASK_CONFIGS.get(task)
    if per_language is None:
        raise ValueError(f"Task {task} not supported.")

    # Newlines are flattened to spaces before detection, presumably because
    # the detector expects single-line input.
    language = detect_language(doc.replace("\n", " ")).name
    language = _LANGUAGE_ALIASES.get(language, language)
    if language not in _CONFIDENCE_LABELS:
        language = _FALLBACK_LANGUAGE

    confidence_str = _CONFIDENCE_LABELS[language]
    hypothesis_template, labels = per_language[language]

    # Normalise to a mapping from hypothesis text to the label shown to the
    # user; plain lists are displayed capitalised.
    if isinstance(labels, list):
        display_labels = {label: label.capitalize() for label in labels}
    else:
        display_labels = labels

    result = classifier(
        doc,
        candidate_labels=list(display_labels),
        hypothesis_template=hypothesis_template,
    )
    _logger.debug("Zero-shot classification result: %s", result)

    # The first entry of `labels`/`scores` is reported as the prediction.
    top_label = result["labels"][0]
    top_score = result["scores"][0]
    return f"{display_labels[top_label]}\n({confidence_str}: {top_score:.0%})"
|
|
|
|
|
# Tasks offered in the UI. Each string must match a task name that
# `classification` recognises; any other value makes it raise ValueError.
_TASK_CHOICES = [
    "Sentiment classification",
    "News topic classification",
    "Spam detection",
    "Product feedback detection",
]

# Task selector, pre-set to the first task in the list.
# NOTE(review): `gr.inputs` looks like Gradio's legacy component namespace —
# confirm it is still available in the Gradio version this app is pinned to.
dropdown = gr.inputs.Dropdown(
    choices=_TASK_CHOICES,
    default=_TASK_CHOICES[0],
    label="Task",
)
|
|
|
|
|
# Example sentence the text box is pre-filled with.
_EXAMPLE_TEXT = "Jeg er helt vild med fodbolden 😊"

# Free-text input holding the document to classify.
# NOTE(review): `gr.inputs` looks like Gradio's legacy component namespace —
# confirm it is still available in the Gradio version this app is pinned to.
input_textbox = gr.inputs.Textbox(default=_EXAMPLE_TEXT, label="Text")
|
|
|
|
|
# Output component that displays the string returned by `classification`.
# NOTE(review): `gr.outputs` looks like Gradio's legacy component namespace —
# confirm it is still available in the Gradio version this app is pinned to.
_output_label = gr.outputs.Label(type="text")

# Wire the inputs, the classification function and the output together. The
# input order matches the `(task, doc)` parameter order of `classification`.
interface = gr.Interface(
    title="Scandinavian zero-shot text classification",
    description=DESCRIPTION,
    fn=classification,
    inputs=[dropdown, input_textbox],
    outputs=_output_label,
)
|
|
|
|
|
# Launch the interface. This call sits at module level without a
# `__main__` guard, so it runs whenever this file is executed or imported.
interface.launch()
|
|