# Visualization helpers for anomaly-detection results: MSE time-series plots,
# histograms, heatmaps, posture plots, and heatmap-annotated video export.
import matplotlib.pyplot as plt | |
from matplotlib.backends.backend_agg import FigureCanvasAgg | |
import matplotlib.colors as mcolors | |
from matplotlib.colors import LinearSegmentedColormap | |
import seaborn as sns | |
import numpy as np | |
import pandas as pd | |
import cv2 | |
from matplotlib.patches import Rectangle | |
from utils import seconds_to_timecode | |
from anomaly_detection import determine_anomalies | |
import gradio as gr | |
def plot_mse(df, mse_values, title, color='navy', time_threshold=3, anomaly_threshold=4):
    """Plot MSE over time with rolling statistics and highlighted anomalies.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'Timecode' ("HH:MM:SS"-style strings) and 'Frame' columns.
    mse_values : numpy.ndarray
        1-D array of MSE values aligned row-for-row with ``df``.
    title : str
        Plot title.
    color : str
        Matplotlib color for the scatter points and rolling mean.
    time_threshold : float
        Maximum gap in seconds for merging nearby anomalies into one group.
    anomaly_threshold : float
        Std-dev multiplier above the mean used as the anomaly cutoff.

    Returns
    -------
    (matplotlib.figure.Figure, list)
        The figure and the list of frame numbers flagged as anomalous.
    """
    # plt.subplots() already creates the figure we draw on; the original's
    # extra plt.figure() call leaked an unused figure on every invocation
    # (and the dpi=300 was applied to the leaked one, not this one).
    fig, ax = plt.subplots(figsize=(16, 8), dpi=300)

    if 'Seconds' not in df.columns:
        # Convert "HH:MM:SS" (or "MM:SS") timecodes to absolute seconds.
        df['Seconds'] = df['Timecode'].apply(
            lambda x: sum(float(t) * 60 ** i for i, t in enumerate(reversed(x.split(':')))))

    # Align df and mse_values to the same length, then drop NaN samples.
    min_length = min(len(df), len(mse_values))
    df = df.iloc[:min_length].copy()
    mse_values = mse_values[:min_length]
    valid_mask = ~np.isnan(mse_values)
    df = df[valid_mask]
    mse_values = mse_values[valid_mask]

    def get_continuous_segments(seconds, values, max_gap=1):
        """Split (seconds, values) pairs into runs whose gaps are <= max_gap s."""
        segments = []
        current_segment = []
        for sec, val in zip(seconds, values):
            if not current_segment or (sec - current_segment[-1][0] <= max_gap):
                current_segment.append((sec, val))
            else:
                segments.append(current_segment)
                current_segment = [(sec, val)]
        if current_segment:
            segments.append(current_segment)
        return segments

    # Plot each continuous segment separately so gaps are not bridged.
    for segment in get_continuous_segments(df['Seconds'], mse_values):
        segment_seconds, segment_mse = zip(*segment)
        ax.scatter(segment_seconds, segment_mse, color=color, alpha=0.3, s=5)
        # Rolling mean/std band only makes sense with more than one point.
        if len(segment) > 1:
            segment_df = pd.DataFrame({'Seconds': segment_seconds, 'MSE': segment_mse})
            segment_df = segment_df.sort_values('Seconds')
            window = min(10, len(segment))
            mean = segment_df['MSE'].rolling(window=window, min_periods=1, center=True).mean()
            std = segment_df['MSE'].rolling(window=window, min_periods=1, center=True).std()
            ax.plot(segment_df['Seconds'], mean, color=color, linewidth=0.5)
            ax.fill_between(segment_df['Seconds'], mean - std, mean + std, color=color, alpha=0.1)

    median = np.median(mse_values)
    ax.axhline(y=median, color='black', linestyle='--', label='Median Baseline')

    # Anomaly cutoff: mean + k * std.
    threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
    # NOTE(review): the label/annotation display the std-dev multiplier, not
    # the computed threshold value — preserved as-is; confirm intent.
    ax.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold: {anomaly_threshold:.1f}')
    ax.text(ax.get_xlim()[1], threshold, f'Threshold: {anomaly_threshold:.1f}',
            verticalalignment='center', horizontalalignment='left', color='red')

    # Mark anomalous samples (indices supplied by determine_anomalies).
    anomalies = determine_anomalies(mse_values, anomaly_threshold)
    anomaly_frames = df['Frame'].iloc[anomalies].tolist()
    ax.scatter(df['Seconds'].iloc[anomalies], mse_values[anomalies], color='red', s=20, zorder=5)

    # Group anomalies that are within time_threshold seconds of each other.
    anomaly_data = list(zip(df['Timecode'].iloc[anomalies],
                            df['Seconds'].iloc[anomalies],
                            mse_values[anomalies]))
    anomaly_data.sort(key=lambda x: x[1])

    grouped_anomalies = []
    current_group = []
    for timecode, sec, mse in anomaly_data:
        if not current_group or sec - current_group[-1][1] <= time_threshold:
            current_group.append((timecode, sec, mse))
        else:
            grouped_anomalies.append(current_group)
            current_group = [(timecode, sec, mse)]
    if current_group:
        grouped_anomalies.append(current_group)

    # Shade each anomaly group with a translucent red band.
    for group in grouped_anomalies:
        start_sec = group[0][1]
        end_sec = group[-1][1]
        rect = Rectangle((start_sec, ax.get_ylim()[0]), end_sec - start_sec,
                         ax.get_ylim()[1] - ax.get_ylim()[0],
                         facecolor='red', alpha=0.2, zorder=1)
        ax.add_patch(rect)

    # Annotate only the highest-MSE sample of each group to avoid clutter.
    for group in grouped_anomalies:
        highest_mse_anomaly = max(group, key=lambda x: x[2])
        timecode, sec, mse = highest_mse_anomaly
        ax.annotate(timecode, (sec, mse), textcoords="offset points", xytext=(0, 10),
                    ha='center', fontsize=6, color='red')

    # X axis: evenly spaced timecode labels across the full duration.
    max_seconds = df['Seconds'].max()
    num_ticks = 100
    tick_locations = np.linspace(0, max_seconds, num_ticks)
    tick_labels = [seconds_to_timecode(int(s)) for s in tick_locations]
    ax.set_xticks(tick_locations)
    ax.set_xticklabels(tick_labels, rotation=90, ha='center', fontsize=6)

    ax.set_xlabel('Timecode')
    ax.set_ylabel('Mean Squared Error')
    ax.set_title(title)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend()
    plt.tight_layout()
    # Close so the figure is not kept alive by pyplot; caller holds the ref.
    plt.close()
    return fig, anomaly_frames
