Update visualization.py
Browse files- visualization.py +10 -15
visualization.py
CHANGED
@@ -321,30 +321,25 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
|
|
321 |
|
322 |
# Function to create the correlation heatmap
|
323 |
def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
'Voice MSE': mse_voice
|
328 |
-
}
|
329 |
-
mse_df = pd.DataFrame(mse_data)
|
330 |
-
correlation_matrix = mse_df.corr()
|
331 |
|
332 |
-
plt.figure(figsize=(
|
333 |
-
sns.heatmap(
|
334 |
-
plt.title(
|
335 |
-
plt.
|
336 |
return plt.gcf()
|
337 |
|
338 |
-
# Function to create the 3D scatter plot
|
339 |
def plot_3d_scatter(mse_embeddings, mse_posture, mse_voice):
|
340 |
fig = plt.figure(figsize=(10, 8))
|
341 |
ax = fig.add_subplot(111, projection='3d')
|
342 |
-
ax.scatter(mse_posture, mse_embeddings, mse_voice, c='b', marker='o')
|
343 |
|
|
|
344 |
ax.set_xlabel('Body Posture MSE')
|
345 |
ax.set_ylabel('Facial Features MSE')
|
346 |
ax.set_zlabel('Voice MSE')
|
347 |
ax.set_title('3D Scatter Plot of MSEs')
|
348 |
|
349 |
-
plt.
|
350 |
-
return
|
|
|
321 |
|
322 |
# Function to create the correlation heatmap
|
323 |
def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
324 |
+
data = np.vstack((mse_embeddings, mse_posture, mse_voice)).T
|
325 |
+
df = pd.DataFrame(data, columns=["Facial Features", "Body Posture", "Voice"])
|
326 |
+
corr = df.corr()
|
|
|
|
|
|
|
|
|
327 |
|
328 |
+
plt.figure(figsize=(10, 8))
|
329 |
+
heatmap = sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
|
330 |
+
plt.title('Correlation Heatmap of MSEs')
|
331 |
+
plt.tight_layout()
|
332 |
return plt.gcf()
|
333 |
|
|
|
334 |
def plot_3d_scatter(mse_embeddings, mse_posture, mse_voice):
|
335 |
fig = plt.figure(figsize=(10, 8))
|
336 |
ax = fig.add_subplot(111, projection='3d')
|
|
|
337 |
|
338 |
+
ax.scatter(mse_posture, mse_embeddings, mse_voice, c='r', marker='o')
|
339 |
ax.set_xlabel('Body Posture MSE')
|
340 |
ax.set_ylabel('Facial Features MSE')
|
341 |
ax.set_zlabel('Voice MSE')
|
342 |
ax.set_title('3D Scatter Plot of MSEs')
|
343 |
|
344 |
+
plt.tight_layout()
|
345 |
+
return plt.gcf()
|