File size: 5,121 Bytes
96b98f3 e252d2c 43728f4 96b98f3 e252d2c 96b98f3 e252d2c 96b98f3 e252d2c 96b98f3 e252d2c 43728f4 e252d2c 96b98f3 e252d2c 96b98f3 43728f4 e252d2c 96b98f3 e252d2c 43728f4 e252d2c 96b98f3 edfc8c7 e252d2c edfc8c7 e252d2c edfc8c7 43728f4 edfc8c7 e252d2c edfc8c7 e252d2c edfc8c7 e252d2c 43728f4 e252d2c edfc8c7 96b98f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import functools

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
# Select the non-interactive Agg backend so Matplotlib never tries to open
# GUI windows (e.g. when running on a headless server).
matplotlib.use('Agg')

# Load the Iris dataset and unpack the feature matrix, the integer labels,
# and the human-readable names used for axis labels and predictions.
iris = load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names
class_names = iris.target_names

# Hold out 30% of the samples for evaluation; the fixed seed keeps the
# split identical across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=42,
)
def _plot_feature_importances(ax, importances):
    """Draw a horizontal bar chart of per-feature importances on *ax*."""
    positions = np.arange(len(feature_names))
    ax.barh(positions, importances, color='skyblue')
    ax.set_yticks(positions)
    ax.set_yticklabels(feature_names)
    ax.set_xlabel("Importance")
    ax.set_title("Feature Importances")


def _plot_confusion_matrix(fig, ax, cm):
    """Draw *cm* as an annotated heatmap on *ax*, with a colorbar on *fig*."""
    im = ax.imshow(cm, interpolation='nearest', cmap='Blues')
    ax.set_title("Confusion Matrix")
    fig.colorbar(im, ax=ax)

    ticks = np.arange(len(class_names))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_yticklabels(class_names)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')

    # Write each count in its cell: white text on dark (high-count) cells,
    # black text elsewhere, so the numbers stay readable on the colormap.
    thresh = cm.max() / 2.0
    for (i, j), count in np.ndenumerate(cm):
        color = "white" if count > thresh else "black"
        ax.text(j, i, format(count, "d"), ha="center", va="center", color=color)


def train_and_evaluate(learning_rate, n_estimators, max_depth):
    """Train a GradientBoostingClassifier and summarize its test-set results.

    Args:
        learning_rate: Passed to ``GradientBoostingClassifier``.
        n_estimators: Passed to ``GradientBoostingClassifier`` after an
            ``int`` cast, because UI sliders may deliver whole numbers as
            floats, which sklearn's parameter validation rejects.
        max_depth: Passed to ``GradientBoostingClassifier`` after an
            ``int`` cast, for the same reason.

    Returns:
        A ``(text, figure)`` tuple: ``"Accuracy: <value>"`` with three
        decimals, and a Matplotlib figure holding a feature-importance bar
        chart and an annotated confusion-matrix heatmap.
    """
    # Build the figure without pyplot: pyplot keeps global state (a registry
    # of open figures and a "current figure") that would leak one figure per
    # call and is not safe when several requests are handled concurrently.
    from matplotlib.figure import Figure

    clf = GradientBoostingClassifier(
        learning_rate=learning_rate,
        n_estimators=int(n_estimators),
        max_depth=int(max_depth),
        random_state=42,
    )
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    accuracy = accuracy_score(y_test, y_pred)
    # Pin the label set so the matrix is always len(class_names) square and
    # stays aligned with the tick labels even if a class never appears.
    cm = confusion_matrix(y_test, y_pred, labels=np.arange(len(class_names)))

    fig = Figure(figsize=(10, 4))
    axs = fig.subplots(nrows=1, ncols=2)
    _plot_feature_importances(axs[0], clf.feature_importances_)
    _plot_confusion_matrix(fig, axs[1], cm)
    fig.tight_layout()

    return f"Accuracy: {accuracy:.3f}", fig
