Update visualization.py
Browse files- visualization.py +10 -2
visualization.py
CHANGED
@@ -9,6 +9,7 @@ import cv2
|
|
9 |
from matplotlib.patches import Rectangle
|
10 |
from utils import seconds_to_timecode
|
11 |
from anomaly_detection import determine_anomalies
|
|
|
12 |
|
13 |
def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_threshold=4):
|
14 |
plt.figure(figsize=(16, 8), dpi=300)
|
@@ -206,7 +207,7 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
|
|
206 |
return fig
|
207 |
|
208 |
|
209 |
-
def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, output_path, desired_fps, largest_cluster):
|
210 |
# Filter the DataFrame to only include frames from the largest cluster
|
211 |
df_largest_cluster = df[df['Cluster'] == largest_cluster]
|
212 |
mse_embeddings = mse_embeddings[df['Cluster'] == largest_cluster]
|
@@ -254,6 +255,10 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, outpu
|
|
254 |
plt.tight_layout()
|
255 |
|
256 |
line = None
|
|
|
|
|
|
|
|
|
257 |
for frame_count in range(total_frames):
|
258 |
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count)
|
259 |
ret, frame = cap.read()
|
@@ -262,7 +267,7 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, outpu
|
|
262 |
|
263 |
if line:
|
264 |
line.remove()
|
265 |
-
line = ax.axvline(x=frame_count, color='
|
266 |
|
267 |
canvas = FigureCanvasAgg(fig)
|
268 |
canvas.draw()
|
@@ -280,6 +285,9 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, outpu
|
|
280 |
cv2.putText(combined_frame, f"Time: {timecode}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|
281 |
|
282 |
out.write(combined_frame)
|
|
|
|
|
|
|
283 |
|
284 |
cap.release()
|
285 |
out.release()
|
|
|
9 |
from matplotlib.patches import Rectangle
|
10 |
from utils import seconds_to_timecode
|
11 |
from anomaly_detection import determine_anomalies
|
12 |
+
import gradio as gr
|
13 |
|
14 |
def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_threshold=4):
|
15 |
plt.figure(figsize=(16, 8), dpi=300)
|
|
|
207 |
return fig
|
208 |
|
209 |
|
210 |
+
def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, output_path, desired_fps, largest_cluster, progress=gr.Progress()):
|
211 |
# Filter the DataFrame to only include frames from the largest cluster
|
212 |
df_largest_cluster = df[df['Cluster'] == largest_cluster]
|
213 |
mse_embeddings = mse_embeddings[df['Cluster'] == largest_cluster]
|
|
|
255 |
plt.tight_layout()
|
256 |
|
257 |
line = None
|
258 |
+
|
259 |
+
# Add progress tracking
|
260 |
+
progress(0.9, desc="Generating video with heatmap")
|
261 |
+
|
262 |
for frame_count in range(total_frames):
|
263 |
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count)
|
264 |
ret, frame = cap.read()
|
|
|
267 |
|
268 |
if line:
|
269 |
line.remove()
|
270 |
+
line = ax.axvline(x=frame_count, color='blue', linewidth=3)
|
271 |
|
272 |
canvas = FigureCanvasAgg(fig)
|
273 |
canvas.draw()
|
|
|
285 |
cv2.putText(combined_frame, f"Time: {timecode}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|
286 |
|
287 |
out.write(combined_frame)
|
288 |
+
|
289 |
+
# Update progress
|
290 |
+
progress(0.9 + (0.1 * (frame_count + 1) / total_frames), desc="Generating video with heatmap")
|
291 |
|
292 |
cap.release()
|
293 |
out.release()
|