reab5555 commited on
Commit
b79d539
·
verified ·
1 Parent(s): 219299d

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +53 -29
visualization.py CHANGED
@@ -217,27 +217,45 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
217
  plt.close()
218
  return fig
219
 
220
- def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
221
- frame_count = int(t * video_fps)
 
 
 
 
222
 
223
  # Normalize MSE values
224
- mse_embeddings_norm = (mse_embeddings - np.min(mse_embeddings)) / (np.max(mse_embeddings) - np.min(mse_embeddings))
225
- mse_posture_norm = (mse_posture - np.min(mse_posture)) / (np.max(mse_posture) - np.min(mse_posture))
226
- mse_voice_norm = (mse_voice - np.min(mse_voice)) / (np.max(mse_voice) - np.min(mse_voice))
 
 
 
 
 
 
 
 
 
 
227
 
228
- combined_mse = np.zeros((3, total_frames))
229
- combined_mse[0] = mse_embeddings_norm
230
- combined_mse[1] = mse_posture_norm
231
- combined_mse[2] = mse_voice_norm
 
 
232
 
233
  fig, ax = plt.subplots(figsize=(video_width / 240, 0.5))
234
- ax.imshow(combined_mse, aspect='auto', cmap='Reds', vmin=0, vmax=1, extent=[0, total_frames, 0, 3])
235
  ax.set_yticks([0.5, 1.5, 2.5])
236
  ax.set_yticklabels(['Voice', 'Posture', 'Face'], fontsize=7)
237
  ax.set_xticks([])
238
-
239
- ax.axvline(x=frame_count, color='black', linewidth=3)
240
-
 
 
241
  plt.tight_layout(pad=0.5)
242
 
243
  canvas = FigureCanvas(fig)
@@ -247,7 +265,7 @@ def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_f
247
  plt.close(fig)
248
  return heatmap_img
249
 
250
- def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice, output_folder, desired_fps, largest_cluster):
251
  print(f"Creating heatmap video. Output folder: {output_folder}")
252
 
253
  os.makedirs(output_folder, exist_ok=True)
@@ -262,31 +280,30 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
262
 
263
  # Get video properties
264
  width, height = video.w, video.h
265
- total_frames = int(video.duration * video.fps)
 
 
 
 
 
 
 
 
 
266
 
267
- # Fill missing MSE values with 0.001
268
- def pad_with_zeros(mse_array, total_frames):
269
- if len(mse_array) < total_frames:
270
- return np.pad(mse_array, (0, total_frames - len(mse_array)), 'constant', constant_values=0.001)
271
- else:
272
- return mse_array[:total_frames]
273
-
274
- mse_embeddings = pad_with_zeros(mse_embeddings, total_frames)
275
- mse_posture = pad_with_zeros(mse_posture, total_frames)
276
- mse_voice = pad_with_zeros(mse_voice, total_frames)
277
-
278
  def combine_video_and_heatmap(t):
279
  video_frame = video.get_frame(t)
280
- heatmap_frame = create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video.fps, total_frames, width)
 
281
  heatmap_frame_resized = cv2.resize(heatmap_frame, (width, heatmap_frame.shape[0]))
282
  combined_frame = np.vstack((video_frame, heatmap_frame_resized))
283
  return combined_frame
284
 
285
- final_clip = VideoClip(combine_video_and_heatmap, duration=video.duration)
286
  final_clip = final_clip.set_audio(video.audio)
287
 
288
  # Write the final video
289
- final_clip.write_videofile(heatmap_video_path, codec='libx264', audio_codec='aac', fps=video.fps)
290
 
291
  # Close the video clips
292
  video.close()
@@ -300,6 +317,13 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
300
  print(f"Failed to create heatmap video at: {heatmap_video_path}")
301
  return None
302
 
 
 
 
 
 
 
 
303
 
304
  # Function to create the correlation heatmap
305
  def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
 
217
  plt.close()
218
  return fig
219
 
220
+ def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, analysis_fps, video_width):
221
+ video_frame_count = int(t * video_fps)
222
+ analysis_frame_count = int(t * analysis_fps)
223
+
224
+ # Ensure we don't go out of bounds
225
+ analysis_frame_count = min(analysis_frame_count, len(mse_embeddings) - 1)
226
 
