reab5555 commited on
Commit
ac5de2e
·
verified ·
1 Parent(s): eeed558

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +3 -3
visualization.py CHANGED
@@ -236,17 +236,17 @@ def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Combi
236
  fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 6), sharex=True, gridspec_kw={'height_ratios': [1, 1, 1.2], 'hspace': 0})
237
 
238
  # Face heatmap
239
- sns.heatmap(mse_face.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax1, xticklabels=False, yticklabels=False)
240
  ax1.set_ylabel('Face', rotation=0, ha='right', va='center')
241
  ax1.yaxis.set_label_coords(-0.01, 0.5)
242
 
243
  # Posture heatmap
244
- sns.heatmap(mse_posture.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax2, xticklabels=False, yticklabels=False)
245
  ax2.set_ylabel('Posture', rotation=0, ha='right', va='center')
246
  ax2.yaxis.set_label_coords(-0.01, 0.5)
247
 
248
  # Voice heatmap
249
- sns.heatmap(mse_voice.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax3, yticklabels=False)
250
  ax3.set_ylabel('Voice', rotation=0, ha='right', va='center')
251
  ax3.yaxis.set_label_coords(-0.01, 0.5)
252
 
 
236
  fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 6), sharex=True, gridspec_kw={'height_ratios': [1, 1, 1.2], 'hspace': 0})
237
 
238
  # Face heatmap
239
+ sns.heatmap(mse_face.reshape(1, -1), cmap='Reds', cbar=False, ax=ax1, xticklabels=False, yticklabels=False)
240
  ax1.set_ylabel('Face', rotation=0, ha='right', va='center')
241
  ax1.yaxis.set_label_coords(-0.01, 0.5)
242
 
243
  # Posture heatmap
244
+ sns.heatmap(mse_posture.reshape(1, -1), cmap='Reds', cbar=False, ax=ax2, xticklabels=False, yticklabels=False)
245
  ax2.set_ylabel('Posture', rotation=0, ha='right', va='center')
246
  ax2.yaxis.set_label_coords(-0.01, 0.5)
247
 
248
  # Voice heatmap
249
+ sns.heatmap(mse_voice.reshape(1, -1), cmap='Reds', cbar=False, ax=ax3, yticklabels=False)
250
  ax3.set_ylabel('Voice', rotation=0, ha='right', va='center')
251
  ax3.yaxis.set_label_coords(-0.01, 0.5)
252