def plot_mse_histogram(mse_values, title, anomaly_threshold, color='blue'):
    """Plot a histogram of MSE values with the anomaly threshold marked.

    Parameters
    ----------
    mse_values : numpy.ndarray
        1-D array of MSE values.
    title : str
        Plot title.
    anomaly_threshold : float
        Std-dev multiplier above the mean; drawn as a vertical red line.
    color : str
        Bar color.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # plt.subplots() creates the only figure needed; the original's extra
    # plt.figure() call leaked an unused figure each time (and its dpi=300
    # never applied to the returned figure).
    fig, ax = plt.subplots(figsize=(16, 3), dpi=300)
    ax.hist(mse_values, bins=100, edgecolor='black', color=color, alpha=0.7)
    ax.set_xlabel('Mean Squared Error')
    ax.set_ylabel('Number of Samples')
    ax.set_title(title)
    # Same cutoff definition as plot_mse: mean + k * std.
    threshold = np.mean(mse_values) + anomaly_threshold * np.std(mse_values)
    ax.axvline(x=threshold, color='red', linestyle='--', linewidth=2)
    plt.tight_layout()
    plt.close()
    return fig
def plot_mse_heatmap(mse_values, title, df):
    """Render MSE values as a 1-row heatmap with timecode tick labels.

    Parameters
    ----------
    mse_values : array-like
        1-D sequence of MSE values.
    title : str
        Plot title.
    df : pandas.DataFrame
        Must contain a 'Timecode' column aligned with ``mse_values``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # plt.subplots() creates the figure we draw on; the original's preceding
    # plt.figure() call leaked an unused figure on every invocation.
    fig, ax = plt.subplots(figsize=(20, 3), dpi=300)
    # Heatmaps are 2-D; present the series as a single row.
    mse_2d = np.asarray(mse_values).reshape(1, -1)
    sns.heatmap(mse_2d, cmap='YlOrRd', cbar=False, ax=ax)
    # Cap the tick count at the sample count so short inputs don't produce
    # duplicated tick positions.
    num_ticks = min(60, len(mse_values))
    tick_locations = np.linspace(0, len(mse_values) - 1, num_ticks).astype(int)
    tick_labels = [df['Timecode'].iloc[i] for i in tick_locations]
    ax.set_xticks(tick_locations)
    ax.set_xticklabels(tick_labels, rotation=90, ha='center', va='top')
    ax.set_title(title)
    # A single-row heatmap needs no y-axis labels.
    ax.set_yticks([])
    plt.tight_layout()
    plt.close()
    return fig
