import warnings
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn.datasets import make_blobs
from sklearn.exceptions import ConvergenceWarning
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.svm import LinearSVC

def train_model(C, n_samples):
    """Fit a hinge-loss LinearSVC on two blobs and plot its margins.

    LinearSVC (LIBLINEAR) does not expose support vectors the way SVC
    (LIBSVM) does. They are instead recovered from the decision function:
    samples whose absolute decision value is at most 1 lie on or inside the
    margin boundaries.

    Args:
        C: Regularization parameter passed to ``LinearSVC``.
        n_samples: Number of points generated with ``make_blobs``. It is cast
            to ``int`` because a UI slider may deliver it as a float.

    Returns:
        A matplotlib ``Figure``. It shows the samples, the decision boundary
        (solid), the margins (dashed) and the circled support vectors, and it
        is titled with the value of ``C``.
    """
    X, y = make_blobs(n_samples=int(n_samples), centers=2, random_state=0)

    # LIBLINEAR can emit a ConvergenceWarning when it stops at max_iter.
    # Silence it only around fitting, which is the only call that can raise it.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ConvergenceWarning)
        clf = LinearSVC(C=C, loss="hinge", random_state=42).fit(X, y)

    # Equivalent manual form: np.dot(X, clf.coef_[0]) + clf.intercept_[0]
    decision_function = clf.decision_function(X)
    # The support vectors are the samples that lie within the margin
    # boundaries, whose size is conventionally constrained to 1. The tiny
    # epsilon tolerates rounding for points sitting exactly on the margin.
    support_vector_indices = np.where(np.abs(decision_function) <= 1 + 1e-15)[0]
    support_vectors = X[support_vector_indices]

    # Build the figure without pyplot. Pyplot keeps a global registry of open
    # figures, which these figures were never removed from, so it grew on
    # every call. Pyplot is also not thread-safe, and UI handlers may run
    # concurrently.
    fig = Figure()
    ax = fig.subplots()

    ax.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=plt.cm.Paired)
    DecisionBoundaryDisplay.from_estimator(
        clf,
        X,
        ax=ax,
        grid_resolution=50,
        plot_method="contour",
        colors="k",
        levels=[-1, 0, 1],
        alpha=0.5,
        linestyles=["--", "-", "--"],
    )
    # Hollow black circles mark the recovered support vectors.
    ax.scatter(
        support_vectors[:, 0],
        support_vectors[:, 1],
        s=100,
        linewidth=1,
        facecolors="none",
        edgecolors="k",
    )
    ax.set_title("C=" + str(C))

    return fig

def iter_grid(n_rows, n_cols):
    """Yield once per cell of an ``n_rows`` x ``n_cols`` Gradio layout grid.

    Each ``yield`` happens while a ``gr.Row`` and a ``gr.Column`` context are
    both open. Components the consumer creates between two ``next()`` calls
    are therefore placed in that cell. The consumer should exhaust the
    generator so that the last Column and Row contexts are exited.
    """
    def _cells_of_one_row():
        # One Column context per cell. Control returns to the caller while
        # the context is open.
        for _col in range(n_cols):
            with gr.Column():
                yield

    for _row in range(n_rows):
        with gr.Row():
            yield from _cells_of_one_row()

title = "📈 Linear Support Vector Classification"
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    # Adjacent literals are joined at compile time. Each fragment ends with a
    # space so words do not run together in the rendered text.
    gr.Markdown(
        "Unlike SVC (based on LIBSVM), LinearSVC "
        "(based on LIBLINEAR) does not provide the "
        "support vectors. This example demonstrates "
        "how to obtain the support vectors in LinearSVC."
    )

    n_samples = gr.Slider(
        minimum=20, maximum=100, step=5, value=20, label="Number of Samples"
    )

    input_model = "LinearSVC"
    # One plot per regularization parameter C, laid out in a 1x2 grid.
    # partial() binds C immediately, avoiding late binding inside the loop.
    for _, C in zip(iter_grid(1, 2), [1, 100]):
        plot = gr.Plot(label=input_model)

        fn = partial(train_model, C)
        # Render once when the page loads; otherwise the plots stay empty
        # until the slider is first moved.
        demo.load(fn=fn, inputs=[n_samples], outputs=plot)
        n_samples.change(fn=fn, inputs=[n_samples], outputs=plot)


demo.launch()