Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# coding: utf-8 | |
import gradio as gr | |
import numpy as np | |
from transformers import ( | |
AutoModelForSequenceClassification, | |
AutoTokenizer, | |
TextClassificationPipeline, | |
pipeline, | |
) | |
from langdetect import detect | |
from matplotlib import pyplot as plt | |
import imageio | |
# TODO: move constants into a separate module

# Minimum top-class probability for the model to count as "sure";
# below this value the output lists several candidate departments.
ML_MODEL_SURE = 0.6

# Returned to the UI when the detected language is neither German nor French.
UNKNOWN_LANG_TEXT = (
    "The language is not recognized, it must be either in German or in French."
)

# Placeholder shown in the (empty) input textbox, in German and French.
PLACEHOLDER_TEXT = "Geben Sie bitte den Titel und den Submitted Text des Vorstoss ein.\nVeuillez entrer le titre et le Submitted Text de la requête."

# Disclaimers prepended to the output when the model is not sure.
UNSURE_DE_TEXT = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"
UNSURE_FR_TEXT = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"

# Bar labels for the score plot, French and German variants.
# Position i labels the i-th score returned by the classifier.
# NOTE(review): presumably both tuples follow the model's label order and are
# parallel to each other (same department at the same index) -- confirm.
BARS_DEP_FR = (
    "DDPS",
    "DFI",
    "AS-MPC",
    "DFJP",
    "DEFR",
    "DETEC",
    "DFAE",
    "Parl",
    "ChF",
    "DFF",
    "AF",
    "TF",
)
BARS_DEP_DE = (
    "VBS",
    "EDI",
    "AB-BA",
    "EJPD",
    "WBF",
    "UVEK",
    "EDA",
    "Parl",
    "BK",
    "EFD",
    "BV",
    "BGer",
)
def load_model(modelFolder):
    """Build a text classification pipeline from a saved model directory.

    Both the sequence-classification model and its tokenizer are loaded
    from ``modelFolder`` with ``from_pretrained`` and wrapped in a
    ``TextClassificationPipeline``, which is returned.
    """
    return TextClassificationPipeline(
        model=AutoModelForSequenceClassification.from_pretrained(modelFolder),
        tokenizer=AutoTokenizer.from_pretrained(modelFolder),
    )
# Lazily created French->German translation pipeline, reused across calls so
# the translation model is not reloaded for every request.
_FR_DE_TRANSLATOR = None


def _get_fr_de_translator():
    """Return the shared FR->DE translation pipeline, creating it on first use."""
    global _FR_DE_TRANSLATOR
    if _FR_DE_TRANSLATOR is None:
        _FR_DE_TRANSLATOR = pipeline(
            "translation", model="Helsinki-NLP/opus-mt-fr-de"
        )
    return _FR_DE_TRANSLATOR


def translate_to_de(inputText):
    """Translate French user input to German for better classification.

    Only the first 1000 characters of ``inputText`` are translated.

    Returns:
        The ``translation_text`` of the first translation result.
    """
    translated = _get_fr_de_translator()(inputText[0:1000])
    return translated[0]["translation_text"]
def create_bar_plot(rates, language):
    """Render the class scores as a horizontal bar chart.

    Args:
        rates: one score per department, in the same order as the bar labels.
        language: ``"fr"`` selects the French labels; anything else selects
            the German ones.

    Returns:
        A tuple ``(image, barnames)``: the chart decoded from PNG by imageio,
        and the tuple of labels that was used for the bars.
    """
    import io  # local import: only needed for the in-memory PNG buffer

    barnames = BARS_DEP_FR if language == "fr" else BARS_DEP_DE
    y_pos = np.arange(len(barnames))

    # Draw on a private figure instead of pyplot's global current figure, so
    # the chart never contains bars left over from another call.
    fig, ax = plt.subplots()
    buffer = io.BytesIO()
    try:
        ax.barh(y_pos, rates)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(barnames)
        # Encode as PNG in memory rather than into a fixed "rates.png" file,
        # which simultaneous requests would overwrite for each other.
        fig.savefig(buffer, format="png")
    finally:
        # Figures created through pyplot stay registered until closed.
        plt.close(fig)

    # Decode the PNG again (enables better display in the UI).
    im = imageio.v2.imread(buffer.getvalue(), format="png")
    return im, barnames
