reab5555 commited on
Commit
ba165b2
·
verified ·
1 Parent(s): bff0ca5

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +0 -95
visualization.py CHANGED
@@ -127,101 +127,6 @@ def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_thre
127
  plt.close()
128
  return fig, anomaly_frames
129
 
130
- def plot_combined_mse(df, mse_embeddings, mse_posture, mse_voice, title, anomaly_threshold=4, time_threshold=3):
131
- plt.figure(figsize=(16, 8), dpi=300)
132
- fig, ax = plt.subplots(figsize=(16, 8))
133
-
134
- if 'Seconds' not in df.columns:
135
- df['Seconds'] = df['Timecode'].apply(
136
- lambda x: sum(float(t) * 60 ** i for i, t in enumerate(reversed(x.split(':')))))
137
-
138
- # Ensure the lengths of the DataFrame and MSE values are consistent
139
- min_length = min(len(df), len(mse_embeddings), len(mse_posture), len(mse_voice))
140
- df = df.iloc[:min_length].copy()
141
- mse_embeddings = mse_embeddings[:min_length]
142
- mse_posture = mse_posture[:min_length]
143
- mse_voice = mse_voice[:min_length]
144
-
145
- # Remove NaN values
146
- valid_mask = ~np.isnan(mse_embeddings) & ~np.isnan(mse_posture) & ~np.isnan(mse_voice)
147
- df = df[valid_mask]
148
- mse_embeddings = mse_embeddings[valid_mask]
149
- mse_posture = mse_posture[valid_mask]
150
- mse_voice = mse_voice[valid_mask]
151
-
152
- # Function to identify continuous segments
153
- def get_continuous_segments(seconds, values, max_gap=1):
154
- segments = []
155
- current_segment = []
156
- for i, (sec, val) in enumerate(zip(seconds, values)):
157
- if not current_segment or (sec - current_segment[-1][0] <= max_gap):
158
- current_segment.append((sec, val))
159
- else:
160
- segments.append(current_segment)
161
- current_segment = [(sec, val)]
162
- if current_segment:
163
- segments.append(current_segment)
164
- return segments
165
-
166
- # Scale all MSE values to the same range (0 to 1)
167
- def scale_mse(mse_values):
168
- min_val = np.min(mse_values)
169
- max_val = np.max(mse_values)
170
- return (mse_values - min_val) / (max_val - min_val)
171
-
172
- mse_embeddings_scaled = scale_mse(mse_embeddings)
173
- mse_posture_scaled = scale_mse(mse_posture)
174
- mse_voice_scaled = scale_mse(mse_voice)
175
-
176
- # Plot each data series
177
- for mse_values, color, label in zip([mse_embeddings_scaled, mse_posture_scaled, mse_voice_scaled],
178
- ['navy', 'purple', 'green'],
179
- ['Facial Features', 'Body Posture', 'Voice']):
180
- segments = get_continuous_segments(df['Seconds'], mse_values)
181
-
182
- for segment in segments:
183
- segment_seconds, segment_mse = zip(*segment)
184
- ax.scatter(segment_seconds, segment_mse, color=color, alpha=0.3, s=5, label=label if segment == segments[0] else "")
185
-
186
- if len(segment) > 1:
187
- segment_df = pd.DataFrame({'Seconds': segment_seconds, 'MSE': segment_mse})
188
- segment_df = segment_df.sort_values('Seconds')
189
- mean = segment_df['MSE'].rolling(window=min(10, len(segment)), min_periods=1, center=True).mean()
190
- std = segment_df['MSE'].rolling(window=min(10, len(segment)), min_periods=1, center=True).std()
191
-
192
- ax.plot(segment_df['Seconds'], mean, color=color, linewidth=0.5)
193
- ax.fill_between(segment_df['Seconds'], mean - std, mean + std, color=color, alpha=0.1)
194
-
195
- # Plot median baseline for each series
196
- median = np.median(mse_values)
197
- ax.axhline(y=median, color=color, linestyle=':', alpha=0.5, label=f'{label} Baseline Median')
198
-
199
- # Plot threshold for each series
200
- threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
201
- ax.axhline(y=threshold, color=color, linestyle='--', alpha=1, label=f'{label} Anomaly Threshold')
202
-
203
- # Plot anomalies in red
204
- anomalies = mse_values > threshold
205
- ax.scatter(df['Seconds'][anomalies], mse_values[anomalies], color='red', s=20, zorder=5)
206
-
207
- max_seconds = df['Seconds'].max()
208
- num_ticks = 100
209
- tick_locations = np.linspace(0, max_seconds, num_ticks)
210
- tick_labels = [seconds_to_timecode(int(s)) for s in tick_locations]
211
-
212
- ax.set_xticks(tick_locations)
213
- ax.set_xticklabels(tick_labels, rotation=90, ha='center', fontsize=6)
214
-
215
- ax.set_xlabel('Timecode')
216
- ax.set_ylabel('Scaled Mean Squared Error')
217
- ax.set_title(title)
218
-
219
- ax.grid(True, linestyle='--', alpha=0.7)
220
- ax.legend()
221
- plt.tight_layout()
222
- plt.close()
223
- return fig
224
-
225
 
226
  def plot_mse_histogram(mse_values, title, anomaly_threshold, color='blue'):
227
  plt.figure(figsize=(16, 3), dpi=300)
 
127
  plt.close()
128
  return fig, anomaly_frames
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  def plot_mse_histogram(mse_values, title, anomaly_threshold, color='blue'):
132
  plt.figure(figsize=(16, 3), dpi=300)