Update visualization.py
Browse files- visualization.py +35 -25
visualization.py
CHANGED
@@ -296,35 +296,45 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
|
|
296 |
print(f"Failed to create heatmap video at: {heatmap_video_path}")
|
297 |
return None
|
298 |
|
299 |
-
def create_heatmap(t,
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
308 |
combined_mse[0] = mse_embeddings_norm
|
|
|
309 |
combined_mse[1] = mse_posture_norm
|
310 |
combined_mse[2] = mse_voice_norm
|
311 |
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
return heatmap_img
|
328 |
|
329 |
def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
330 |
data = np.vstack((mse_embeddings, mse_posture, mse_voice)).T
|
|
|
296 |
print(f"Failed to create heatmap video at: {heatmap_video_path}")
|
297 |
return None
|
298 |
|
299 |
+
def create_heatmap(t, mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered, fps, total_frames, width):
|
300 |
+
# Normalize the MSE values
|
301 |
+
mse_embeddings_norm = normalize_mse(mse_embeddings_filtered)
|
302 |
+
mse_posture_norm = normalize_mse(mse_posture_filtered)
|
303 |
+
mse_voice_norm = normalize_mse(mse_voice_filtered)
|
304 |
+
|
305 |
+
# Debug prints
|
306 |
+
print(f"mse_embeddings_norm shape: {mse_embeddings_norm.shape}")
|
307 |
+
print(f"mse_posture_norm shape: {mse_posture_norm.shape}")
|
308 |
+
print(f"mse_voice_norm shape: {mse_voice_norm.shape}")
|
309 |
+
|
310 |
+
# Ensure combined_mse has the correct shape
|
311 |
+
combined_mse = np.zeros((total_frames, width))
|
312 |
+
|
313 |
+
# Adjust shapes and pad with zeros if necessary
|
314 |
+
mse_embeddings_norm = pad_or_trim_array(mse_embeddings_norm, width)
|
315 |
+
mse_posture_norm = pad_or_trim_array(mse_posture_norm, width)
|
316 |
+
mse_voice_norm = pad_or_trim_array(mse_voice_norm, width)
|
317 |
+
|
318 |
combined_mse[0] = mse_embeddings_norm
|
319 |
+
# Assuming you combine posture and voice MSEs similarly
|
320 |
combined_mse[1] = mse_posture_norm
|
321 |
combined_mse[2] = mse_voice_norm
|
322 |
|
323 |
+
# Return or use combined_mse as needed
|
324 |
+
return combined_mse
|
325 |
+
|
326 |
+
def normalize_mse(mse):
|
327 |
+
# Your normalization logic here
|
328 |
+
return mse / np.max(mse)
|
329 |
+
|
330 |
+
def pad_or_trim_array(arr, target_length):
|
331 |
+
if len(arr) > target_length:
|
332 |
+
# Trim the array
|
333 |
+
return arr[:target_length]
|
334 |
+
elif len(arr) < target_length:
|
335 |
+
# Pad the array with zeros
|
336 |
+
return np.pad(arr, (0, target_length - len(arr)), 'constant')
|
337 |
+
return arr
|
|
|
338 |
|
339 |
def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
|
340 |
data = np.vstack((mse_embeddings, mse_posture, mse_voice)).T
|