Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	fix visualisations + add heatmap
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,22 +1,28 @@ | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
             
            import numpy as np
         | 
|  | |
| 3 | 
             
            import matplotlib.pyplot as plt
         | 
| 4 | 
             
            from sklearn.datasets import load_iris
         | 
| 5 | 
             
            from sklearn.ensemble import GradientBoostingClassifier
         | 
| 6 | 
             
            from sklearn.model_selection import train_test_split
         | 
| 7 | 
             
            from sklearn.metrics import accuracy_score, confusion_matrix
         | 
| 8 |  | 
|  | |
|  | |
|  | |
|  | |
| 9 | 
             
            iris = load_iris()
         | 
| 10 | 
             
            X, y = iris.data, iris.target
         | 
| 11 | 
             
            feature_names = iris.feature_names
         | 
| 12 | 
             
            class_names = iris.target_names
         | 
| 13 |  | 
|  | |
| 14 | 
             
            X_train, X_test, y_train, y_test = train_test_split(
         | 
| 15 | 
             
                X, y, test_size=0.3, random_state=42
         | 
| 16 | 
             
            )
         | 
| 17 |  | 
| 18 | 
             
            def train_and_evaluate(learning_rate, n_estimators, max_depth):
         | 
| 19 | 
            -
                # Train model
         | 
| 20 | 
             
                clf = GradientBoostingClassifier(
         | 
| 21 | 
             
                    learning_rate=learning_rate,
         | 
| 22 | 
             
                    n_estimators=n_estimators,
         | 
| @@ -25,29 +31,52 @@ def train_and_evaluate(learning_rate, n_estimators, max_depth): | |
| 25 | 
             
                )
         | 
| 26 | 
             
                clf.fit(X_train, y_train)
         | 
| 27 |  | 
| 28 | 
            -
                # Predict  | 
| 29 | 
             
                y_pred = clf.predict(X_test)
         | 
|  | |
|  | |
| 30 | 
             
                accuracy = accuracy_score(y_test, y_pred)
         | 
|  | |
|  | |
| 31 | 
             
                cm = confusion_matrix(y_test, y_pred)
         | 
| 32 |  | 
| 33 | 
            -
                #  | 
| 34 | 
            -
                 | 
| 35 |  | 
| 36 | 
            -
                #  | 
| 37 | 
             
                importances = clf.feature_importances_
         | 
| 38 | 
            -
                 | 
| 39 | 
            -
                 | 
| 40 | 
            -
                 | 
| 41 | 
            -
                 | 
| 42 | 
            -
                 | 
| 43 | 
            -
             | 
| 44 | 
            -
             | 
| 45 | 
            -
                 | 
| 46 | 
            -
                 | 
| 47 | 
            -
                 | 
| 48 | 
            -
             | 
| 49 | 
            -
             | 
| 50 | 
            -
                )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 51 |  | 
| 52 | 
             
            def predict_species(sepal_length, sepal_width, petal_length, petal_width,
         | 
| 53 | 
             
                                learning_rate, n_estimators, max_depth):
         | 
