reab5555 committed
Commit 1db5ddf · verified · 1 Parent(s): 044329c

Update visualization.py

Files changed (1)
  1. visualization.py +58 -71
visualization.py CHANGED
@@ -216,56 +216,60 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
     plt.tight_layout()
     plt.close()
     return fig
-
-def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, analysis_fps, video_width):
-    video_frame_count = int(t * video_fps)
-    analysis_frame_count = int(t * analysis_fps)
-
-    # Ensure we don't go out of bounds
-    analysis_frame_count = min(analysis_frame_count, len(mse_embeddings) - 1)
-
-    # Normalize MSE values
-    def safe_normalize(arr):
-        min_val, max_val = np.min(arr), np.max(arr)
-        if min_val == max_val:
-            return np.zeros_like(arr)
-        return (arr - min_val) / (max_val - min_val)
-
-    mse_embeddings_norm = safe_normalize(mse_embeddings)
-    mse_posture_norm = safe_normalize(mse_posture)
-    mse_voice_norm = safe_normalize(mse_voice)
-
-    # Create heatmap data
-    heatmap_width = int(video_width / 240 * 100) # Adjust this multiplier as needed
-    combined_mse = np.zeros((3, heatmap_width))
-
-    # Map analysis frames to heatmap width
-    for i in range(heatmap_width):
-        frame_index = int(i * len(mse_embeddings) / heatmap_width)
-        combined_mse[0, i] = mse_embeddings_norm[frame_index]
-        combined_mse[1, i] = mse_posture_norm[frame_index]
-        combined_mse[2, i] = mse_voice_norm[frame_index]
-
-    fig, ax = plt.subplots(figsize=(video_width / 240, 0.5))
-    ax.imshow(combined_mse, aspect='auto', cmap='Reds', vmin=0, vmax=1, extent=[0, heatmap_width, 0, 3])
-    ax.set_yticks([0.5, 1.5, 2.5])
-    ax.set_yticklabels(['Voice', 'Posture', 'Face'], fontsize=7)
-    ax.set_xticks([])
-
-    # Calculate the position of the vertical line
-    line_pos = (video_frame_count / video_fps) / (len(mse_embeddings) / analysis_fps) * heatmap_width
-    ax.axvline(x=line_pos, color='black', linewidth=3)
-
-    plt.tight_layout(pad=0.5)
-
-    canvas = FigureCanvas(fig)
-    canvas.draw()
-    heatmap_img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
-    heatmap_img = heatmap_img.reshape(canvas.get_width_height()[::-1] + (3,))
-    plt.close(fig)
-    return heatmap_img
-
-def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice, output_folder, analysis_fps, largest_cluster):
+def create_combined_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, analysis_fps, video_width):
+    def plot_single_mse_heatmap(mse_values, height=1):
+        plt.figure(figsize=(video_width / 100, height), dpi=100)
+        fig, ax = plt.subplots(figsize=(video_width / 100, height))
+
+        # Reshape MSE values to 2D array for heatmap
+        mse_2d = mse_values.reshape(1, -1)
+
+        # Create heatmap
+        sns.heatmap(mse_2d, cmap='YlOrRd', cbar=False, ax=ax)
+
+        # Remove all axes
+        ax.set_xticks([])
+        ax.set_yticks([])
+        ax.axis('off')
+
+        plt.tight_layout(pad=0)
+
+        # Convert plot to image
+        canvas = FigureCanvas(fig)
+        canvas.draw()
+        image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
+        image = image.reshape(canvas.get_width_height()[::-1] + (3,))
+        plt.close(fig)
+        return image
+
+    # Create individual heatmaps
+    face_heatmap = plot_single_mse_heatmap(mse_embeddings)
+    posture_heatmap = plot_single_mse_heatmap(mse_posture)
+    voice_heatmap = plot_single_mse_heatmap(mse_voice)
+
+    # Combine heatmaps vertically
+    combined_heatmap = np.vstack((face_heatmap, posture_heatmap, voice_heatmap))
+
+    # Add labels
+    label_height = 20
+    label_image = np.ones((label_height, combined_heatmap.shape[1], 3), dtype=np.uint8) * 255
+    cv2.putText(label_image, 'Face', (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
+    cv2.putText(label_image, 'Posture', (5, 35), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
+    cv2.putText(label_image, 'Voice', (5, 55), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
+
+    combined_heatmap = np.vstack((label_image, combined_heatmap))
+
+    # Calculate position of vertical line
+    video_frame = int(t * video_fps)
+    total_analysis_frames = len(mse_embeddings)
+    line_pos = int((video_frame / (video_fps / analysis_fps)) * combined_heatmap.shape[1] / total_analysis_frames)
+
+    # Draw vertical line
+    cv2.line(combined_heatmap, (line_pos, 0), (line_pos, combined_heatmap.shape[0]), (0, 0, 0), 2)
+
+    return combined_heatmap
+
+def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice, output_folder, analysis_fps):
     print(f"Creating heatmap video. Output folder: {output_folder}")
 
     os.makedirs(output_folder, exist_ok=True)
@@ -280,30 +284,19 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice, output_folder, analysis_fps, largest_cluster):
 
     # Get video properties
     width, height = video.w, video.h
-    video_duration = video.duration
-    video_fps = video.fps
-
-    # Calculate the number of analysis frames
-    analysis_frames = int(video_duration * analysis_fps)
-
-    # Ensure MSE arrays match the number of analysis frames
-    mse_embeddings = pad_or_trim(mse_embeddings, analysis_frames)
-    mse_posture = pad_or_trim(mse_posture, analysis_frames)
-    mse_voice = pad_or_trim(mse_voice, analysis_frames)
 
     def combine_video_and_heatmap(t):
         video_frame = video.get_frame(t)
-        analysis_frame = int(t * analysis_fps)
-        heatmap_frame = create_heatmap(analysis_frame, mse_embeddings, mse_posture, mse_voice, analysis_fps, analysis_frames, width)
-        heatmap_frame_resized = cv2.resize(heatmap_frame, (width, heatmap_frame.shape[0]))
+        heatmap_frame = create_combined_heatmap(t, mse_embeddings, mse_posture, mse_voice, video.fps, analysis_fps, width)
+        heatmap_frame_resized = cv2.resize(heatmap_frame, (width, int(height * 0.2)))
        combined_frame = np.vstack((video_frame, heatmap_frame_resized))
         return combined_frame
 
-    final_clip = VideoClip(combine_video_and_heatmap, duration=video_duration)
+    final_clip = VideoClip(combine_video_and_heatmap, duration=video.duration)
     final_clip = final_clip.set_audio(video.audio)
 
     # Write the final video
-    final_clip.write_videofile(heatmap_video_path, codec='libx264', audio_codec='aac', fps=video_fps)
+    final_clip.write_videofile(heatmap_video_path, codec='libx264', audio_codec='aac', fps=video.fps)
 
     # Close the video clips
     video.close()
@@ -317,12 +310,6 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice, output_folder, analysis_fps, largest_cluster):
         print(f"Failed to create heatmap video at: {heatmap_video_path}")
         return None
 
-def pad_or_trim(mse_array, target_length):
-    if len(mse_array) < target_length:
-        return np.pad(mse_array, (0, target_length - len(mse_array)), 'constant', constant_values=0)
-    else:
-        return mse_array[:target_length]
-
 
 
 # Function to create the correlation heatmap
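
With this commit, create_heatmap and pad_or_trim are removed in favor of create_combined_heatmap, and create_video_with_heatmap no longer takes a largest_cluster argument. A minimal calling sketch follows, assuming the updated visualization.py is importable from the working directory; the paths, array length, and analysis_fps value below are illustrative assumptions, not values from the repo:

# Hypothetical usage of the updated create_video_with_heatmap; all inputs are assumptions.
import numpy as np
import pandas as pd

from visualization import create_video_with_heatmap  # module changed in this commit

video_path = "input/session.mp4"     # hypothetical source video
output_folder = "output/heatmaps"    # hypothetical output directory
analysis_fps = 10                    # assumed analysis frame rate

# Dummy MSE traces for face embeddings, posture, and voice; in the real
# pipeline these would come from the upstream anomaly-detection step.
n_analysis_frames = 600              # e.g. 60 s of video analyzed at 10 fps
mse_embeddings = np.random.rand(n_analysis_frames)
mse_posture = np.random.rand(n_analysis_frames)
mse_voice = np.random.rand(n_analysis_frames)

# Placeholder DataFrame; how df is used is outside the hunks shown in this diff.
df = pd.DataFrame({"Frame": np.arange(n_analysis_frames)})

result = create_video_with_heatmap(
    video_path, df, mse_embeddings, mse_posture, mse_voice,
    output_folder, analysis_fps,
)
print(result)  # the diff shows None on failure; the success-path return value is outside these hunks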