Update clustering.py
Browse files- clustering.py +47 -43
clustering.py
CHANGED
@@ -1,43 +1,47 @@
|
|
1 |
-
from sklearn.cluster import KMeans
|
2 |
-
from sklearn.metrics import silhouette_score
|
3 |
-
from sklearn.preprocessing import StandardScaler
|
4 |
-
import streamlit as st
|
5 |
-
import matplotlib.pyplot as plt
|
6 |
-
import seaborn as sns
|
7 |
-
import pandas as pd
|
8 |
-
from sklearn.decomposition import PCA
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
from sklearn.cluster import KMeans
|
2 |
+
from sklearn.metrics import silhouette_score
|
3 |
+
from sklearn.preprocessing import StandardScaler
|
4 |
+
import streamlit as st
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import seaborn as sns
|
7 |
+
import pandas as pd
|
8 |
+
from sklearn.decomposition import PCA
|
9 |
+
|
10 |
+
def summarize_cluster_characteristics(clustered_data, labels, cluster_number):
    """Return the per-column mean of the rows assigned to one cluster.

    Args:
        clustered_data: Table of observations; indexed with a boolean mask and
            then averaged column-wise (used here as a pandas DataFrame).
        labels: Cluster label for each row of ``clustered_data``; must support
            element-wise comparison with ``==`` (e.g. a NumPy array).
        cluster_number: Label of the cluster to summarize.

    Returns:
        A dict mapping each column name to its mean within the cluster.
    """
    in_cluster = labels == cluster_number
    return clustered_data[in_cluster].mean().to_dict()
|
14 |
+
|
15 |
+
def perform_clustering(df, n_clusters):
    """Cluster the rows of ``df`` with KMeans on standardized features.

    Rows containing missing values are dropped first. Every remaining column
    is standardized with ``StandardScaler`` and passed to KMeans, so all
    columns are expected to be numeric.

    Args:
        df: Input table (used here as a pandas DataFrame). The caller's
            object is not modified; a cleaned copy is returned instead.
        n_clusters: Number of clusters passed to ``KMeans``.

    Returns:
        A tuple ``(df, score, df_value_scaled, labels, model)`` where ``df`` is
        the NaN-free copy with an added ``'Cluster'`` column, ``score`` is the
        silhouette score of the labelling, ``df_value_scaled`` is the
        standardized feature matrix, ``labels`` are the per-row cluster
        labels, and ``model`` is the fitted ``KMeans`` instance.
    """
    # Take an explicit copy: when dropna() actually removes rows, pandas may
    # still track the result as derived from the caller's frame, and adding
    # the 'Cluster' column below would emit SettingWithCopyWarning.
    df = df.dropna().copy()

    scaler = StandardScaler()
    df_value_scaled = scaler.fit_transform(df)

    # Apply KMeans with the selected number of clusters. fit_predict is the
    # documented equivalent of fit() followed by predict() on the same data,
    # without the second pass over the samples.
    model = KMeans(n_clusters=n_clusters, random_state=42)
    labels = model.fit_predict(df_value_scaled)
    score = silhouette_score(df_value_scaled, labels)

    df['Cluster'] = labels
    return df, score, df_value_scaled, labels, model
|
29 |
+
|
30 |
+
def plot_clusters(df_value_scaled, labels, new_data_point=None):
    """Render a 2-D PCA scatter plot of the clusters in the Streamlit app.

    The feature matrix is projected onto its first two principal components
    and each sample is colored by its cluster label.

    Args:
        df_value_scaled: Feature matrix to project (one row per sample).
        labels: Cluster label for each row of ``df_value_scaled``.
        new_data_point: Optional 2-D array of extra point(s) to highlight in
            red. If its column count equals the number of features in
            ``df_value_scaled`` it is projected with the PCA fitted here;
            otherwise its first two columns are plotted as given.

    Returns:
        None. The figure is sent to Streamlit via ``st.pyplot``.
    """
    pca = PCA(n_components=2)
    components = pca.fit_transform(df_value_scaled)
    df_components = pd.DataFrame(data=components, columns=['PC1', 'PC2'])
    df_components['Cluster'] = labels

    # Use an explicit figure/axes pair instead of pyplot's global state so the
    # figure can be handed to Streamlit and then released.
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.scatterplot(x='PC1', y='PC2', hue='Cluster', data=df_components,
                    palette='viridis', s=100, alpha=0.7, ax=ax)

    # Plot new data point if provided
    if new_data_point is not None:
        point_2d = new_data_point
        # The axes are principal components, so a point given in the original
        # feature space must go through the same projection to land in the
        # right place. Input with a different width is plotted unchanged.
        n_features = pca.components_.shape[1]
        if new_data_point.shape[1] == n_features:
            point_2d = pca.transform(new_data_point)
        ax.scatter(point_2d[:, 0], point_2d[:, 1], color='red', marker='o',
                   s=100, label='New Data Point')

    ax.set_title('Cluster Visualization')
    ax.set_xlabel('Principal Component 1')
    ax.set_ylabel('Principal Component 2')
    ax.legend(title='Cluster')
    st.pyplot(fig)
    # Close the figure so repeated Streamlit reruns do not accumulate open
    # matplotlib figures.
    plt.close(fig)
|