Spaces:
Runtime error
Runtime error
import gradio as gr | |
from datasets import load_dataset | |
from sklearn.model_selection import train_test_split | |
from sklearn.preprocessing import LabelEncoder | |
from sklearn.linear_model import LogisticRegression | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.pipeline import Pipeline | |
# 1. Laden und Vorbereiten des Datensatzes (einmalig beim Start) | |
try: | |
dataset = load_dataset("banking77") | |
texts = dataset['train']['text'] + dataset['test']['text'] | |
labels = dataset['train']['label'] + dataset['test']['label'] | |
label_encoder = LabelEncoder() | |
numerical_labels = label_encoder.fit_transform(labels) | |
label_names = label_encoder.classes_ | |
train_texts, test_texts, train_labels, test_labels = train_test_split( | |
texts, numerical_labels, test_size=0.2, random_state=42, stratify=numerical_labels | |
) | |
except Exception as e: | |
print(f"Fehler beim Laden des Datensatzes: {e}") | |
label_names = ["Fehler beim Laden"] | |
pipeline = None | |
# 2. Trainieren des Modells (einmalig beim Start) | |
if 'pipeline' not in locals() or pipeline is None: | |
try: | |
pipeline = Pipeline([ | |
('tfidf', TfidfVectorizer()), | |
('classifier', LogisticRegression(solver='liblinear', multi_class='ovr', random_state=42)) | |
]) | |
pipeline.fit(train_texts, train_labels) | |
print("Modell erfolgreich trainiert.") | |
except Exception as e: | |
print(f"Fehler beim Trainieren des Modells: {e}") | |
pipeline = None | |
# 3. Funktion für die Vorhersage | |
def predict_intent(text): | |
if pipeline is not None and label_names: | |
prediction = pipeline.predict([text])[0] | |
predicted_label = label_names[prediction] | |
probabilities = pipeline.predict_proba([text])[0] | |
confidences = {label_names[i]: f"{probabilities[i]:.2f}" for i in range(len(label_names))} | |
return predicted_label, confidences | |
else: | |
return "Fehler", {"Fehler": "Modell nicht geladen oder trainiert."} | |
# 4. Erstellen der Gradio Interface | |
iface = gr.Interface( | |
fn=predict_intent, | |
inputs=gr.Textbox(label="Gib deine Kundenanfrage ein:"), | |
outputs=[ | |
gr.Label(label="Vorhergesagte Kundenintention:"), | |
gr.JSON(label="Konfidenzwerte:") | |
], | |
title="Vorhersage der Kundenintention (Banking77)", | |
description="Dieses Demo sagt die Kundenintention basierend auf der Eingabe einer Textanfrage vorher. Das Modell wurde auf dem Banking77-Datensatz trainiert.", | |
examples=[ | |
["Ich habe mein Passwort vergessen."], | |
["Wie kann ich Geld überweisen?"], | |
["Meine Karte ist verloren gegangen."], | |
["Was ist der aktuelle Zinssatz für ein Sparkonto?"] | |
] | |
) | |
# 5. Starten der Gradio App (wird beim Ausführen des Skripts aktiv) | |
iface.launch(share=False) |