reab5555's picture
Update visualization.py
1d6cb38 verified
raw
history blame
9.65 kB
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import numpy as np
import pandas as pd
import cv2
from moviepy.editor import VideoFileClip, AudioFileClip, CompositeVideoClip, ImageClip, VideoClip, concatenate_videoclips
from moviepy.video.fx.all import resize
from moviepy.video.io.bindings import mplfig_to_npimage
from PIL import Image, ImageDraw, ImageFont
from matplotlib.patches import Rectangle
from utils import seconds_to_timecode
from anomaly_detection import determine_anomalies
from scipy import interpolate
import librosa
import librosa.display
import gradio as gr
import os
def _get_continuous_segments(seconds, values, max_gap=1):
    """Group (second, value) pairs into runs whose neighbours are at most ``max_gap`` apart."""
    segments = []
    current_segment = []
    for sec, val in zip(seconds, values):
        if not current_segment or (sec - current_segment[-1][0] <= max_gap):
            current_segment.append((sec, val))
        else:
            segments.append(current_segment)
            current_segment = [(sec, val)]
    if current_segment:
        segments.append(current_segment)
    return segments


def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_threshold=4):
    """Plot MSE values over time and highlight the anomalous frames.

    Args:
        df: DataFrame with 'Timecode' and 'Frame' columns. If it has no
            'Seconds' column one is derived from 'Timecode' and added to the
            caller's DataFrame (side effect kept for backward compatibility).
        mse_values: Sequence of MSE values, positionally aligned with ``df``.
            NaN entries are dropped together with their ``df`` rows.
        title: Title of the plot.
        color: Colour of the scatter points, rolling mean and std band.
        time_threshold: Unused; kept so existing callers keep working.
        anomaly_threshold: Number of standard deviations above the mean used
            for the threshold line; also forwarded to ``determine_anomalies``.

    Returns:
        Tuple ``(fig, anomaly_frames)``: the (already closed) matplotlib
        figure and the list of 'Frame' values selected by
        ``determine_anomalies``. The list is empty when there is no valid
        MSE value to plot.
    """
    # Create exactly one figure: everything opened through pyplot stays
    # registered (and in memory) until it is explicitly closed.
    fig, ax = plt.subplots(figsize=(16, 8))

    if 'Seconds' not in df.columns:
        # "HH:MM:SS" -> seconds; the right-most field is weighted by 60**0.
        df['Seconds'] = df['Timecode'].apply(
            lambda x: sum(float(t) * 60 ** i for i, t in enumerate(reversed(x.split(':')))))

    # Boolean-mask indexing below needs an ndarray, not a plain list.
    mse_values = np.asarray(mse_values, dtype=float)

    # Ensure df and mse_values have the same length
    min_length = min(len(df), len(mse_values))
    df = df.iloc[:min_length].copy()
    mse_values = mse_values[:min_length]

    # Remove NaN values from both, keeping them aligned
    valid_mask = ~np.isnan(mse_values)
    df = df[valid_mask]
    mse_values = mse_values[valid_mask]

    ax.set_xlabel('Timecode')
    ax.set_ylabel('Mean Squared Error')
    ax.set_title(title)
    ax.grid(True, linestyle='--', alpha=0.7)

    if mse_values.size == 0:
        # Nothing to plot: the statistics and tick labels below would be
        # computed from NaN and fail, so return the empty, labelled axes.
        fig.tight_layout()
        plt.close(fig)
        return fig, []

    # Plot each continuous segment separately so gaps are not bridged by a line
    for segment in _get_continuous_segments(df['Seconds'], mse_values):
        segment_seconds, segment_mse = zip(*segment)
        ax.scatter(segment_seconds, segment_mse, color=color, alpha=0.3, s=5)

        # Rolling mean and std only make sense with more than one point
        if len(segment) > 1:
            segment_df = pd.DataFrame({'Seconds': segment_seconds, 'MSE': segment_mse})
            segment_df = segment_df.sort_values('Seconds')
            rolling = segment_df['MSE'].rolling(
                window=min(10, len(segment)), min_periods=1, center=True)
            mean = rolling.mean()
            std = rolling.std()
            ax.plot(segment_df['Seconds'], mean, color=color, linewidth=0.5)
            ax.fill_between(segment_df['Seconds'], mean - std, mean + std, color=color, alpha=0.1)

    median = np.median(mse_values)
    ax.axhline(y=median, color='black', linestyle='--', label='Median Baseline')

    threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
    ax.axhline(y=threshold, color='red', linestyle='--', label='Anomaly Threshold')
    ax.text(ax.get_xlim()[1], threshold, 'Anomaly Threshold', verticalalignment='center',
            horizontalalignment='left', color='red')

    # NOTE(review): the result is used to index both df (.iloc) and
    # mse_values, so it is presumably a boolean mask or integer positions
    # -- confirm against anomaly_detection.determine_anomalies.
    anomalies = determine_anomalies(mse_values, anomaly_threshold)
    anomaly_frames = df['Frame'].iloc[anomalies].tolist()
    ax.scatter(df['Seconds'].iloc[anomalies], mse_values[anomalies], color='red', s=20, zorder=5)

    # 80 evenly spaced timecode ticks from 0 to the last plotted second
    max_seconds = df['Seconds'].max()
    num_ticks = 80
    tick_locations = np.linspace(0, max_seconds, num_ticks)
    tick_labels = [seconds_to_timecode(int(s)) for s in tick_locations]
    ax.set_xticks(tick_locations)
    ax.set_xticklabels(tick_labels, rotation=90, ha='center', fontsize=6)

    ax.legend()
    fig.tight_layout()
    # Close this specific figure; the returned object can still be rendered.
    plt.close(fig)
    return fig, anomaly_frames
