"""Plotting helpers for MSE / posture series and a video writer that appends
a per-frame heatmap strip underneath the original footage."""

import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import numpy as np
import pandas as pd
import cv2
from matplotlib.patches import Rectangle
import gradio as gr

from utils import seconds_to_timecode
from anomaly_detection import determine_anomalies


def _timecode_to_seconds(timecode):
    """Convert a colon-separated timecode string (e.g. ``HH:MM:SS``) to seconds.

    Fields are read right to left as powers of 60, so ``MM:SS`` and
    fractional seconds are accepted as well.
    """
    return sum(
        float(part) * 60 ** power
        for power, part in enumerate(reversed(timecode.split(':')))
    )


def _set_timecode_ticks(ax, max_seconds, num_ticks):
    """Put ``num_ticks`` evenly spaced, timecode-labelled ticks on the x-axis."""
    # An empty frame yields NaN for ``max()``; int(NaN) would raise below.
    if not np.isfinite(max_seconds):
        return
    tick_locations = np.linspace(0, max_seconds, num_ticks)
    tick_labels = [seconds_to_timecode(int(s)) for s in tick_locations]
    ax.set_xticks(tick_locations)
    ax.set_xticklabels(tick_labels, rotation=90, ha='center', fontsize=6)


def _get_continuous_segments(seconds, values, max_gap=1):
    """Split ``(second, value)`` pairs into runs whose gaps are <= ``max_gap``."""
    segments = []
    current_segment = []
    for sec, val in zip(seconds, values):
        if current_segment and sec - current_segment[-1][0] > max_gap:
            segments.append(current_segment)
            current_segment = []
        current_segment.append((sec, val))
    if current_segment:
        segments.append(current_segment)
    return segments


def _group_by_time_gap(anomaly_data, time_threshold):
    """Group ``(timecode, seconds, mse)`` tuples lying within ``time_threshold`` s.

    ``anomaly_data`` must already be sorted by its seconds field.
    """
    groups = []
    current_group = []
    for item in anomaly_data:
        if current_group and item[1] - current_group[-1][1] > time_threshold:
            groups.append(current_group)
            current_group = []
        current_group.append(item)
    if current_group:
        groups.append(current_group)
    return groups


def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_threshold=4):
    """Plot MSE over time with rolling statistics and highlighted anomalies.

    Args:
        df: DataFrame with ``Timecode`` and ``Frame`` columns. If it has no
            ``Seconds`` column one is derived from ``Timecode`` and written
            back onto the caller's frame (pre-existing side effect, kept for
            compatibility).
        mse_values: One MSE value per row of ``df``; NaN entries are dropped.
        title: Axes title.
        color: Colour for the scatter points, rolling mean and std band.
        time_threshold: Anomalies at most this many seconds apart are merged
            into one highlighted group.
        anomaly_threshold: Number of standard deviations above the mean used
            for the threshold line; also forwarded to ``determine_anomalies``.

    Returns:
        ``(fig, anomaly_frames)`` - the Matplotlib figure and the list of
        ``Frame`` values flagged as anomalous.
    """
    # One figure only: a preceding ``plt.figure(...)`` would never be closed.
    fig, ax = plt.subplots(figsize=(16, 8))

    if 'Seconds' not in df.columns:
        df['Seconds'] = df['Timecode'].apply(_timecode_to_seconds)

    # Work on a plain float array so masking/indexing below is positional.
    mse_values = np.asarray(mse_values, dtype=float)

    # Ensure df and mse_values have the same length.
    min_length = min(len(df), len(mse_values))
    df = df.iloc[:min_length].copy()
    mse_values = mse_values[:min_length]

    # Remove NaN values.
    valid_mask = ~np.isnan(mse_values)
    df = df[valid_mask]
    mse_values = mse_values[valid_mask]

    # Plot each continuous segment separately so gaps are not bridged.
    for segment in _get_continuous_segments(df['Seconds'], mse_values):
        segment_seconds, segment_mse = zip(*segment)
        ax.scatter(segment_seconds, segment_mse, color=color, alpha=0.3, s=5)

        # Rolling mean and std only make sense with more than one point.
        if len(segment) > 1:
            segment_df = pd.DataFrame(
                {'Seconds': segment_seconds, 'MSE': segment_mse}
            ).sort_values('Seconds')
            rolling = segment_df['MSE'].rolling(
                window=min(10, len(segment)), min_periods=1, center=True)
            mean = rolling.mean()
            std = rolling.std()
            ax.plot(segment_df['Seconds'], mean, color=color, linewidth=0.5)
            ax.fill_between(segment_df['Seconds'], mean - std, mean + std,
                            color=color, alpha=0.1)

    median = np.median(mse_values)
    ax.axhline(y=median, color='black', linestyle='--', label='Median Baseline')

    threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
    ax.axhline(y=threshold, color='red', linestyle='--',
               label=f'Threshold: {anomaly_threshold:.1f}')
    ax.text(ax.get_xlim()[1], threshold, f'Threshold: {anomaly_threshold:.1f}',
            verticalalignment='center', horizontalalignment='left', color='red')

    anomalies = determine_anomalies(mse_values, anomaly_threshold)
    anomaly_frames = df['Frame'].iloc[anomalies].tolist()

    ax.scatter(df['Seconds'].iloc[anomalies], mse_values[anomalies],
               color='red', s=20, zorder=5)

    anomaly_data = sorted(
        zip(df['Timecode'].iloc[anomalies],
            df['Seconds'].iloc[anomalies],
            mse_values[anomalies]),
        key=lambda item: item[1],
    )

    for group in _group_by_time_gap(anomaly_data, time_threshold):
        # axvspan covers the full axes height without feeding back into the
        # y data limits, unlike a Rectangle sized from the current ylim.
        ax.axvspan(group[0][1], group[-1][1], facecolor='red', alpha=0.2, zorder=1)

        # Label each group once, at its highest-MSE point.
        timecode, sec, mse = max(group, key=lambda item: item[2])
        ax.annotate(timecode, (sec, mse), textcoords="offset points",
                    xytext=(0, 10), ha='center', fontsize=6, color='red')

    _set_timecode_ticks(ax, df['Seconds'].max(), 100)

    ax.set_xlabel('Timecode')
    ax.set_ylabel('Mean Squared Error')
    ax.set_title(title)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend()
    fig.tight_layout()
    plt.close(fig)
    return fig, anomaly_frames


