import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import numpy as np
import pandas as pd
import cv2
from moviepy.editor import VideoFileClip, AudioFileClip, CompositeVideoClip, ImageClip, VideoClip, concatenate_videoclips
from moviepy.video.fx.all import resize
from PIL import Image, ImageDraw, ImageFont
from matplotlib.patches import Rectangle
from utils import seconds_to_timecode
from anomaly_detection import determine_anomalies
from scipy import interpolate
import gradio as gr
import os


def _timecode_to_seconds(timecode):
    """Convert a colon-separated timecode string (e.g. ``"01:02:03"``) to seconds.

    Fields are read right to left as successive powers of 60, so ``"SS"``,
    ``"MM:SS"`` and ``"HH:MM:SS"`` are all accepted.
    """
    return sum(float(part) * 60 ** i for i, part in enumerate(reversed(timecode.split(':'))))


def _split_on_gaps(items, position, max_gap):
    """Split ordered *items* into runs whose consecutive positions are within *max_gap*.

    ``position`` maps an item to a number. An item joins the current run when
    ``position(item) - position(previous) <= max_gap``; otherwise a new run starts.
    """
    runs = []
    current = []
    for item in items:
        if not current or position(item) - position(current[-1]) <= max_gap:
            current.append(item)
        else:
            runs.append(current)
            current = [item]
    if current:
        runs.append(current)
    return runs


def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_threshold=4):
    """Plot per-row MSE over time and highlight anomalies.

    Args:
        df: DataFrame with ``'Timecode'`` (colon-separated string) and ``'Frame'``
            columns, row-aligned with ``mse_values``. If it has no ``'Seconds'``
            column, one is computed from ``'Timecode'`` and added to ``df`` in place.
        mse_values: Array-like of MSE values, one per row of ``df``. Both inputs are
            truncated to their common length and rows with NaN MSE are dropped.
        title: Axes title.
        color: Colour of the scatter points, rolling mean and rolling-std band.
        time_threshold: Anomalies at most this many seconds apart are shaded and
            labelled as one group.
        anomaly_threshold: Passed to ``determine_anomalies``; the dashed threshold
            line is drawn at ``mean + anomaly_threshold * std``.

    Returns:
        ``(fig, anomaly_frames)``: the figure (already closed in pyplot) and the
        ``'Frame'`` values of the rows flagged by ``determine_anomalies``.
    """
    # Exactly one figure, so nothing is left open in pyplot's registry.
    fig, ax = plt.subplots(figsize=(16, 8))

    if 'Seconds' not in df.columns:
        df['Seconds'] = df['Timecode'].apply(_timecode_to_seconds)

    # Align lengths, then drop rows whose MSE is NaN.
    mse_values = np.asarray(mse_values)
    min_length = min(len(df), len(mse_values))
    df = df.iloc[:min_length].copy()
    mse_values = mse_values[:min_length]
    valid_mask = ~np.isnan(mse_values)
    df = df[valid_mask]
    mse_values = mse_values[valid_mask]

    ax.set_xlabel('Timecode')
    ax.set_ylabel('Mean Squared Error')
    ax.set_title(title)
    ax.grid(True, linestyle='--', alpha=0.7)

    if len(mse_values) == 0:
        # Nothing to plot; median/threshold/ticks would all be NaN.
        fig.tight_layout()
        plt.close(fig)
        return fig, []

    # Draw each stretch of points no more than 1 s apart separately so the
    # rolling curves do not bridge gaps in the data.
    points = list(zip(df['Seconds'], mse_values))
    for segment in _split_on_gaps(points, lambda point: point[0], 1):
        segment_seconds, segment_mse = zip(*segment)
        ax.scatter(segment_seconds, segment_mse, color=color, alpha=0.3, s=5)
        if len(segment) > 1:
            segment_df = pd.DataFrame({'Seconds': segment_seconds, 'MSE': segment_mse}).sort_values('Seconds')
            rolling = segment_df['MSE'].rolling(window=min(10, len(segment)), min_periods=1, center=True)
            mean = rolling.mean()
            std = rolling.std()
            ax.plot(segment_df['Seconds'], mean, color=color, linewidth=0.5)
            ax.fill_between(segment_df['Seconds'], mean - std, mean + std, color=color, alpha=0.1)

    median = np.median(mse_values)
    ax.axhline(y=median, color='black', linestyle='--', label='Median Baseline')

    threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
    ax.axhline(y=threshold, color='red', linestyle='--', label='Anomaly Threshold')
    ax.text(ax.get_xlim()[1], threshold, 'Anomaly Threshold',
            verticalalignment='center', horizontalalignment='left', color='red')

    anomalies = determine_anomalies(mse_values, anomaly_threshold)
    anomaly_frames = df['Frame'].iloc[anomalies].tolist()
    anomaly_timecodes = df['Timecode'].iloc[anomalies]
    anomaly_seconds = df['Seconds'].iloc[anomalies]
    anomaly_mse = mse_values[anomalies]
    ax.scatter(anomaly_seconds, anomaly_mse, color='red', s=20, zorder=5)

    anomaly_data = sorted(zip(anomaly_timecodes, anomaly_seconds, anomaly_mse), key=lambda item: item[1])
    for group in _split_on_gaps(anomaly_data, lambda item: item[1], time_threshold):
        # axvspan spans the full axes height without feeding back into the
        # y autoscaling (a data-space Rectangle sized from get_ylim() does).
        ax.axvspan(group[0][1], group[-1][1], facecolor='red', alpha=0.2, zorder=1)
        timecode, sec, mse = max(group, key=lambda item: item[2])
        ax.annotate(timecode, (sec, mse), textcoords="offset points", xytext=(0, 10),
                    ha='center', fontsize=6, color='red')

    tick_locations = np.linspace(0, df['Seconds'].max(), 100)
    ax.set_xticks(tick_locations)
    ax.set_xticklabels([seconds_to_timecode(int(s)) for s in tick_locations],
                       rotation=90, ha='center', fontsize=6)

    ax.legend()
    fig.tight_layout()
    plt.close(fig)
    return fig, anomaly_frames