def plot_mse_histogram(mse_values, title, anomaly_threshold, color='blue'):
    """Plot a 100-bin histogram of MSE values with the anomaly threshold marked.

    Args:
        mse_values: Sequence of MSE values.
        title: Title of the plot.
        anomaly_threshold: Number of standard deviations above the mean at
            which the vertical threshold line is drawn.
        color: Fill colour of the histogram bars.

    Returns:
        The (already closed) matplotlib figure.
    """
    # Create exactly one figure: everything opened through pyplot stays
    # registered (and in memory) until it is explicitly closed.
    fig, ax = plt.subplots(figsize=(16, 3))

    ax.hist(mse_values, bins=100, edgecolor='black', color=color, alpha=0.7)
    ax.set_xlabel('Mean Squared Error')
    ax.set_ylabel('Number of Frames')
    ax.set_title(title)

    # NaN-aware statistics: a single missing sample must not turn the
    # threshold into NaN. Identical to mean/std when there are no NaNs.
    mean = np.nanmean(mse_values)
    std = np.nanstd(mse_values)
    threshold = mean + anomaly_threshold * std
    ax.axvline(x=threshold, color='red', linestyle='--', linewidth=2)

    fig.tight_layout()
    # Close this specific figure; the returned object can still be rendered.
    plt.close(fig)
    return fig
def plot_mse_heatmap(mse_values, title, df):
    """Render MSE values as a single-row heatmap with timecode tick labels.

    Args:
        mse_values: Sequence of MSE values; drawn as one heatmap row.
        title: Title of the plot.
        df: DataFrame whose 'Timecode' column supplies the x tick labels
            (looked up by position).

    Returns:
        The (already closed) matplotlib figure.
    """
    # Create exactly one figure: everything opened through pyplot stays
    # registered (and in memory) until it is explicitly closed.
    fig, ax = plt.subplots(figsize=(20, 3))

    # Reshape MSE values to a 1 x N array for the heatmap; asarray lets
    # plain lists through as well.
    mse_2d = np.asarray(mse_values).reshape(1, -1)
    sns.heatmap(mse_2d, cmap='YlOrRd', cbar=False, ax=ax)

    # At most 60 evenly spaced ticks, labelled with timecodes
    num_ticks = min(60, len(mse_values))
    tick_locations = np.linspace(0, len(mse_values) - 1, num_ticks).astype(int)
    # Drop ticks that have no row in df, so every remaining lookup is valid
    tick_locations = tick_locations[tick_locations < len(df)]
    tick_labels = [df['Timecode'].iloc[i] for i in tick_locations]
    ax.set_xticks(tick_locations)
    ax.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')

    ax.set_title(title)
    # The single row carries no meaningful y axis
    ax.set_yticks([])

    fig.tight_layout()
    # Close this specific figure; the returned object can still be rendered.
    plt.close(fig)
    return fig
def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
    """Plot posture scores over time with a 10-sample rolling mean.

    Args:
        df: DataFrame with 'Timecode' and 'Frame' columns. A 'Seconds'
            column derived from 'Timecode' is written to the caller's
            DataFrame (side effect kept for backward compatibility).
        posture_scores: Mapping of frame -> score; entries whose score is
            ``None`` are skipped.
        color: Colour of the scatter points and the rolling-mean line.
        anomaly_threshold: Unused; kept so existing callers keep working.

    Returns:
        The (already closed) matplotlib figure.
    """
    # Create exactly one figure: everything opened through pyplot stays
    # registered (and in memory) until it is explicitly closed.
    fig, ax = plt.subplots(figsize=(16, 8))

    # "HH:MM:SS" -> seconds; the right-most field is weighted by 60**0.
    df['Seconds'] = df['Timecode'].apply(
        lambda x: sum(float(t) * 60 ** i for i, t in enumerate(reversed(x.split(':')))))

    posture_data = [(frame, score) for frame, score in posture_scores.items() if score is not None]

    # zip(*[]) cannot be unpacked into two names, so only plot when at
    # least one frame actually has a score.
    if posture_data:
        posture_frames, scores = zip(*posture_data)
        posture_df = pd.DataFrame({'Frame': posture_frames, 'Score': scores})
        posture_df = posture_df.merge(df[['Frame', 'Seconds']], on='Frame', how='inner')
        # The rolling mean and the line must follow time order, not the
        # insertion order of the mapping.
        posture_df = posture_df.sort_values('Seconds')

        ax.scatter(posture_df['Seconds'], posture_df['Score'], color=color, alpha=0.3, s=5)
        mean = posture_df['Score'].rolling(window=10).mean()
        ax.plot(posture_df['Seconds'], mean, color=color, linewidth=0.5)

    ax.set_xlabel('Timecode')
    ax.set_ylabel('Posture Score')
    ax.set_title("Body Posture Over Time")
    ax.grid(True, linestyle='--', alpha=0.7)

    # 80 evenly spaced timecode ticks from 0 to the last second in df
    max_seconds = df['Seconds'].max()
    num_ticks = 80
    tick_locations = np.linspace(0, max_seconds, num_ticks)
    tick_labels = [seconds_to_timecode(int(s)) for s in tick_locations]
    ax.set_xticks(tick_locations)
    ax.set_xticklabels(tick_labels, rotation=90, ha='center', fontsize=6)

    fig.tight_layout()
    # Close this specific figure; the returned object can still be rendered.
    plt.close(fig)
    return fig
