reab5555 committed on
Commit 6c4ec2c · verified · 1 Parent(s): 51a8f17

Update visualization.py

Files changed (1)
  1. visualization.py +33 -98
visualization.py CHANGED
@@ -218,104 +218,6 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
      return fig


-
- def fill_with_zeros(mse_array, total_frames):
-     result = np.zeros(total_frames)
-     indices = np.linspace(0, total_frames - 1, len(mse_array)).astype(int)
-     result[indices] = mse_array
-     return result
-
- def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, desired_fps, total_frames, video_width):
-     fig, ax = plt.subplots(figsize=(video_width / 250, 0.4))
-
-     # Create the full heatmap for the entire video duration
-     combined_mse = np.array([mse_embeddings, mse_posture, mse_voice])
-
-     # Use pcolormesh for better performance with large datasets
-     im = ax.pcolormesh(np.arange(total_frames) / desired_fps, [0, 1, 2], combined_mse,
-                        cmap='Reds', vmin=0, vmax=np.max(combined_mse))
-
-     ax.set_ylim(0, 3)
-     ax.set_yticks([0.5, 1.5, 2.5])
-     ax.set_yticklabels(['Face', 'Posture', 'Voice'], fontsize=7)
-
-     # Set x-axis to show full video duration
-     ax.set_xlim(0, total_frames / desired_fps)
-
-     # Add vertical line for current time
-     current_time = t
-     ax.axvline(x=current_time, color='black', linewidth=2)
-
-     # Set x-axis ticks and labels
-     ax.set_xticks([0, current_time, total_frames / desired_fps])
-     ax.set_xticklabels(['0:00', f'{current_time:.2f}', f'{total_frames / desired_fps:.2f}'], fontsize=6)
-
-     plt.tight_layout(pad=0.5)
-
-     canvas = FigureCanvas(fig)
-     canvas.draw()
-     heatmap_img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
-     heatmap_img = heatmap_img.reshape(canvas.get_width_height()[::-1] + (3,))
-     plt.close(fig)
-     return heatmap_img
-
- def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice, output_folder, desired_fps, largest_cluster):
-     print(f"Creating heatmap video. Output folder: {output_folder}")
-     os.makedirs(output_folder, exist_ok=True)
-     output_filename = os.path.basename(video_path).rsplit('.', 1)[0] + '_heatmap.mp4'
-     heatmap_video_path = os.path.join(output_folder, output_filename)
-     print(f"Heatmap video will be saved at: {heatmap_video_path}")
-
-     # Load the original video
-     video = VideoFileClip(video_path)
-
-     # Get video properties
-     width, height = video.w, video.h
-     total_frames = int(video.duration * video.fps)
-
-     # Ensure MSE arrays align with original video frames
-     def align_mse_array(mse_array, original_fps, desired_fps, total_frames):
-         original_times = np.arange(len(mse_array)) / original_fps
-         desired_times = np.arange(total_frames) / desired_fps
-         interpolated_mse = np.interp(desired_times, original_times, mse_array)
-         return interpolated_mse
-
-     original_fps = len(mse_embeddings) / video.duration
-     mse_embeddings = align_mse_array(mse_embeddings, original_fps, desired_fps, total_frames)
-     mse_posture = align_mse_array(mse_posture, original_fps, desired_fps, total_frames)
-     mse_voice = align_mse_array(mse_voice, original_fps, desired_fps, total_frames)
-
-     def combine_video_and_heatmap(t):
-         frame_index = int(t * desired_fps)
-         video_frame = video.get_frame(t)
-
-         heatmap_frame = create_heatmap(t, mse_embeddings, mse_posture, mse_voice, desired_fps, total_frames, width)
-         heatmap_frame_resized = cv2.resize(heatmap_frame, (width, int(height * 0.2)))
-
-         combined_frame = np.vstack((video_frame, heatmap_frame_resized))
-         return combined_frame
-
-     final_clip = VideoClip(combine_video_and_heatmap, duration=video.duration)
-     final_clip = final_clip.set_fps(desired_fps)
-
-     if video.audio is not None:
-         final_clip = final_clip.set_audio(video.audio.set_fps(desired_fps))
-
-     final_clip.write_videofile(heatmap_video_path, codec='libx264', audio_codec='aac', fps=desired_fps)
-
-     # Close the video clips
-     video.close()
-     final_clip.close()
-
-     if os.path.exists(heatmap_video_path):
-         print(f"Heatmap video created at: {heatmap_video_path}")
-         print(f"Heatmap video size: {os.path.getsize(heatmap_video_path)} bytes")
-         return heatmap_video_path
-     else:
-         print(f"Failed to create heatmap video at: {heatmap_video_path}")
-         return None
-
-
  # Function to create the correlation heatmap
  def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
      data = np.vstack((mse_embeddings, mse_posture, mse_voice)).T
@@ -328,3 +230,36 @@ def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
      plt.title('Correlation Heatmap of MSEs')
      plt.tight_layout()
      return plt.gcf()
+
+ def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Stacked MSE Heatmaps"):
+     plt.figure(figsize=(20, 9), dpi=300)
+     fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 9), sharex=True)
+
+     # Face heatmap
+     sns.heatmap(mse_face.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax1)
+     ax1.set_yticks([0.5])
+     ax1.set_yticklabels(['Face'], rotation=0, va='center')
+     ax1.set_xticks([])
+
+     # Posture heatmap
+     sns.heatmap(mse_posture.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax2)
+     ax2.set_yticks([0.5])
+     ax2.set_yticklabels(['Posture'], rotation=0, va='center')
+     ax2.set_xticks([])
+
+     # Voice heatmap
+     sns.heatmap(mse_voice.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax3)
+     ax3.set_yticks([0.5])
+     ax3.set_yticklabels(['Voice'], rotation=0, va='center')
+
+     # Set x-axis ticks to timecodes for the bottom subplot
+     num_ticks = min(60, len(mse_voice))
+     tick_locations = np.linspace(0, len(mse_voice) - 1, num_ticks).astype(int)
+     tick_labels = [df['Timecode'].iloc[i] if i < len(df) else '' for i in tick_locations]
+     ax3.set_xticks(tick_locations)
+     ax3.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')
+
+     plt.suptitle(title)
+     plt.tight_layout()
+     plt.close()
+     return fig
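
For reference, a minimal usage sketch of the newly added plot_stacked_mse_heatmaps (not part of the commit). The array length, the 'Timecode' format, and the import path are assumptions for illustration; the function itself only expects three 1-D MSE arrays of equal length and a DataFrame whose 'Timecode' column can be indexed at the chosen tick positions.

# Minimal usage sketch with synthetic inputs; the import path is hypothetical and
# assumes visualization.py has module-level imports of numpy as np, seaborn as sns
# and matplotlib.pyplot as plt, as the rest of the file suggests.
import numpy as np
import pandas as pd

from visualization import plot_stacked_mse_heatmaps

n_frames = 300  # hypothetical number of analysed frames
mse_face = np.random.rand(n_frames)
mse_posture = np.random.rand(n_frames)
mse_voice = np.random.rand(n_frames)

# Hypothetical per-frame timecodes in M:SS form, one row per frame.
df = pd.DataFrame({'Timecode': [f'{i // 60}:{i % 60:02d}' for i in range(n_frames)]})

fig = plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df,
                                title='Stacked MSE Heatmaps')
fig.savefig('stacked_mse_heatmaps.png', dpi=150)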