def plot_mse_histogram(mse_values, title, anomaly_threshold, color='blue'):
    """Plot a 100-bin histogram of MSE values with the anomaly threshold marked.

    NaN values are ignored (as in ``plot_mse``). The red dashed line is at
    ``mean + anomaly_threshold * std`` of the remaining values and is omitted
    when no values remain.

    Returns:
        The matplotlib Figure (already closed in pyplot).
    """
    values = np.asarray(mse_values, dtype=float)
    values = values[~np.isnan(values)]

    fig, ax = plt.subplots(figsize=(16, 3))
    ax.hist(values, bins=100, edgecolor='black', color=color, alpha=0.7)
    ax.set_xlabel('Mean Squared Error')
    ax.set_ylabel('Number of Frames')
    ax.set_title(title)

    if values.size:
        threshold = np.mean(values) + anomaly_threshold * np.std(values)
        ax.axvline(x=threshold, color='red', linestyle='--', linewidth=2)

    fig.tight_layout()
    plt.close(fig)
    return fig


def plot_mse_heatmap(mse_values, title, df):
    """Draw MSE values as a one-row heatmap with timecode tick labels.

    Cell ``i`` shows ``mse_values[i]`` and is labelled ``df['Timecode'].iloc[i]``,
    so ``df`` is expected to be row-aligned with ``mse_values``. Up to 60 evenly
    spaced cells are labelled (only those that have a matching ``df`` row).

    Returns:
        The matplotlib Figure (already closed in pyplot).
    """
    mse_2d = np.asarray(mse_values).reshape(1, -1)
    n_cells = mse_2d.shape[1]

    fig, ax = plt.subplots(figsize=(20, 3))
    sns.heatmap(mse_2d, cmap='YlOrRd', cbar=False, ax=ax)

    tick_cells = np.linspace(0, n_cells - 1, min(60, n_cells)).astype(int)
    tick_cells = tick_cells[tick_cells < len(df)]
    # Heatmap cell i spans [i, i + 1) on the x axis; label its centre.
    ax.set_xticks(tick_cells + 0.5)
    ax.set_xticklabels([df['Timecode'].iloc[i] for i in tick_cells], rotation=90, ha='center', va='top')
    ax.set_title(title)
    ax.set_yticks([])

    fig.tight_layout()
    plt.close(fig)
    return fig


def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
    """Plot posture scores over time with a 10-sample rolling mean.

    Args:
        df: DataFrame with ``'Timecode'`` and ``'Frame'`` columns. A ``'Seconds'``
            column is (re)computed from ``'Timecode'`` and written to ``df`` in place.
        posture_scores: Mapping of frame -> score. ``None`` scores are skipped and
            frames absent from ``df['Frame']`` are dropped (inner merge).
        color: Colour of the points and rolling-mean line.
        anomaly_threshold: Not used; kept for API compatibility.

    Returns:
        The matplotlib Figure (already closed in pyplot).
    """
    fig, ax = plt.subplots(figsize=(16, 8))

    df['Seconds'] = df['Timecode'].apply(_timecode_to_seconds)

    posture_df = pd.DataFrame(
        [(frame, score) for frame, score in posture_scores.items() if score is not None],
        columns=['Frame', 'Score'],
    )
    if not posture_df.empty:
        posture_df = posture_df.merge(df[['Frame', 'Seconds']], on='Frame', how='inner')
        # Frame order = time order, so the rolling mean does not depend on mapping order.
        posture_df = posture_df.sort_values('Frame', kind='stable')
        ax.scatter(posture_df['Seconds'], posture_df['Score'], color=color, alpha=0.3, s=5)
        mean = posture_df['Score'].rolling(window=10).mean()
        ax.plot(posture_df['Seconds'], mean, color=color, linewidth=0.5)

    ax.set_xlabel('Timecode')
    ax.set_ylabel('Posture Score')
    ax.set_title("Body Posture Over Time")
    ax.grid(True, linestyle='--', alpha=0.7)

    tick_locations = np.linspace(0, df['Seconds'].max(), 80)
    ax.set_xticks(tick_locations)
    ax.set_xticklabels([seconds_to_timecode(int(s)) for s in tick_locations],
                       rotation=90, ha='center', fontsize=6)

    fig.tight_layout()
    plt.close(fig)
    return fig


