Update visualization.py
Browse files- visualization.py +8 -2
visualization.py
CHANGED
@@ -217,9 +217,15 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
|
|
217 |
plt.close()
|
218 |
return fig
|
219 |
|
220 |
-
def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
|
221 |
frame_count = int(t * video_fps)
|
222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
# Normalize MSE values
|
224 |
mse_embeddings_norm = (mse_embeddings - np.min(mse_embeddings)) / (np.max(mse_embeddings) - np.min(mse_embeddings))
|
225 |
mse_posture_norm = (mse_posture - np.min(mse_posture)) / (np.max(mse_posture) - np.min(mse_posture))
|
@@ -274,7 +280,7 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
|
|
274 |
|
275 |
def combine_video_and_heatmap(t):
|
276 |
video_frame = video.get_frame(t)
|
277 |
-
heatmap_frame = create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video.fps, total_frames, width)
|
278 |
heatmap_frame_resized = cv2.resize(heatmap_frame, (width, heatmap_frame.shape[0]))
|
279 |
combined_frame = np.vstack((video_frame, heatmap_frame_resized))
|
280 |
return combined_frame
|
|
|
217 |
plt.close()
|
218 |
return fig
|
219 |
|
220 |
+
def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width, largest_cluster):
|
221 |
frame_count = int(t * video_fps)
|
222 |
|
223 |
+
# Replace MSE values outside of the largest cluster with zeros
|
224 |
+
mask = (largest_cluster == 1)
|
225 |
+
mse_embeddings[~mask] = 0
|
226 |
+
mse_posture[~mask] = 0
|
227 |
+
mse_voice[~mask] = 0
|
228 |
+
|
229 |
# Normalize MSE values
|
230 |
mse_embeddings_norm = (mse_embeddings - np.min(mse_embeddings)) / (np.max(mse_embeddings) - np.min(mse_embeddings))
|
231 |
mse_posture_norm = (mse_posture - np.min(mse_posture)) / (np.max(mse_posture) - np.min(mse_posture))
|
|
|
280 |
|
281 |
def combine_video_and_heatmap(t):
|
282 |
video_frame = video.get_frame(t)
|
283 |
+
heatmap_frame = create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video.fps, total_frames, width, largest_cluster)
|
284 |
heatmap_frame_resized = cv2.resize(heatmap_frame, (width, heatmap_frame.shape[0]))
|
285 |
combined_frame = np.vstack((video_frame, heatmap_frame_resized))
|
286 |
return combined_frame
|