reab5555's picture
Update visualization.py
1ca7095 verified
raw
history blame
11.4 kB
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import numpy as np
import pandas as pd
import cv2
from matplotlib.patches import Rectangle
from utils import seconds_to_timecode
from anomaly_detection import determine_anomalies
def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_threshold=4):
    """Plot MSE over time with rolling statistics and highlighted anomalies.

    Samples are split into continuous segments (consecutive samples at most
    1 s apart). Each segment gets its own rolling mean and +/-1 std band, so
    the curves never bridge gaps. Rows reported by ``determine_anomalies`` are
    marked in red. Anomalies no more than ``time_threshold`` seconds apart are
    shaded as one span, and the highest point of each span is labelled with
    its timecode.

    Args:
        df: DataFrame row-aligned with ``mse_values``. Must contain 'Timecode'
            and 'Frame'. If it lacks 'Seconds', that column is derived from
            the colon-separated 'Timecode' and added to ``df`` in place.
        mse_values: 1-D array-like of MSE values. Both inputs are truncated
            to the shorter length, and NaN entries are dropped.
        title: Axes title.
        color: Colour of the scatter points, rolling mean and std band.
        time_threshold: Maximum gap (seconds) between anomalies shaded as one
            span.
        anomaly_threshold: Passed to ``determine_anomalies``. The red dashed
            line is drawn at mean + anomaly_threshold * std (this assumes
            ``determine_anomalies`` uses the same rule; not visible here).

    Returns:
        tuple: ``(fig, anomaly_frames)``, where ``anomaly_frames`` lists the
        'Frame' values of the anomalous rows (empty if there is no valid data).
    """
    # A single figure; the previous extra plt.figure() call leaked an unused figure.
    fig, ax = plt.subplots(figsize=(16, 8))

    if 'Seconds' not in df.columns:
        # Rightmost field counts seconds, then minutes, then hours (weights 60**i).
        df['Seconds'] = df['Timecode'].apply(
            lambda x: sum(float(t) * 60 ** i for i, t in enumerate(reversed(x.split(':')))))

    # Use an ndarray so slicing, masking and anomaly indexing below are
    # positional whether a list, Series or array was passed.
    mse_values = np.asarray(mse_values, dtype=float)
    min_length = min(len(df), len(mse_values))
    df = df.iloc[:min_length].copy()
    mse_values = mse_values[:min_length]
    valid_mask = ~np.isnan(mse_values)
    df = df[valid_mask]
    mse_values = mse_values[valid_mask]

    ax.set_xlabel('Timecode')
    ax.set_ylabel('Mean Squared Error')
    ax.set_title(title)

    if mse_values.size == 0:
        # No valid samples: the statistics and tick positions below would be NaN.
        plt.close(fig)
        return fig, []

    def group_by_gap(items, max_gap):
        # `items` are tuples whose first element is a time in seconds.
        # Start a new group whenever the gap to the previous item exceeds max_gap.
        groups = []
        for item in items:
            if groups and item[0] - groups[-1][-1][0] <= max_gap:
                groups[-1].append(item)
            else:
                groups.append([item])
        return groups

    # Scatter every point, plus a per-segment rolling mean and +/-1 std band.
    seconds = df['Seconds'].to_numpy()
    for segment in group_by_gap(list(zip(seconds, mse_values)), max_gap=1):
        segment_seconds, segment_mse = zip(*segment)
        ax.scatter(segment_seconds, segment_mse, color=color, alpha=0.3, s=5)
        if len(segment) > 1:
            segment_df = pd.DataFrame({'Seconds': segment_seconds, 'MSE': segment_mse})
            segment_df = segment_df.sort_values('Seconds')
            rolling = segment_df['MSE'].rolling(window=min(10, len(segment)), min_periods=1, center=True)
            mean = rolling.mean()
            std = rolling.std()
            ax.plot(segment_df['Seconds'], mean, color=color, linewidth=0.5)
            ax.fill_between(segment_df['Seconds'], mean - std, mean + std, color=color, alpha=0.1)

    median = np.median(mse_values)
    ax.axhline(y=median, color='black', linestyle='--', label='Median Baseline')
    threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
    ax.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold: {anomaly_threshold:.1f}')
    ax.text(ax.get_xlim()[1], threshold, f'Threshold: {anomaly_threshold:.1f}',
            verticalalignment='center', horizontalalignment='left', color='red')

    anomalies = determine_anomalies(mse_values, anomaly_threshold)
    anomaly_frames = df['Frame'].iloc[anomalies].tolist()
    anomaly_seconds = df['Seconds'].iloc[anomalies].to_numpy()
    anomaly_mse = mse_values[anomalies]
    anomaly_timecodes = df['Timecode'].iloc[anomalies].tolist()
    ax.scatter(anomaly_seconds, anomaly_mse, color='red', s=20, zorder=5)

    anomaly_data = sorted(zip(anomaly_seconds, anomaly_mse, anomaly_timecodes), key=lambda a: a[0])
    for group in group_by_gap(anomaly_data, max_gap=time_threshold):
        # axvspan spans the full axes height in axes coordinates. A data-space
        # Rectangle sized from get_ylim() re-enters y-autoscaling and enlarges
        # the y-range with every group drawn.
        ax.axvspan(group[0][0], group[-1][0], facecolor='red', alpha=0.2, zorder=1)
        peak_seconds, peak_mse, peak_timecode = max(group, key=lambda a: a[1])
        ax.annotate(peak_timecode, (peak_seconds, peak_mse), textcoords="offset points",
                    xytext=(0, 10), ha='center', fontsize=6, color='red')

    num_ticks = 100
    tick_locations = np.linspace(0, df['Seconds'].max(), num_ticks)
    ax.set_xticks(tick_locations)
    ax.set_xticklabels([seconds_to_timecode(int(s)) for s in tick_locations],
                       rotation=90, ha='center', fontsize=6)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend()
    fig.tight_layout()
    plt.close(fig)
    return fig, anomaly_frames
