# reab5555 — "Update visualization.py" (commit 68302f1, verified)
# Page residue from the hosting site's file view: raw | history | blame | 14.9 kB
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import numpy as np
import pandas as pd
import cv2
from moviepy.editor import VideoFileClip, AudioFileClip, CompositeVideoClip, ImageClip, VideoClip, concatenate_videoclips
from moviepy.video.fx.all import resize
from PIL import Image, ImageDraw, ImageFont
from matplotlib.patches import Rectangle
from scipy import interpolate
import os
# Utility functions
def seconds_to_timecode(seconds):
    """Format a number of seconds as an ``HH:MM:SS`` string.

    Fractional seconds are truncated and hours are not wrapped at 24, so
    ``123456`` becomes ``"34:17:36"``.
    """
    hours, within_hour = divmod(seconds, 3600)
    minutes = within_hour // 60
    secs = seconds % 60
    return f"{int(hours):02d}:{int(minutes):02d}:{int(secs):02d}"
def determine_anomalies(values, threshold):
    """Return the indices of entries lying more than ``threshold`` standard
    deviations above the mean of *values*.

    The comparison is strict (``>``); the result is the first element of
    ``np.nonzero``, i.e. an integer index array along the first axis.
    """
    cutoff = np.mean(values) + threshold * np.std(values)
    return np.nonzero(values > cutoff)[0]
def _timecode_to_seconds(timecode):
    """Convert a colon-separated timecode (``"SS"``, ``"MM:SS"``, ``"HH:MM:SS"``) to seconds.

    Each field, read from the right, is weighted by an increasing power of 60.
    """
    return sum(float(part) * 60 ** i for i, part in enumerate(reversed(timecode.split(':'))))


def _split_on_gaps(items, key, max_gap):
    """Split *items* into consecutive runs whose successive ``key`` values differ by at most *max_gap*.

    Items are consumed in the order given; a new run starts whenever the step
    from the previous item's key is not ``<= max_gap``.
    """
    runs = []
    current = []
    for item in items:
        if not current or key(item) - key(current[-1]) <= max_gap:
            current.append(item)
        else:
            runs.append(current)
            current = [item]
    if current:
        runs.append(current)
    return runs


