# bk-departements / app.py (by BK-AI, 4.11 kB)
# Commit 4829b64: "refactor app, prepare for second prediction"
#!/usr/bin/env python
# coding: utf-8
import functools
import io

import gradio as gr
import imageio
import numpy as np
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TextClassificationPipeline,
    pipeline,
)
# TODO: move these constants into a separate module.

# Minimum top score for a prediction to be shown as certain; below it,
# show_chosen_category prefixes a disclaimer and lists several candidates.
ML_MODEL_SURE = 0.6

# Returned to the UI when the input is neither German nor French.
UNKNOWN_LANG_TEXT = (
    "The language is not recognized, it must be either in German or in French."
)

# Placeholder of the input textbox: German sentence, newline, French sentence.
PLACEHOLDER_TEXT = (
    "Geben Sie bitte den Titel und den Submitted Text des Vorstosses ein.\n"
    "Veuillez entrer le titre et le Submitted Text de la requête."
)

# Disclaimers prepended to the output when the top score is below ML_MODEL_SURE.
UNSURE_DE_TEXT = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"
UNSURE_FR_TEXT = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"

# Bar/category labels in French and German. Position i of each tuple is used
# as the label for score i of the classifier output, so both tuples must stay
# index-aligned with each other.
# NOTE(review): the order presumably matches the label ids of
# "saved_model_dep" — confirm against that model's config before reordering.
BARS_DEP_FR = (
    "DDPS",
    "DFI",
    "AS-MPC",
    "DFJP",
    "DEFR",
    "DETEC",
    "DFAE",
    "Parl",
    "ChF",
    "DFF",
    "AF",
    "TF",
)
BARS_DEP_DE = (
    "VBS",
    "EDI",
    "AB-BA",
    "EJPD",
    "WBF",
    "UVEK",
    "EDA",
    "Parl",
    "BK",
    "EFD",
    "BV",
    "BGer",
)
def load_model(modelFolder):
    """Build a text-classification pipeline from a saved model.

    Args:
        modelFolder: location handed to ``from_pretrained`` for both the
            sequence-classification model and its tokenizer.

    Returns:
        A ``TextClassificationPipeline`` wrapping the loaded model and tokenizer.
    """
    # Keyword arguments are evaluated left to right, so the model is still
    # loaded before the tokenizer.
    return TextClassificationPipeline(
        model=AutoModelForSequenceClassification.from_pretrained(modelFolder),
        tokenizer=AutoTokenizer.from_pretrained(modelFolder),
    )
@functools.lru_cache(maxsize=1)
def _fr_de_translator():
    """Create the French-to-German translation pipeline once and reuse it."""
    return pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de")


def translate_to_de(inputText):
    """Translate French user input to German so the classifier sees German text.

    The translation pipeline is built on first use and cached, instead of
    being reloaded on every request.

    Args:
        inputText: French text entered by the user. Only its first 1000
            characters are translated.

    Returns:
        The German translation as a string.
    """
    translator = _fr_de_translator()
    translated = translator(inputText[:1000])
    return translated[0]["translation_text"]
def create_bar_plot(rates, language):
    """Draw the classifier scores as a horizontal bar chart and return it as an image.

    Args:
        rates: one score per department, index-aligned with the label tuple.
        language: "fr" labels the bars with BARS_DEP_FR; any other value
            uses BARS_DEP_DE.

    Returns:
        Tuple ``(im, barnames)``: the chart decoded by ``imageio.v2.imread``
        and the label tuple that was used, so the caller can reuse it for the
        text output.
    """
    barnames = BARS_DEP_FR if language == "fr" else BARS_DEP_DE
    y_pos = np.arange(len(barnames))
    # Draw on a private Figure rather than pyplot's global "current figure",
    # and render to memory rather than a fixed "rates.png" on disk. Concurrent
    # requests can then neither draw on each other's chart nor overwrite each
    # other's file.
    fig = Figure()
    ax = fig.subplots()
    ax.barh(y_pos, rates)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(barnames)
    # Saving as PNG and decoding it again gives an image array for the UI.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    im = imageio.v2.imread(buffer.getvalue())
    return im, barnames