def plot_mse_histogram(mse_values, title, anomaly_threshold, color='blue'):
    """Plot a 100-bin histogram of MSE values with the anomaly threshold marked.

    Args:
        mse_values: 1-D array-like of MSE values. NaN entries are ignored.
        title: Axes title.
        anomaly_threshold: Multiplier k. The red dashed line sits at
            mean + k * std of the non-NaN values.
        color: Bar colour.

    Returns:
        The matplotlib Figure (already closed in pyplot, so it is not shown
        implicitly).
    """
    values = np.asarray(mse_values, dtype=float)
    # A single NaN would make mean/std NaN and silently hide the threshold line.
    values = values[~np.isnan(values)]

    fig, ax = plt.subplots(figsize=(16, 3))
    ax.hist(values, bins=100, edgecolor='black', color=color, alpha=0.7)
    ax.set_xlabel('Mean Squared Error')
    ax.set_ylabel('Number of Samples')
    ax.set_title(title)

    if values.size:
        threshold = np.mean(values) + anomaly_threshold * np.std(values)
        ax.axvline(x=threshold, color='red', linestyle='--', linewidth=2)

    fig.tight_layout()
    plt.close(fig)
    return fig
def plot_mse_heatmap(mse_values, title, df):
    """Draw MSE values as a one-row heatmap with timecode tick labels.

    Args:
        mse_values: 1-D array-like of MSE values, one heatmap cell each.
        title: Axes title.
        df: DataFrame whose 'Timecode' column supplies tick labels by
            position. Assumed row-aligned with ``mse_values`` and at least as
            long (TODO confirm against callers).

    Returns:
        The matplotlib Figure (already closed in pyplot).
    """
    values = np.asarray(mse_values, dtype=float)
    fig, ax = plt.subplots(figsize=(20, 3))

    # One row, one column per sample.
    sns.heatmap(values.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax)

    # Up to 60 evenly spaced labels. Capping at len(values) avoids duplicate ticks.
    num_ticks = min(60, len(values))
    tick_indices = np.linspace(0, len(values) - 1, num_ticks).astype(int)
    # seaborn draws cell i over [i, i + 1], so its centre is at i + 0.5.
    ax.set_xticks(tick_indices + 0.5)
    ax.set_xticklabels([df['Timecode'].iloc[i] for i in tick_indices],
                       rotation=90, ha='center', va='top')
    ax.set_title(title)
    ax.set_yticks([])

    fig.tight_layout()
    plt.close(fig)
    return fig
def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
    """Plot per-frame body-posture scores over time with a 10-sample rolling mean.

    Args:
        df: DataFrame with 'Frame' and 'Timecode' columns. A 'Seconds' column,
            derived from the colon-separated 'Timecode', is (re)written on
            ``df`` in place.
        posture_scores: Mapping of frame -> score. Entries whose score is None
            are skipped, and only frames present in ``df['Frame']`` are plotted.
        color: Colour of the scatter points and rolling-mean line.
        anomaly_threshold: Accepted for interface compatibility; not used.

    Returns:
        The matplotlib Figure (already closed in pyplot).
    """
    fig, ax = plt.subplots(figsize=(16, 8))

    # Rightmost field counts seconds, then minutes, then hours (weights 60**i).
    df['Seconds'] = df['Timecode'].apply(
        lambda x: sum(float(t) * 60 ** i for i, t in enumerate(reversed(x.split(':')))))

    scored = [(frame, score) for frame, score in posture_scores.items() if score is not None]
    if scored:  # zip(*[]) cannot be unpacked, so skip plotting when nothing is scored.
        frames, scores = zip(*scored)
        posture_df = pd.DataFrame({'Frame': frames, 'Score': scores})
        posture_df = posture_df.merge(df[['Frame', 'Seconds']], on='Frame', how='inner')
        # Order by time so the rolling window and the line follow the timeline
        # even if posture_scores was not in frame order. A stable sort leaves
        # already-ordered rows (and ties) untouched.
        posture_df = posture_df.sort_values('Seconds', kind='mergesort')
        ax.scatter(posture_df['Seconds'], posture_df['Score'], color=color, alpha=0.3, s=5)
        mean = posture_df['Score'].rolling(window=10).mean()
        ax.plot(posture_df['Seconds'], mean, color=color, linewidth=0.5)

    ax.set_xlabel('Timecode')
    ax.set_ylabel('Posture Score')
    ax.set_title("Body Posture Over Time")
    ax.grid(True, linestyle='--', alpha=0.7)

    num_ticks = 80
    tick_locations = np.linspace(0, df['Seconds'].max(), num_ticks)
    ax.set_xticks(tick_locations)
    ax.set_xticklabels([seconds_to_timecode(int(s)) for s in tick_locations],
                       rotation=90, ha='center', fontsize=6)

    fig.tight_layout()
    plt.close(fig)
    return fig