def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_threshold=4):
    """Plot MSE values over time and highlight anomalous samples.

    Args:
        df: DataFrame with 'Timecode' and 'Frame' columns, one row per entry of
            *mse_values*. If it has no 'Seconds' column, one is added to the
            caller's DataFrame in place.
        mse_values: 1-D sequence of MSE values aligned with the rows of *df*.
            Both are truncated to the shorter length; NaN entries and their
            rows are dropped.
        title: Axes title.
        color: Colour for the scatter points, rolling mean and std band.
        time_threshold: Anomalies at most this many seconds apart are merged
            into one shaded span (and annotated once, at the highest MSE).
        anomaly_threshold: A sample is anomalous when it exceeds
            mean + anomaly_threshold * std of the valid values.

    Returns:
        ``(fig, anomaly_frames)``: the figure (closed in pyplot, still usable
        for saving/embedding) and the list of ``df['Frame']`` values at the
        anomalous samples.
    """
    # Create exactly one figure. The earlier extra plt.figure() was never used
    # or closed, so it leaked one open figure per call.
    fig, ax = plt.subplots(figsize=(16, 8))

    # Intentionally mutates the caller's DataFrame (kept for compatibility).
    if 'Seconds' not in df.columns:
        df['Seconds'] = df['Timecode'].apply(_timecode_to_seconds)

    # Align lengths, then drop NaN samples together with their rows.
    mse_values = np.asarray(mse_values, dtype=float)
    min_length = min(len(df), len(mse_values))
    df = df.iloc[:min_length].copy()
    mse_values = mse_values[:min_length]
    valid_mask = ~np.isnan(mse_values)
    df = df[valid_mask]
    mse_values = mse_values[valid_mask]
    seconds = df['Seconds'].to_numpy()

    # Scatter every sample. Draw the rolling mean +/- std band per continuous
    # segment, so the line never bridges gaps longer than 1 second.
    for segment in _split_on_gaps(zip(seconds, mse_values), key=lambda p: p[0], max_gap=1):
        segment_seconds, segment_mse = zip(*segment)
        ax.scatter(segment_seconds, segment_mse, color=color, alpha=0.3, s=5)
        if len(segment) > 1:
            segment_df = pd.DataFrame({'Seconds': segment_seconds, 'MSE': segment_mse})
            segment_df = segment_df.sort_values('Seconds')
            rolling = segment_df['MSE'].rolling(window=min(10, len(segment)), min_periods=1, center=True)
            mean = rolling.mean()
            std = rolling.std()
            ax.plot(segment_df['Seconds'], mean, color=color, linewidth=0.5)
            ax.fill_between(segment_df['Seconds'], mean - std, mean + std, color=color, alpha=0.1)

    anomaly_frames = []
    if mse_values.size:
        ax.axhline(y=np.median(mse_values), color='black', linestyle='--', label='Median Baseline')
        # Same rule as determine_anomalies, so the line matches the red points.
        threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
        ax.axhline(y=threshold, color='red', linestyle='--', label='Anomaly Threshold')
        # x is in axes coordinates (1.0 = right edge) and y is in data
        # coordinates, so later x-limit changes cannot move the label.
        ax.text(1.0, threshold, 'Anomaly Threshold', transform=ax.get_yaxis_transform(),
                verticalalignment='center', horizontalalignment='left', color='red')

        anomalies = determine_anomalies(mse_values, anomaly_threshold)
        anomaly_frames = df['Frame'].iloc[anomalies].tolist()
        anomaly_seconds = seconds[anomalies]
        anomaly_mse = mse_values[anomalies]
        ax.scatter(anomaly_seconds, anomaly_mse, color='red', s=20, zorder=5)

        anomaly_data = sorted(zip(df['Timecode'].iloc[anomalies], anomaly_seconds, anomaly_mse),
                              key=lambda a: a[1])
        for group in _split_on_gaps(anomaly_data, key=lambda a: a[1], max_gap=time_threshold):
            # axvspan spans the full height in axes coordinates. Unlike a
            # data-space Rectangle sized from get_ylim(), it neither depends on
            # nor enlarges the y-limits.
            ax.axvspan(group[0][1], group[-1][1], facecolor='red', alpha=0.2, zorder=1)
            timecode, sec, mse = max(group, key=lambda a: a[2])
            ax.annotate(timecode, (sec, mse), textcoords="offset points", xytext=(0, 10),
                        ha='center', fontsize=6, color='red')

    # With no valid samples the max is NaN; skip ticks instead of crashing in int(nan).
    max_seconds = df['Seconds'].max()
    if pd.notna(max_seconds):
        tick_locations = np.linspace(0, max_seconds, 100)
        ax.set_xticks(tick_locations)
        ax.set_xticklabels([seconds_to_timecode(int(s)) for s in tick_locations],
                           rotation=90, ha='center', fontsize=6)

    ax.set_xlabel('Timecode')
    ax.set_ylabel('Mean Squared Error')
    ax.set_title(title)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend()
    fig.tight_layout()
    plt.close(fig)
    return fig, anomaly_frames
def plot_mse_histogram(mse_values, title, anomaly_threshold, color='blue'):
    """Plot a 100-bin histogram of MSE values with the anomaly threshold marked.

    Args:
        mse_values: 1-D sequence of MSE values. NaN entries are ignored, so
            they neither break the binning nor turn the threshold into NaN.
        title: Axes title.
        anomaly_threshold: The red dashed line is drawn at
            mean + anomaly_threshold * std of the non-NaN values.
        color: Bar colour.

    Returns:
        The figure. It is closed in pyplot so repeated calls do not accumulate
        open figures; the returned object remains usable.
    """
    values = np.asarray(mse_values, dtype=float)
    values = values[~np.isnan(values)]

    # One figure only; a separate plt.figure() would be leaked unused.
    fig, ax = plt.subplots(figsize=(16, 3))
    ax.hist(values, bins=100, edgecolor='black', color=color, alpha=0.7)
    ax.set_xlabel('Mean Squared Error')
    ax.set_ylabel('Number of Frames')
    ax.set_title(title)

    if values.size:
        threshold = np.mean(values) + anomaly_threshold * np.std(values)
        ax.axvline(x=threshold, color='red', linestyle='--', linewidth=2)

    fig.tight_layout()
    plt.close(fig)
    return fig