def show_chosen_category(barnames, rates, language):
    """Build the text shown to the user for a prediction.

    If the top score reaches ML_MODEL_SURE, only that department is shown.
    Otherwise a language-specific disclaimer is followed by every department
    scoring at least 10 %, from highest to lowest score.

    Args:
        barnames: department labels, index-aligned with ``rates``.
        rates: classifier scores, presumably in [0, 1]; they are formatted as
            rounded integer percentages. ``rates`` is not modified.
        language: "fr" selects the French disclaimer; any other value the
            German one.

    Returns:
        The formatted output string.

    Raises:
        ValueError: if ``rates`` is empty.
    """
    if len(rates) == 0:
        raise ValueError("rates must contain at least one score")
    distance = "\t\t\t\t\t"
    # Indices from highest to lowest score. The sort is stable (also with
    # reverse=True), so on ties the lower index comes first, as with np.argmax.
    order = sorted(range(len(rates)), key=lambda idx: rates[idx], reverse=True)
    best = order[0]
    # ML model not sure: list every department with a score of at least 10 %.
    if rates[best] < ML_MODEL_SURE:
        header = UNSURE_FR_TEXT if language == "fr" else UNSURE_DE_TEXT
        # Format as real percentages: the previous str(rate)[2:4] slicing
        # turned 0.6 into "6" and 1.0 into "0".
        lines = [
            f"\t{rates[idx]:.0%}{distance}{barnames[idx]}\n"
            for idx in order
            if rates[idx] >= 0.1
        ]
        return header + "".join(lines)
    # ML model pretty sure: show only the top department.
    return f"{rates[best]:.0%}{distance}{barnames[best]}"
# Department classifier: loaded once at module import and used by attribution().
pipeDep = load_model(modelFolder="saved_model_dep")
# TODO: a second classifier is planned; not loaded yet:
# pipeOffice = load_model("saved_model_office")
def attribution(inputText):
    """Predict the department for the user's text; this is the Gradio callback.

    German input is classified directly. French input is first translated to
    German, but the bar labels and disclaimer stay in French.

    Args:
        inputText: title and submitted text entered in the textbox.

    Returns:
        Tuple ``(text, image)`` with the formatted prediction and the bar
        chart. If the language cannot be detected, or is neither German nor
        French, returns ``(UNKNOWN_LANG_TEXT, None)``.
    """
    # Reset the global pyplot figure so bars from earlier requests do not
    # accumulate if the chart is drawn on it.
    plt.clf()
    try:
        language = detect(inputText)
    except LangDetectException:
        # langdetect raises this when it finds nothing to detect (e.g. an
        # empty textbox). Previously this crashed the request.
        return UNKNOWN_LANG_TEXT, None
    # Translate the input to German if necessary.
    if language == "fr":
        inputText = translate_to_de(inputText)
    elif language != "de":
        return UNKNOWN_LANG_TEXT, None
    # Make the prediction with the first 1000 characters; with
    # return_all_scores the first element holds one score per label.
    prediction = pipeDep(inputText[:1000], return_all_scores=True)
    rates = [row["score"] for row in prediction[0]]
    # Create the bar plot and the output text.
    im, barnames = create_bar_plot(rates, language)
    chosenCategoryText = show_chosen_category(barnames, rates, language)
    return chosenCategoryText, im
# Build and start the Gradio UI: a 20-line text input, a text output and an
# image output, all wired to attribution().
_text_input = gr.components.Textbox(placeholder=PLACEHOLDER_TEXT, lines=20)
interface = gr.Interface(
    fn=attribution,
    inputs=[_text_input],
    outputs=["text", "image"],
)
interface.launch()