Update visualization.py
Browse files- visualization.py +4 -4
visualization.py
CHANGED
@@ -317,13 +317,13 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
|
|
317 |
|
318 |
# Function to create the correlation heatmap
|
319 |
def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
320 |
-
data
|
321 |
-
df
|
322 |
-
corr
|
323 |
|
324 |
plt.figure(figsize=(10, 8), dpi=300)
|
325 |
|
326 |
-
heatmap
|
327 |
plt.title('Correlation Heatmap of MSEs')
|
328 |
plt.tight_layout()
|
329 |
return plt.gcf()
|
|
|
317 |
|
318 |
# Function to create the correlation heatmap
|
319 |
def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
320 |
+
data = np.vstack((mse_embeddings, mse_posture, mse_voice)).T
|
321 |
+
df = pd.DataFrame(data, columns=["Facial Features", "Body Posture", "Voice"])
|
322 |
+
corr = df.corr()
|
323 |
|
324 |
plt.figure(figsize=(10, 8), dpi=300)
|
325 |
|
326 |
+
heatmap = sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
|
327 |
plt.title('Correlation Heatmap of MSEs')
|
328 |
plt.tight_layout()
|
329 |
return plt.gcf()
|