reab5555 commited on
Commit
e0def5d
·
verified ·
1 Parent(s): 97ad667

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +110 -0
visualization.py CHANGED
@@ -216,3 +216,113 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
216
  plt.tight_layout()
217
  plt.close()
218
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  plt.tight_layout()
217
  plt.close()
218
  return fig
219
+
220
+
221
+
222
+ def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
223
+ frame_count = int(t * video_fps)
224
+
225
+ # Normalize MSE values
226
+ mse_embeddings_norm = (mse_embeddings - np.min(mse_embeddings)) / (np.max(mse_embeddings) - np.min(mse_embeddings))
227
+ mse_posture_norm = (mse_posture - np.min(mse_posture)) / (np.max(mse_posture) - np.min(mse_posture))
228
+ mse_voice_norm = (mse_voice - np.min(mse_voice)) / (np.max(mse_voice) - np.min(mse_voice))
229
+
230
+ combined_mse = np.zeros((3, total_frames))
231
+ combined_mse[0] = mse_embeddings_norm
232
+ combined_mse[1] = mse_posture_norm
233
+ combined_mse[2] = mse_voice_norm
234
+
235
+ fig, ax = plt.subplots(figsize=(video_width / 300, 0.4))
236
+ ax.imshow(combined_mse, aspect='auto', cmap='Reds', vmin=0, vmax=1, extent=[0, total_frames, 0, 3])
237
+ ax.set_yticks([0.5, 1.5, 2.5])
238
+ ax.set_yticklabels(['Voice', 'Posture', 'Face'], fontsize=7)
239
+ ax.set_xticks([])
240
+
241
+ ax.axvline(x=frame_count, color='black', linewidth=2)
242
+
243
+ plt.tight_layout(pad=0.5)
244
+
245
+ canvas = FigureCanvas(fig)
246
+ canvas.draw()
247
+ heatmap_img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
248
+ heatmap_img = heatmap_img.reshape(canvas.get_width_height()[::-1] + (3,))
249
+ plt.close(fig)
250
+ return heatmap_img
251
+
252
+ def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice, output_folder, desired_fps, largest_cluster):
253
+ print(f"Creating heatmap video. Output folder: {output_folder}")
254
+ os.makedirs(output_folder, exist_ok=True)
255
+ output_filename = os.path.basename(video_path).rsplit('.', 1)[0] + '_heatmap.mp4'
256
+ heatmap_video_path = os.path.join(output_folder, output_filename)
257
+ print(f"Heatmap video will be saved at: {heatmap_video_path}")
258
+
259
+ # Load the original video
260
+ video = VideoFileClip(video_path)
261
+
262
+ # Get video properties
263
+ width, height = video.w, video.h
264
+ total_frames = int(video.duration * video.fps)
265
+
266
+ def fill_with_previous_values(mse_array, total_frames):
267
+ result = np.zeros(total_frames)
268
+ indices = np.linspace(0, total_frames - 1, len(mse_array)).astype(int)
269
+ result[indices] = mse_array
270
+ for i in range(1, total_frames):
271
+ if result[i] == 0:
272
+ result[i] = result[i-1]
273
+ return result
274
+
275
+ # Fill gaps with previous values
276
+ mse_embeddings = fill_with_previous_values(mse_embeddings, total_frames)
277
+ mse_posture = fill_with_previous_values(mse_posture, total_frames)
278
+ mse_voice = fill_with_previous_values(mse_voice, total_frames)
279
+
280
+ def combine_video_and_heatmap(t):
281
+ video_frame = video.get_frame(t)
282
+ heatmap_frame = create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video.fps, total_frames, width)
283
+ heatmap_frame_resized = cv2.resize(heatmap_frame, (width, heatmap_frame.shape[0]))
284
+
285
+ # Convert heatmap frame to RGB if it's RGBA
286
+ if heatmap_frame_resized.shape[2] == 4:
287
+ heatmap_frame_resized = cv2.cvtColor(heatmap_frame_resized, cv2.COLOR_RGBA2RGB)
288
+
289
+ # Ensure both frames have the same number of channels
290
+ if video_frame.shape[2] != heatmap_frame_resized.shape[2]:
291
+ if video_frame.shape[2] == 3:
292
+ heatmap_frame_resized = heatmap_frame_resized[:, :, :3] # Use only RGB channels
293
+ else:
294
+ video_frame = cv2.cvtColor(video_frame, cv2.COLOR_RGB2RGBA)
295
+
296
+ combined_frame = np.vstack((video_frame, heatmap_frame_resized))
297
+ return combined_frame
298
+
299
+ final_clip = VideoClip(combine_video_and_heatmap, duration=video.duration)
300
+ final_clip = final_clip.set_audio(video.audio)
301
+
302
+ # Write the final video
303
+ final_clip.write_videofile(heatmap_video_path, codec='libx264', audio_codec='aac', fps=video.fps)
304
+
305
+ # Close the video clips
306
+ video.close()
307
+ final_clip.close()
308
+
309
+ if os.path.exists(heatmap_video_path):
310
+ print(f"Heatmap video created at: {heatmap_video_path}")
311
+ print(f"Heatmap video size: {os.path.getsize(heatmap_video_path)} bytes")
312
+ return heatmap_video_path
313
+ else:
314
+ print(f"Failed to create heatmap video at: {heatmap_video_path}")
315
+ return None
316
+
317
+ # Function to create the correlation heatmap
318
+ def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
319
+ data = np.vstack((mse_embeddings, mse_posture, mse_voice)).T
320
+ df = pd.DataFrame(data, columns=["Facial Features", "Body Posture", "Voice"])
321
+ corr = df.corr()
322
+
323
+ plt.figure(figsize=(10, 8), dpi=300)
324
+
325
+ heatmap = sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
326
+ plt.title('Correlation Heatmap of MSEs')
327
+ plt.tight_layout()
328
+ return plt.gcf()