caliex committed
Commit 616f933 · 1 Parent(s): d365bcb

Create app.py

Files changed (1)
app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
import gradio as gr
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import StratifiedKFold

colors = ["navy", "turquoise", "darkorange"]


def make_ellipses(gmm, ax):
    for n, color in enumerate(colors):
        if gmm.covariance_type == "full":
            covariances = gmm.covariances_[n][:2, :2]
        elif gmm.covariance_type == "tied":
            covariances = gmm.covariances_[:2, :2]
        elif gmm.covariance_type == "diag":
            covariances = np.diag(gmm.covariances_[n][:2])
        elif gmm.covariance_type == "spherical":
            covariances = np.eye(gmm.means_.shape[1]) * gmm.covariances_[n]
        v, w = np.linalg.eigh(covariances)
        u = w[0] / np.linalg.norm(w[0])
        angle = np.arctan2(u[1], u[0])
        angle = 180 * angle / np.pi  # convert to degrees
        v = 2.0 * np.sqrt(2.0) * np.sqrt(v)
        ell = mpl.patches.Ellipse(
            gmm.means_[n, :2], v[0], v[1], angle=180 + angle, color=color
        )
        ell.set_clip_box(ax.bbox)
        ell.set_alpha(0.5)
        ax.add_artist(ell)
        ax.set_aspect("equal", "datalim")
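
# Note (added for clarity, not in the original commit): the branches above
# exist because GaussianMixture stores `covariances_` in a different shape
# for each covariance type:
#   "full"      -> (n_components, n_features, n_features)
#   "tied"      -> (n_features, n_features), shared by all components
#   "diag"      -> (n_components, n_features)
#   "spherical" -> (n_components,), one variance per component
# The scaling 2 * sqrt(2) * sqrt(eigenvalue) draws each ellipse with
# semi-axes of sqrt(2) standard deviations along the principal directions.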


def classify_iris(cov_type):
    iris = datasets.load_iris()

    # Break up the dataset into non-overlapping training (75%) and testing
    # (25%) sets.
    skf = StratifiedKFold(n_splits=4)
    # Only take the first fold.
    train_index, test_index = next(iter(skf.split(iris.data, iris.target)))

    X_train = iris.data[train_index]
    y_train = iris.target[train_index]
    X_test = iris.data[test_index]
    y_test = iris.target[test_index]

    n_classes = len(np.unique(y_train))

    # Build a GMM with the requested covariance type.
    estimator = GaussianMixture(
        n_components=n_classes, covariance_type=cov_type, max_iter=20, random_state=0
    )

    # Since we have class labels for the training data, we can
    # initialize the GMM parameters in a supervised manner.
    estimator.means_init = np.array(
        [X_train[y_train == i].mean(axis=0) for i in range(n_classes)]
    )

    # Train the other parameters using the EM algorithm.
    estimator.fit(X_train)

    fig, ax = plt.subplots(figsize=(8, 6))

    make_ellipses(estimator, ax)

    # Plot the training data as dots.
    for n, color in enumerate(colors):
        data = iris.data[iris.target == n]
        ax.scatter(
            data[:, 0], data[:, 1], s=0.8, color=color, label=iris.target_names[n]
        )

    # Plot the test data with crosses.
    for n, color in enumerate(colors):
        data = X_test[y_test == n]
        ax.scatter(data[:, 0], data[:, 1], marker="x", color=color)

    y_train_pred = estimator.predict(X_train)
    train_accuracy = np.mean(y_train_pred.ravel() == y_train.ravel()) * 100
    ax.text(0.05, 0.9, "Train accuracy: %.1f" % train_accuracy, transform=ax.transAxes)

    y_test_pred = estimator.predict(X_test)
    test_accuracy = np.mean(y_test_pred.ravel() == y_test.ravel()) * 100
    ax.text(0.05, 0.8, "Test accuracy: %.1f" % test_accuracy, transform=ax.transAxes)

    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_title(cov_type.capitalize())

    plt.legend(scatterpoints=1, loc="lower right", prop=dict(size=12))

    # Save the plot to a file and return its path.
    output_path = "classification_plot.png"
    plt.savefig(output_path)
    plt.close()

    return output_path


iface = gr.Interface(
    fn=classify_iris,
    inputs=gr.Radio(["spherical", "diag", "tied", "full"], label="Covariance Type"),
    outputs="image",
    title="Gaussian Mixture Model Covariance",
    description="Explore different covariance types for Gaussian mixture models (GMMs) in this demonstration. GMMs are commonly used for clustering, but in this example we compare the clusters they find with the actual classes in the dataset. Initializing the means of the Gaussians with the class means of the training set keeps the comparison valid. The plots show the predicted labels on both training and test data using GMMs with spherical, diagonal, tied, and full covariance matrices. Interestingly, while full covariance is expected to perform best, it is prone to overfitting on small datasets and may not generalize well to held-out test data. See the original scikit-learn example for more information: https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_covariances.html",
    examples=[
        ["spherical"],
        ["diag"],
        ["tied"],
        ["full"],
    ],
)

iface.launch()
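
A quick, hypothetical smoke test (not part of this commit): the handler can be exercised without the web UI by calling classify_iris directly, e.g. by pasting the loop below just above the iface.launch() line. Note that every call writes to the same classification_plot.png, so each iteration overwrites the previous figure.

    for cov_type in ["spherical", "diag", "tied", "full"]:
        # Refits the GMM for this covariance type and saves the figure.
        print(cov_type, "->", classify_iris(cov_type))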