reab5555 commited on
Commit
e0599e6
·
verified ·
1 Parent(s): 0167254

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +8 -25
visualization.py CHANGED
@@ -217,30 +217,13 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
217
  plt.close()
218
  return fig
219
 
220
- def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width, largest_cluster):
221
  frame_count = int(t * video_fps)
222
 
223
- # Replace MSE values outside of the largest cluster with zeros
224
- mask = (largest_cluster == 1)
225
- mse_embeddings[~mask] = 0
226
- mse_posture[~mask] = 0
227
- mse_voice[~mask] = 0
228
-
229
- # Check if all values are zero
230
- if np.all(mse_embeddings == 0):
231
- mse_embeddings_norm = mse_embeddings
232
- else:
233
- mse_embeddings_norm = (mse_embeddings - np.min(mse_embeddings)) / (np.max(mse_embeddings) - np.min(mse_embeddings))
234
-
235
- if np.all(mse_posture == 0):
236
- mse_posture_norm = mse_posture
237
- else:
238
- mse_posture_norm = (mse_posture - np.min(mse_posture)) / (np.max(mse_posture) - np.min(mse_posture))
239
-
240
- if np.all(mse_voice == 0):
241
- mse_voice_norm = mse_voice
242
- else:
243
- mse_voice_norm = (mse_voice - np.min(mse_voice)) / (np.max(mse_voice) - np.min(mse_voice))
244
 
245
  combined_mse = np.zeros((3, total_frames))
246
  combined_mse[0] = mse_embeddings_norm
@@ -253,7 +236,7 @@ def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_f
253
  ax.set_yticklabels(['Voice', 'Posture', 'Face'], fontsize=7)
254
  ax.set_xticks([])
255
 
256
- ax.axvline(x=frame_count, color='black', linewidth=2)
257
 
258
  plt.tight_layout(pad=0.5)
259
 
@@ -291,7 +274,7 @@ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_v
291
 
292
  def combine_video_and_heatmap(t):
293
  video_frame = video.get_frame(t)
294
- heatmap_frame = create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video.fps, total_frames, width, largest_cluster)
295
  heatmap_frame_resized = cv2.resize(heatmap_frame, (width, heatmap_frame.shape[0]))
296
  combined_frame = np.vstack((video_frame, heatmap_frame_resized))
297
  return combined_frame
@@ -326,4 +309,4 @@ def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
326
  heatmap = sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
327
  plt.title('Correlation Heatmap of MSEs')
328
  plt.tight_layout()
329
- return plt.gcf()
 
217
  plt.close()
218
  return fig
219
 
220
+ def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
221
  frame_count = int(t * video_fps)
222
 
223
+ # Normalize MSE values
224
+ mse_embeddings_norm = (mse_embeddings - np.min(mse_embeddings)) / (np.max(mse_embeddings) - np.min(mse_embeddings))
225
+ mse_posture_norm = (mse_posture - np.min(mse_posture)) / (np.max(mse_posture) - np.min(mse_posture))
226
+ mse_voice_norm = (mse_voice - np.min(mse_voice)) / (np.max(mse_voice) - np.min(mse_voice))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  combined_mse = np.zeros((3, total_frames))
229
  combined_mse[0] = mse_embeddings_norm
 
236
  ax.set_yticklabels(['Voice', 'Posture', 'Face'], fontsize=7)
237
  ax.set_xticks([])
238
 
239
+ ax.axvline(x=frame_count, color='black', linewidth=3)
240
 
241
  plt.tight_layout(pad=0.5)
242
 
 
274
 
275
  def combine_video_and_heatmap(t):
276
  video_frame = video.get_frame(t)
277
+ heatmap_frame = create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video.fps, total_frames, width)
278
  heatmap_frame_resized = cv2.resize(heatmap_frame, (width, heatmap_frame.shape[0]))
279
  combined_frame = np.vstack((video_frame, heatmap_frame_resized))
280
  return combined_frame
 
309
  heatmap = sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
310
  plt.title('Correlation Heatmap of MSEs')
311
  plt.tight_layout()
312
+ return plt.gcf()