reab5555 commited on
Commit
92bbbb6
·
verified ·
1 Parent(s): 26ad44b

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +5 -4
visualization.py CHANGED
@@ -216,7 +216,6 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
216
  plt.close()
217
  return fig
218
 
219
-
220
  def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
221
  frame_count = int(t * video_fps)
222
 
@@ -230,14 +229,16 @@ def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_f
230
  combined_mse[1] = mse_posture_norm
231
  combined_mse[2] = mse_voice_norm
232
 
233
- fig, ax = plt.subplots(figsize=(video_width / 50, 0.5))
234
  ax.imshow(combined_mse, aspect='auto', cmap='hot', vmin=0, vmax=1, extent=[0, total_frames, 0, 3])
235
  ax.set_yticks([0.5, 1.5, 2.5])
236
- ax.set_yticklabels(['Voice', 'Posture', 'Face'])
237
  ax.set_xticks([])
238
 
239
- ax.axvline(x=frame_count, color='black', linewidth=3)
240
 
 
 
241
  canvas = FigureCanvas(fig)
242
  canvas.draw()
243
  heatmap_img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
 
216
  plt.close()
217
  return fig
218
 
 
219
  def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_frames, video_width):
220
  frame_count = int(t * video_fps)
221
 
 
229
  combined_mse[1] = mse_posture_norm
230
  combined_mse[2] = mse_voice_norm
231
 
232
+ fig, ax = plt.subplots(figsize=(video_width / 25, 0.2)) # Much thinner height, wider width
233
  ax.imshow(combined_mse, aspect='auto', cmap='hot', vmin=0, vmax=1, extent=[0, total_frames, 0, 3])
234
  ax.set_yticks([0.5, 1.5, 2.5])
235
+ ax.set_yticklabels(['Voice', 'Posture', 'Face'], fontsize=6) # Smaller font size
236
  ax.set_xticks([])
237
 
238
+ ax.axvline(x=frame_count, color='black', linewidth=2) # Thinner line for smaller heatmap
239
 
240
+ plt.tight_layout(pad=0.1) # Reduce padding around the plot
241
+
242
  canvas = FigureCanvas(fig)
243
  canvas.draw()
244
  heatmap_img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')