File size: 5,121 Bytes
96b98f3
 
e252d2c
43728f4
96b98f3
 
 
 
 
e252d2c
 
 
 
96b98f3
 
 
 
 
e252d2c
96b98f3
 
 
 
 
e252d2c
96b98f3
 
 
 
 
 
 
 
e252d2c
43728f4
e252d2c
 
96b98f3
e252d2c
 
96b98f3
43728f4
e252d2c
 
96b98f3
e252d2c
43728f4
e252d2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96b98f3
 
 
 
 
 
 
 
 
 
 
 
 
 
edfc8c7
 
 
e252d2c
edfc8c7
 
 
 
 
 
e252d2c
edfc8c7
 
 
 
43728f4
edfc8c7
 
 
 
e252d2c
edfc8c7
e252d2c
edfc8c7
e252d2c
43728f4
e252d2c
 
 
edfc8c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96b98f3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import gradio as gr
import numpy as np
import matplotlib

# Select the non-interactive Agg backend BEFORE pyplot is imported, so
# Matplotlib never attempts to open GUI windows in headless environments.
# (Calling use() after importing pyplot is not guaranteed to switch backends.)
matplotlib.use('Agg')

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

# Load the Iris dataset (150 samples, 4 numeric features, 3 classes).
iris = load_iris()
X, y = iris.data, iris.target
feature_names = iris.feature_names      # e.g. "sepal length (cm)", used as axis/input labels
class_names = iris.target_names         # species names, used for plot ticks and predictions

# Hold out 30% of the data for evaluation; the fixed seed keeps the split
# reproducible across app restarts so reported accuracy is stable.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

def train_and_evaluate(learning_rate, n_estimators, max_depth):
    """Train a GradientBoostingClassifier on the Iris train split and evaluate it.

    Args:
        learning_rate: Boosting learning rate from the UI slider.
        n_estimators: Number of boosting stages. Cast to int because Gradio
            sliders may deliver floats and sklearn requires an integer count.
        max_depth: Maximum depth of each tree; cast to int for the same reason.

    Returns:
        A ``(results_text, fig)`` tuple: an accuracy summary string and a
        matplotlib Figure holding feature importances (left panel) and a
        confusion-matrix heatmap (right panel).
    """
    clf = GradientBoostingClassifier(
        learning_rate=learning_rate,
        n_estimators=int(n_estimators),
        max_depth=int(max_depth),
        random_state=42
    )
    clf.fit(X_train, y_train)

    # Evaluate on the held-out split.
    y_pred = clf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    cm = confusion_matrix(y_test, y_pred)

    # One figure, two panels side by side.
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))

    # --- Subplot 1: Feature Importances ---
    importances = clf.feature_importances_
    axs[0].barh(range(len(feature_names)), importances, color='skyblue')
    axs[0].set_yticks(range(len(feature_names)))
    axs[0].set_yticklabels(feature_names)
    axs[0].set_xlabel("Importance")
    axs[0].set_title("Feature Importances")

    # --- Subplot 2: Confusion Matrix Heatmap ---
    im = axs[1].imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    axs[1].set_title("Confusion Matrix")
    fig.colorbar(im, ax=axs[1])
    axs[1].set_xticks(range(len(class_names)))
    axs[1].set_yticks(range(len(class_names)))
    axs[1].set_xticklabels(class_names, rotation=45, ha="right")
    axs[1].set_yticklabels(class_names)
    axs[1].set_ylabel('True Label')
    axs[1].set_xlabel('Predicted Label')

    # Annotate each cell with its count; use white text on dark cells.
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            color = "white" if cm[i, j] > thresh else "black"
            axs[1].text(j, i, format(cm[i, j], "d"),
                        ha="center", va="center", color=color)

    plt.tight_layout()

    # Detach the figure from pyplot's global registry so repeated clicks do
    # not accumulate open figures; the Figure object itself remains fully
    # usable for Gradio to render.
    plt.close(fig)

    results_text = f"Accuracy: {accuracy:.3f}"
    return results_text, fig

def predict_species(sepal_length, sepal_width, petal_length, petal_width,
                    learning_rate, n_estimators, max_depth):
    """Fit a fresh GradientBoostingClassifier and classify one user sample.

    NOTE(review): the model is retrained from scratch on every call so that
    predictions always reflect the current slider hyperparameters; each
    click therefore pays the full training cost.

    Args:
        sepal_length, sepal_width, petal_length, petal_width: Measurements
            (cm) describing a single iris flower.
        learning_rate: Boosting learning rate.
        n_estimators: Number of boosting stages. Cast to int because Gradio
            sliders may deliver floats and sklearn requires an integer count.
        max_depth: Maximum tree depth; cast to int for the same reason.

    Returns:
        A human-readable string naming the predicted species.
    """
    clf = GradientBoostingClassifier(
        learning_rate=learning_rate,
        n_estimators=int(n_estimators),
        max_depth=int(max_depth),
        random_state=42
    )
    clf.fit(X_train, y_train)
    user_sample = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
    prediction = clf.predict(user_sample)[0]
    return f"Predicted species: {class_names[prediction]}"

# --- Gradio UI ----------------------------------------------------------
# Component creation order inside each `with` block defines the rendered
# layout, so the statements below are intentionally order-sensitive.
with gr.Blocks() as demo:
    # Tab 1: pick hyperparameters, train, and view accuracy + diagnostics.
    with gr.Tab("Train & Evaluate"):
        gr.Markdown("## Train a GradientBoostingClassifier on the Iris dataset")

        # Hyperparameter controls; defaults mirror sklearn's own defaults.
        learning_rate_slider = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
        n_estimators_slider = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
        max_depth_slider = gr.Slider(1, 10, value=3, step=1, label="max_depth")

        train_button = gr.Button("Train & Evaluate")
        output_text = gr.Textbox(label="Results")
        output_plot = gr.Plot(label="Feature Importances & Confusion Matrix")

        # Wire the button to training: sliders in, (text, figure) out.
        train_button.click(
            fn=train_and_evaluate,
            inputs=[learning_rate_slider, n_estimators_slider, max_depth_slider],
            outputs=[output_text, output_plot],
        )

    # Tab 2: classify a single user-entered flower with its own
    # hyperparameter set (independent of the sliders on Tab 1).
    with gr.Tab("Predict"):
        gr.Markdown("## Predict Iris Species with GradientBoostingClassifier")

        # Default measurements correspond to a typical setosa sample.
        sepal_length_input = gr.Number(value=5.1, label=feature_names[0])
        sepal_width_input  = gr.Number(value=3.5, label=feature_names[1])
        petal_length_input = gr.Number(value=1.4, label=feature_names[2])
        petal_width_input  = gr.Number(value=0.2, label=feature_names[3])

        learning_rate_slider2   = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
        n_estimators_slider2    = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
        max_depth_slider2       = gr.Slider(1, 10, value=3, step=1, label="max_depth")

        predict_button = gr.Button("Predict")
        prediction_text = gr.Textbox(label="Prediction")

        # Input order here must match predict_species' parameter order.
        predict_button.click(
            fn=predict_species,
            inputs=[
                sepal_length_input,
                sepal_width_input,
                petal_length_input,
                petal_width_input,
                learning_rate_slider2,
                n_estimators_slider2,
                max_depth_slider2,
            ],
            outputs=prediction_text
        )

# Start the web server (blocks until interrupted).
demo.launch()