reab5555 commited on
Commit
99614c2
·
verified ·
1 Parent(s): bc0cbbf

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +10 -15
visualization.py CHANGED
@@ -321,30 +321,25 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
321
 
322
  # Function to create the correlation heatmap
323
  def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
324
- mse_data = {
325
- 'Facial Features MSE': mse_embeddings,
326
- 'Body Posture MSE': mse_posture,
327
- 'Voice MSE': mse_voice
328
- }
329
- mse_df = pd.DataFrame(mse_data)
330
- correlation_matrix = mse_df.corr()
331
 
332
- plt.figure(figsize=(8, 6))
333
- sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
334
- plt.title("Correlation Heatmap of MSEs")
335
- plt.close()
336
  return plt.gcf()
337
 
338
- # Function to create the 3D scatter plot
339
  def plot_3d_scatter(mse_embeddings, mse_posture, mse_voice):
340
  fig = plt.figure(figsize=(10, 8))
341
  ax = fig.add_subplot(111, projection='3d')
342
- ax.scatter(mse_posture, mse_embeddings, mse_voice, c='b', marker='o')
343
 
 
344
  ax.set_xlabel('Body Posture MSE')
345
  ax.set_ylabel('Facial Features MSE')
346
  ax.set_zlabel('Voice MSE')
347
  ax.set_title('3D Scatter Plot of MSEs')
348
 
349
- plt.close()
350
- return fig
 
321
 
322
  # Function to create the correlation heatmap
323
  def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
324
+ data = np.vstack((mse_embeddings, mse_posture, mse_voice)).T
325
+ df = pd.DataFrame(data, columns=["Facial Features", "Body Posture", "Voice"])
326
+ corr = df.corr()
 
 
 
 
327
 
328
+ plt.figure(figsize=(10, 8))
329
+ heatmap = sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
330
+ plt.title('Correlation Heatmap of MSEs')
331
+ plt.tight_layout()
332
  return plt.gcf()
333
 
 
334
  def plot_3d_scatter(mse_embeddings, mse_posture, mse_voice):
335
  fig = plt.figure(figsize=(10, 8))
336
  ax = fig.add_subplot(111, projection='3d')
 
337
 
338
+ ax.scatter(mse_posture, mse_embeddings, mse_voice, c='r', marker='o')
339
  ax.set_xlabel('Body Posture MSE')
340
  ax.set_ylabel('Facial Features MSE')
341
  ax.set_zlabel('Voice MSE')
342
  ax.set_title('3D Scatter Plot of MSEs')
343
 
344
+ plt.tight_layout()
345
+ return plt.gcf()