reab5555 commited on
Commit
8e8d1ed
·
verified ·
1 Parent(s): 320dc6d

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +7 -4
visualization.py CHANGED
@@ -232,10 +232,14 @@ def filter_mse_for_most_frequent_person(df, mse_embeddings, mse_posture, mse_voi
232
  # Create a mask for the most frequent person frames
233
  mask = df['Frame'].isin(most_frequent_person_frames)
234
 
 
 
 
 
235
  # Apply the mask to filter the MSE arrays
236
- mse_embeddings_filtered = np.where(mask, mse_embeddings, 0)
237
- mse_posture_filtered = np.where(mask, mse_posture, 0)
238
- mse_voice_filtered = np.where(mask, mse_voice, 0)
239
 
240
  return mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered
241
 
@@ -292,7 +296,6 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
292
  print(f"Failed to create heatmap video at: {heatmap_video_path}")
293
  return None
294
 
295
- # Define the create_heatmap function
296
  def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
297
  frame_count = int(t * video_fps)
298
 
 
232
  # Create a mask for the most frequent person frames
233
  mask = df['Frame'].isin(most_frequent_person_frames)
234
 
235
+ # Pad mask to match the length of the video frames
236
+ padded_mask = np.zeros(len(mse_embeddings), dtype=bool)
237
+ padded_mask[:len(mask)] = mask
238
+
239
  # Apply the mask to filter the MSE arrays
240
+ mse_embeddings_filtered = np.where(padded_mask, mse_embeddings, 0)
241
+ mse_posture_filtered = np.where(padded_mask, mse_posture, 0)
242
+ mse_voice_filtered = np.where(padded_mask, mse_voice, 0)
243
 
244
  return mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered
245
 
 
296
  print(f"Failed to create heatmap video at: {heatmap_video_path}")
297
  return None
298
 
 
299
  def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
300
  frame_count = int(t * video_fps)
301