reab5555 committed on
Commit
3ca4e17
·
verified ·
1 Parent(s): 8757665

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +43 -10
visualization.py CHANGED
@@ -211,7 +211,7 @@ def fig_to_img(fig):
211
  plt.close(fig)
212
  return img
213
 
214
- def create_video_with_heatmap(video_path, mse_heatmap_embeddings_img, mse_heatmap_posture_img, output_path, desired_fps):
215
  cap = cv2.VideoCapture(video_path)
216
  original_fps = cap.get(cv2.CAP_PROP_FPS)
217
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -219,30 +219,63 @@ def create_video_with_heatmap(video_path, mse_heatmap_embeddings_img, mse_heatma
219
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
220
 
221
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
222
- out = cv2.VideoWriter(output_path, fourcc, desired_fps, (width, height + 400))
223
 
224
- # Resize heatmap images to match the width of the video frames
225
- mse_heatmap_embeddings_img = cv2.resize(mse_heatmap_embeddings_img, (width, 200))
226
- mse_heatmap_posture_img = cv2.resize(mse_heatmap_posture_img, (width, 200))
227
 
228
- for frame_count in range(0, total_frames):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count)
230
  ret, frame = cap.read()
231
  if not ret:
232
  break
233
 
234
- heatmap_combined_img = np.vstack((mse_heatmap_embeddings_img, mse_heatmap_posture_img))
 
 
235
 
236
- combined_frame = np.vstack((frame, heatmap_combined_img))
 
 
 
 
 
 
237
 
238
  seconds = frame_count / original_fps
239
  timecode = f"{int(seconds//3600):02d}:{int((seconds%3600)//60):02d}:{int(seconds%60):02d}"
240
- cv2.putText(combined_frame, f"Time: {timecode}", (10, height + 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
241
 
242
  out.write(combined_frame)
243
 
244
  cap.release()
245
  out.release()
246
- plt.close('all')
247
 
248
  return output_path
 
211
  plt.close(fig)
212
  return img
213
 
214
+ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, output_path, desired_fps):
215
  cap = cv2.VideoCapture(video_path)
216
  original_fps = cap.get(cv2.CAP_PROP_FPS)
217
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
 
219
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
220
 
221
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
222
+ out = cv2.VideoWriter(output_path, fourcc, desired_fps, (width, height + 200))
223
 
224
+ # Create custom colormap
225
+ cmap = mcolors.LinearSegmentedColormap.from_list("custom",
226
+ [(1, 1, 1), (0, 0, 1), (0.5, 0, 0.5)], N=256)
227
 
228
+ # Ensure heatmap data covers all frames
229
+ mse_embeddings = np.interp(np.linspace(0, len(mse_embeddings) - 1, total_frames),
230
+ np.arange(len(mse_embeddings)), mse_embeddings)
231
+ mse_posture = np.interp(np.linspace(0, len(mse_posture) - 1, total_frames),
232
+ np.arange(len(mse_posture)), mse_posture)
233
+
234
+ # Normalize MSE values
235
+ mse_embeddings_norm = (mse_embeddings - np.min(mse_embeddings)) / (np.max(mse_embeddings) - np.min(mse_embeddings))
236
+ mse_posture_norm = (mse_posture - np.min(mse_posture)) / (np.max(mse_posture) - np.min(mse_posture))
237
+
238
+ # Combine MSEs
239
+ combined_mse = np.zeros((2, total_frames, 3))
240
+ combined_mse[0] = np.array([1 - mse_embeddings_norm, 1 - mse_embeddings_norm, mse_embeddings_norm]).T # RGB for facial
241
+ combined_mse[1] = np.array([1 - mse_posture_norm, mse_posture_norm, 1 - mse_posture_norm]).T # RGB for posture
242
+
243
+ fig, ax = plt.subplots(figsize=(width/100, 2))
244
+ im = ax.imshow(combined_mse, aspect='auto', extent=[0, total_frames, 0, 2])
245
+ ax.set_yticks([0.5, 1.5])
246
+ ax.set_yticklabels(['Face', 'Posture'])
247
+ ax.set_xticks([])
248
+ plt.tight_layout()
249
+
250
+ line = None
251
+ frame_interval = int(original_fps / desired_fps)
252
+
253
+ for frame_count in range(0, total_frames, frame_interval):
254
  cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count)
255
  ret, frame = cap.read()
256
  if not ret:
257
  break
258
 
259
+ if line:
260
+ line.remove()
261
+ line = ax.axvline(x=frame_count, color='r', linewidth=2)
262
 
263
+ canvas = FigureCanvasAgg(fig)
264
+ canvas.draw()
265
+ heatmap_img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
266
+ heatmap_img = heatmap_img.reshape(canvas.get_width_height()[::-1] + (3,))
267
+ heatmap_img = cv2.resize(heatmap_img, (width, 200))
268
+
269
+ combined_frame = np.vstack((frame, heatmap_img))
270
 
271
  seconds = frame_count / original_fps
272
  timecode = f"{int(seconds//3600):02d}:{int((seconds%3600)//60):02d}:{int(seconds%60):02d}"
273
+ cv2.putText(combined_frame, f"Time: {timecode}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
274
 
275
  out.write(combined_frame)
276
 
277
  cap.release()
278
  out.release()
279
+ plt.close(fig)
280
 
281
  return output_path