Spaces:
Sleeping
Sleeping
import io | |
import base64 | |
import numpy as np | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from cluster.clusterer import Clusterer | |
matplotlib.use("Agg") | |
sns.set() | |
def plot(
    clusterer: Clusterer,
    X: np.ndarray,
    *,
    title: str = "K-means Clustering",
    xlabel: str = "Normalized Petal Length (cm)",
    ylabel: str = "Normalized Petal Length (cm)",
) -> None:
    """Draw a 2-D scatter plot of the clusters and store it on *clusterer*.

    The cluster description is obtained from ``clusterer.to_dict(X)["clusters"]``.
    Each entry is expected to provide ``"points"`` (an iterable of points
    whose first two components are used as x and y), ``"cluster_id"`` (used
    in the legend label) and ``"centroid"`` (drawn as a red cross).

    The rendered figure is encoded by ``plt_bytes`` and assigned to
    ``clusterer.plot``; nothing is returned.

    Args:
        clusterer: Object exposing ``to_dict(X)``; receives the ``plot``
            attribute.
        X: Data passed through to ``clusterer.to_dict``.
        title: Title drawn above the axes.
        xlabel: Label of the horizontal axis.
        ylabel: Label of the vertical axis.
    """
    # NOTE(review): the default x and y labels are identical, which looks like
    # a copy-paste slip. They are kept as-is for backward compatibility; pass
    # ``xlabel=`` / ``ylabel=`` to name the two plotted features correctly.
    cluster_data = clusterer.to_dict(X)["clusters"]

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        for cluster in cluster_data:
            points = cluster["points"]
            # Data points belonging to this cluster.
            sns.scatterplot(
                x=[point[0] for point in points],
                y=[point[1] for point in points],
                label=f"Cluster {cluster['cluster_id']}",
                ax=ax,
            )
            # Centroid of this cluster, marked with a red cross.
            centroid = cluster["centroid"]
            ax.scatter(
                x=centroid[0],
                y=centroid[1],
                marker="x",
                s=100,
                linewidth=2,
                color="red",
            )

        ax.legend()
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)

        clusterer.plot = plt_bytes(fig)
    finally:
        # Make sure the figure is released even when plotting fails part-way;
        # closing a figure that plt_bytes already closed is harmless.
        plt.close(fig)
def plt_bytes(fig) -> str:
    """Render *fig* to PNG and return the image as a base64 string.

    The figure is closed via ``plt.close`` once rendering has finished,
    whether or not ``savefig`` succeeded, so the caller must not reuse it.

    Args:
        fig: Figure object exposing ``savefig``.

    Returns:
        The PNG bytes, base64-encoded and decoded to ``str`` as UTF-8.
    """
    try:
        # The in-memory buffer is released as soon as its bytes are copied out.
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png")
            png_bytes = buf.getvalue()
    finally:
        # Always unregister the figure from pyplot, even if savefig raised.
        plt.close(fig)
    return base64.b64encode(png_bytes).decode("utf-8")