def plot_mse_heatmap(mse_values, title, df):
    """Draw MSE values as a single-row heatmap with timecode tick labels.

    Args:
        mse_values: Sequence of MSE values, flattened into one row of cells.
        title: Axes title.
        df: DataFrame whose 'Timecode' column labels cell i by row i. Ticks
            for cells beyond ``len(df)`` are omitted.

    Returns:
        The figure (closed in pyplot; the returned object remains usable).
    """
    # One figure only; a separate plt.figure() would be leaked unused.
    fig, ax = plt.subplots(figsize=(20, 3))

    mse_2d = np.asarray(mse_values).reshape(1, -1)
    sns.heatmap(mse_2d, cmap='YlOrRd', cbar=False, ax=ax)

    n_cells = mse_2d.shape[1]
    num_ticks = min(60, n_cells)
    tick_positions = np.linspace(0, n_cells - 1, num_ticks).astype(int)
    # Only label cells that have a matching timecode row.
    tick_positions = tick_positions[tick_positions < len(df)]
    tick_labels = df['Timecode'].iloc[tick_positions].tolist()

    # seaborn places cell i between x=i and x=i+1, so its centre is i + 0.5.
    ax.set_xticks(tick_positions + 0.5)
    ax.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')
    ax.set_title(title)
    ax.set_yticks([])

    fig.tight_layout()
    plt.close(fig)
    return fig
def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
    """Plot posture scores over time with a 10-sample rolling mean.

    Args:
        df: DataFrame with 'Frame' and 'Timecode' columns. A 'Seconds' column
            derived from 'Timecode' is (re)written into the caller's DataFrame.
        posture_scores: Mapping of frame id -> score; ``None`` scores are
            skipped. Frames missing from ``df['Frame']`` are dropped by the
            inner merge.
        color: Colour for the points and the rolling-mean line.
        anomaly_threshold: Currently unused; kept for interface compatibility.

    Returns:
        The figure (closed in pyplot; the returned object remains usable).
    """
    # One figure only; a separate plt.figure() would be leaked unused.
    fig, ax = plt.subplots(figsize=(16, 8))

    df['Seconds'] = df['Timecode'].apply(
        lambda x: sum(float(t) * 60 ** i for i, t in enumerate(reversed(x.split(':')))))

    posture_data = [(frame, score) for frame, score in posture_scores.items() if score is not None]
    # zip(*[]) cannot be unpacked into two names, so skip the data series when empty.
    if posture_data:
        frames, scores = zip(*posture_data)
        posture_df = pd.DataFrame({'Frame': frames, 'Score': scores})
        posture_df = posture_df.merge(df[['Frame', 'Seconds']], on='Frame', how='inner')
        # The rolling mean and line must follow time order, not dict order.
        posture_df = posture_df.sort_values('Seconds', kind='stable').reset_index(drop=True)
        ax.scatter(posture_df['Seconds'], posture_df['Score'], color=color, alpha=0.3, s=5)
        rolling_mean = posture_df['Score'].rolling(window=10).mean()
        ax.plot(posture_df['Seconds'], rolling_mean, color=color, linewidth=0.5)

    ax.set_xlabel('Timecode')
    ax.set_ylabel('Posture Score')
    ax.set_title("Body Posture Over Time")
    ax.grid(True, linestyle='--', alpha=0.7)

    # An empty df gives a NaN max; skip ticks instead of crashing in int(nan).
    max_seconds = df['Seconds'].max()
    if pd.notna(max_seconds):
        tick_locations = np.linspace(0, max_seconds, 80)
        ax.set_xticks(tick_locations)
        ax.set_xticklabels([seconds_to_timecode(int(s)) for s in tick_locations],
                           rotation=90, ha='center', fontsize=6)

    fig.tight_layout()
    plt.close(fig)
    return fig
def filter_mse_for_most_frequent_person(df, mse_embeddings, mse_posture, mse_voice, most_frequent_person_frames):
    """Zero out MSE samples whose frame does not belong to the selected person.

    All inputs are truncated to the shortest of ``len(df)`` and the three MSE
    lengths. Sample i is kept when ``df['Frame'].iloc[i]`` is in
    *most_frequent_person_frames*, and replaced by 0 otherwise.

    Args:
        df: DataFrame with a 'Frame' column aligned row-by-row with the MSE arrays.
        mse_embeddings, mse_posture, mse_voice: 1-D MSE sequences.
        most_frequent_person_frames: A single frame id, or any collection of
            frame ids (list, ndarray, set, tuple, Series, ...).

    Returns:
        Tuple of three ndarrays ``(embeddings, posture, voice)`` of equal length.
    """
    frames = most_frequent_person_frames
    # A lone id (including a string) is wrapped. Any other iterable is used as
    # a collection; previously sets/tuples were wrapped whole and never matched.
    if isinstance(frames, (str, bytes)) or not np.iterable(frames):
        frames = [frames]
    elif not isinstance(frames, (list, np.ndarray)):
        frames = list(frames)

    min_length = min(len(df), len(mse_embeddings), len(mse_posture), len(mse_voice))
    keep = df['Frame'].iloc[:min_length].isin(frames).to_numpy()

    return tuple(
        np.where(keep, np.asarray(signal)[:min_length], 0)
        for signal in (mse_embeddings, mse_posture, mse_voice)
    )
