Spaces:
Runtime error
Runtime error
File size: 3,281 Bytes
53f862b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import inspect
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.ticker import NullFormatter
import numpy as np
from sklearn import datasets, manifold
# Seed passed to every dataset generator (and to t-SNE) for reproducible output.
SEED = 0
# t-SNE target dimensionality: two, so the embedding can be drawn as a scatter plot.
N_COMPONENTS = 2

# Also seed NumPy's legacy global RNG so any unseeded NumPy randomness is repeatable.
np.random.seed(seed=SEED)
def get_circles(n_samples):
    """Return the concentric-circles toy dataset.

    Args:
        n_samples: total number of points, split between the two circles.

    Returns:
        ``(X, color)`` exactly as produced by ``datasets.make_circles``:
        the 2-D points and the integer circle label of each point, which
        the caller uses as the scatter colour.
    """
    # Inner circle at half the outer radius, light Gaussian noise, fixed seed.
    circle_options = {"factor": 0.5, "noise": 0.05, "random_state": SEED}
    points, labels = datasets.make_circles(n_samples=n_samples, **circle_options)
    return points, labels
def get_s_curve(n_samples):
    """Return the 3-D S-curve dataset with its 2nd and 3rd columns swapped.

    Args:
        n_samples: number of points to sample along the curve.

    Returns:
        ``(X, color)`` where ``X`` is the ``datasets.make_s_curve`` output with
        columns 1 and 2 exchanged, and ``color`` is the generator's
        position-along-the-curve value for each point.
    """
    points, position = datasets.make_s_curve(n_samples=n_samples, random_state=SEED)
    # The "S" profile lives in columns 0 and 2; move it into columns 0 and 1,
    # which are the ones plotted when the raw data is shown. Fancy indexing on
    # the right-hand side makes a copy, so the in-place swap is safe.
    points[:, [1, 2]] = points[:, [2, 1]]
    return points, position
def get_uniform_grid(n_samples):
    """Return a regular square grid of points on the unit square.

    The grid has ``int(sqrt(n_samples))`` points per side, so it contains
    ``int(sqrt(n_samples)) ** 2`` points, which may be fewer than requested.

    Args:
        n_samples: approximate number of points wanted.

    Returns:
        ``(X, color)``: an ``(m, 2)`` array of ``(x, y)`` grid coordinates in
        row-major order, and the x coordinate of each point, used as colour.
    """
    per_side = int(np.sqrt(n_samples))
    ticks = np.linspace(0, 1, per_side)
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    flat_x = grid_x.ravel()
    points = np.column_stack((flat_x, grid_y.ravel()))
    return points, flat_x
# Dataset name shown in the UI -> generator taking n_samples and returning (X, color).
# Insertion order matters: it is the order of the choices in the radio button.
DATA_MAPPING = {
    "circles": get_circles,
    "s-curve": get_s_curve,
    "uniform grid": get_uniform_grid,
}
def _tsne_iteration_kwarg():
    """Return the ``TSNE`` keyword that sets the optimisation iteration budget.

    scikit-learn 1.5 renamed ``n_iter`` to ``max_iter`` and later removed
    ``n_iter`` entirely; use whichever name the installed version accepts.
    """
    params = inspect.signature(manifold.TSNE).parameters
    return "max_iter" if "max_iter" in params else "n_iter"


def plot_data(dataset: str, perplexity: int, n_samples: int, tsne: bool):
    """Generate a dataset and return a scatter plot of it, optionally t-SNE embedded.

    Args:
        dataset: key of ``DATA_MAPPING`` selecting the data generator.
        perplexity: t-SNE perplexity, either a number or a dict holding it under
            ``"value"``. It is clamped below the number of generated points,
            because scikit-learn rejects perplexity >= n_samples.
        n_samples: requested number of points. The uniform grid may produce fewer.
        tsne: if true, embed the data with t-SNE before plotting; otherwise plot
            the first two columns of the raw data.

    Returns:
        A ``matplotlib.figure.Figure`` holding the scatter plot, coloured by the
        generator's colour values, with tick labels hidden.
    """
    if isinstance(perplexity, dict):
        perplexity = perplexity["value"]
    perplexity = int(perplexity)

    # Cast defensively: the data generators require an integer sample count.
    X, color = DATA_MAPPING[dataset](int(n_samples))

    if tsne:
        # The UI allows e.g. 100 samples with perplexity 100, which scikit-learn
        # refuses ("perplexity must be less than n_samples"); clamp instead of crashing.
        effective_perplexity = min(perplexity, X.shape[0] - 1)
        embedder = manifold.TSNE(
            n_components=N_COMPONENTS,
            init="random",
            random_state=SEED,
            perplexity=effective_perplexity,
            **{_tsne_iteration_kwarg(): 400},
        )
        Y = embedder.fit_transform(X)
    else:
        Y = X

    # Build the figure without pyplot. A pyplot figure stays registered globally
    # until explicitly closed, and this callback runs on every UI change in a
    # long-lived server, so pyplot figures would accumulate.
    fig = Figure(figsize=(7, 7))
    ax = fig.add_subplot(1, 1, 1)
    ax.scatter(Y[:, 0], Y[:, 1], c=color)
    ax.xaxis.set_major_formatter(NullFormatter())
    ax.yaxis.set_major_formatter(NullFormatter())
    ax.axis("tight")
    return fig
# Page heading; also passed to gr.Blocks as the page title.
title = "t-SNE: The effect of various perplexity values on the shape"
# Intro text rendered as Markdown under the heading. Adjacent string literals are
# joined with no separator, so each fragment must carry its own trailing space.
description = (
    "An illustration of t-SNE on the two concentric circles, the S-curve "
    "and the uniform grid datasets for different perplexity values."
)
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<b>{title}</b>")
    gr.Markdown(description)

    input_data = gr.Radio(
        list(DATA_MAPPING),
        value="circles",
        label="dataset",
    )
    n_samples = gr.Slider(
        minimum=100,
        maximum=1000,
        value=150,
        step=25,
        label='Number of Samples',
    )
    perplexity = gr.Slider(
        minimum=2,
        maximum=100,
        value=5,
        step=1,
        label='Perplexity',
    )

    # Positional order must match plot_data(dataset, perplexity, n_samples, ...).
    plot_inputs = [input_data, perplexity, n_samples]

    with gr.Row():
        # One column per view: the raw data and its t-SNE embedding.
        for plot_label, use_tsne in (("Original data", False), ("t-SNE", True)):
            with gr.Column():
                plot = gr.Plot(label=plot_label)
                fn = partial(plot_data, tsne=use_tsne)
                # Redraw this plot whenever any control changes.
                for control in plot_inputs:
                    control.change(fn=fn, inputs=plot_inputs, outputs=plot)
                # Draw once on page load too, so the plot is not empty until
                # the user first touches a control.
                demo.load(fn=fn, inputs=plot_inputs, outputs=plot)

demo.launch()
|