Spaces:
Runtime error
Runtime error
Update visualization.py
Browse files- visualization.py +7 -4
visualization.py
CHANGED
|
@@ -232,10 +232,14 @@ def filter_mse_for_most_frequent_person(df, mse_embeddings, mse_posture, mse_voi
|
|
| 232 |
# Create a mask for the most frequent person frames
|
| 233 |
mask = df['Frame'].isin(most_frequent_person_frames)
|
| 234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
# Apply the mask to filter the MSE arrays
|
| 236 |
-
mse_embeddings_filtered = np.where(
|
| 237 |
-
mse_posture_filtered = np.where(
|
| 238 |
-
mse_voice_filtered = np.where(
|
| 239 |
|
| 240 |
return mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered
|
| 241 |
|
|
@@ -292,7 +296,6 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
|
|
| 292 |
print(f"Failed to create heatmap video at: {heatmap_video_path}")
|
| 293 |
return None
|
| 294 |
|
| 295 |
-
# Define the create_heatmap function
|
| 296 |
def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
|
| 297 |
frame_count = int(t * video_fps)
|
| 298 |
|
|
|
|
| 232 |
# Create a mask for the most frequent person frames
|
| 233 |
mask = df['Frame'].isin(most_frequent_person_frames)
|
| 234 |
|
| 235 |
+
# Pad mask to match the length of the video frames
|
| 236 |
+
padded_mask = np.zeros(len(mse_embeddings), dtype=bool)
|
| 237 |
+
padded_mask[:len(mask)] = mask
|
| 238 |
+
|
| 239 |
# Apply the mask to filter the MSE arrays
|
| 240 |
+
mse_embeddings_filtered = np.where(padded_mask, mse_embeddings, 0)
|
| 241 |
+
mse_posture_filtered = np.where(padded_mask, mse_posture, 0)
|
| 242 |
+
mse_voice_filtered = np.where(padded_mask, mse_voice, 0)
|
| 243 |
|
| 244 |
return mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered
|
| 245 |
|
|
|
|
| 296 |
print(f"Failed to create heatmap video at: {heatmap_video_path}")
|
| 297 |
return None
|
| 298 |
|
|
|
|
| 299 |
def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
|
| 300 |
frame_count = int(t * video_fps)
|
| 301 |
|