def filter_mse_for_most_frequent_person(df, mse_embeddings, mse_posture, mse_voice, most_frequent_person_frames):
    """Zero the MSE entries of rows whose ``'Frame'`` is not in *most_frequent_person_frames*.

    All inputs are first truncated to their common length; position ``i`` of each
    MSE array is treated as belonging to ``df.iloc[i]``.

    Args:
        df: DataFrame with a ``'Frame'`` column, row-aligned with the MSE arrays.
        mse_embeddings, mse_posture, mse_voice: Array-likes of per-row MSE values.
        most_frequent_person_frames: A collection of frame identifiers (list,
            tuple, set, ndarray, Series, ...) or a single identifier.

    Returns:
        ``(embeddings, posture, voice)`` numpy arrays of the common length, with
        entries for non-matching rows set to 0.
    """
    frames = most_frequent_person_frames
    if not pd.api.types.is_list_like(frames):
        frames = [frames]

    min_length = min(len(df), len(mse_embeddings), len(mse_posture), len(mse_voice))
    mask = df['Frame'].iloc[:min_length].isin(frames).to_numpy()

    return tuple(
        np.where(mask, np.asarray(values)[:min_length], 0)
        for values in (mse_embeddings, mse_posture, mse_voice)
    )


def _resample_to_length(values, length):
    """Resample a 1-D array to exactly *length* points.

    Downsampling keeps the (NaN-ignoring) maximum of each bin so short spikes stay
    visible; upsampling interpolates linearly. Empty input yields zeros.
    """
    values = np.asarray(values, dtype=float)
    if length <= 0:
        return np.zeros(0)
    if values.size == 0:
        return np.zeros(length)
    if values.size > length:
        # Step is > 1, so the floored bin starts are strictly increasing.
        bin_starts = np.linspace(0, values.size, length, endpoint=False).astype(int)
        return np.fmax.reduceat(values, bin_starts)
    return np.interp(np.linspace(0, values.size - 1, length), np.arange(values.size), values)


def _build_heatmap_strip(signals, width, row_height=20, cmap_name='YlOrRd'):
    """Render each signal as one colour band spanning the whole timeline.

    Each signal is scaled with ``normalize_mse``, resampled to *width* columns and
    colour-mapped; bands are stacked top to bottom in input order.

    Returns:
        A ``(len(signals) * row_height, width, 3)`` uint8 RGB image.
    """
    cmap = plt.get_cmap(cmap_name)
    bands = []
    for signal in signals:
        levels = np.nan_to_num(normalize_mse(signal), nan=0.0, posinf=1.0, neginf=0.0)
        columns = np.clip(_resample_to_length(levels, width), 0.0, 1.0)
        rgb = (cmap(columns)[:, :3] * 255).astype(np.uint8)
        bands.append(np.repeat(rgb[np.newaxis], row_height, axis=0))
    return np.vstack(bands)


def _with_playhead(strip, t, duration):
    """Return a copy of *strip* with a 2-px black marker at time *t* out of *duration*."""
    frame = strip.copy()
    width = frame.shape[1]
    if duration and width:
        x = min(width - 1, max(0, int(t / duration * width)))
        frame[:, max(0, x - 1):x + 1] = 0
    return frame


