Spaces:
Runtime error
Runtime error
Update visualization.py
Browse files- visualization.py +10 -2
visualization.py
CHANGED
|
@@ -9,6 +9,7 @@ import cv2
|
|
| 9 |
from matplotlib.patches import Rectangle
|
| 10 |
from utils import seconds_to_timecode
|
| 11 |
from anomaly_detection import determine_anomalies
|
|
|
|
| 12 |
|
| 13 |
def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_threshold=4):
|
| 14 |
plt.figure(figsize=(16, 8), dpi=300)
|
|
@@ -206,7 +207,7 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
|
|
| 206 |
return fig
|
| 207 |
|
| 208 |
|
| 209 |
-
def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, output_path, desired_fps, largest_cluster):
|
| 210 |
# Filter the DataFrame to only include frames from the largest cluster
|
| 211 |
df_largest_cluster = df[df['Cluster'] == largest_cluster]
|
| 212 |
mse_embeddings = mse_embeddings[df['Cluster'] == largest_cluster]
|
|
@@ -254,6 +255,10 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, outpu
|
|
| 254 |
plt.tight_layout()
|
| 255 |
|
| 256 |
line = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
for frame_count in range(total_frames):
|
| 258 |
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count)
|
| 259 |
ret, frame = cap.read()
|
|
@@ -262,7 +267,7 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, outpu
|
|
| 262 |
|
| 263 |
if line:
|
| 264 |
line.remove()
|
| 265 |
-
line = ax.axvline(x=frame_count, color='
|
| 266 |
|
| 267 |
canvas = FigureCanvasAgg(fig)
|
| 268 |
canvas.draw()
|
|
@@ -280,6 +285,9 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, outpu
|
|
| 280 |
cv2.putText(combined_frame, f"Time: {timecode}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|
| 281 |
|
| 282 |
out.write(combined_frame)
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
cap.release()
|
| 285 |
out.release()
|
|
|
|
| 9 |
from matplotlib.patches import Rectangle
|
| 10 |
from utils import seconds_to_timecode
|
| 11 |
from anomaly_detection import determine_anomalies
|
| 12 |
+
import gradio as gr
|
| 13 |
|
| 14 |
def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_threshold=4):
|
| 15 |
plt.figure(figsize=(16, 8), dpi=300)
|
|
|
|
| 207 |
return fig
|
| 208 |
|
| 209 |
|
| 210 |
+
def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, output_path, desired_fps, largest_cluster, progress=gr.Progress()):
|
| 211 |
# Filter the DataFrame to only include frames from the largest cluster
|
| 212 |
df_largest_cluster = df[df['Cluster'] == largest_cluster]
|
| 213 |
mse_embeddings = mse_embeddings[df['Cluster'] == largest_cluster]
|
|
|
|
| 255 |
plt.tight_layout()
|
| 256 |
|
| 257 |
line = None
|
| 258 |
+
|
| 259 |
+
# Add progress tracking
|
| 260 |
+
progress(0.9, desc="Generating video with heatmap")
|
| 261 |
+
|
| 262 |
for frame_count in range(total_frames):
|
| 263 |
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count)
|
| 264 |
ret, frame = cap.read()
|
|
|
|
| 267 |
|
| 268 |
if line:
|
| 269 |
line.remove()
|
| 270 |
+
line = ax.axvline(x=frame_count, color='blue', linewidth=3)
|
| 271 |
|
| 272 |
canvas = FigureCanvasAgg(fig)
|
| 273 |
canvas.draw()
|
|
|
|
| 285 |
cv2.putText(combined_frame, f"Time: {timecode}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|
| 286 |
|
| 287 |
out.write(combined_frame)
|
| 288 |
+
|
| 289 |
+
# Update progress
|
| 290 |
+
progress(0.9 + (0.1 * (frame_count + 1) / total_frames), desc="Generating video with heatmap")
|
| 291 |
|
| 292 |
cap.release()
|
| 293 |
out.release()
|