def normalize_mse(mse):
    """Scale *mse* so that its largest non-NaN value becomes 1.

    Returns *mse* unchanged (the same object) when it is empty, entirely NaN,
    or its maximum is not positive. NaN entries are ignored when finding the
    maximum and stay NaN in the result; previously a single NaN silently
    disabled normalization, and an empty input raised ValueError.
    """
    values = np.asarray(mse, dtype=float)
    if values.size == 0 or np.all(np.isnan(values)):
        return mse
    peak = np.nanmax(values)
    return values / peak if peak > 0 else mse
def pad_or_trim_array(arr, target_length):
    """Return *arr* cut or zero-padded at the end to exactly *target_length* items.

    A longer input is sliced, and a shorter one is right-padded with zeros via
    ``np.pad``. An input already of the right length is returned as the same
    object.
    """
    shortfall = target_length - len(arr)
    if shortfall < 0:
        return arr[:target_length]
    if shortfall > 0:
        return np.pad(arr, (0, shortfall), 'constant')
    return arr
def create_heatmap(t, mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered, fps, total_frames, width):
    """Return a ``(3, width, 3)`` uint8 RGB strip encoding the three MSE signals at time *t*.

    Each signal is normalized by normalize_mse and padded/trimmed to
    *total_frames* by pad_or_trim_array. The value at frame index
    ``int(t * fps)`` is then painted across a whole row:
      * row 0, red channel: facial-feature MSE,
      * row 1, green channel: body-posture MSE,
      * row 2, blue channel: voice MSE.

    Args:
        t: Time in seconds.
        mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered:
            1-D per-frame MSE sequences.
        fps: Frames per second used to map *t* to a frame index.
        total_frames: Length every signal is padded/trimmed to.
        width: Width of the strip in pixels.
    """
    heatmap_height = 3  # one row per signal
    heatmap_frame = np.zeros((heatmap_height, width, 3), dtype=np.uint8)
    if total_frames <= 0:
        return heatmap_frame

    # Clamp against the embedding length (as before) and also against the
    # padded/trimmed length total_frames. An index beyond that raised
    # IndexError, and a negative one silently wrapped to the end.
    frame_index = min(int(t * fps), len(mse_embeddings_filtered) - 1, total_frames - 1)
    frame_index = max(frame_index, 0)

    signals = (mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered)
    for channel, signal in enumerate(signals):
        normalized = pad_or_trim_array(normalize_mse(signal), total_frames)
        # NaN -> 0 and clip to [0, 1] so the uint8 cast can neither wrap
        # (negative values) nor be undefined (NaN).
        intensity = np.clip(np.nan_to_num(normalized[frame_index]), 0.0, 1.0)
        heatmap_frame[channel, :, channel] = (intensity * 255).astype(np.uint8)
    return heatmap_frame
def _resample_to_length(values, length):
    """Linearly resample a 1-D sequence onto *length* evenly spaced points.

    An empty input yields *length* zeros, where np.interp would raise ValueError.
    """
    values = np.asarray(values, dtype=float)
    if length <= 0:
        return np.zeros(0)
    if values.size == 0:
        return np.zeros(length)
    return np.interp(np.linspace(0, values.size - 1, length), np.arange(values.size), values)


