reab5555 commited on
Commit
3ad6751
·
verified ·
1 Parent(s): 50833e2

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +14 -9
visualization.py CHANGED
@@ -218,9 +218,7 @@ def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
218
 
219
 
220
  def create_heatmap(frame_time, mse_embeddings, mse_posture, mse_voice):
221
- fig = Figure(figsize=(10, 1))
222
- canvas = FigureCanvas(fig)
223
- ax = fig.add_subplot(111)
224
  time_index = int(frame_time)
225
 
226
  if time_index < len(mse_embeddings) and time_index < len(mse_posture) and time_index < len(mse_voice):
@@ -231,6 +229,7 @@ def create_heatmap(frame_time, mse_embeddings, mse_posture, mse_voice):
231
  ax.barh(['Face', 'Posture', 'Voice'], mse_values, color=['navy', 'purple', 'green'])
232
  ax.set_xlim(0, 1) # Normalize the MSE values
233
 
 
234
  canvas.draw()
235
  img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
236
  img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
@@ -263,21 +262,27 @@ def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
263
  df = pd.DataFrame(data, columns=["Facial Features", "Body Posture", "Voice"])
264
  corr = df.corr()
265
 
266
- plt.figure(figsize=(10, 8))
 
267
  heatmap = sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
268
  plt.title('Correlation Heatmap of MSEs')
269
  plt.tight_layout()
270
  return plt.gcf()
271
 
272
  def plot_3d_scatter(mse_embeddings, mse_posture, mse_voice):
273
- fig = plt.figure(figsize=(10, 8))
 
274
  ax = fig.add_subplot(111, projection='3d')
275
 
276
- ax.scatter(mse_posture, mse_embeddings, mse_voice, c='r', marker='o')
 
 
 
 
277
  ax.set_xlabel('Body Posture MSE')
278
  ax.set_ylabel('Facial Features MSE')
279
  ax.set_zlabel('Voice MSE')
280
- ax.set_title('3D Scatter Plot of MSEs')
281
 
282
- plt.tight_layout()
283
- return plt.gcf()
 
 
218
 
219
 
220
  def create_heatmap(frame_time, mse_embeddings, mse_posture, mse_voice):
221
+ fig, ax = plt.subplots(figsize=(10, 1))
 
 
222
  time_index = int(frame_time)
223
 
224
  if time_index < len(mse_embeddings) and time_index < len(mse_posture) and time_index < len(mse_voice):
 
229
  ax.barh(['Face', 'Posture', 'Voice'], mse_values, color=['navy', 'purple', 'green'])
230
  ax.set_xlim(0, 1) # Normalize the MSE values
231
 
232
+ canvas = FigureCanvas(fig)
233
  canvas.draw()
234
  img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
235
  img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
 
262
  df = pd.DataFrame(data, columns=["Facial Features", "Body Posture", "Voice"])
263
  corr = df.corr()
264
 
265
+ plt.figure(figsize=(10, 8), dpi=300)
266
+
267
  heatmap = sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
268
  plt.title('Correlation Heatmap of MSEs')
269
  plt.tight_layout()
270
  return plt.gcf()
271
 
272
  def plot_3d_scatter(mse_embeddings, mse_posture, mse_voice):
273
+ fig = plt.figure()
274
+ plt.figure(figsize=(16, 8), dpi=300)
275
  ax = fig.add_subplot(111, projection='3d')
276
 
277
+ # Scatter plot
278
+ ax.scatter(mse_posture, mse_embeddings, mse_voice, c=['purple']*len(mse_posture), label='Body Posture', alpha=0.6)
279
+ ax.scatter(mse_posture, mse_embeddings, mse_voice, c=['navy']*len(mse_embeddings), label='Facial Features', alpha=0.6)
280
+ ax.scatter(mse_posture, mse_embeddings, mse_voice, c=['green']*len(mse_voice), label='Voice', alpha=0.6)
281
+
282
  ax.set_xlabel('Body Posture MSE')
283
  ax.set_ylabel('Facial Features MSE')
284
  ax.set_zlabel('Voice MSE')
 
285
 
286
+ ax.legend()
287
+ plt.close(fig)
288
+ return fig