reab5555 commited on
Commit
79bd6b6
·
verified ·
1 Parent(s): 11e16fe

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +37 -24
visualization.py CHANGED
@@ -16,37 +16,50 @@ def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_thre
16
 
17
  # Ensure df and mse_values have the same length and remove NaN values
18
  min_length = min(len(df), len(mse_values))
19
- df = df.iloc[:min_length]
20
  mse_values = mse_values[:min_length]
21
 
22
- # Remove NaN values
23
- mask = ~np.isnan(mse_values)
24
- df = df[mask]
25
- mse_values = mse_values[mask]
26
-
27
- # Calculate rolling mean and std
28
- mean = pd.Series(mse_values).rolling(window=10, min_periods=1).mean()
29
- std = pd.Series(mse_values).rolling(window=10, min_periods=1).std()
30
-
31
- # Plot scatter points
32
- ax.scatter(df['Seconds'], mse_values, color=color, alpha=0.3, s=5)
33
-
34
- # Plot mean line and std fill only for continuous valid segments
35
  valid_mask = ~np.isnan(mse_values)
36
- segments = np.split(np.arange(len(df)), np.where(~valid_mask)[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  for segment in segments:
38
- if len(segment) > 0 and valid_mask[segment[0]]:
39
- ax.plot(df['Seconds'].iloc[segment], mean.iloc[segment], color=color, linewidth=0.5)
40
- ax.fill_between(df['Seconds'].iloc[segment],
41
- mean.iloc[segment] - std.iloc[segment],
42
- mean.iloc[segment] + std.iloc[segment],
43
- color=color, alpha=0.1)
44
-
45
- # Add median line
 
 
 
 
 
 
46
  median = np.median(mse_values)
47
  ax.axhline(y=median, color='black', linestyle='--', label='Median Baseline')
48
 
49
- # Add threshold line
50
  threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
51
  ax.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold: {anomaly_threshold:.1f}')
52
  ax.text(ax.get_xlim()[1], threshold, f'Threshold: {anomaly_threshold:.1f}', verticalalignment='center', horizontalalignment='left', color='red')
 
16
 
17
  # Ensure df and mse_values have the same length and remove NaN values
18
  min_length = min(len(df), len(mse_values))
19
+ df = df.iloc[:min_length].copy()
20
  mse_values = mse_values[:min_length]
21
 
22
+ # Remove NaN values and create a mask for valid data
 
 
 
 
 
 
 
 
 
 
 
 
23
  valid_mask = ~np.isnan(mse_values)
24
+ df = df[valid_mask]
25
+ mse_values = mse_values[valid_mask]
26
+
27
+ # Function to identify continuous segments
28
+ def get_continuous_segments(seconds, values, max_gap=1):
29
+ segments = []
30
+ current_segment = []
31
+ for i, (sec, val) in enumerate(zip(seconds, values)):
32
+ if not current_segment or (sec - current_segment[-1][0] <= max_gap):
33
+ current_segment.append((sec, val))
34
+ else:
35
+ segments.append(current_segment)
36
+ current_segment = [(sec, val)]
37
+ if current_segment:
38
+ segments.append(current_segment)
39
+ return segments
40
+
41
+ # Get continuous segments
42
+ segments = get_continuous_segments(df['Seconds'], mse_values)
43
+
44
+ # Plot each segment separately
45
  for segment in segments:
46
+ segment_seconds, segment_mse = zip(*segment)
47
+ ax.scatter(segment_seconds, segment_mse, color=color, alpha=0.3, s=5)
48
+
49
+ # Calculate and plot rolling mean and std for this segment
50
+ if len(segment) > 1: # Only if there's more than one point in the segment
51
+ segment_df = pd.DataFrame({'Seconds': segment_seconds, 'MSE': segment_mse})
52
+ segment_df = segment_df.sort_values('Seconds')
53
+ mean = segment_df['MSE'].rolling(window=min(10, len(segment)), min_periods=1, center=True).mean()
54
+ std = segment_df['MSE'].rolling(window=min(10, len(segment)), min_periods=1, center=True).std()
55
+
56
+ ax.plot(segment_df['Seconds'], mean, color=color, linewidth=0.5)
57
+ ax.fill_between(segment_df['Seconds'], mean - std, mean + std, color=color, alpha=0.1)
58
+
59
+ # Rest of the function remains the same
60
  median = np.median(mse_values)
61
  ax.axhline(y=median, color='black', linestyle='--', label='Median Baseline')
62
 
 
63
  threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
64
  ax.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold: {anomaly_threshold:.1f}')
65
  ax.text(ax.get_xlim()[1], threshold, f'Threshold: {anomaly_threshold:.1f}', verticalalignment='center', horizontalalignment='left', color='red')