| @@ -65,13 +94,14 @@ def predict_species(sepal_length, sepal_width, petal_length, petal_width, | |
| 65 | 
             
            with gr.Blocks() as demo:
         | 
| 66 | 
             
                with gr.Tab("Train & Evaluate"):
         | 
| 67 | 
             
                    gr.Markdown("## Train a GradientBoostingClassifier on the Iris dataset")
         | 
|  | |
| 68 | 
             
                    learning_rate_slider = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
         | 
| 69 | 
             
                    n_estimators_slider = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
         | 
| 70 | 
             
                    max_depth_slider = gr.Slider(1, 10, value=3, step=1, label="max_depth")
         | 
| 71 |  | 
| 72 | 
             
                    train_button = gr.Button("Train & Evaluate")
         | 
| 73 | 
             
                    output_text = gr.Textbox(label="Results")
         | 
| 74 | 
            -
                    output_plot = gr.Plot(label="Feature  | 
| 75 |  | 
| 76 | 
             
                    train_button.click(
         | 
| 77 | 
             
                        fn=train_and_evaluate,
         | 
| @@ -81,14 +111,15 @@ with gr.Blocks() as demo: | |
| 81 |  | 
| 82 | 
             
                with gr.Tab("Predict"):
         | 
| 83 | 
             
                    gr.Markdown("## Predict Iris Species with GradientBoostingClassifier")
         | 
|  | |
| 84 | 
             
                    sepal_length_input = gr.Number(value=5.1, label=feature_names[0])
         | 
| 85 | 
            -
                    sepal_width_input | 
| 86 | 
             
                    petal_length_input = gr.Number(value=1.4, label=feature_names[2])
         | 
| 87 | 
            -
                    petal_width_input | 
| 88 |  | 
| 89 | 
            -
                    learning_rate_slider2 | 
| 90 | 
            -
                    n_estimators_slider2 | 
| 91 | 
            -
                    max_depth_slider2 | 
| 92 |  | 
| 93 | 
             
                    predict_button = gr.Button("Predict")
         | 
| 94 | 
             
                    prediction_text = gr.Textbox(label="Prediction")
         | 
|  | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
             
            import numpy as np
         | 
| 3 | 
            +
            import matplotlib
         | 
| 4 | 
             
            import matplotlib.pyplot as plt
         | 
| 5 | 
             
            from sklearn.datasets import load_iris
         | 
| 6 | 
             
            from sklearn.ensemble import GradientBoostingClassifier
         | 
| 7 | 
             
            from sklearn.model_selection import train_test_split
         | 
| 8 | 
             
            from sklearn.metrics import accuracy_score, confusion_matrix
         | 
| 9 |  | 
| 10 | 
            +
            # This line ensures Matplotlib doesn't try to open windows in certain environments:
         | 
| 11 | 
            +
            matplotlib.use('Agg')
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            # Load the Iris dataset
         | 
| 14 | 
             
            iris = load_iris()
         | 
| 15 | 
             
            X, y = iris.data, iris.target
         | 
| 16 | 
             
            feature_names = iris.feature_names
         | 
| 17 | 
             
            class_names = iris.target_names
         | 
| 18 |  | 
| 19 | 
            +
            # Train/test split
         | 
| 20 | 
             
            X_train, X_test, y_train, y_test = train_test_split(
         | 
| 21 | 
             
                X, y, test_size=0.3, random_state=42
         | 
| 22 | 
             
            )
         | 
| 23 |  | 
| 24 | 
             
            def train_and_evaluate(learning_rate, n_estimators, max_depth):
         | 
| 25 | 
            +
                # Train the model
         | 
| 26 | 
             
                clf = GradientBoostingClassifier(
         | 
| 27 | 
             
                    learning_rate=learning_rate,
         | 
| 28 | 
             
                    n_estimators=n_estimators,
         | 
|  | |
| 31 | 
             
                )
         | 
| 32 | 
             
                clf.fit(X_train, y_train)
         | 
| 33 |  | 
| 34 | 
            +
                # Predict on test set
         | 
| 35 | 
             
                y_pred = clf.predict(X_test)
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                # Calculate accuracy
         | 
| 38 | 
             
                accuracy = accuracy_score(y_test, y_pred)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                # Calculate confusion matrix
         | 
| 41 | 
             
                cm = confusion_matrix(y_test, y_pred)
         | 
| 42 |  | 
| 43 | 
            +
                # Create a single figure with 2 subplots
         | 
| 44 | 
            +
                fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
         | 
| 45 |  | 
| 46 | 
            +
                # --- Subplot 1: Feature Importances ---
         | 
| 47 | 
             
                importances = clf.feature_importances_
         | 
| 48 | 
            +
                axs[0].barh(range(len(feature_names)), importances, color='skyblue')
         | 
| 49 | 
            +
                axs[0].set_yticks(range(len(feature_names)))
         | 
| 50 | 
            +
                axs[0].set_yticklabels(feature_names)
         | 
| 51 | 
            +
                axs[0].set_xlabel("Importance")
         | 
| 52 | 
            +
                axs[0].set_title("Feature Importances")
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                # --- Subplot 2: Confusion Matrix Heatmap ---
         | 
| 55 | 
            +
                im = axs[1].imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
         | 
| 56 | 
            +
                axs[1].set_title("Confusion Matrix")
         | 
| 57 | 
            +
                # Add colorbar
         | 
| 58 | 
            +
                cbar = fig.colorbar(im, ax=axs[1])
         | 
| 59 | 
            +
                # Tick marks for x/y axes
         | 
| 60 | 
            +
                axs[1].set_xticks(range(len(class_names)))
         | 
| 61 | 
            +
                axs[1].set_yticks(range(len(class_names)))
         | 
| 62 | 
            +
                axs[1].set_xticklabels(class_names, rotation=45, ha="right")
         | 
| 63 | 
            +
                axs[1].set_yticklabels(class_names)
         | 
| 64 | 
            +
                axs[1].set_ylabel('True Label')
         | 
| 65 | 
            +
                axs[1].set_xlabel('Predicted Label')
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                # Write the counts in each cell
         | 
| 68 | 
            +
                thresh = cm.max() / 2.0
         | 
| 69 | 
            +
                for i in range(cm.shape[0]):
         | 
| 70 | 
            +
                    for j in range(cm.shape[1]):
         | 
| 71 | 
            +
                        color = "white" if cm[i, j] > thresh else "black"
         | 
| 72 | 
            +
                        axs[1].text(j, i, format(cm[i, j], "d"),
         | 
| 73 | 
            +
                                    ha="center", va="center", color=color)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                plt.tight_layout()
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                # Return textual results + the figure
         | 
| 78 | 
            +
                results_text = f"Accuracy: {accuracy:.3f}"
         | 
| 79 | 
            +
                return results_text, fig
         | 
| 80 |  | 
| 81 | 
             
            def predict_species(sepal_length, sepal_width, petal_length, petal_width,
         | 
| 82 | 
             
                                learning_rate, n_estimators, max_depth):
         | 
|  | |
| 94 | 
             
            with gr.Blocks() as demo:
         | 
| 95 | 
             
                with gr.Tab("Train & Evaluate"):
         | 
| 96 | 
             
                    gr.Markdown("## Train a GradientBoostingClassifier on the Iris dataset")
         | 
| 97 | 
            +
             | 
| 98 | 
             
                    learning_rate_slider = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
         | 
| 99 | 
             
                    n_estimators_slider = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
         | 
| 100 | 
             
                    max_depth_slider = gr.Slider(1, 10, value=3, step=1, label="max_depth")
         | 
| 101 |  | 
| 102 | 
             
                    train_button = gr.Button("Train & Evaluate")
         | 
| 103 | 
             
                    output_text = gr.Textbox(label="Results")
         | 
| 104 | 
            +
                    output_plot = gr.Plot(label="Feature Importances & Confusion Matrix")
         | 
| 105 |  | 
| 106 | 
             
                    train_button.click(
         | 
| 107 | 
             
                        fn=train_and_evaluate,
         | 
|  | |
| 111 |  | 
| 112 | 
             
                with gr.Tab("Predict"):
         | 
| 113 | 
             
                    gr.Markdown("## Predict Iris Species with GradientBoostingClassifier")
         | 
| 114 | 
            +
             | 
| 115 | 
             
                    sepal_length_input = gr.Number(value=5.1, label=feature_names[0])
         | 
| 116 | 
            +
                    sepal_width_input  = gr.Number(value=3.5, label=feature_names[1])
         | 
| 117 | 
             
                    petal_length_input = gr.Number(value=1.4, label=feature_names[2])
         | 
| 118 | 
            +
                    petal_width_input  = gr.Number(value=0.2, label=feature_names[3])
         | 
| 119 |  | 
| 120 | 
            +
                    learning_rate_slider2   = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
         | 
| 121 | 
            +
                    n_estimators_slider2    = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
         | 
| 122 | 
            +
                    max_depth_slider2       = gr.Slider(1, 10, value=3, step=1, label="max_depth")
         | 
| 123 |  | 
| 124 | 
             
                    predict_button = gr.Button("Predict")
         | 
| 125 | 
             
                    prediction_text = gr.Textbox(label="Prediction")
         | 

