Hnabil's picture
Update app.py
2cecf9a
raw
history blame
4.1 kB
from functools import partial
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter
import numpy as np
from sklearn import datasets, manifold
# Seed passed as random_state to the dataset generators below.
SEED = 0
# Number of output dimensions of the t-SNE embedding (the plots are 2-D scatters).
N_COMPONENTS = 2
# Also seed NumPy's global RNG; the generators in this file pass random_state explicitly.
np.random.seed(SEED)
def get_circles(n_samples):
    """Build the "Circles" toy dataset.

    Delegates to ``datasets.make_circles`` with ``factor=0.5``, ``noise=0.05``
    and ``random_state=SEED``, and returns its ``(points, labels)`` pair.
    """
    circle_options = dict(factor=0.5, noise=0.05, random_state=SEED)
    points, labels = datasets.make_circles(n_samples=n_samples, **circle_options)
    return points, labels
def get_s_curve(n_samples):
    """Build the "S-curve" toy dataset with its 2nd and 3rd columns swapped.

    Returns the ``(points, color)`` pair from ``datasets.make_s_curve``
    (seeded with ``SEED``). After the swap, the original third coordinate sits
    in column 1. Presumably this is so that plotting columns 0 and 1, as
    ``plot_data`` does, shows the curve's shape.
    """
    points, hue = datasets.make_s_curve(n_samples=n_samples, random_state=SEED)
    # The fancy-indexed right-hand side is a copy, so the in-place swap is safe.
    points[:, [1, 2]] = points[:, [2, 1]]
    return points, hue
def get_uniform_grid(n_samples):
    """Build a regular square grid on [0, 1] x [0, 1].

    The grid has ``int(sqrt(n_samples))`` points per side, so it can hold
    fewer than ``n_samples`` points in total.

    Returns:
        ``(points, color)``: ``points`` is an ``(side**2, 2)`` array of (x, y)
        pairs, and ``color`` is each point's x-coordinate.
    """
    side = int(np.sqrt(n_samples))
    ticks = np.linspace(0, 1, side)
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    flat_x = grid_x.ravel()
    points = np.column_stack((flat_x, grid_y.ravel()))
    return points, flat_x
# Maps each dataset name offered in the UI to the function that generates it.
# Insertion order matters: list(DATA_MAPPING) supplies the radio-button choices.
DATA_MAPPING = {
    'Circles': get_circles,
    'S-curve': get_s_curve,
    'Uniform Grid': get_uniform_grid,
}
def _tsne_iter_kwargs(n_iter):
    """Return ``{name: n_iter}`` using the iteration keyword this TSNE accepts.

    scikit-learn 1.5 renamed ``TSNE(n_iter=...)`` to ``max_iter`` and later
    releases dropped ``n_iter``. Inspecting the constructor keeps the app
    working on both older and newer versions.
    """
    import inspect

    accepted = inspect.signature(manifold.TSNE).parameters
    name = "max_iter" if "max_iter" in accepted else "n_iter"
    return {name: n_iter}


def plot_data(dataset: str, perplexity: int, n_samples: int, tsne: bool):
    """Draw a scatter plot of a toy dataset, optionally embedded with t-SNE.

    Args:
        dataset: key of ``DATA_MAPPING`` selecting the data generator.
        perplexity: t-SNE perplexity. A dict is also accepted, in which case
            its ``'value'`` entry is used. The value is clamped to
            ``n_points - 1``, because scikit-learn rejects a perplexity that
            is not below the number of samples.
        n_samples: requested number of points, passed to the generator
            (the uniform grid may produce fewer).
        tsne: if true, embed the data into ``N_COMPONENTS`` dimensions with
            t-SNE before plotting. Otherwise plot the first two columns of
            the raw data.

    Returns:
        The matplotlib ``Figure`` holding the scatter plot.
    """
    if isinstance(perplexity, dict):
        # NOTE(review): presumably a Gradio update dict; kept for compatibility.
        perplexity = perplexity['value']
    perplexity = int(perplexity)
    X, color = DATA_MAPPING[dataset](int(n_samples))
    if tsne:
        # The UI allows perplexity 100 with only 100 samples, which TSNE rejects
        # ("perplexity must be less than n_samples"), so clamp it.
        perplexity = min(perplexity, X.shape[0] - 1)
        model = manifold.TSNE(
            n_components=N_COMPONENTS,
            init="random",
            random_state=SEED,
            perplexity=perplexity,
            **_tsne_iter_kwargs(400),
        )
        Y = model.fit_transform(X)
    else:
        Y = X
    fig, ax = plt.subplots(figsize=(7, 7))
    ax.scatter(Y[:, 0], Y[:, 1], c=color)
    ax.xaxis.set_major_formatter(NullFormatter())
    ax.yaxis.set_major_formatter(NullFormatter())
    ax.axis("tight")
    # This runs on every UI event. Unregister the figure from pyplot's global
    # figure manager so figures do not pile up for the life of the process.
    # The returned Figure object itself remains usable.
    plt.close(fig)
    return fig
title = "t-SNE: The effect of various perplexity values on the shape"
# Markdown rendered under the title of the demo page.
description = """
t-distributed Stochastic Neighbor Embedding ([t-SNE](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html)) is a powerful technique for dimensionality reduction and visualization of high-dimensional datasets.
One of the key parameters in t-SNE is perplexity, which is related to the number of nearest neighbors each point takes into account; it balances attention between the local and global structure of the data.
In this illustration, we explore the impact of various perplexity values on t-SNE visualizations using three commonly used datasets: Concentric Circles, S-curve and Uniform Grid.
By comparing the resulting visualizations, we demonstrate how changing the perplexity value affects the shape of the visualization.
Created by [@Hnabil](https://huggingface.co/Hnabil) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/manifold/plot_t_sne_perplexity.html)
"""
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<b>{title}</b>")
    gr.Markdown(description)
    input_data = gr.Radio(list(DATA_MAPPING), value="Circles", label="dataset")
    n_samples = gr.Slider(
        minimum=100, maximum=1000, value=150, step=25, label='Number of Samples'
    )
    perplexity = gr.Slider(
        minimum=2, maximum=100, value=5, step=1, label='Perplexity'
    )
    # Every callback receives the same inputs, in plot_data's positional order.
    plot_inputs = [input_data, perplexity, n_samples]
    # One entry per panel: (plot label, run t-SNE?, components whose change
    # re-renders it). The raw-data panel ignores perplexity, so perplexity
    # is not one of its triggers.
    panels = (
        ("Original data", False, (input_data, n_samples)),
        ("t-SNE", True, (input_data, perplexity, n_samples)),
    )
    with gr.Row():
        for plot_label, use_tsne, triggers in panels:
            with gr.Column():
                plot = gr.Plot(label=plot_label)
                fn = partial(plot_data, tsne=use_tsne)
                for trigger in triggers:
                    trigger.change(fn=fn, inputs=plot_inputs, outputs=plot)
                # Render once on page load as well.
                demo.load(fn=fn, inputs=plot_inputs, outputs=plot)
demo.launch()