reab5555 commited on
Commit
4680051
·
verified ·
1 Parent(s): 6732975

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +10 -20
visualization.py CHANGED
@@ -202,59 +202,49 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
202
 
203
 
204
  def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, output_path):
205
- from matplotlib.backends.backend_agg import FigureCanvasAgg
206
- import cv2
207
- # Open the video
208
  cap = cv2.VideoCapture(video_path)
209
  fps = cap.get(cv2.CAP_PROP_FPS)
210
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
211
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
212
 
213
- # Create the output video writer
214
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
215
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height + 200)) # Additional 200 pixels for heatmap
216
 
217
- # Prepare the heatmap data
218
- heatmap_data = np.vstack((mse_embeddings, mse_posture))
 
 
219
 
220
- # Create a figure for the heatmap
221
  fig, ax = plt.subplots(figsize=(width/100, 2))
222
- im = ax.imshow(heatmap_data, aspect='auto', cmap='YlOrRd')
223
  ax.set_yticks([])
224
  ax.set_xticks([])
225
  plt.tight_layout()
226
 
227
- frame_count = 0
228
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
229
-
230
- line = None # Initialize line object
231
 
232
- while True:
233
  ret, frame = cap.read()
234
  if not ret:
235
  break
236
 
237
- # Update the heatmap with the current frame position
238
  if line:
239
- line.remove() # Remove the previous line
240
  line = ax.axvline(x=frame_count, color='r', linewidth=2)
241
 
242
- # Convert the matplotlib figure to an image
243
  canvas = FigureCanvasAgg(fig)
244
  canvas.draw()
245
  heatmap_img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
246
  heatmap_img = heatmap_img.reshape(canvas.get_width_height()[::-1] + (3,))
247
  heatmap_img = cv2.resize(heatmap_img, (width, 200))
248
 
249
- # Combine the video frame and the heatmap
250
  combined_frame = np.vstack((frame, heatmap_img))
251
 
252
- # Add timecode to the frame
253
  timecode = df['Timecode'][frame_count] if frame_count < len(df) else "End"
254
  cv2.putText(combined_frame, f"Time: {timecode}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
255
 
256
  out.write(combined_frame)
257
- frame_count += 1
258
 
259
  cap.release()
260
  out.release()
 
202
 
203
 
204
  def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, output_path):
 
 
 
205
  cap = cv2.VideoCapture(video_path)
206
  fps = cap.get(cv2.CAP_PROP_FPS)
207
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
208
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
209
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
210
 
 
211
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
212
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height + 200))
213
 
214
+ # Ensure heatmap data covers all frames
215
+ heatmap_data = np.zeros((2, total_frames))
216
+ heatmap_data[0, :len(mse_embeddings)] = mse_embeddings
217
+ heatmap_data[1, :len(mse_posture)] = mse_posture
218
 
 
219
  fig, ax = plt.subplots(figsize=(width/100, 2))
220
+ im = ax.imshow(heatmap_data, aspect='auto', cmap='YlOrRd', extent=[0, total_frames, 0, 2])
221
  ax.set_yticks([])
222
  ax.set_xticks([])
223
  plt.tight_layout()
224
 
225
+ line = None
 
 
 
226
 
227
+ for frame_count in range(total_frames):
228
  ret, frame = cap.read()
229
  if not ret:
230
  break
231
 
 
232
  if line:
233
+ line.remove()
234
  line = ax.axvline(x=frame_count, color='r', linewidth=2)
235
 
 
236
  canvas = FigureCanvasAgg(fig)
237
  canvas.draw()
238
  heatmap_img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
239
  heatmap_img = heatmap_img.reshape(canvas.get_width_height()[::-1] + (3,))
240
  heatmap_img = cv2.resize(heatmap_img, (width, 200))
241
 
 
242
  combined_frame = np.vstack((frame, heatmap_img))
243
 
 
244
  timecode = df['Timecode'][frame_count] if frame_count < len(df) else "End"
245
  cv2.putText(combined_frame, f"Time: {timecode}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
246
 
247
  out.write(combined_frame)
 
248
 
249
  cap.release()
250
  out.release()