|
"""Gradio app that showcases Scandinavian zero-shot text classification models.""" |
|
|
|
import gradio as gr |
|
from transformers import pipeline |
|
from luga import language as detect_language |
|
|
|
|
|
|
|
# Build the zero-shot classification pipeline once at import time so that every
# request handled by the app reuses the same loaded model.
classifier = pipeline(
    task="zero-shot-classification",
    model="alexandrainst/scandi-nli-large",
)
|
|
|
|
|
|
|
# Introductory text for the demo; passed to the Gradio interface as its description.
DESCRIPTION = (
    "Classify text in Danish, Swedish or Norwegian into categories, without\n"
    "any training data!\n"
    "\n"
    "Note that the models will most likely *not* work as well as a finetuned model on your\n"
    "specific data, but they can be used as a starting point for your own classification\n"
    "task ✨\n"
    "\n"
    "\n"
    "Also, be patient, as this demo is running on a CPU!"
)
|
|
|
|
|
def _capitalised(*labels: str) -> dict:
    """Map each candidate label to the capitalised form shown to the user."""
    return {label: label.capitalize() for label in labels}


# Language code whose prompts are used when the detected language is neither
# Swedish ("sv") nor Norwegian ("no").
_FALLBACK_LANGUAGE = "da"

# Word for "confidence level" printed next to the score, keyed by language code.
# Any language not listed here gets the Danish spelling.
_CONFIDENCE_WORDS = {"sv": "konfidensnivå", "no": "konfidensnivå"}
_DEFAULT_CONFIDENCE_WORD = "konfidensniveau"

# task name -> language code -> (hypothesis template, {candidate label: display label}).
# The candidate labels are what gets substituted into the template and sent to
# the model; the display label is what the user sees for the winning candidate.
_TASK_PROMPTS = {
    "Sentiment classification": {
        "sv": (
            "Detta exempel är {}.",
            _capitalised("positivt", "negativt", "neutralt"),
        ),
        "no": (
            "Dette eksemplet er {}.",
            _capitalised("positivt", "negativt", "nøytralt"),
        ),
        "da": (
            "Dette eksempel er {}.",
            _capitalised("positivt", "negativt", "neutralt"),
        ),
    },
    "News topic classification": {
        "sv": (
            "Detta exempel handlar om {}.",
            _capitalised(
                "krig",
                "politik",
                "utbildning",
                "hälsa",
                "ekonomi",
                "mode",
                "sport",
            ),
        ),
        "no": (
            "Dette eksemplet handler om {}.",
            _capitalised(
                "krig",
                "politikk",
                "utdanning",
                "helse",
                "økonomi",
                "mote",
                "sport",
            ),
        ),
        "da": (
            "Denne nyhedsartikel handler primært om {}.",
            _capitalised(
                "krig",
                "politik",
                "uddannelse",
                "sundhed",
                "økonomi",
                "mode",
                "sport",
            ),
        ),
    },
    "Spam detection": {
        "sv": (
            "Det här e-postmeddelandet ser {}",
            {
                "ut som ett skräppostmeddelande": "Spam",
                "inte ut som ett skräppostmeddelande": "Inte spam",
            },
        ),
        "no": (
            "Denne e-posten ser {}",
            {
                "ut som en spam-e-post": "Spam",
                "ikke ut som en spam-e-post": "Ikke spam",
            },
        ),
        "da": (
            "Denne e-mail ligner {}",
            {
                "en spam e-mail": "Spam",
                "ikke en spam e-mail": "Ikke spam",
            },
        ),
    },
}


def classification(task: str, doc: str) -> str:
    """Classify text into categories.

    The language of the text is detected first and decides which hypothesis
    template and candidate labels are sent to the zero-shot classifier. Texts
    detected as anything other than Swedish ("sv") or Norwegian ("no") are
    handled with the Danish prompts.

    Args:
        task (str):
            Task to perform. One of "Sentiment classification", "News topic
            classification" or "Spam detection".
        doc (str):
            Text to classify.

    Returns:
        str:
            The display name of the top-ranked label, followed on a new line
            by the confidence word and the score as a whole percentage in
            parentheses.

    Raises:
        ValueError:
            If `task` is not one of the supported tasks.
    """
    # Reject unsupported tasks up front, before spending time on language
    # detection and model inference.
    prompts_by_language = _TASK_PROMPTS.get(task) if isinstance(task, str) else None
    if prompts_by_language is None:
        raise ValueError(f"Task {task} not supported.")

    # Newlines are flattened to spaces before detection (presumably the
    # detector expects single-line input -- confirm against the luga docs).
    language = detect_language(doc.replace("\n", " ")).name

    hypothesis_template, display_names = prompts_by_language.get(
        language, prompts_by_language[_FALLBACK_LANGUAGE]
    )
    confidence_word = _CONFIDENCE_WORDS.get(language, _DEFAULT_CONFIDENCE_WORD)

    result = classifier(
        doc,
        candidate_labels=list(display_names),
        hypothesis_template=hypothesis_template,
    )

    # Echo the raw pipeline output to stdout so it shows up in the app logs.
    print(result)

    # The first entry of "labels"/"scores" is reported as the prediction.
    top_label = result["labels"][0]
    top_score = result["scores"][0]
    return f"{display_names[top_label]}\n({confidence_word}: {top_score:.0%})"
|
|
|
|
|
# Task selector. The choice strings are the exact task names that
# `classification` accepts as its first argument.
dropdown = gr.inputs.Dropdown(
    choices=[
        "Sentiment classification",
        "News topic classification",
        "Spam detection",
    ],
    default="Sentiment classification",
    label="Task",
)

# Free-text input holding the document to classify, and the label widget that
# displays the string returned by `classification`.
text_input = gr.inputs.Textbox(label="Text")
prediction_output = gr.outputs.Label(type="text")

# Wire the components to the classification function; input order matches the
# function's (task, doc) parameters.
interface = gr.Interface(
    fn=classification,
    inputs=[dropdown, text_input],
    outputs=prediction_output,
    title="Scandinavian zero-shot text classification",
    description=DESCRIPTION,
)

# Start serving the app.
interface.launch()
|
|