reab5555 commited on
Commit
77142e3
·
verified ·
1 Parent(s): 3fab580

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +16 -3
visualization.py CHANGED
@@ -149,8 +149,18 @@ def plot_combined_mse(df, mse_embeddings, mse_posture, mse_voice, title, anomaly
149
  segments.append(current_segment)
150
  return segments
151
 
 
 
 
 
 
 
 
 
 
 
152
  # Plot each data series
153
- for mse_values, color, label in zip([mse_embeddings, mse_posture, mse_voice],
154
  ['navy', 'purple', 'green'],
155
  ['Facial Features', 'Body Posture', 'Voice']):
156
  segments = get_continuous_segments(df['Seconds'], mse_values)
@@ -176,6 +186,10 @@ def plot_combined_mse(df, mse_embeddings, mse_posture, mse_voice, title, anomaly
176
  threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
177
  ax.axhline(y=threshold, color=color, linestyle=':', alpha=0.5, label=f'{label} Threshold')
178
 
 
 
 
 
179
  max_seconds = df['Seconds'].max()
180
  num_ticks = 100
181
  tick_locations = np.linspace(0, max_seconds, num_ticks)
@@ -185,7 +199,7 @@ def plot_combined_mse(df, mse_embeddings, mse_posture, mse_voice, title, anomaly
185
  ax.set_xticklabels(tick_labels, rotation=90, ha='center', fontsize=6)
186
 
187
  ax.set_xlabel('Timecode')
188
- ax.set_ylabel('Mean Squared Error')
189
  ax.set_title(title)
190
 
191
  ax.grid(True, linestyle='--', alpha=0.7)
@@ -194,7 +208,6 @@ def plot_combined_mse(df, mse_embeddings, mse_posture, mse_voice, title, anomaly
194
  plt.close()
195
  return fig
196
 
197
-
198
  def plot_mse_histogram(mse_values, title, anomaly_threshold, color='blue'):
199
  plt.figure(figsize=(16, 3), dpi=300)
200
  fig, ax = plt.subplots(figsize=(16, 3))
 
149
  segments.append(current_segment)
150
  return segments
151
 
152
+ # Scale all MSE values to the same range (0 to 1)
153
+ def scale_mse(mse_values):
154
+ min_val = np.min(mse_values)
155
+ max_val = np.max(mse_values)
156
+ return (mse_values - min_val) / (max_val - min_val)
157
+
158
+ mse_embeddings_scaled = scale_mse(mse_embeddings)
159
+ mse_posture_scaled = scale_mse(mse_posture)
160
+ mse_voice_scaled = scale_mse(mse_voice)
161
+
162
  # Plot each data series
163
+ for mse_values, color, label in zip([mse_embeddings_scaled, mse_posture_scaled, mse_voice_scaled],
164
  ['navy', 'purple', 'green'],
165
  ['Facial Features', 'Body Posture', 'Voice']):
166
  segments = get_continuous_segments(df['Seconds'], mse_values)
 
186
  threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
187
  ax.axhline(y=threshold, color=color, linestyle=':', alpha=0.5, label=f'{label} Threshold')
188
 
189
+ # Plot anomalies in red
190
+ anomalies = mse_values > threshold
191
+ ax.scatter(df['Seconds'][anomalies], mse_values[anomalies], color='red', s=20, zorder=5)
192
+
193
  max_seconds = df['Seconds'].max()
194
  num_ticks = 100
195
  tick_locations = np.linspace(0, max_seconds, num_ticks)
 
199
  ax.set_xticklabels(tick_labels, rotation=90, ha='center', fontsize=6)
200
 
201
  ax.set_xlabel('Timecode')
202
+ ax.set_ylabel('Scaled Mean Squared Error')
203
  ax.set_title(title)
204
 
205
  ax.grid(True, linestyle='--', alpha=0.7)
 
208
  plt.close()
209
  return fig
210
 
 
211
  def plot_mse_histogram(mse_values, title, anomaly_threshold, color='blue'):
212
  plt.figure(figsize=(16, 3), dpi=300)
213
  fig, ax = plt.subplots(figsize=(16, 3))