Update visualization.py
Browse files- visualization.py +37 -24
visualization.py
CHANGED
@@ -16,37 +16,50 @@ def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_thre
|
|
16 |
|
17 |
# Ensure df and mse_values have the same length and remove NaN values
|
18 |
min_length = min(len(df), len(mse_values))
|
19 |
-
df = df.iloc[:min_length]
|
20 |
mse_values = mse_values[:min_length]
|
21 |
|
22 |
-
# Remove NaN values
|
23 |
-
mask = ~np.isnan(mse_values)
|
24 |
-
df = df[mask]
|
25 |
-
mse_values = mse_values[mask]
|
26 |
-
|
27 |
-
# Calculate rolling mean and std
|
28 |
-
mean = pd.Series(mse_values).rolling(window=10, min_periods=1).mean()
|
29 |
-
std = pd.Series(mse_values).rolling(window=10, min_periods=1).std()
|
30 |
-
|
31 |
-
# Plot scatter points
|
32 |
-
ax.scatter(df['Seconds'], mse_values, color=color, alpha=0.3, s=5)
|
33 |
-
|
34 |
-
# Plot mean line and std fill only for continuous valid segments
|
35 |
valid_mask = ~np.isnan(mse_values)
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
for segment in segments:
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
median = np.median(mse_values)
|
47 |
ax.axhline(y=median, color='black', linestyle='--', label='Median Baseline')
|
48 |
|
49 |
-
# Add threshold line
|
50 |
threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
|
51 |
ax.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold: {anomaly_threshold:.1f}')
|
52 |
ax.text(ax.get_xlim()[1], threshold, f'Threshold: {anomaly_threshold:.1f}', verticalalignment='center', horizontalalignment='left', color='red')
|
|
|
16 |
|
17 |
# Ensure df and mse_values have the same length and remove NaN values
|
18 |
min_length = min(len(df), len(mse_values))
|
19 |
+
df = df.iloc[:min_length].copy()
|
20 |
mse_values = mse_values[:min_length]
|
21 |
|
22 |
+
# Remove NaN values and create a mask for valid data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
valid_mask = ~np.isnan(mse_values)
|
24 |
+
df = df[valid_mask]
|
25 |
+
mse_values = mse_values[valid_mask]
|
26 |
+
|
27 |
+
# Function to identify continuous segments
|
28 |
+
def get_continuous_segments(seconds, values, max_gap=1):
|
29 |
+
segments = []
|
30 |
+
current_segment = []
|
31 |
+
for i, (sec, val) in enumerate(zip(seconds, values)):
|
32 |
+
if not current_segment or (sec - current_segment[-1][0] <= max_gap):
|
33 |
+
current_segment.append((sec, val))
|
34 |
+
else:
|
35 |
+
segments.append(current_segment)
|
36 |
+
current_segment = [(sec, val)]
|
37 |
+
if current_segment:
|
38 |
+
segments.append(current_segment)
|
39 |
+
return segments
|
40 |
+
|
41 |
+
# Get continuous segments
|
42 |
+
segments = get_continuous_segments(df['Seconds'], mse_values)
|
43 |
+
|
44 |
+
# Plot each segment separately
|
45 |
for segment in segments:
|
46 |
+
segment_seconds, segment_mse = zip(*segment)
|
47 |
+
ax.scatter(segment_seconds, segment_mse, color=color, alpha=0.3, s=5)
|
48 |
+
|
49 |
+
# Calculate and plot rolling mean and std for this segment
|
50 |
+
if len(segment) > 1: # Only if there's more than one point in the segment
|
51 |
+
segment_df = pd.DataFrame({'Seconds': segment_seconds, 'MSE': segment_mse})
|
52 |
+
segment_df = segment_df.sort_values('Seconds')
|
53 |
+
mean = segment_df['MSE'].rolling(window=min(10, len(segment)), min_periods=1, center=True).mean()
|
54 |
+
std = segment_df['MSE'].rolling(window=min(10, len(segment)), min_periods=1, center=True).std()
|
55 |
+
|
56 |
+
ax.plot(segment_df['Seconds'], mean, color=color, linewidth=0.5)
|
57 |
+
ax.fill_between(segment_df['Seconds'], mean - std, mean + std, color=color, alpha=0.1)
|
58 |
+
|
59 |
+
# Rest of the function remains the same
|
60 |
median = np.median(mse_values)
|
61 |
ax.axhline(y=median, color='black', linestyle='--', label='Median Baseline')
|
62 |
|
|
|
63 |
threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
|
64 |
ax.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold: {anomaly_threshold:.1f}')
|
65 |
ax.text(ax.get_xlim()[1], threshold, f'Threshold: {anomaly_threshold:.1f}', verticalalignment='center', horizontalalignment='left', color='red')
|