227
  # Normalize MSE values
228
+ def safe_normalize(arr):
229
+ min_val, max_val = np.min(arr), np.max(arr)
230
+ if min_val == max_val:
231
+ return np.zeros_like(arr)
232
+ return (arr - min_val) / (max_val - min_val)
233
+
234
+ mse_embeddings_norm = safe_normalize(mse_embeddings)
235
+ mse_posture_norm = safe_normalize(mse_posture)
236
+ mse_voice_norm = safe_normalize(mse_voice)
237
+
238
+ # Create heatmap data
239
+ heatmap_width = int(video_width / 240 * 100) # Adjust this multiplier as needed
240
+ combined_mse = np.zeros((3, heatmap_width))
241
 
242
+ # Map analysis frames to heatmap width
243
+ for i in range(heatmap_width):
244
+ frame_index = int(i * len(mse_embeddings) / heatmap_width)
245
+ combined_mse[0, i] = mse_embeddings_norm[frame_index]
246
+ combined_mse[1, i] = mse_posture_norm[frame_index]
247
+ combined_mse[2, i] = mse_voice_norm[frame_index]
248
 
249
  fig, ax = plt.subplots(figsize=(video_width / 240, 0.5))
250
+ ax.imshow(combined_mse, aspect='auto', cmap='Reds', vmin=0, vmax=1, extent=[0, heatmap_width, 0, 3])
251
  ax.set_yticks([0.5, 1.5, 2.5])
252
  ax.set_yticklabels(['Voice', 'Posture', 'Face'], fontsize=7)
253
  ax.set_xticks([])
254
+
255
+ # Calculate the position of the vertical line
256
+ line_pos = (video_frame_count / video_fps) / (len(mse_embeddings) / analysis_fps) * heatmap_width
257
+ ax.axvline(x=line_pos, color='black', linewidth=3)
258
+
259
  plt.tight_layout(pad=0.5)
260
 
261
  canvas = FigureCanvas(fig)
 
265
  plt.close(fig)
266
  return heatmap_img
267
 
268
+ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice, output_folder, analysis_fps, largest_cluster):
269
  print(f"Creating heatmap video. Output folder: {output_folder}")
270
 
271
  os.makedirs(output_folder, exist_ok=True)
 
280
 
281
  # Get video properties
282
  width, height = video.w, video.h
283
+ video_duration = video.duration
284
+ video_fps = video.fps
285
+
286
+ # Calculate the number of analysis frames
287
+ analysis_frames = int(video_duration * analysis_fps)
288
+
289
+ # Ensure MSE arrays match the number of analysis frames
290
+ mse_embeddings = pad_or_trim(mse_embeddings, analysis_frames)
291
+ mse_posture = pad_or_trim(mse_posture, analysis_frames)
292
+ mse_voice = pad_or_trim(mse_voice, analysis_frames)
293
 
 
 
 
 
 
 
 
 
 
 
 
294
  def combine_video_and_heatmap(t):
295
  video_frame = video.get_frame(t)
296
+ analysis_frame = int(t * analysis_fps)
297
+ heatmap_frame = create_heatmap(analysis_frame, mse_embeddings, mse_posture, mse_voice, analysis_fps, analysis_frames, width)
298
  heatmap_frame_resized = cv2.resize(heatmap_frame, (width, heatmap_frame.shape[0]))
299
  combined_frame = np.vstack((video_frame, heatmap_frame_resized))
300
  return combined_frame
301
 
302
+ final_clip = VideoClip(combine_video_and_heatmap, duration=video_duration)
303
  final_clip = final_clip.set_audio(video.audio)
304
 
305
  # Write the final video
306
+ final_clip.write_videofile(heatmap_video_path, codec='libx264', audio_codec='aac', fps=video_fps)
307
 
308
  # Close the video clips
309
  video.close()
 
317
  print(f"Failed to create heatmap video at: {heatmap_video_path}")
318
  return None
319
 
320
+ def pad_or_trim(mse_array, target_length):
321
+ if len(mse_array) < target_length:
322
+ return np.pad(mse_array, (0, target_length - len(mse_array)), 'constant', constant_values=0)
323
+ else:
324
+ return mse_array[:target_length]
325
+
326
+
327
 
328
  # Function to create the correlation heatmap
329
  def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):