def plot_posture(df, posture_scores, color='blue', anomaly_threshold=3):
    """Plot body-posture scores over time with a rolling mean.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'Timecode' and 'Frame' columns.
    posture_scores : dict
        Mapping of frame number -> score (or None for unscored frames).
    color : str
        Matplotlib color for points and rolling mean.
    anomaly_threshold : float
        Unused; kept for interface compatibility with the other plotters.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # plt.subplots() creates the figure; the original's extra plt.figure()
    # call leaked an unused figure per call (its dpi=300 never applied).
    fig, ax = plt.subplots(figsize=(16, 8), dpi=300)

    # Same timecode-to-seconds conversion and guard as plot_mse.
    if 'Seconds' not in df.columns:
        df['Seconds'] = df['Timecode'].apply(
            lambda x: sum(float(t) * 60 ** i for i, t in enumerate(reversed(x.split(':')))))

    # Keep only frames that actually received a score.
    posture_data = [(frame, score) for frame, score in posture_scores.items() if score is not None]
    if not posture_data:
        # No scored frames: return an empty-but-valid figure instead of
        # crashing on zip(*[]) as the original did.
        plt.close()
        return fig
    # Renamed locals so the posture_scores parameter is not shadowed.
    frames, scores = zip(*posture_data)

    posture_df = pd.DataFrame({'Frame': frames, 'Score': scores})
    posture_df = posture_df.merge(df[['Frame', 'Seconds']], on='Frame', how='inner')

    ax.scatter(posture_df['Seconds'], posture_df['Score'], color=color, alpha=0.3, s=5)
    rolling_mean = posture_df['Score'].rolling(window=10).mean()
    ax.plot(posture_df['Seconds'], rolling_mean, color=color, linewidth=0.5)

    ax.set_xlabel('Timecode')
    ax.set_ylabel('Posture Score')
    ax.set_title("Body Posture Over Time")
    ax.grid(True, linestyle='--', alpha=0.7)

    # X axis: evenly spaced timecode labels across the full duration.
    max_seconds = df['Seconds'].max()
    num_ticks = 80
    tick_locations = np.linspace(0, max_seconds, num_ticks)
    tick_labels = [seconds_to_timecode(int(s)) for s in tick_locations]
    ax.set_xticks(tick_locations)
    ax.set_xticklabels(tick_labels, rotation=90, ha='center', fontsize=6)

    plt.tight_layout()
    plt.close()
    return fig
def create_video_with_heatmap(video_path, df, mse_embeddings, mse_posture, mse_voice,
                              output_path, desired_fps, largest_cluster, progress=None):
    """Write a copy of the video with a 3-row MSE heatmap strip appended below.

    Parameters
    ----------
    video_path : str
        Path to the input video.
    df : pandas.DataFrame
        Must contain a 'Cluster' column aligned with the MSE arrays.
    mse_embeddings, mse_posture, mse_voice : numpy.ndarray
        Per-sample MSE series for face, posture and voice.
    output_path : str
        Destination path for the rendered video.
    desired_fps : float
        NOTE(review): currently unused — output is written at the source fps;
        confirm whether resampling to desired_fps was intended.
    largest_cluster : hashable
        Cluster label whose rows are kept.
    progress : callable or None
        Optional Gradio-style progress callback ``progress(frac, desc=...)``.
        The original referenced an undefined global ``progress`` (NameError);
        it is now an explicit, optional parameter.

    Returns
    -------
    str
        ``output_path``.
    """
    def _report(frac, desc):
        # Progress reporting is best-effort and optional.
        if progress is not None:
            progress(frac, desc=desc)

    def _stretch(values):
        # Linearly resample a series to one value per video frame.
        return np.interp(np.linspace(0, len(values) - 1, total_frames),
                         np.arange(len(values)), values)

    def _normalize(values):
        # Min-max normalize to [0, 1]; a constant series maps to zeros
        # instead of dividing by zero (NaNs) as the original did.
        lo, hi = np.min(values), np.max(values)
        if hi == lo:
            return np.zeros_like(values)
        return (values - lo) / (hi - lo)

    # Keep only samples from the largest cluster.
    cluster_mask = df['Cluster'] == largest_cluster
    mse_embeddings = mse_embeddings[cluster_mask]
    mse_posture = mse_posture[cluster_mask]
    mse_voice = mse_voice[cluster_mask]

    cap = cv2.VideoCapture(video_path)
    original_fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # Output frame is the source frame plus a 200-px heatmap strip below.
    out = cv2.VideoWriter(output_path, fourcc, original_fps, (width, height + 200))

    # Bug fix: the original allocated np.zeros((2, total_frames)) and then
    # assigned combined_mse[2], raising IndexError — three modalities need
    # three rows.
    combined_mse = np.zeros((3, total_frames))
    combined_mse[0] = _normalize(_stretch(mse_embeddings))  # face
    combined_mse[1] = _normalize(_stretch(mse_posture))     # posture
    combined_mse[2] = _normalize(_stretch(mse_voice))       # voice

    # Custom colormap: gray at low MSE fading to full red at high MSE.
    cdict = {
        'red':   [(0.0, 0.5, 0.5),
                  (1.0, 1.0, 1.0)],
        'green': [(0.0, 0.5, 0.5),
                  (1.0, 0.0, 0.0)],
        'blue':  [(0.0, 0.5, 0.5),
                  (1.0, 0.0, 0.0)]
    }
    custom_cmap = LinearSegmentedColormap('custom_cmap', segmentdata=cdict, N=256)

    fig, ax = plt.subplots(figsize=(width / 100, 2))
    ax.imshow(combined_mse, aspect='auto', cmap=custom_cmap, extent=[0, total_frames, 0, 3])
    ax.set_yticks([0.5, 1.5, 2.5])
    ax.set_yticklabels(['Face', 'Posture', 'Voice'])
    ax.set_xticks([])
    plt.tight_layout()

    # Hoisted out of the loop: one Agg canvas is reused for every frame
    # (the original rebuilt it per frame).
    canvas = FigureCanvasAgg(fig)
    line = None
    _report(0.9, "Generating video with heatmap")

    for frame_count in range(total_frames):
        # Sequential read; the original's per-frame CAP_PROP_POS_FRAMES seek
        # was redundant (and slow) since frames are consumed in order.
        ret, frame = cap.read()
        if not ret:
            break

        # Move the blue playhead line to the current frame.
        if line:
            line.remove()
        line = ax.axvline(x=frame_count, color='blue', linewidth=3)

        canvas.draw()
        # NOTE(review): tostring_rgb is deprecated (removed in newer
        # matplotlib); kept for parity — consider canvas.buffer_rgba().
        heatmap_img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
        heatmap_img = heatmap_img.reshape(canvas.get_width_height()[::-1] + (3,))
        heatmap_img = cv2.resize(heatmap_img, (width, 200))
        # OpenCV expects BGR channel order.
        heatmap_img = cv2.cvtColor(heatmap_img, cv2.COLOR_RGB2BGR)

        combined_frame = np.vstack((frame, heatmap_img))
        seconds = frame_count / original_fps
        timecode = f"{int(seconds//3600):02d}:{int((seconds%3600)//60):02d}:{int(seconds%60):02d}"
        cv2.putText(combined_frame, f"Time: {timecode}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        out.write(combined_frame)

        # This stage covers the final 10% of overall progress.
        _report(0.9 + (0.1 * (frame_count + 1) / total_frames), "Generating video with heatmap")

    cap.release()
    out.release()
    plt.close(fig)
    return output_path