Jensen-holm's picture
clustering plot for kmeans
b378fde
raw
history blame
1.19 kB
import io
import base64
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from cluster.clusterer import Clusterer
# Select the non-interactive Agg backend so figures are rendered off-screen
# (to an in-memory buffer) without requiring a display.
matplotlib.use("Agg")
# Apply seaborn's default theme to every figure created after this point.
sns.set()
def plot(clusterer: Clusterer, X: np.ndarray) -> None:
    """Draw the clusters of *clusterer* and store the image on ``clusterer.plot``.

    Every cluster's points are drawn as one seaborn scatter series labelled
    ``"Cluster <cluster_id>"``, and its centroid is overlaid as a red "x".
    The finished figure is encoded by ``plt_bytes`` and the resulting string
    is assigned to ``clusterer.plot``; nothing is returned.

    Args:
        clusterer: Object whose ``to_dict(X)`` returns a mapping with a
            ``"clusters"`` entry. Each cluster is indexed by ``"points"``,
            ``"centroid"`` and ``"cluster_id"``.
        X: Data forwarded to ``clusterer.to_dict``. Only components 0 and 1
            of each point and centroid are plotted.
    """
    cluster_data = clusterer.to_dict(X)["clusters"]

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        # Plot the clusters and data points.
        for cluster in cluster_data:
            points = cluster["points"]
            sns.scatterplot(
                x=[point[0] for point in points],
                y=[point[1] for point in points],
                label=f"Cluster {cluster['cluster_id']}",
                ax=ax,
            )
            centroid = cluster["centroid"]
            ax.scatter(
                x=centroid[0],
                y=centroid[1],
                marker="x",
                s=100,
                linewidth=2,
                color="red",
            )

        ax.legend()
        ax.set_title("K-means Clustering")
        # NOTE(review): both axes carry the same label although they plot
        # different components (index 1 vs. index 0) -- confirm the intended
        # x-axis label.
        ax.set_ylabel("Normalized Petal Length (cm)")
        ax.set_xlabel("Normalized Petal Length (cm)")

        clusterer.plot = plt_bytes(fig)
    finally:
        # pyplot keeps every open figure registered until it is closed, so
        # release it even when drawing or encoding fails. Closing a figure
        # that has already been closed is a no-op.
        plt.close(fig)
def plt_bytes(fig) -> str:
    """Return *fig* rendered as a base64-encoded PNG string.

    The figure is closed with ``plt.close`` before returning, including when
    rendering raises, so pyplot does not keep it registered.

    Args:
        fig: Figure object providing ``savefig``.

    Returns:
        The PNG bytes of the figure, base64-encoded and decoded as UTF-8 text.
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png")
            png_bytes = buf.getvalue()
    finally:
        # Always release the figure; otherwise a failed save would leave it
        # alive in pyplot's figure registry.
        plt.close(fig)
    return base64.b64encode(png_bytes).decode("utf-8")