reab5555 commited on
Commit
8ad22bc
·
verified ·
1 Parent(s): ee488cb

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +6 -1
visualization.py CHANGED
@@ -219,6 +219,10 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
219
 
220
 
221
  def filter_mse_for_most_frequent_person(df, mse_embeddings, mse_posture, mse_voice, most_frequent_person_frames):
 
 
 
 
222
  # Create a mask for the most frequent person frames
223
  mask = df['Frame'].isin(most_frequent_person_frames)
224
 
@@ -229,6 +233,7 @@ def filter_mse_for_most_frequent_person(df, mse_embeddings, mse_posture, mse_voi
229
 
230
  return mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered
231
 
 
232
  def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice, output_folder, desired_fps, most_frequent_person_frames):
233
  print(f"Creating heatmap video. Output folder: {output_folder}")
234
 
@@ -282,7 +287,7 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
282
  print(f"Failed to create heatmap video at: {heatmap_video_path}")
283
  return None
284
 
285
- # Define the create_heatmap function
286
  def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
287
  frame_count = int(t * video_fps)
288
 
 
219
 
220
 
221
  def filter_mse_for_most_frequent_person(df, mse_embeddings, mse_posture, mse_voice, most_frequent_person_frames):
222
+ # Ensure most_frequent_person_frames is a list
223
+ if not isinstance(most_frequent_person_frames, (list, np.ndarray)):
224
+ most_frequent_person_frames = [most_frequent_person_frames]
225
+
226
  # Create a mask for the most frequent person frames
227
  mask = df['Frame'].isin(most_frequent_person_frames)
228
 
 
233
 
234
  return mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered
235
 
236
+
237
  def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice, output_folder, desired_fps, most_frequent_person_frames):
238
  print(f"Creating heatmap video. Output folder: {output_folder}")
239
 
 
287
  print(f"Failed to create heatmap video at: {heatmap_video_path}")
288
  return None
289
 
290
+
291
  def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
292
  frame_count = int(t * video_fps)
293