caliex commited on
Commit
02adacc
·
1 Parent(s): 10cd465

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from matplotlib.colors import LogNorm
4
+ from sklearn import mixture
5
+ import gradio as gr
6
+ import tempfile
7
+ import os
8
+
9
+ def generate_gaussian_mixture(n_samples):
10
+ # generate random sample, two components
11
+ np.random.seed(0)
12
+
13
+ # generate spherical data centered on (20, 20)
14
+ shifted_gaussian = np.random.randn(n_samples, 2) + np.array([20, 20])
15
+
16
+ # generate zero centered stretched Gaussian data
17
+ C = np.array([[0.0, -0.7], [3.5, 0.7]])
18
+ stretched_gaussian = np.dot(np.random.randn(n_samples, 2), C)
19
+
20
+ # concatenate the two datasets into the final training set
21
+ X_train = np.vstack([shifted_gaussian, stretched_gaussian])
22
+
23
+ # fit a Gaussian Mixture Model with two components
24
+ clf = mixture.GaussianMixture(n_components=2, covariance_type="full")
25
+ clf.fit(X_train)
26
+
27
+ # display predicted scores by the model as a contour plot
28
+ x = np.linspace(-20.0, 30.0)
29
+ y = np.linspace(-20.0, 40.0)
30
+ X, Y = np.meshgrid(x, y)
31
+ XX = np.array([X.ravel(), Y.ravel()]).T
32
+ Z = -clf.score_samples(XX)
33
+ Z = Z.reshape(X.shape)
34
+
35
+ fig, ax = plt.subplots()
36
+ CS = ax.contour(
37
+ X, Y, Z, norm=LogNorm(vmin=1.0, vmax=1000.0), levels=np.logspace(0, 3, 10)
38
+ )
39
+ CB = fig.colorbar(CS, shrink=0.8, extend="both")
40
+ ax.scatter(X_train[:, 0], X_train[:, 1], 0.8)
41
+
42
+ ax.set_title("Negative log-likelihood predicted by a GMM")
43
+ ax.axis("tight")
44
+
45
+ # Save the plot as a temporary image file
46
+ temp_dir = tempfile.mkdtemp()
47
+ temp_file_path = os.path.join(temp_dir, "gmm_plot.png")
48
+ fig.savefig(temp_file_path)
49
+ plt.close(fig)
50
+
51
+ return temp_file_path
52
+
53
+ def plot_to_image(file_path):
54
+ with open(file_path, "rb") as f:
55
+ image_bytes = f.read()
56
+ os.remove(file_path)
57
+ return image_bytes
58
+
59
+ inputs = gr.inputs.Slider(100, 1000, step=100, default=300, label="Number of Samples")
60
+ outputs = gr.outputs.Image(type="pil", label="GMM Plot")
61
+
62
+ title = "Density Estimation for a Gaussian mixture"
63
+ description = "In this example, you can visualize the density estimation of a mixture of two Gaussians using a Gaussian Mixture Model (GMM). The data used for the model is generated from two Gaussians with distinct centers and covariance matrices. By adjusting the number of samples, you can observe how the GMM captures the underlying distribution and generates a contour plot representing the estimated density. This interactive application allows you to explore the behavior of the GMM and gain insights into the modeling of complex data distributions using mixture models. See the original scikit-learn example here: https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_pdf.html"
64
+ gr.Interface(generate_gaussian_mixture, inputs, outputs, title=title, description=description, postprocess=plot_to_image).launch()