reab5555 commited on
Commit
6bd6cac
·
verified ·
1 Parent(s): 395572b

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +2 -2
visualization.py CHANGED
@@ -234,7 +234,7 @@ def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
234
 
235
  def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Combined MSE Heatmaps"):
236
  plt.figure(figsize=(20, 6), dpi=300)
237
- fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 6), sharex=True, gridspec_kw={'height_ratios': [1, 1, 1.2], 'hspace': 0})
238
 
239
  # Face heatmap
240
  sns.heatmap(mse_face.reshape(1, -1), cmap='Reds', cbar=False, ax=ax1, xticklabels=False, yticklabels=False)
@@ -290,7 +290,7 @@ def create_video_with_heatmap(video_path, mse_face, mse_posture, mse_voice, df,
290
  # Calculate the position of the vertical line
291
  line_pos = int(t / video.duration * video.w)
292
 
293
- # Add the vertical line to the heatmap
294
  heatmap_with_line = heatmap_resized.copy()
295
  cv2.line(heatmap_with_line, (line_pos, 0), (line_pos, heatmap_height), (0, 0, 0), 2)
296
 
 
234
 
235
  def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Combined MSE Heatmaps"):
236
  plt.figure(figsize=(20, 6), dpi=300)
237
+ fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 8), sharex=True, gridspec_kw={'height_ratios': [1, 1, 1.2], 'hspace': 0})
238
 
239
  # Face heatmap
240
  sns.heatmap(mse_face.reshape(1, -1), cmap='Reds', cbar=False, ax=ax1, xticklabels=False, yticklabels=False)
 
290
  # Calculate the position of the vertical line
291
  line_pos = int(t / video.duration * video.w)
292
 
293
+ # Add the vertical line to the heatmap only
294
  heatmap_with_line = heatmap_resized.copy()
295
  cv2.line(heatmap_with_line, (line_pos, 0), (line_pos, heatmap_height), (0, 0, 0), 2)
296