reab5555 commited on
Commit
b6d52dc
·
verified ·
1 Parent(s): 0b9bf9e

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +10 -2
visualization.py CHANGED
@@ -9,6 +9,7 @@ import cv2
9
  from matplotlib.patches import Rectangle
10
  from utils import seconds_to_timecode
11
  from anomaly_detection import determine_anomalies
 
12
 
13
  def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_threshold=4):
14
  plt.figure(figsize=(16, 8), dpi=300)
@@ -206,7 +207,7 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
206
  return fig
207
 
208
 
209
- def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, output_path, desired_fps, largest_cluster):
210
  # Filter the DataFrame to only include frames from the largest cluster
211
  df_largest_cluster = df[df['Cluster'] == largest_cluster]
212
  mse_embeddings = mse_embeddings[df['Cluster'] == largest_cluster]
@@ -254,6 +255,10 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, outpu
254
  plt.tight_layout()
255
 
256
  line = None
 
 
 
 
257
  for frame_count in range(total_frames):
258
  cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count)
259
  ret, frame = cap.read()
@@ -262,7 +267,7 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, outpu
262
 
263
  if line:
264
  line.remove()
265
- line = ax.axvline(x=frame_count, color='r', linewidth=2)
266
 
267
  canvas = FigureCanvasAgg(fig)
268
  canvas.draw()
@@ -280,6 +285,9 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, outpu
280
  cv2.putText(combined_frame, f"Time: {timecode}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
281
 
282
  out.write(combined_frame)
 
 
 
283
 
284
  cap.release()
285
  out.release()
 
9
  from matplotlib.patches import Rectangle
10
  from utils import seconds_to_timecode
11
  from anomaly_detection import determine_anomalies
12
+ import gradio as gr
13
 
14
  def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_threshold=4):
15
  plt.figure(figsize=(16, 8), dpi=300)
 
207
  return fig
208
 
209
 
210
+ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, output_path, desired_fps, largest_cluster, progress=gr.Progress()):
211
  # Filter the DataFrame to only include frames from the largest cluster
212
  df_largest_cluster = df[df['Cluster'] == largest_cluster]
213
  mse_embeddings = mse_embeddings[df['Cluster'] == largest_cluster]
 
255
  plt.tight_layout()
256
 
257
  line = None
258
+
259
+ # Add progress tracking
260
+ progress(0.9, desc="Generating video with heatmap")
261
+
262
  for frame_count in range(total_frames):
263
  cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count)
264
  ret, frame = cap.read()
 
267
 
268
  if line:
269
  line.remove()
270
+ line = ax.axvline(x=frame_count, color='blue', linewidth=3)
271
 
272
  canvas = FigureCanvasAgg(fig)
273
  canvas.draw()
 
285
  cv2.putText(combined_frame, f"Time: {timecode}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
286
 
287
  out.write(combined_frame)
288
+
289
+ # Update progress
290
+ progress(0.9 + (0.1 * (frame_count + 1) / total_frames), desc="Generating video with heatmap")
291
 
292
  cap.release()
293
  out.release()