reab5555 commited on
Commit
4f600d6
·
verified ·
1 Parent(s): 3dff9c0

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +35 -0
visualization.py CHANGED
@@ -154,6 +154,41 @@ def plot_combined_mse(df, mse_embeddings, mse_posture, mse_voice, title):
154
  plt.close()
155
  return fig
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def plot_mse_histogram(mse_values, title, anomaly_threshold, color='blue'):
158
  plt.figure(figsize=(16, 3), dpi=300)
159
  fig, ax = plt.subplots(figsize=(16, 3))
 
154
  plt.close()
155
  return fig
156
 
157
+ def plot_combined_heatmap(mse_embeddings, mse_posture, mse_voice, df):
158
+ # Normalize MSE values
159
+ mse_embeddings_norm = (mse_embeddings - np.min(mse_embeddings)) / (np.max(mse_embeddings) - np.min(mse_embeddings))
160
+ mse_posture_norm = (mse_posture - np.min(mse_posture)) / (np.max(mse_posture) - np.min(mse_posture))
161
+ mse_voice_norm = (mse_voice - np.min(mse_voice)) / (np.max(mse_voice) - np.min(mse_voice))
162
+
163
+ combined_mse = np.zeros((3, len(mse_embeddings)))
164
+ combined_mse[0] = mse_embeddings_norm
165
+ combined_mse[1] = mse_posture_norm
166
+ combined_mse[2] = mse_voice_norm
167
+
168
+ plt.figure(figsize=(20, 3), dpi=300)
169
+ fig, ax = plt.subplots(figsize=(20, 3))
170
+
171
+ # Create heatmap
172
+ sns.heatmap(combined_mse, cmap='Reds', cbar=False, ax=ax)
173
+
174
+ # Set y-axis labels
175
+ ax.set_yticks([0.5, 1.5, 2.5])
176
+ ax.set_yticklabels(['Facial Features', 'Body Posture', 'Voice'], fontsize=8)
177
+
178
+ # Set x-axis ticks to timecodes
179
+ num_ticks = min(60, len(mse_embeddings))
180
+ tick_locations = np.linspace(0, len(mse_embeddings) - 1, num_ticks).astype(int)
181
+ tick_labels = [df['Timecode'].iloc[i] if i < len(df) else '' for i in tick_locations]
182
+
183
+ ax.set_xticks(tick_locations)
184
+ ax.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')
185
+
186
+ ax.set_title('Combined MSE Heatmap')
187
+
188
+ plt.tight_layout()
189
+ plt.close(fig)
190
+ return fig
191
+
192
  def plot_mse_histogram(mse_values, title, anomaly_threshold, color='blue'):
193
  plt.figure(figsize=(16, 3), dpi=300)
194
  fig, ax = plt.subplots(figsize=(16, 3))