Update visualization.py
Browse files- visualization.py +2 -2
visualization.py
CHANGED
@@ -234,7 +234,7 @@ def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
|
234 |
|
235 |
def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Combined MSE Heatmaps"):
|
236 |
plt.figure(figsize=(20, 6), dpi=300)
|
237 |
-
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20,
|
238 |
|
239 |
# Face heatmap
|
240 |
sns.heatmap(mse_face.reshape(1, -1), cmap='Reds', cbar=False, ax=ax1, xticklabels=False, yticklabels=False)
|
@@ -290,7 +290,7 @@ def create_video_with_heatmap(video_path, mse_face, mse_posture, mse_voice, df,
|
|
290 |
# Calculate the position of the vertical line
|
291 |
line_pos = int(t / video.duration * video.w)
|
292 |
|
293 |
-
# Add the vertical line to the heatmap
|
294 |
heatmap_with_line = heatmap_resized.copy()
|
295 |
cv2.line(heatmap_with_line, (line_pos, 0), (line_pos, heatmap_height), (0, 0, 0), 2)
|
296 |
|
|
|
234 |
|
235 |
def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Combined MSE Heatmaps"):
|
236 |
plt.figure(figsize=(20, 6), dpi=300)
|
237 |
+
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 8), sharex=True, gridspec_kw={'height_ratios': [1, 1, 1.2], 'hspace': 0})
|
238 |
|
239 |
# Face heatmap
|
240 |
sns.heatmap(mse_face.reshape(1, -1), cmap='Reds', cbar=False, ax=ax1, xticklabels=False, yticklabels=False)
|
|
|
290 |
# Calculate the position of the vertical line
|
291 |
line_pos = int(t / video.duration * video.w)
|
292 |
|
293 |
+
# Add the vertical line to the heatmap only
|
294 |
heatmap_with_line = heatmap_resized.copy()
|
295 |
cv2.line(heatmap_with_line, (line_pos, 0), (line_pos, heatmap_height), (0, 0, 0), 2)
|
296 |
|