def plot_mse_histogram(mse_values, title, anomaly_threshold, color='blue'):
    """Plot a 100-bin histogram of MSE values with the anomaly threshold marked.

    Args:
        mse_values: 1-D sequence of MSE values; NaN entries are ignored.
        title: Axes title.
        anomaly_threshold: Number of standard deviations above the mean at
            which the vertical threshold line is drawn.
        color: Bar colour.

    Returns:
        The Matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(16, 3))

    # Drop NaNs (as plot_mse does); they break both hist() and the statistics.
    mse_values = np.asarray(mse_values, dtype=float)
    mse_values = mse_values[~np.isnan(mse_values)]

    ax.hist(mse_values, bins=100, edgecolor='black', color=color, alpha=0.7)
    ax.set_xlabel('Mean Squared Error')
    ax.set_ylabel('Number of Samples')
    ax.set_title(title)

    threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
    ax.axvline(x=threshold, color='red', linestyle='--', linewidth=2)

    fig.tight_layout()
    plt.close(fig)
    return fig


def plot_mse_heatmap(mse_values, title, df):
    """Render MSE values as a single-row heatmap with timecode tick labels.

    Args:
        mse_values: 1-D array of MSE values, one per row of ``df``.
        title: Axes title.
        df: DataFrame whose ``Timecode`` column supplies the tick labels.

    Returns:
        The Matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(20, 3))

    # Reshape MSE values to a 1 x N array for the heatmap.
    mse_values = np.asarray(mse_values)
    sns.heatmap(mse_values.reshape(1, -1), cmap='YlOrRd', cbar=False, ax=ax)

    # Up to 60 timecode ticks; np.unique drops the repeated positions that
    # linspace produces when the series is shorter than that.
    num_ticks = 60
    tick_locations = np.unique(
        np.linspace(0, len(mse_values) - 1, num_ticks).astype(int))
    tick_labels = [df['Timecode'].iloc[i] for i in tick_locations]

    ax.set_xticks(tick_locations)
    ax.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')
    ax.set_title(title)

    # Single row, so the y-axis carries no information.
    ax.set_yticks([])

    fig.tight_layout()
    plt.close(fig)
    return fig


