reab5555 committed on
Commit
869705c
·
verified ·
1 Parent(s): 32c0667

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +25 -18
visualization.py CHANGED
@@ -227,7 +227,7 @@ def fill_with_zeros(mse_array, total_frames):
227
 
228
  def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, desired_fps, total_frames, video_width):
229
  frame_count = int(t * desired_fps)
230
- window_size = min(300, total_frames)
231
  start_frame = max(0, frame_count - window_size // 2)
232
  end_frame = min(total_frames, start_frame + window_size)
233
 
@@ -237,17 +237,25 @@ def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, desired_fps, total
237
  mse_voice[start_frame:end_frame]
238
  ])
239
 
240
- fig, ax = plt.subplots(figsize=(video_width / 300, 0.4))
 
 
 
 
241
  im = ax.imshow(combined_mse, aspect='auto', cmap='Reds',
242
  extent=[start_frame/desired_fps, end_frame/desired_fps, 0, 3],
243
- vmin=0, vmax=max(np.max(mse_embeddings), np.max(mse_posture), np.max(mse_voice)))
 
244
  ax.set_yticks([0.5, 1.5, 2.5])
245
- ax.set_yticklabels(['Voice', 'Posture', 'Face'], fontsize=7)
246
 
247
- ax.axvline(x=t, color='black', linewidth=2)
 
 
248
 
249
- ax.set_xticks([start_frame/desired_fps, t, end_frame/desired_fps])
250
- ax.set_xticklabels([f'{start_frame/desired_fps:.2f}', f'{t:.2f}', f'{end_frame/desired_fps:.2f}'], fontsize=6)
 
251
 
252
  plt.tight_layout(pad=0.5)
253
 
@@ -272,23 +280,22 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
272
  width, height = video.w, video.h
273
  total_frames = int(video.duration * desired_fps)
274
 
275
- # Fill gaps with zeros
276
- mse_embeddings = fill_with_zeros(mse_embeddings, total_frames)
277
- mse_posture = fill_with_zeros(mse_posture, total_frames)
278
- mse_voice = fill_with_zeros(mse_voice, total_frames)
 
 
 
 
 
279
 
280
  def combine_video_and_heatmap(t):
281
  original_frame = int(t * video.fps)
282
  video_frame = video.get_frame(original_frame / video.fps)
283
 
284
  heatmap_frame = create_heatmap(t, mse_embeddings, mse_posture, mse_voice, desired_fps, total_frames, width)
285
- heatmap_frame_resized = cv2.resize(heatmap_frame, (width, heatmap_frame.shape[0]))
286
-
287
- if video_frame.shape[2] != heatmap_frame_resized.shape[2]:
288
- if video_frame.shape[2] == 3:
289
- heatmap_frame_resized = heatmap_frame_resized[:, :, :3]
290
- else:
291
- video_frame = cv2.cvtColor(video_frame, cv2.COLOR_RGB2RGBA)
292
 
293
  combined_frame = np.vstack((video_frame, heatmap_frame_resized))
294
  return combined_frame
 
227
 
228
  def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, desired_fps, total_frames, video_width):
229
  frame_count = int(t * desired_fps)
230
+ window_size = min(600, total_frames) # Increased window size for better context
231
  start_frame = max(0, frame_count - window_size // 2)
232
  end_frame = min(total_frames, start_frame + window_size)
233
 
 
237
  mse_voice[start_frame:end_frame]
238
  ])
239
 
240
+ # Calculate global min and max for consistent scaling
241
+ vmin = 0
242
+ vmax = max(np.max(mse_embeddings), np.max(mse_posture), np.max(mse_voice))
243
+
244
+ fig, ax = plt.subplots(figsize=(video_width / 100, 0.4)) # Adjusted figure size
245
  im = ax.imshow(combined_mse, aspect='auto', cmap='Reds',
246
  extent=[start_frame/desired_fps, end_frame/desired_fps, 0, 3],
247
+ vmin=vmin, vmax=vmax, interpolation='nearest')
248
+
249
  ax.set_yticks([0.5, 1.5, 2.5])
250
+ ax.set_yticklabels(['Face', 'Posture', 'Voice'], fontsize=7)
251
 
252
+ # Add vertical line for current time
253
+ current_time = t
254
+ ax.axvline(x=current_time, color='black', linewidth=2)
255
 
256
+ # Set x-axis ticks and labels
257
+ ax.set_xticks([start_frame/desired_fps, current_time, end_frame/desired_fps])
258
+ ax.set_xticklabels([f'{start_frame/desired_fps:.2f}', f'{current_time:.2f}', f'{end_frame/desired_fps:.2f}'], fontsize=6)
259
 
260
  plt.tight_layout(pad=0.5)
261
 
 
280
  width, height = video.w, video.h
281
  total_frames = int(video.duration * desired_fps)
282
 
283
+ # Interpolate MSE values to match the desired fps
284
+ def interpolate_mse(mse_array):
285
+ original_indices = np.linspace(0, total_frames - 1, len(mse_array))
286
+ new_indices = np.arange(total_frames)
287
+ return np.interp(new_indices, original_indices, mse_array)
288
+
289
+ mse_embeddings = interpolate_mse(mse_embeddings)
290
+ mse_posture = interpolate_mse(mse_posture)
291
+ mse_voice = interpolate_mse(mse_voice)
292
 
293
  def combine_video_and_heatmap(t):
294
  original_frame = int(t * video.fps)
295
  video_frame = video.get_frame(original_frame / video.fps)
296
 
297
  heatmap_frame = create_heatmap(t, mse_embeddings, mse_posture, mse_voice, desired_fps, total_frames, width)
298
+ heatmap_frame_resized = cv2.resize(heatmap_frame, (width, int(height * 0.2)))
 
 
 
 
 
 
299
 
300
  combined_frame = np.vstack((video_frame, heatmap_frame_resized))
301
  return combined_frame