def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice, output_folder, desired_fps, most_frequent_person_frames):
    """Write a copy of the video with an MSE heatmap strip stacked underneath.

    The strip has three colour bands (top to bottom: embeddings, posture, voice)
    covering the whole video, after zeroing rows whose frame is not in
    ``most_frequent_person_frames``; a black marker follows the playback time.
    NOTE(review): the MSE arrays are assumed to be row-aligned with ``df`` and to
    span the video evenly — confirm against the extraction step.

    Args:
        video_path: Source video file.
        df: DataFrame with a ``'Frame'`` column, row-aligned with the MSE arrays.
        mse_embeddings, mse_posture, mse_voice: Per-row MSE values.
        output_folder: Output directory (created if missing); the file is named
            ``<video name without extension>_heatmap.mp4``.
        desired_fps: Not used; the output keeps the source video's fps.
        most_frequent_person_frames: Frame(s) to keep, as accepted by
            ``filter_mse_for_most_frequent_person``.

    Returns:
        The output path if the file exists after writing, otherwise ``None``.
    """
    print(f"Creating heatmap video. Output folder: {output_folder}")
    os.makedirs(output_folder, exist_ok=True)
    output_filename = os.path.splitext(os.path.basename(video_path))[0] + '_heatmap.mp4'
    heatmap_video_path = os.path.join(output_folder, output_filename)
    print(f"Heatmap video will be saved at: {heatmap_video_path}")

    video = VideoFileClip(video_path)
    final_clip = None
    try:
        duration = video.duration
        # Filter while MSE positions still line up with df rows; resampling to the
        # strip width happens afterwards inside _build_heatmap_strip.
        filtered = filter_mse_for_most_frequent_person(
            df, mse_embeddings, mse_posture, mse_voice, most_frequent_person_frames)
        # The strip is identical for every frame, so render it once up front.
        strip = _build_heatmap_strip(filtered, video.w)

        def combine_video_and_heatmap(t):
            return np.vstack((video.get_frame(t), _with_playhead(strip, t, duration)))

        final_clip = VideoClip(combine_video_and_heatmap, duration=duration)
        final_clip = final_clip.set_audio(video.audio)
        final_clip.write_videofile(heatmap_video_path, codec='libx264', audio_codec='aac', fps=video.fps)
    finally:
        if final_clip is not None:
            final_clip.close()
        video.close()

    if os.path.exists(heatmap_video_path):
        print(f"Heatmap video created at: {heatmap_video_path}")
        print(f"Heatmap video size: {os.path.getsize(heatmap_video_path)} bytes")
        return heatmap_video_path
    print(f"Failed to create heatmap video at: {heatmap_video_path}")
    return None


def create_heatmap(t, mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered, fps, total_frames, width):
    """Pack three normalised MSE series into a ``(total_frames, width)`` float array.

    Rows 0, 1 and 2 hold the embedding, posture and voice series, each scaled by
    ``normalize_mse`` and zero-padded or truncated to ``width``; all other rows are
    zero. ``t`` and ``fps`` are accepted but unused. Requires ``total_frames >= 3``.
    ``create_video_with_heatmap`` renders its strip with ``_build_heatmap_strip``;
    this function is kept for existing callers.
    """
    combined_mse = np.zeros((total_frames, width))
    series = (mse_embeddings_filtered, mse_posture_filtered, mse_voice_filtered)
    for row, values in enumerate(series):
        combined_mse[row] = pad_or_trim_array(normalize_mse(values), width)
    return combined_mse


def normalize_mse(mse):
    """Scale *mse* so that its largest finite value becomes 1.

    NaN entries are ignored when finding the peak and stay NaN. When there is no
    finite value, or the peak is 0, the values are returned unscaled (as a float
    array) instead of being divided by zero or NaN.
    """
    mse = np.asarray(mse, dtype=float)
    finite = mse[np.isfinite(mse)]
    if finite.size == 0:
        return mse.copy()
    peak = finite.max()
    if peak == 0:
        return mse.copy()
    return mse / peak


def pad_or_trim_array(arr, target_length):
    """Return *arr* as an array truncated, or zero-padded at the end, to *target_length* items."""
    arr = np.asarray(arr)
    if len(arr) > target_length:
        return arr[:target_length]
    if len(arr) < target_length:
        return np.pad(arr, (0, target_length - len(arr)), 'constant')
    return arr


def plot_correlation_heatmap(mse_embeddings, mse_posture, mse_voice):
    """Plot the Pearson correlation matrix of the three MSE series.

    The series are truncated to their common length before correlating.

    Returns:
        The matplotlib Figure (already closed in pyplot, like the other plots here).
    """
    length = min(len(mse_embeddings), len(mse_posture), len(mse_voice))
    df = pd.DataFrame({
        "Facial Features": np.asarray(mse_embeddings, dtype=float)[:length],
        "Body Posture": np.asarray(mse_posture, dtype=float)[:length],
        "Voice": np.asarray(mse_voice, dtype=float)[:length],
    })
    corr = df.corr()

    fig, ax = plt.subplots(figsize=(10, 8), dpi=300)
    sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1, ax=ax)
    ax.set_title('Correlation Heatmap of MSEs')
    fig.tight_layout()
    plt.close(fig)
    return fig