def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice, output_folder, desired_fps, most_frequent_person_frames):
    """Write a copy of *video_path* with a per-frame MSE colour strip stacked under the picture.

    The MSE arrays are expected to line up row-by-row with *df*, as
    filter_mse_for_most_frequent_person assumes. They are masked to
    *most_frequent_person_frames* at that resolution first, and only then
    stretched to the video's frame count. Previously they were stretched
    first and then truncated to ``len(df)`` video frames, which misaligned
    them in time.

    Args:
        video_path: Source video file.
        df: Per-sample DataFrame with a 'Frame' column.
        mse_embeddings, mse_posture, mse_voice: 1-D MSE sequences aligned with *df*.
        output_folder: Directory for the output (created if missing).
        desired_fps: Not used here; kept for interface compatibility.
        most_frequent_person_frames: Frame id(s) whose samples are kept.

    Returns:
        Path of the written ``<name>_heatmap.mp4``, or None if that file does
        not exist after writing.
    """
    print(f"Creating heatmap video. Output folder: {output_folder}")
    os.makedirs(output_folder, exist_ok=True)
    output_filename = os.path.basename(video_path).rsplit('.', 1)[0] + '_heatmap.mp4'
    heatmap_video_path = os.path.join(output_folder, output_filename)
    print(f"Heatmap video will be saved at: {heatmap_video_path}")

    video = VideoFileClip(video_path)
    final_clip = None
    try:
        width = video.w
        total_frames = int(video.duration * video.fps)

        # Mask at df resolution, then resample each signal to one value per video frame.
        filtered = filter_mse_for_most_frequent_person(df, mse_embeddings, mse_posture, mse_voice, most_frequent_person_frames)
        mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered = (
            _resample_to_length(signal, total_frames) for signal in filtered)

        print(f"Total frames: {total_frames}")
        print(f"mse_embeddings length: {len(mse_embeddings_filtered)}")
        print(f"mse_posture length: {len(mse_posture_filtered)}")
        print(f"mse_voice length: {len(mse_voice_filtered)}")

        def combine_video_and_heatmap(t):
            video_frame = video.get_frame(t)
            heatmap_frame = create_heatmap(t, mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered,
                                           video.fps, total_frames, width)
            heatmap_frame = cv2.resize(heatmap_frame, (width, heatmap_frame.shape[0]))
            # The strip is RGB. If the video frame carries an alpha channel,
            # add an opaque one. An RGB<->BGR swap would not change the
            # channel count and would corrupt the colours.
            if video_frame.shape[2] == 4 and heatmap_frame.shape[2] == 3:
                alpha = np.full(heatmap_frame.shape[:2] + (1,), 255, dtype=heatmap_frame.dtype)
                heatmap_frame = np.concatenate([heatmap_frame, alpha], axis=2)
            combined_frame = np.vstack((video_frame, heatmap_frame))
            # Keep the output height even. libx264's widely supported yuv420p
            # format needs even dimensions, and moviepy only requests it when
            # both are even.
            if combined_frame.shape[0] % 2:
                filler = np.zeros((1,) + combined_frame.shape[1:], dtype=combined_frame.dtype)
                combined_frame = np.vstack((combined_frame, filler))
            return combined_frame

        final_clip = VideoClip(combine_video_and_heatmap, duration=video.duration)
        final_clip = final_clip.set_audio(video.audio)
        final_clip.write_videofile(heatmap_video_path, codec='libx264', audio_codec='aac', fps=video.fps,
                                   preset='medium', ffmpeg_params=['-crf', '23'])
    finally:
        # Release the reader and ffmpeg subprocesses even if writing failed.
        if final_clip is not None:
            final_clip.close()
        video.close()

    if os.path.exists(heatmap_video_path):
        print(f"Heatmap video created at: {heatmap_video_path}")
        print(f"Heatmap video size: {os.path.getsize(heatmap_video_path)} bytes")
        return heatmap_video_path
    print(f"Failed to create heatmap video at: {heatmap_video_path}")
    return None
def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
    """Plot the pairwise Pearson correlation of the three MSE signals as an annotated heatmap.

    The signals are truncated to the shortest length before correlating, so
    unequal lengths no longer make ``np.vstack`` fail.

    Returns:
        The 300-dpi figure. It is closed in pyplot (like the other plot
        helpers here) so repeated calls do not leak open figures; the returned
        object remains usable.
    """
    n = min(len(mse_embeddings), len(mse_posture), len(mse_voice))
    data = np.column_stack([np.asarray(mse_embeddings)[:n],
                            np.asarray(mse_posture)[:n],
                            np.asarray(mse_voice)[:n]])
    signals_df = pd.DataFrame(data, columns=["Facial Features", "Body Posture", "Voice"])
    corr = signals_df.corr()

    fig, ax = plt.subplots(figsize=(10, 8), dpi=300)
    sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1, ax=ax)
    ax.set_title('Correlation Heatmap of MSEs')
    fig.tight_layout()
    plt.close(fig)
    return fig