def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
    """Plot posture scores over time with a 10-sample rolling mean.

    Args:
        df: DataFrame with ``Frame`` and ``Timecode`` columns. A ``Seconds``
            column is (re)computed from ``Timecode`` and written onto the
            caller's frame (pre-existing side effect, kept for compatibility).
        posture_scores: Mapping of frame -> score; ``None`` scores are skipped.
        color: Colour for the scatter points and rolling mean.
        anomaly_threshold: Currently unused; kept for interface compatibility.

    Returns:
        The Matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(16, 8))

    df['Seconds'] = df['Timecode'].apply(_timecode_to_seconds)

    posture_data = [(frame, score) for frame, score in posture_scores.items()
                    if score is not None]

    # With no usable scores, zip(*[]) cannot be unpacked; draw empty axes.
    if posture_data:
        frames, scores = zip(*posture_data)
        posture_df = pd.DataFrame({'Frame': frames, 'Score': scores})
        posture_df = posture_df.merge(df[['Frame', 'Seconds']], on='Frame', how='inner')

        ax.scatter(posture_df['Seconds'], posture_df['Score'],
                   color=color, alpha=0.3, s=5)
        mean = posture_df['Score'].rolling(window=10).mean()
        ax.plot(posture_df['Seconds'], mean, color=color, linewidth=0.5)

    ax.set_xlabel('Timecode')
    ax.set_ylabel('Posture Score')
    ax.set_title("Body Posture Over Time")
    ax.grid(True, linestyle='--', alpha=0.7)

    _set_timecode_ticks(ax, df['Seconds'].max(), 80)

    fig.tight_layout()
    plt.close(fig)
    return fig


def _resample_and_normalize(values, num_samples):
    """Linearly resample ``values`` to ``num_samples`` points, scaled to [0, 1].

    A series with a flat (or undefined) range is returned as all zeros instead
    of dividing by zero.
    """
    values = np.asarray(values, dtype=float)
    resampled = np.interp(
        np.linspace(0, len(values) - 1, num_samples),
        np.arange(len(values)),
        values,
    )
    lowest = np.min(resampled)
    span = np.max(resampled) - lowest
    if not span > 0:
        return np.zeros_like(resampled)
    return (resampled - lowest) / span


def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice,
                              output_path, desired_fps, largest_cluster, progress=None):
    """Write a copy of the video with a Face/Posture/Voice heatmap strip below it.

    Each MSE series is restricted to rows of ``df`` in ``largest_cluster``,
    resampled to the video's frame count and min-max normalised. A 200-pixel
    strip showing the three series (with a moving cursor at the current frame)
    is stacked under every frame, and a timecode is stamped top-left.

    Args:
        video_path: Path of the input video.
        df: DataFrame with a ``Cluster`` column, one row per MSE sample.
        mse_embeddings: Per-row MSE values for the "Face" strip.
        mse_posture: Per-row MSE values for the "Posture" strip.
        mse_voice: Per-row MSE values for the "Voice" strip.
        output_path: Path of the ``mp4v``-encoded output file.
        desired_fps: Currently unused; output is written at the source FPS.
        largest_cluster: Cluster label whose rows are kept.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``
            (e.g. a ``gr.Progress`` instance supplied by the caller). When
            ``None`` no progress is reported.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If the video reports no frames.
    """
    heatmap_height = 200
    progress_desc = "Generating video with heatmap"

    def report(fraction):
        if progress is not None:
            progress(fraction, desc=progress_desc)

    # Keep only samples belonging to the largest cluster.
    cluster_mask = (df['Cluster'] == largest_cluster).to_numpy()
    mse_embeddings = np.asarray(mse_embeddings)[cluster_mask]
    mse_posture = np.asarray(mse_posture)[cluster_mask]
    mse_voice = np.asarray(mse_voice)[cluster_mask]

    cap = cv2.VideoCapture(video_path)
    out = None
    fig = None
    try:
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            raise ValueError(f"No frames reported for video: {video_path}")

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, original_fps,
                              (width, height + heatmap_height))

        # One row per modality (three rows, matching the three tick labels).
        combined_mse = np.vstack([
            _resample_and_normalize(mse_embeddings, total_frames),
            _resample_and_normalize(mse_posture, total_frames),
            _resample_and_normalize(mse_voice, total_frames),
        ])

        # Custom colormap: mid-gray for low MSE fading to pure red for high MSE.
        cdict = {
            'red': [(0.0, 0.5, 0.5), (1.0, 1.0, 1.0)],
            'green': [(0.0, 0.5, 0.5), (1.0, 0.0, 0.0)],
            'blue': [(0.0, 0.5, 0.5), (1.0, 0.0, 0.0)],
        }
        custom_cmap = LinearSegmentedColormap('custom_cmap', segmentdata=cdict, N=256)

        fig, ax = plt.subplots(figsize=(width / 100, 2))
        ax.imshow(combined_mse, aspect='auto', cmap=custom_cmap,
                  extent=[0, total_frames, 0, 3])
        ax.set_yticks([0.5, 1.5, 2.5])
        ax.set_yticklabels(['Face', 'Posture', 'Voice'])
        ax.set_xticks([])
        fig.tight_layout()

        # A single Agg canvas is reused for every frame.
        canvas = FigureCanvasAgg(fig)
        line = None

        report(0.9)

        # Frames are read sequentially, so no per-frame seek is needed.
        for frame_count in range(total_frames):
            ret, frame = cap.read()
            if not ret:
                break

            if line is not None:
                line.remove()
            line = ax.axvline(x=frame_count, color='blue', linewidth=3)

            canvas.draw()
            # buffer_rgba() replaces the deprecated tostring_rgb(); the alpha
            # channel is dropped during the conversion to OpenCV's BGR order.
            heatmap_img = np.asarray(canvas.buffer_rgba())
            heatmap_img = cv2.resize(heatmap_img, (width, heatmap_height))
            heatmap_img = cv2.cvtColor(heatmap_img, cv2.COLOR_RGBA2BGR)

            combined_frame = np.vstack((frame, heatmap_img))

            seconds = frame_count / original_fps
            timecode = (f"{int(seconds//3600):02d}:"
                        f"{int((seconds%3600)//60):02d}:"
                        f"{int(seconds%60):02d}")
            cv2.putText(combined_frame, f"Time: {timecode}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            out.write(combined_frame)

            report(0.9 + (0.1 * (frame_count + 1) / total_frames))
    finally:
        # Release native handles even if reading or encoding fails midway.
        cap.release()
        if out is not None:
            out.release()
        if fig is not None:
            plt.close(fig)

    return output_path