import gradio as gr
import matplotlib
matplotlib.use("Agg")  # Non-interactive backend for remote environments; must be set before importing pyplot
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
# Pre-populate a short list of "recommended" Hugging Face datasets
# (Replace "datasorg/iris" etc. with real dataset IDs you want to showcase)
SUGGESTED_DATASETS = [
    "datasorg/iris",           # hypothetical ID
    "uciml/wine_quality-red",  # example from the HF Hub
    "SKIP/ENTER_CUSTOM",       # sentinel: tells the app to use the custom-ID text box instead
]
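
# Optional sketch: since "datasorg/iris" above is a placeholder, a custom ID can
# be validated against the Hub before attempting a full download. This helper is
# an assumption (its name is made up and it is not wired into the UI below); it
# relies on huggingface_hub's real HfApi.dataset_info call.
def dataset_exists(repo_id: str) -> bool:
    """Return True if `repo_id` resolves to a dataset repo on the Hub."""
    from huggingface_hub import HfApi
    try:
        HfApi().dataset_info(repo_id)
        return True
    except Exception:
        return False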
def load_and_prepare_dataset(dataset_id, label_column, feature_columns):
    """
    Load a dataset from the Hugging Face Hub, convert it to a pandas
    DataFrame, and return X, y as NumPy arrays for modeling.
    """
    # Load only the "train" split for simplicity; many datasets also
    # provide "test" and "validation" splits.
    ds = load_dataset(dataset_id, split="train")

    # Convert to a DataFrame for easy manipulation
    df = pd.DataFrame(ds)

    # Validate the selected columns
    if not feature_columns:
        raise ValueError("Please select at least one feature column.")
    if label_column not in df.columns:
        raise ValueError(f"Label column '{label_column}' not in dataset columns: {df.columns.tolist()}")
    for col in feature_columns:
        if col not in df.columns:
            raise ValueError(f"Feature column '{col}' not in dataset columns: {df.columns.tolist()}")

    # Split into X and y
    X = df[feature_columns].values
    y = df[label_column].values
    return X, y, df.columns.tolist()
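
# Hedged variant: not every Hub dataset exposes a "train" split. One possible
# fallback (an assumption, not used by the app as written) is to load the full
# DatasetDict and fall back to the first available split.
def load_first_available_split(dataset_id):
    ds_dict = load_dataset(dataset_id)  # no split argument -> DatasetDict of all splits
    split_name = "train" if "train" in ds_dict else list(ds_dict.keys())[0]
    return ds_dict[split_name]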
def train_model(dataset_id, custom_dataset_id, label_column, feature_columns,
                learning_rate, n_estimators, max_depth, test_size):
    """
    1. Determine the final dataset ID (either from the dropdown or the custom text box).
    2. Load the dataset -> DataFrame -> X, y.
    3. Train a GradientBoostingClassifier.
    4. Generate plots & metrics (accuracy and confusion matrix).
    """
    # Decide which dataset ID to use
    if dataset_id != "SKIP/ENTER_CUSTOM":
        final_id = dataset_id
    else:
        # Use the user-supplied "custom_dataset_id"
        final_id = custom_dataset_id.strip()

    # Prepare data
    X, y, columns_available = load_and_prepare_dataset(
        final_id,
        label_column,
        feature_columns
    )

    # Train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=42
    )

    # Train model
    clf = GradientBoostingClassifier(
        learning_rate=learning_rate,
        n_estimators=int(n_estimators),
        max_depth=int(max_depth),
        random_state=42
    )
    clf.fit(X_train, y_train)

    # Evaluate
    y_pred = clf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    cm = confusion_matrix(y_test, y_pred)

    # Plot figure
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))

    # Subplot 1: feature importances
    importances = clf.feature_importances_
    axs[0].barh(range(len(feature_columns)), importances, color='skyblue')
    axs[0].set_yticks(range(len(feature_columns)))
    axs[0].set_yticklabels(feature_columns)
    axs[0].set_xlabel("Importance")
    axs[0].set_title("Feature Importances")

    # Subplot 2: confusion-matrix heatmap
    im = axs[1].imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    axs[1].set_title("Confusion Matrix")
    plt.colorbar(im, ax=axs[1])
    axs[1].set_xlabel("Predicted")
    axs[1].set_ylabel("True")

    # Annotate each cell with its count, switching text color for contrast
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            color = "white" if cm[i, j] > thresh else "black"
            axs[1].text(j, i, format(cm[i, j], "d"), ha="center", va="center", color=color)

    plt.tight_layout()

    output_text = f"**Dataset used:** {final_id}\n\n"
    output_text += f"**Accuracy:** {accuracy:.3f}\n\n"
    output_text += "**Confusion Matrix** (raw counts shown in the plot)."
    # Return the column list as a string so the Markdown component can render it
    return output_text, fig, f"Columns available: {columns_available}"
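
# Optional sketch: a single train/test split can be noisy on small datasets.
# Cross-validation gives a steadier accuracy estimate; this helper is an
# assumption (not called by the UI) built on sklearn's cross_val_score.
def cv_accuracy(X, y, learning_rate=0.1, n_estimators=100, max_depth=3, folds=5):
    from sklearn.model_selection import cross_val_score
    clf = GradientBoostingClassifier(
        learning_rate=learning_rate,
        n_estimators=int(n_estimators),
        max_depth=int(max_depth),
        random_state=42,
    )
    scores = cross_val_score(clf, X, y, cv=folds)
    return scores.mean(), scores.std()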
def update_columns(dataset_id, custom_dataset_id):
    """
    Callback that fetches the columns of the chosen dataset so the user
    can pick which columns to use as features and as the label.
    """
    if dataset_id != "SKIP/ENTER_CUSTOM":
        final_id = dataset_id
    else:
        final_id = custom_dataset_id.strip()

    # Try to load the dataset and return its columns
    try:
        ds = load_dataset(final_id, split="train")
        df = pd.DataFrame(ds)
        cols = df.columns.tolist()
        # Populate both the label dropdown and the feature checkbox group
        return gr.update(choices=cols), gr.update(choices=cols), f"Columns found: {cols}"
    except Exception as e:
        return gr.update(choices=[]), gr.update(choices=[]), f"Error loading {final_id}: {e}"
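
# Note: load_dataset caches downloads on disk, but update_columns and
# train_model still rebuild the DataFrame on every click. A small in-process
# cache (a sketch, keyed on the dataset ID) would avoid the repeated conversion.
from functools import lru_cache

@lru_cache(maxsize=8)
def cached_dataframe(dataset_id):
    return pd.DataFrame(load_dataset(dataset_id, split="train"))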
with gr.Blocks() as demo:
    gr.Markdown("## Train a GradientBoostingClassifier on a Hugging Face dataset of your choice")

    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            choices=SUGGESTED_DATASETS,
            value=SUGGESTED_DATASETS[0],
            label="Choose a dataset"
        )
        custom_dataset_id = gr.Textbox(
            label="Or enter an HF dataset ID (user/dataset)",
            value="",
            placeholder="e.g. 'username/my_custom_dataset'"
        )

    # Button to load columns from the chosen dataset
    load_cols_btn = gr.Button("Load columns")
    load_cols_info = gr.Markdown()

    with gr.Row():
        label_col = gr.Dropdown(choices=[], label="Label column (choose 1)")
        feature_cols = gr.CheckboxGroup(choices=[], label="Feature columns (choose 1 or more)")

    # Hyperparameters
    learning_rate_slider = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
    n_estimators_slider = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
    max_depth_slider = gr.Slider(1, 10, value=3, step=1, label="max_depth")
    test_size_slider = gr.Slider(0.1, 0.9, value=0.3, step=0.1, label="test_size (fraction)")

    train_button = gr.Button("Train & Evaluate")
    output_text = gr.Markdown()
    output_plot = gr.Plot()
    # Also show the available columns for reference after training
    columns_return = gr.Markdown()

    # When "Load columns" is clicked, fetch the dataset columns
    load_cols_btn.click(
        fn=update_columns,
        inputs=[dataset_dropdown, custom_dataset_id],
        outputs=[label_col, feature_cols, load_cols_info]
    )

    # When "Train & Evaluate" is clicked, train the model
    train_button.click(
        fn=train_model,
        inputs=[
            dataset_dropdown,
            custom_dataset_id,
            label_col,
            feature_cols,
            learning_rate_slider,
            n_estimators_slider,
            max_depth_slider,
            test_size_slider
        ],
        outputs=[output_text, output_plot, columns_return]
    )
if __name__ == "__main__":
    demo.launch()
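
# Usage hint: on a Hugging Face Space, launch() as called above is all that is
# needed. When running locally, demo.launch(share=True) (a standard Gradio
# option) would also expose a temporary public URL.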