def plot_stacked_mse_heatmaps(mse_face, mse_posture, mse_voice, df, title="Combined MSE Heatmaps"):
    """Stack the face, posture and voice MSE heatmaps on a shared x axis.

    Args:
        mse_face: Sequence of face MSE values (top row).
        mse_posture: Sequence of posture MSE values (middle row).
        mse_voice: Sequence of voice MSE values (bottom row); its length
            determines the x tick positions.
        df: DataFrame whose 'Timecode' column supplies the x tick labels
            (looked up by position; positions beyond ``df`` get an empty label).
        title: Figure title.

    Returns:
        The (already closed) matplotlib figure.
    """
    # Create exactly one figure: everything opened through pyplot stays
    # registered (and in memory) until it is explicitly closed.
    fig, (ax1, ax2, ax3) = plt.subplots(
        3, 1, figsize=(20, 8), sharex=True,
        gridspec_kw={'height_ratios': [1, 1, 1.2], 'hspace': 0})

    # Only the bottom (voice) panel keeps x tick labels.
    panels = (
        (ax1, mse_face, 'Face', {'xticklabels': False}),
        (ax2, mse_posture, 'Posture', {'xticklabels': False}),
        (ax3, mse_voice, 'Voice', {}),
    )
    for ax, values, label, extra in panels:
        # asarray lets plain lists through; each series becomes one 1 x N row.
        sns.heatmap(np.asarray(values).reshape(1, -1), cmap='Reds', cbar=False, ax=ax,
                    yticklabels=False, **extra)
        ax.set_ylabel(label, rotation=0, ha='right', va='center')
        ax.yaxis.set_label_coords(-0.01, 0.5)

    # Set x-axis ticks to timecodes for the bottom subplot (at most 60)
    num_ticks = min(60, len(mse_voice))
    tick_locations = np.linspace(0, len(mse_voice) - 1, num_ticks).astype(int)
    tick_labels = [df['Timecode'].iloc[i] if i < len(df) else '' for i in tick_locations]
    ax3.set_xticks(tick_locations)
    ax3.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')

    # Remove spines
    for ax in (ax1, ax2, ax3):
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)

    fig.suptitle(title)
    fig.tight_layout()
    # Close this specific figure; the returned object can still be rendered.
    plt.close(fig)
    return fig
def plot_audio_waveform(audio_path, title="Audio Waveform"):
    """Plot the waveform of an audio file with HH:MM:SS tick labels.

    Args:
        audio_path: Path of the audio file, passed to ``librosa.load``.
        title: Title of the plot.

    Returns:
        The (already closed) matplotlib figure.
    """
    # Load the audio file
    y, sr = librosa.load(audio_path)

    # Draw on an explicit figure/axes instead of pyplot's implicit
    # "current" ones, so the right figure is returned and closed.
    fig, ax = plt.subplots(figsize=(20, 4))
    librosa.display.waveshow(y, sr=sr, ax=ax)

    # Ten evenly spaced ticks starting at 0. linspace always yields exactly
    # ten values; arange with a float step can yield a varying count and
    # rejects the zero step produced by a zero-length clip.
    max_time = librosa.get_duration(y=y, sr=sr)
    x_ticks = np.linspace(0, max_time, 10, endpoint=False)
    x_labels = [f"{int(t//3600):02d}:{int((t%3600)//60):02d}:{int(t%60):02d}" for t in x_ticks]
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_labels, rotation=45)

    ax.set_title(title)
    ax.set_xlabel("Time")
    ax.set_ylabel("Amplitude")

    fig.tight_layout()
    # Close the figure like the other plot_* helpers do, so figures do not
    # accumulate in pyplot; the returned object can still be rendered.
    plt.close(fig)
    return fig