#!/usr/bin/env python
# coding: utf-8
"""Gradio demo that assigns a German or French text to a department.

The text's language is detected; French input is machine-translated to
German, then a sequence-classification model loaded from ``saved_model_dep``
scores every department. The UI shows a bar chart of all scores plus a short
textual verdict.
"""

import functools
import io

import gradio as gr
import imageio
import numpy as np
from langdetect import LangDetectException, detect
from matplotlib import pyplot as plt
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TextClassificationPipeline,
    pipeline,
)

# TODO: move constants into extra file
# Below this top probability the model is reported as "not sure".
ML_MODEL_SURE = 0.6
# When the model is unsure, departments with at least this probability are listed.
MIN_LISTED_PROBABILITY = 0.1
# Only this many leading characters are translated / classified.
MAX_INPUT_CHARS = 1000
UNKNOWN_LANG_TEXT = (
    "The language is not recognized, it must be either in German or in French."
)
PLACEHOLDER_TEXT = "Geben Sie bitte den Titel und den Sumbmitted Text des Vorstoss ein.\nVeuillez entrer le titre et le Submitted Text de la requête."
UNSURE_DE_TEXT = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"
UNSURE_FR_TEXT = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"
# Bar labels per UI language.
# NOTE(review): both tuples are assumed to follow the classifier's label
# index order - confirm against the training label mapping.
BARS_DEP_FR = (
    "DDPS",
    "DFI",
    "AS-MPC",
    "DFJP",
    "DEFR",
    "DETEC",
    "DFAE",
    "Parl",
    "ChF",
    "DFF",
    "AF",
    "TF",
)
BARS_DEP_DE = (
    "VBS",
    "EDI",
    "AB-BA",
    "EJPD",
    "WBF",
    "UVEK",
    "EDA",
    "Parl",
    "BK",
    "EFD",
    "BV",
    "BGer",
)


def load_model(modelFolder):
    """Load a model from ``modelFolder`` and wrap it in a classification pipeline.

    Args:
        modelFolder: Path or identifier passed to ``from_pretrained`` for both
            the model and its tokenizer.

    Returns:
        A ``TextClassificationPipeline`` built from that model and tokenizer.
    """
    model = AutoModelForSequenceClassification.from_pretrained(modelFolder)
    tokenizer = AutoTokenizer.from_pretrained(modelFolder)
    return TextClassificationPipeline(model=model, tokenizer=tokenizer)


@functools.lru_cache(maxsize=1)
def _get_fr_de_translator():
    """Build the French-to-German translation pipeline once and reuse it."""
    return pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de")


def translate_to_de(inputText):
    """Translate French user input to German before classification.

    Only the first ``MAX_INPUT_CHARS`` characters are translated.

    Args:
        inputText: The French text entered by the user.

    Returns:
        The German translation as a string.
    """
    # The translator is cached: rebuilding the pipeline on every request
    # would reload the translation model each time.
    translator = _get_fr_de_translator()
    translated = translator(inputText[0:MAX_INPUT_CHARS])
    return translated[0]["translation_text"]


def create_bar_plot(rates, language):
    """Render the department scores as a horizontal bar chart.

    Args:
        rates: One score per department, in the same order as the label tuples.
        language: Detected language code; ``"fr"`` selects the French labels,
            anything else the German ones.

    Returns:
        A tuple ``(image, barnames)``: the chart as an image array and the
        label tuple that was used.
    """
    barnames = BARS_DEP_FR if language == "fr" else BARS_DEP_DE
    y_pos = np.arange(len(barnames))

    # A private figure and an in-memory PNG keep concurrent requests from
    # sharing pyplot's global figure or overwriting a common file on disk.
    buffer = io.BytesIO()
    fig, ax = plt.subplots()
    try:
        ax.barh(y_pos, rates)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(barnames)
        fig.savefig(buffer, format="png")
    finally:
        # Always release the figure so repeated requests do not accumulate them.
        plt.close(fig)

    # Round-trip through PNG (as before) so the UI receives an image array.
    image = imageio.v2.imread(buffer.getvalue(), format="png")
    return image, barnames


def _format_percent(rate):
    """Return ``rate`` (a fraction in [0, 1]) as a truncated whole percentage."""
    # Computed arithmetically: slicing the float's string form turns 0.5 into
    # "5" and 1.0 into "0". The epsilon absorbs binary round-off such as
    # 0.29 * 100 == 28.999999999999996.
    return f"{int(float(rate) * 100 + 1e-9)}%"


def show_chosen_category(barnames, rates, language):
    """Create the output text for the predicted department(s).

    If the highest probability is below ``ML_MODEL_SURE``, a disclaimer is
    shown followed by every department whose probability is at least
    ``MIN_LISTED_PROBABILITY``, highest first. Otherwise only the top
    department is shown. ``rates`` is not modified.

    Args:
        barnames: Department labels, index-aligned with ``rates``.
        rates: One probability per department.
        language: Detected language code; ``"fr"`` selects the French
            disclaimer, anything else the German one.

    Returns:
        The text to display in the UI.
    """
    distance = "\t\t\t\t\t"
    best_index = int(np.argmax(rates))
    best_rate = rates[best_index]

    # ML model pretty sure: show only one department.
    if best_rate >= ML_MODEL_SURE:
        return _format_percent(best_rate) + distance + barnames[best_index]

    # ML model not sure: list all sufficiently likely departments. A stable
    # descending sort keeps earlier departments first on ties, matching
    # repeated argmax, without zeroing entries of the caller's list.
    header = UNSURE_FR_TEXT if language == "fr" else UNSURE_DE_TEXT
    ranked = sorted(range(len(rates)), key=lambda idx: rates[idx], reverse=True)
    lines = [
        "\t" + _format_percent(rates[idx]) + distance + barnames[idx] + "\n"
        for idx in ranked
        if rates[idx] >= MIN_LISTED_PROBABILITY
    ]
    return header + "".join(lines)


pipeDep = load_model("saved_model_dep")
# pipeOffice = load_model("saved_model_office")


# Function called by the UI
def attribution(inputText):
    """Classify the user's text and build the UI outputs.

    Args:
        inputText: Text entered in the UI textbox.

    Returns:
        A tuple ``(text, image)``. ``image`` is ``None`` and ``text`` is
        ``UNKNOWN_LANG_TEXT`` when the language is neither German nor French
        or cannot be detected.
    """
    if not inputText or not inputText.strip():
        return UNKNOWN_LANG_TEXT, None

    try:
        language = detect(inputText)
    except LangDetectException:
        # Raised when the text has no usable features (e.g. only digits).
        return UNKNOWN_LANG_TEXT, None

    # Translate the input to German if necessary
    if language == "fr":
        inputText = translate_to_de(inputText)
    elif language != "de":
        return UNKNOWN_LANG_TEXT, None

    # Make the prediction with the first MAX_INPUT_CHARS characters.
    # NOTE(review): return_all_scores is kept because the bar labels rely on
    # scores arriving in label-index order - confirm before switching to top_k.
    prediction = pipeDep(inputText[0:MAX_INPUT_CHARS], return_all_scores=True)
    rates = [row["score"] for row in prediction[0]]

    # Create barplot & output text
    im, barnames = create_bar_plot(rates, language)
    chosenCategoryText = show_chosen_category(barnames, rates, language)
    return chosenCategoryText, im


# display the UI
interface = gr.Interface(
    fn=attribution,
    inputs=[gr.components.Textbox(lines=20, placeholder=PLACEHOLDER_TEXT)],
    outputs=["text", "image"],
)
interface.launch()