Update visualization.py
Browse files- visualization.py +7 -4
visualization.py
CHANGED
@@ -232,10 +232,14 @@ def filter_mse_for_most_frequent_person(df, mse_embeddings, mse_posture, mse_voi
|
|
232 |
# Create a mask for the most frequent person frames
|
233 |
mask = df['Frame'].isin(most_frequent_person_frames)
|
234 |
|
|
|
|
|
|
|
|
|
235 |
# Apply the mask to filter the MSE arrays
|
236 |
-
mse_embeddings_filtered = np.where(
|
237 |
-
mse_posture_filtered = np.where(
|
238 |
-
mse_voice_filtered = np.where(
|
239 |
|
240 |
return mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered
|
241 |
|
@@ -292,7 +296,6 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
|
|
292 |
print(f"Failed to create heatmap video at: {heatmap_video_path}")
|
293 |
return None
|
294 |
|
295 |
-
# Define the create_heatmap function
|
296 |
def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
|
297 |
frame_count = int(t * video_fps)
|
298 |
|
|
|
232 |
# Create a mask for the most frequent person frames
|
233 |
mask = df['Frame'].isin(most_frequent_person_frames)
|
234 |
|
235 |
+
# Pad mask to match the length of the video frames
|
236 |
+
padded_mask = np.zeros(len(mse_embeddings), dtype=bool)
|
237 |
+
padded_mask[:len(mask)] = mask
|
238 |
+
|
239 |
# Apply the mask to filter the MSE arrays
|
240 |
+
mse_embeddings_filtered = np.where(padded_mask, mse_embeddings, 0)
|
241 |
+
mse_posture_filtered = np.where(padded_mask, mse_posture, 0)
|
242 |
+
mse_voice_filtered = np.where(padded_mask, mse_voice, 0)
|
243 |
|
244 |
return mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered
|
245 |
|
|
|
296 |
print(f"Failed to create heatmap video at: {heatmap_video_path}")
|
297 |
return None
|
298 |
|
|
|
299 |
def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
|
300 |
frame_count = int(t * video_fps)
|
301 |
|