def show_chosen_category(barnames, rates, language):
    """Create the output text for the predicted department(s).

    - If the highest score reaches ``ML_MODEL_SURE``, only that department
      is shown.
    - Otherwise a disclaimer (French if ``language == "fr"``, else German) is
      followed by every department with a score of at least 10%, highest
      score first.

    Percentages are rounded to whole numbers. ``rates`` is not modified.

    Args:
        barnames: department labels, parallel to ``rates``.
        rates: one score per department.
        language: language code used to pick the disclaimer text.

    Returns:
        The formatted text to display.
    """
    # Work on a copy-like array view so the caller's list is never mutated.
    scores = np.asarray(rates, dtype=float)
    best = int(np.argmax(scores))
    distance = "\t\t\t\t\t"

    # ML model pretty sure, show only one department.
    # "{:.0%}" formats the number itself; slicing str(score) would turn
    # 0.5 into "5%" and 1.0 into "0%".
    if scores[best] >= ML_MODEL_SURE:
        return f"{scores[best]:.0%}{distance}{barnames[best]}"

    # ML model not sure: list each department with a probability >= 10%.
    header = UNSURE_FR_TEXT if language == "fr" else UNSURE_DE_TEXT
    # Stable sort on the negated scores: descending order, ties keep the
    # lower index first (same order repeated argmax would produce).
    ranking = np.argsort(-scores, kind="stable")
    lines = [
        f"\t{scores[i]:.0%}{distance}{barnames[i]}\n"
        for i in ranking
        if scores[i] >= 0.1
    ]
    return header + "".join(lines)
# Department classifier, loaded once at start-up from the local model folder.
pipeDep = load_model("saved_model_dep")
# pipeOffice = load_model("saved_model_office")


# Function called by the UI
def attribution(inputText):
    """Classify a German or French text and build the UI outputs.

    French input is translated to German before classification; any other
    language than German or French is rejected.

    Args:
        inputText: the text entered by the user.

    Returns:
        A tuple ``(text, image)``. ``text`` is the formatted prediction and
        ``image`` the score bar chart. If the input is empty or its language
        cannot be determined or is unsupported, ``(UNKNOWN_LANG_TEXT, None)``
        is returned instead.
    """
    # Local import: langdetect is already a dependency of this file, only
    # this exception class is additionally needed here.
    from langdetect.lang_detect_exception import LangDetectException

    # Clear the current pyplot figure so nothing from an earlier request
    # lingers on it.
    plt.clf()

    # Nothing to detect in empty / whitespace-only input.
    if not inputText or not inputText.strip():
        return UNKNOWN_LANG_TEXT, None

    # detect() raises when the text has no usable features (e.g. only
    # digits or punctuation); report that like an unknown language instead
    # of letting the UI callback crash.
    try:
        language = detect(inputText)
    except LangDetectException:
        return UNKNOWN_LANG_TEXT, None

    # Translate the input to german if necessary
    if language == "fr":
        inputText = translate_to_de(inputText)
    elif language != "de":
        return UNKNOWN_LANG_TEXT, None

    # Make the prediction with the 1000 first characters.
    # return_all_scores=True is kept deliberately: the scores are matched to
    # the bar labels by position.
    # NOTE(review): presumably this returns scores in model label order,
    # which must match BARS_DEP_* -- confirm before switching to `top_k`.
    prediction = pipeDep(inputText[0:1000], return_all_scores=True)
    rates = [row["score"] for row in prediction[0]]

    # Create barplot & output text
    im, barnames = create_bar_plot(rates, language)
    chosenCategoryText = show_chosen_category(barnames, rates, language)
    return chosenCategoryText, im
# Build the Gradio UI around `attribution`: a single 20-line textbox as
# input, the prediction text and the bar-chart image as outputs.
interface = gr.Interface(
    fn=attribution,
    inputs=[
        gr.components.Textbox(
            lines=20,
            placeholder=PLACEHOLDER_TEXT,
        ),
    ],
    outputs=[
        "text",
        "image",
    ],
)

# Start serving the UI.
interface.launch()