def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, output_path, desired_fps, largest_cluster):
    """Write a copy of the video with a Face/Posture MSE strip under every frame.

    The MSE rows belonging to ``largest_cluster`` are stretched by linear
    interpolation to the video's frame count and min-max normalised per row.
    They are drawn as a 200 px two-row heatmap beneath each frame, with a red
    cursor at the current frame and a "Time: HH:MM:SS" overlay.

    Args:
        video_path: Source video, read with OpenCV.
        df: DataFrame with a 'Cluster' column, row-aligned with both MSE inputs.
        mse_embeddings: Per-row facial MSE values (array-like).
        mse_posture: Per-row posture MSE values (array-like).
        output_path: Destination file, written with the 'mp4v' fourcc.
        desired_fps: Currently unused; the output keeps the source frame rate.
        largest_cluster: Cluster label whose rows are kept.

    Returns:
        ``output_path``.
    """
    in_cluster = (df['Cluster'] == largest_cluster).to_numpy()
    mse_embeddings = np.asarray(mse_embeddings, dtype=float)[in_cluster]
    mse_posture = np.asarray(mse_posture, dtype=float)[in_cluster]

    def stretch(values, length):
        # Resample `values` to `length` points by linear interpolation.
        return np.interp(np.linspace(0, len(values) - 1, length), np.arange(len(values)), values)

    def normalize(values):
        # Min-max scale to [0, 1]. A constant series maps to 0 instead of 0/0 = NaN.
        lo, hi = np.min(values), np.max(values)
        if hi == lo:
            return np.zeros_like(values)
        return (values - lo) / (hi - lo)

    cap = cv2.VideoCapture(video_path)
    out = None
    fig = None
    try:
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, original_fps, (width, height + 200))

        # Row 0 = face, row 1 = posture, one column per video frame.
        combined_mse = np.vstack([
            normalize(stretch(mse_embeddings, total_frames)),
            normalize(stretch(mse_posture, total_frames)),
        ])

        # Low MSE -> cyan (no red, full green/blue); high MSE -> pure red.
        cdict = {
            'red': [(0.0, 0.0, 0.0),
                    (1.0, 1.0, 1.0)],
            'green': [(0.0, 1.0, 1.0),
                      (1.0, 0.0, 0.0)],
            'blue': [(0.0, 1.0, 1.0),
                     (1.0, 0.0, 0.0)],
        }
        custom_cmap = LinearSegmentedColormap('custom_cmap', segmentdata=cdict, N=256)

        fig, ax = plt.subplots(figsize=(width / 100, 2))
        ax.imshow(combined_mse, aspect='auto', cmap=custom_cmap, extent=[0, total_frames, 0, 2])
        # With the default origin='upper', row 0 (face) fills the top half
        # (y in [1, 2]) and row 1 (posture) the bottom half (y in [0, 1]).
        ax.set_yticks([0.5, 1.5])
        ax.set_yticklabels(['Posture', 'Face'])
        ax.set_xticks([])
        fig.tight_layout()
        canvas = FigureCanvasAgg(fig)  # created once and reused for every frame

        cursor = None
        # Frames are read sequentially; per-frame seeking (CAP_PROP_POS_FRAMES)
        # is slow and unnecessary when every frame is visited in order.
        for frame_index in range(total_frames):
            ok, frame = cap.read()
            if not ok:
                break
            if cursor is not None:
                cursor.remove()
            cursor = ax.axvline(x=frame_index, color='r', linewidth=2)
            canvas.draw()
            # buffer_rgba() replaces tostring_rgb(), which was deprecated in
            # Matplotlib 3.8 and removed in 3.10. Drop alpha and convert to BGR for OpenCV.
            heatmap_img = cv2.cvtColor(np.asarray(canvas.buffer_rgba()), cv2.COLOR_RGBA2BGR)
            heatmap_img = cv2.resize(heatmap_img, (width, 200))
            combined_frame = np.vstack((frame, heatmap_img))

            seconds = frame_index / original_fps if original_fps > 0 else 0.0
            timecode = f"{int(seconds//3600):02d}:{int((seconds%3600)//60):02d}:{int(seconds%60):02d}"
            cv2.putText(combined_frame, f"Time: {timecode}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            out.write(combined_frame)
    finally:
        # Release the capture, writer and figure even if rendering fails part-way.
        cap.release()
        if out is not None:
            out.release()
        if fig is not None:
            plt.close(fig)
    return output_path