@functools.lru_cache(maxsize=32)
def _fit_species_classifier(learning_rate, n_estimators, max_depth):
    """Fit a classifier for one hyperparameter combination and cache it.

    Training data and ``random_state`` are fixed, so (assuming the fit is
    deterministic for a fixed seed) reusing a cached model gives the same
    predictions as refitting, without retraining on every request. The
    cache is bounded so memory stays limited across many slider settings.
    """
    clf = GradientBoostingClassifier(
        learning_rate=learning_rate,
        n_estimators=n_estimators,
        max_depth=max_depth,
        random_state=42,
    )
    clf.fit(X_train, y_train)
    return clf


def predict_species(sepal_length, sepal_width, petal_length, petal_width,
                    learning_rate, n_estimators, max_depth):
    """Predict the Iris species for a single set of flower measurements.

    Args:
        sepal_length, sepal_width, petal_length, petal_width: The four
            feature values, in the same order as the training columns.
        learning_rate: Passed to ``GradientBoostingClassifier``.
        n_estimators: Passed to ``GradientBoostingClassifier`` after an
            ``int`` cast (UI sliders may deliver whole numbers as floats).
        max_depth: Passed to ``GradientBoostingClassifier`` after an
            ``int`` cast.

    Returns:
        ``"Predicted species: <name>"``, or a prompt asking for the missing
        values if any measurement is ``None``.
    """
    measurements = [sepal_length, sepal_width, petal_length, petal_width]
    # A cleared number input may arrive as None; report that instead of
    # letting sklearn fail on an object-dtype array.
    if any(value is None for value in measurements):
        return "Please enter all four measurements."

    clf = _fit_species_classifier(
        float(learning_rate), int(n_estimators), int(max_depth)
    )
    prediction = clf.predict(np.array([measurements], dtype=float))[0]
    return f"Predicted species: {class_names[prediction]}"
def _hyperparameter_sliders():
    """Create the learning_rate / n_estimators / max_depth sliders.

    The components render into whichever Blocks/Tab context is active when
    this is called, so each tab gets its own independent set of sliders.

    Returns:
        ``(learning_rate, n_estimators, max_depth)`` slider components.
    """
    learning_rate = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
    n_estimators = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
    max_depth = gr.Slider(1, 10, value=3, step=1, label="max_depth")
    return learning_rate, n_estimators, max_depth


with gr.Blocks() as demo:
    with gr.Tab("Train & Evaluate"):
        gr.Markdown("## Train a GradientBoostingClassifier on the Iris dataset")
        learning_rate_slider, n_estimators_slider, max_depth_slider = (
            _hyperparameter_sliders()
        )
        train_button = gr.Button("Train & Evaluate")
        output_text = gr.Textbox(label="Results")
        output_plot = gr.Plot(label="Feature Importances & Confusion Matrix")
        train_button.click(
            fn=train_and_evaluate,
            inputs=[learning_rate_slider, n_estimators_slider, max_depth_slider],
            outputs=[output_text, output_plot],
        )

    with gr.Tab("Predict"):
        gr.Markdown("## Predict Iris Species with GradientBoostingClassifier")
        sepal_length_input = gr.Number(value=5.1, label=feature_names[0])
        sepal_width_input = gr.Number(value=3.5, label=feature_names[1])
        petal_length_input = gr.Number(value=1.4, label=feature_names[2])
        petal_width_input = gr.Number(value=0.2, label=feature_names[3])
        learning_rate_slider2, n_estimators_slider2, max_depth_slider2 = (
            _hyperparameter_sliders()
        )
        predict_button = gr.Button("Predict")
        prediction_text = gr.Textbox(label="Prediction")
        predict_button.click(
            fn=predict_species,
            inputs=[
                sepal_length_input,
                sepal_width_input,
                petal_length_input,
                petal_width_input,
                learning_rate_slider2,
                n_estimators_slider2,
                max_depth_slider2,
            ],
            outputs=prediction_text,
        )

# Launch only when run as a script, so importing this module (e.g. to reuse
# `demo` or the functions above) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()
|