reab5555 commited on
Commit
4d4ae71
·
verified ·
1 Parent(s): 9ad3ed2

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +60 -0
visualization.py CHANGED
@@ -199,3 +199,63 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
199
  plt.tight_layout()
200
  plt.close()
201
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  plt.tight_layout()
200
  plt.close()
201
  return fig
202
+
203
+
204
+ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice, output_path):
205
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
206
+ # Open the video
207
+ cap = cv2.VideoCapture(video_path)
208
+ fps = cap.get(cv2.CAP_PROP_FPS)
209
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
210
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
211
+
212
+ # Create the output video writer
213
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
214
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height + 200)) # Additional 200 pixels for heatmap
215
+
216
+ # Prepare the heatmap data
217
+ heatmap_data = np.vstack((mse_embeddings, mse_posture, mse_voice))
218
+
219
+ # Create a figure for the heatmap
220
+ fig, ax = plt.subplots(figsize=(width/100, 2))
221
+ im = ax.imshow(heatmap_data, aspect='auto', cmap='YlOrRd')
222
+ ax.set_yticks([])
223
+ ax.set_xticks([])
224
+ plt.tight_layout()
225
+
226
+ frame_count = 0
227
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
228
+
229
+ while True:
230
+ ret, frame = cap.read()
231
+ if not ret:
232
+ break
233
+
234
+ # Update the heatmap with the current frame position
235
+ ax.axvline(x=frame_count, color='r', linewidth=2)
236
+
237
+ # Convert the matplotlib figure to an image
238
+ canvas = FigureCanvasAgg(fig)
239
+ canvas.draw()
240
+ heatmap_img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
241
+ heatmap_img = heatmap_img.reshape(canvas.get_width_height()[::-1] + (3,))
242
+ heatmap_img = cv2.resize(heatmap_img, (width, 200))
243
+
244
+ # Combine the video frame and the heatmap
245
+ combined_frame = np.vstack((frame, heatmap_img))
246
+
247
+ # Add timecode to the frame
248
+ timecode = df['Timecode'][frame_count] if frame_count < len(df) else "End"
249
+ cv2.putText(combined_frame, f"Time: {timecode}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
250
+
251
+ out.write(combined_frame)
252
+ frame_count += 1
253
+
254
+ # Remove the vertical line for the next iteration
255
+ ax.lines.pop()
256
+
257
+ cap.release()
258
+ out.release()
259
+ plt.close(fig)
260
+
261
+ return output_path