Spaces:
Runtime error
Runtime error
Update visualization.py
Browse files- visualization.py +10 -15
visualization.py
CHANGED
|
@@ -321,30 +321,25 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
|
|
| 321 |
|
| 322 |
# Function to create the correlation heatmap
|
| 323 |
def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
'Voice MSE': mse_voice
|
| 328 |
-
}
|
| 329 |
-
mse_df = pd.DataFrame(mse_data)
|
| 330 |
-
correlation_matrix = mse_df.corr()
|
| 331 |
|
| 332 |
-
plt.figure(figsize=(
|
| 333 |
-
sns.heatmap(
|
| 334 |
-
plt.title(
|
| 335 |
-
plt.
|
| 336 |
return plt.gcf()
|
| 337 |
|
| 338 |
-
# Function to create the 3D scatter plot
|
| 339 |
def plot_3d_scatter(mse_embeddings, mse_posture, mse_voice):
|
| 340 |
fig = plt.figure(figsize=(10, 8))
|
| 341 |
ax = fig.add_subplot(111, projection='3d')
|
| 342 |
-
ax.scatter(mse_posture, mse_embeddings, mse_voice, c='b', marker='o')
|
| 343 |
|
|
|
|
| 344 |
ax.set_xlabel('Body Posture MSE')
|
| 345 |
ax.set_ylabel('Facial Features MSE')
|
| 346 |
ax.set_zlabel('Voice MSE')
|
| 347 |
ax.set_title('3D Scatter Plot of MSEs')
|
| 348 |
|
| 349 |
-
plt.
|
| 350 |
-
return
|
|
|
|
| 321 |
|
| 322 |
# Function to create the correlation heatmap
|
| 323 |
def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
| 324 |
+
data = np.vstack((mse_embeddings, mse_posture, mse_voice)).T
|
| 325 |
+
df = pd.DataFrame(data, columns=["Facial Features", "Body Posture", "Voice"])
|
| 326 |
+
corr = df.corr()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
+
plt.figure(figsize=(10, 8))
|
| 329 |
+
heatmap = sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
|
| 330 |
+
plt.title('Correlation Heatmap of MSEs')
|
| 331 |
+
plt.tight_layout()
|
| 332 |
return plt.gcf()
|
| 333 |
|
|
|
|
| 334 |
def plot_3d_scatter(mse_embeddings, mse_posture, mse_voice):
|
| 335 |
fig = plt.figure(figsize=(10, 8))
|
| 336 |
ax = fig.add_subplot(111, projection='3d')
|
|
|
|
| 337 |
|
| 338 |
+
ax.scatter(mse_posture, mse_embeddings, mse_voice, c='r', marker='o')
|
| 339 |
ax.set_xlabel('Body Posture MSE')
|
| 340 |
ax.set_ylabel('Facial Features MSE')
|
| 341 |
ax.set_zlabel('Voice MSE')
|
| 342 |
ax.set_title('3D Scatter Plot of MSEs')
|
| 343 |
|
| 344 |
+
plt.tight_layout()
|
| 345 |
+
return plt.gcf()
|