reab5555 committed on
Commit
ae7c4bd
·
verified ·
1 Parent(s): 8e8d1ed

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +35 -25
visualization.py CHANGED
@@ -296,35 +296,45 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
296
  print(f"Failed to create heatmap video at: {heatmap_video_path}")
297
  return None
298
 
299
def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
    """Render the three MSE series as a heatmap strip with a playhead marker.

    Each series is min-max scaled to [0, 1], stacked into a 3 x total_frames
    matrix (embeddings, posture, voice) and drawn with the 'Reds' colormap.
    A vertical black line is drawn at the frame index ``int(t * video_fps)``.

    Args:
        t: Current time; multiplied by ``video_fps`` to get the frame index.
        mse_embeddings: MSE series for row 0. Must be assignable to a row of
            length ``total_frames``.
        mse_posture: MSE series for row 1 (same length requirement).
        mse_voice: MSE series for row 2 (same length requirement).
        video_fps: Frames per second used to convert ``t`` to a frame index.
        total_frames: Number of columns in the heatmap.
        video_width: Figure width is ``video_width / 250`` inches.

    Returns:
        A uint8 array of shape (height, width, 3) holding the rendered RGB
        pixels of the figure.
    """

    def _min_max(series):
        # A constant series has zero range; dividing by it would fill the row
        # with NaN, so map it to all zeros instead.
        series = np.asarray(series, dtype=float)
        low = np.min(series)
        span = np.max(series) - low
        if span == 0:
            return np.zeros_like(series)
        return (series - low) / span

    frame_count = int(t * video_fps)

    combined_mse = np.zeros((3, total_frames))
    combined_mse[0] = _min_max(mse_embeddings)
    combined_mse[1] = _min_max(mse_posture)
    combined_mse[2] = _min_max(mse_voice)

    fig, ax = plt.subplots(figsize=(video_width / 250, 0.6))
    try:
        ax.imshow(combined_mse, aspect='auto', cmap='Reds', vmin=0, vmax=1,
                  extent=[0, total_frames, 0, 3])
        ax.set_yticks([0.5, 1.5, 2.5])
        ax.set_yticklabels(['Voice', 'Posture', 'Face'], fontsize=7)
        ax.set_xticks([])

        # Playhead: marks the frame currently being shown.
        ax.axvline(x=frame_count, color='black', linewidth=3)

        fig.tight_layout(pad=0.5)

        canvas = FigureCanvas(fig)
        canvas.draw()
        # NOTE(review): tostring_rgb is deprecated in recent Matplotlib
        # releases; confirm the pinned version or migrate to buffer_rgba.
        heatmap_img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
        # get_width_height() is (width, height); image arrays are (height, width, 3).
        return heatmap_img.reshape(canvas.get_width_height()[::-1] + (3,))
    finally:
        # Always release the figure, even if drawing fails, so repeated
        # per-frame calls do not accumulate open figures.
        plt.close(fig)
328
 
329
  def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
330
  data = np.vstack((mse_embeddings, mse_posture, mse_voice)).T
 
296
  print(f"Failed to create heatmap video at: {heatmap_video_path}")
297
  return None
298
 
299
def create_heatmap(t, mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered, fps, total_frames, width):
    """Stack the three normalized MSE series into the first rows of a matrix.

    Each series is scaled with ``normalize_mse`` and then padded with zeros or
    truncated to ``width`` entries by ``pad_or_trim_array``. The results are
    written to rows 0 (embeddings), 1 (posture) and 2 (voice) of a zero matrix
    of shape ``(total_frames, width)``; all other rows stay zero.

    Args:
        t: Accepted for interface compatibility; not used here.
        mse_embeddings_filtered: MSE series written to row 0.
        mse_posture_filtered: MSE series written to row 1.
        mse_voice_filtered: MSE series written to row 2.
        fps: Accepted for interface compatibility; not used here.
        total_frames: Number of rows in the returned matrix.
        width: Number of columns; every series is padded/trimmed to this length.

    Returns:
        A float array of shape ``(total_frames, width)``.

    Raises:
        IndexError: If ``total_frames`` is less than 3, because rows 0-2 are
            assigned unconditionally.

    NOTE(review): only rows 0-2 are populated; confirm that the
    ``(total_frames, width)`` shape is what callers expect.
    """
    series_by_row = (
        mse_embeddings_filtered,
        mse_posture_filtered,
        mse_voice_filtered,
    )

    combined_mse = np.zeros((total_frames, width))
    for row_index, series in enumerate(series_by_row):
        combined_mse[row_index] = pad_or_trim_array(normalize_mse(series), width)

    return combined_mse
325
+
326
def normalize_mse(mse):
    """Scale MSE values by their maximum so the largest value becomes 1.

    Args:
        mse: Array-like of MSE values.

    Returns:
        An ndarray with every value divided by the maximum of ``mse``. When
        the input is empty or its maximum is 0 the values are returned
        unscaled (for an all-zero series that is all zeros) rather than
        producing NaN from a 0/0 division or raising on an empty array.
    """
    values = np.asarray(mse)
    # np.max raises on an empty array, so skip it in that case.
    peak = np.max(values) if values.size else 0
    # Dividing by 1 keeps the result type consistent with the normal path.
    scale = peak if peak != 0 else 1
    return values / scale
329
+
330
def pad_or_trim_array(arr, target_length):
    """Return ``arr`` with exactly ``target_length`` entries.

    Longer inputs are truncated from the end; shorter inputs are padded at
    the end with zeros. The result is always an ndarray, including when a
    plain list is passed in.

    Args:
        arr: Array-like sequence of values.
        target_length: Required number of entries; must be non-negative.

    Returns:
        An ndarray of length ``target_length``.

    Raises:
        ValueError: If ``target_length`` is negative (a negative slice bound
            would otherwise silently drop entries from the end).
    """
    if target_length < 0:
        raise ValueError(f"target_length must be non-negative, got {target_length}")

    values = np.asarray(arr)
    current_length = len(values)

    if current_length == target_length:
        return values
    if current_length > target_length:
        return values[:target_length]
    return np.pad(values, (0, target_length - current_length), 'constant')
 
338
 
339
  def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
340
  data = np.vstack((mse_embeddings, mse_posture, mse_voice)).T