Update visualization.py
Browse files- visualization.py +50 -10
visualization.py
CHANGED
@@ -127,7 +127,7 @@ def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_thre
|
|
127 |
plt.close()
|
128 |
return fig, anomaly_frames
|
129 |
|
130 |
-
def plot_combined_mse(df, mse_embeddings, mse_posture, mse_voice, title):
|
131 |
plt.figure(figsize=(16, 8), dpi=300)
|
132 |
fig, ax = plt.subplots(figsize=(16, 8))
|
133 |
|
@@ -135,17 +135,57 @@ def plot_combined_mse(df, mse_embeddings, mse_posture, mse_voice, title):
|
|
135 |
df['Seconds'] = df['Timecode'].apply(
|
136 |
lambda x: sum(float(t) * 60 ** i for i, t in enumerate(reversed(x.split(':')))))
|
137 |
|
138 |
-
#
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
-
ax.
|
144 |
-
ax.
|
145 |
-
ax.plot(df['Seconds'], mse_voice, color='green', label='Voice')
|
146 |
|
147 |
ax.set_xlabel('Timecode')
|
148 |
-
ax.set_ylabel('
|
149 |
ax.set_title(title)
|
150 |
|
151 |
ax.grid(True, linestyle='--', alpha=0.7)
|
@@ -257,7 +297,7 @@ def create_heatmap(t, mse_embeddings, mse_posture, mse_voice, video_fps, total_f
|
|
257 |
combined_mse[1] = mse_posture_norm
|
258 |
combined_mse[2] = mse_voice_norm
|
259 |
|
260 |
-
fig, ax = plt.subplots(figsize=(video_width / 200, 0.
|
261 |
ax.imshow(combined_mse, aspect='auto', cmap='Reds', vmin=0, vmax=1, extent=[0, total_frames, 0, 3])
|
262 |
ax.set_yticks([0.5, 1.5, 2.5])
|
263 |
ax.set_yticklabels(['Voice', 'Posture', 'Face'], fontsize=7)
|
|
|
127 |
plt.close()
|
128 |
return fig, anomaly_frames
|
129 |
|
130 |
+
def plot_combined_mse(df, mse_embeddings, mse_posture, mse_voice, title, anomaly_threshold=4, time_threshold=3):
|
131 |
plt.figure(figsize=(16, 8), dpi=300)
|
132 |
fig, ax = plt.subplots(figsize=(16, 8))
|
133 |
|
|
|
135 |
df['Seconds'] = df['Timecode'].apply(
|
136 |
lambda x: sum(float(t) * 60 ** i for i, t in enumerate(reversed(x.split(':')))))
|
137 |
|
138 |
+
# Function to identify continuous segments
|
139 |
+
def get_continuous_segments(seconds, values, max_gap=1):
|
140 |
+
segments = []
|
141 |
+
current_segment = []
|
142 |
+
for i, (sec, val) in enumerate(zip(seconds, values)):
|
143 |
+
if not current_segment or (sec - current_segment[-1][0] <= max_gap):
|
144 |
+
current_segment.append((sec, val))
|
145 |
+
else:
|
146 |
+
segments.append(current_segment)
|
147 |
+
current_segment = [(sec, val)]
|
148 |
+
if current_segment:
|
149 |
+
segments.append(current_segment)
|
150 |
+
return segments
|
151 |
+
|
152 |
+
# Plot each data series
|
153 |
+
for mse_values, color, label in zip([mse_embeddings, mse_posture, mse_voice],
|
154 |
+
['navy', 'purple', 'green'],
|
155 |
+
['Facial Features', 'Body Posture', 'Voice']):
|
156 |
+
segments = get_continuous_segments(df['Seconds'], mse_values)
|
157 |
+
|
158 |
+
for segment in segments:
|
159 |
+
segment_seconds, segment_mse = zip(*segment)
|
160 |
+
ax.scatter(segment_seconds, segment_mse, color=color, alpha=0.3, s=5, label=label if segment == segments[0] else "")
|
161 |
+
|
162 |
+
if len(segment) > 1:
|
163 |
+
segment_df = pd.DataFrame({'Seconds': segment_seconds, 'MSE': segment_mse})
|
164 |
+
segment_df = segment_df.sort_values('Seconds')
|
165 |
+
mean = segment_df['MSE'].rolling(window=min(10, len(segment)), min_periods=1, center=True).mean()
|
166 |
+
std = segment_df['MSE'].rolling(window=min(10, len(segment)), min_periods=1, center=True).std()
|
167 |
+
|
168 |
+
ax.plot(segment_df['Seconds'], mean, color=color, linewidth=0.5)
|
169 |
+
ax.fill_between(segment_df['Seconds'], mean - std, mean + std, color=color, alpha=0.1)
|
170 |
+
|
171 |
+
# Plot median baseline for each series
|
172 |
+
median = np.median(mse_values)
|
173 |
+
ax.axhline(y=median, color=color, linestyle='--', alpha=0.5, label=f'{label} Median')
|
174 |
+
|
175 |
+
# Plot threshold for each series
|
176 |
+
threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
|
177 |
+
ax.axhline(y=threshold, color=color, linestyle=':', alpha=0.5, label=f'{label} Threshold')
|
178 |
+
|
179 |
+
max_seconds = df['Seconds'].max()
|
180 |
+
num_ticks = 100
|
181 |
+
tick_locations = np.linspace(0, max_seconds, num_ticks)
|
182 |
+
tick_labels = [seconds_to_timecode(int(s)) for s in tick_locations]
|
183 |
|
184 |
+
ax.set_xticks(tick_locations)
|
185 |
+
ax.set_xticklabels(tick_labels, rotation=90, ha='center', fontsize=6)
|
|
|
186 |
|
187 |
ax.set_xlabel('Timecode')
|
188 |
+
ax.set_ylabel('Mean Squared Error')
|
189 |
ax.set_title(title)
|
190 |
|
191 |
ax.grid(True, linestyle='--', alpha=0.7)
|
|
|
297 |
combined_mse[1] = mse_posture_norm
|
298 |
combined_mse[2] = mse_voice_norm
|
299 |
|
300 |
+
fig, ax = plt.subplots(figsize=(video_width / 200, 0.4))
|
301 |
ax.imshow(combined_mse, aspect='auto', cmap='Reds', vmin=0, vmax=1, extent=[0, total_frames, 0, 3])
|
302 |
ax.set_yticks([0.5, 1.5, 2.5])
|
303 |
ax.set_yticklabels(['Voice', 'Posture', 'Face'], fontsize=7)
|