"""Dash app for identifying keyframes in HuggingFace-hosted robot datasets.

One episode (camera videos + parquet data) is downloaded from a dataset
repository, each joint's smoothed action value is plotted with highlighted
candidate keyframe regions, and the matching camera frames are shown when the
user hovers over a plot.
"""

# ------------------ Import Libraries ------------------
import base64
import json
import os
import re
import subprocess
import tempfile
from urllib.parse import urljoin

import cv2
import dash
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import requests
from dash import Input, Output, State, dcc, html, no_update
from scipy.ndimage import gaussian_filter1d

# ------------------ Shared Constants ------------------
# Names given to the entries of each row's "action" vector, in order; one
# graph row is rendered per name.
# NOTE(review): assumes a 6-entry action vector in exactly this order —
# confirm against the dataset's info.json before using other robots.
JOINT_NAMES = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
NUM_JOINTS = len(JOINT_NAMES)

_FONT_FAMILY = "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif"


def _safe_path_component(text) -> str:
    """Return *text* reduced to characters that are safe in one path component.

    Everything except ASCII letters, digits, ``.``, ``_`` and ``-`` (notably
    ``/`` and ``\\``) becomes ``_``. A result consisting only of dots (or an
    empty result) is replaced by ``_`` so it can never name a parent directory.
    """
    cleaned = re.sub(r"[^A-Za-z0-9._-]", "_", str(text))
    return "_" if set(cleaned) <= {"."} else cleaned


# ------------------ Data Download and Processing ------------------
class RemoteDatasetLoader:
    """Download episode videos and tabular data from a HuggingFace dataset repo."""

    def __init__(self, repo_id: str, timeout: int = 30):
        self.repo_id = repo_id
        self.timeout = timeout  # seconds, passed to every requests call
        self.base_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/"

    def _get_dataset_info(self) -> dict:
        """Fetch and parse ``meta/info.json`` for the repository.

        Raises:
            requests.HTTPError: on a non-2xx response.
        """
        info_url = urljoin(self.base_url, "meta/info.json")
        response = requests.get(info_url, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def _get_episode_info(self, episode_id: int) -> dict:
        """Return the ``meta/episodes.jsonl`` entry whose ``episode_index`` matches.

        Raises:
            requests.HTTPError: on a non-2xx response.
            ValueError: if no entry matches *episode_id*.
        """
        episodes_url = urljoin(self.base_url, "meta/episodes.jsonl")
        response = requests.get(episodes_url, timeout=self.timeout)
        response.raise_for_status()
        episodes = [json.loads(line) for line in response.text.splitlines() if line.strip()]
        for episode in episodes:
            if episode.get("episode_index") == episode_id:
                return episode
        raise ValueError(f"Episode {episode_id} not found")

    def _is_valid_mp4(self, file_path):
        """Return True if *file_path* is a cached video worth reusing.

        The file must exist, be at least 100 KiB, and ffprobe must report a
        codec for its first video stream. Any codec is accepted: this file is
        the raw download, which is re-encoded separately when necessary, so
        requiring H.264 here would force AV1/VP9 videos to be re-downloaded on
        every load.
        """
        if not os.path.exists(file_path) or os.path.getsize(file_path) < 1024 * 100:
            return False
        try:
            result = subprocess.run(
                [
                    'ffprobe', '-v', 'error',
                    '-select_streams', 'v:0',
                    '-show_entries', 'stream=codec_name',
                    '-of', 'default=noprint_wrappers=1:nokey=1',
                    file_path,
                ],
                capture_output=True, text=True, timeout=10,
            )
        except (OSError, subprocess.SubprocessError) as e:
            # ffprobe missing or hung: treat the file as unusable and re-download
            print(f"ffprobe video check failed: {e}")
            return False
        return result.returncode == 0 and bool(result.stdout.strip())

    def _download_video(self, video_url: str, save_path: str) -> str:
        """Stream *video_url* to *save_path* and return *save_path*.

        The body is written to ``<save_path>.part`` and moved into place only
        after the transfer completes, so an interrupted download never leaves
        a truncated file at *save_path*.

        Raises:
            requests.HTTPError: on a non-2xx response.
            ValueError: if the Content-Type header does not contain "video".
        """
        # ``with`` releases the streamed connection on every path, including
        # the Content-Type rejection below.
        with requests.get(video_url, timeout=self.timeout, stream=True) as response:
            response.raise_for_status()
            if 'video' not in response.headers.get('Content-Type', ''):
                raise ValueError(
                    f"URL {video_url} does not return video content, "
                    f"Content-Type: {response.headers.get('Content-Type')}"
                )
            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
            part_path = save_path + ".part"
            try:
                with open(part_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
                os.replace(part_path, save_path)
            finally:
                # Only still present if writing or renaming failed
                if os.path.exists(part_path):
                    os.remove(part_path)
        return save_path

    def load_episode_data(self, episode_id: int, video_keys=None, download_dir=None):
        """Download the videos and load the parquet data of one episode.

        Args:
            episode_id: Episode index in the repository.
            video_keys: Feature keys of the videos to fetch. Defaults to the
                first two features whose dtype is "video" in info.json.
            download_dir: Root directory for cached videos; a new temporary
                directory is created when None.

        Returns:
            ``(video_paths, df)``: one entry per video key — the local file
            path, or the remote URL if the download failed — and the episode
            data frame (empty if the parquet file could not be read).

        Raises:
            ValueError: if the episode is not listed in meta/episodes.jsonl.
            requests.HTTPError: if the metadata files cannot be fetched.
        """
        dataset_info = self._get_dataset_info()
        self._get_episode_info(episode_id)  # Raises if the episode does not exist

        if download_dir is None:
            download_dir = tempfile.mkdtemp(prefix="lerobot_videos_")

        if video_keys is None:
            video_keys = [
                key for key, feature in dataset_info["features"].items()
                if feature["dtype"] == "video"
            ][:2]

        chunks_size = dataset_info.get("chunks_size", 1000)
        episode_chunk = episode_id // chunks_size

        # One cache sub-directory per repository; sanitizing keeps every repo
        # id inside download_dir (e.g. "owner/name" -> "owner_name").
        repo_dir = os.path.join(download_dir, _safe_path_component(self.repo_id))
        os.makedirs(repo_dir, exist_ok=True)

        video_paths = []
        for video_key in video_keys:
            video_url = self.base_url + dataset_info["video_path"].format(
                episode_chunk=episode_chunk,
                video_key=video_key,
                episode_index=episode_id,
            )
            # video_key comes from the remote info.json, so sanitize it too
            video_filename = f"episode_{episode_id}_{_safe_path_component(video_key)}.mp4"
            local_path = os.path.join(repo_dir, video_filename)

            # Prefer loading a valid local copy
            if self._is_valid_mp4(local_path):
                print(f"Local valid video found: {local_path}")
                video_paths.append(local_path)
                continue
            try:
                video_paths.append(self._download_video(video_url, local_path))
            except Exception as e:
                print(f"Failed to download video {video_key}: {e}")
                video_paths.append(video_url)

        data_url = self.base_url + dataset_info["data_path"].format(
            episode_chunk=episode_chunk,
            episode_index=episode_id,
        )
        try:
            df = pd.read_parquet(data_url)
        except Exception as e:
            print(f"Failed to load data: {e}")
            df = pd.DataFrame()
        return video_paths, df


def check_ffmpeg_available():
    """Return True if an ``ffmpeg`` executable runs and exits successfully."""
    try:
        result = subprocess.run(['ffmpeg', '-version'], capture_output=True, text=True, timeout=5)
    except (subprocess.TimeoutExpired, OSError):
        return False
    return result.returncode == 0


def get_video_codec_info(video_path):
    """Return ffprobe's ``codec_name`` for the first video stream, or ``'unknown'``."""
    try:
        result = subprocess.run(
            ['ffprobe', '-v', 'quiet', '-print_format', 'json', '-show_streams', video_path],
            capture_output=True, text=True, timeout=10,
        )
        if result.returncode == 0:
            info = json.loads(result.stdout)
            for stream in info.get('streams', []):
                if stream.get('codec_type') == 'video':
                    return stream.get('codec_name', 'unknown')
    except Exception as e:
        print(f"Failed to get video codec info: {e}")
    return 'unknown'


def _default_h264_path(input_path):
    """Return the path used for the H.264 re-encode of *input_path*."""
    return f"{os.path.splitext(input_path)[0]}_h264.mp4"


def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
    """Re-encode *input_path* to an H.264/AAC MP4 with ffmpeg.

    Args:
        input_path: Source video file.
        output_path: Destination file; defaults to ``<input stem>_h264.mp4``.
        quality: 'fast', 'medium' or 'high' (x264 preset/CRF pairs); any other
            value falls back to 'medium'.

    Returns:
        *output_path* on success, otherwise *input_path*.
    """
    if output_path is None:
        output_path = _default_h264_path(input_path)

    quality_params = {
        'fast': ['-preset', 'ultrafast', '-crf', '28'],
        'medium': ['-preset', 'medium', '-crf', '23'],
        'high': ['-preset', 'slow', '-crf', '18'],
    }
    params = quality_params.get(quality, quality_params['medium'])
    cmd = [
        'ffmpeg', '-i', input_path,
        '-c:v', 'libx264', '-c:a', 'aac',
        '-movflags', '+faststart', '-y',
        *params,
        output_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
    except subprocess.TimeoutExpired:
        print("Re-encoding timeout")
        return input_path
    except Exception as e:
        print(f"Re-encoding exception: {e}")
        return input_path

    if result.returncode == 0:
        return output_path
    print(f"Re-encoding failed: {result.stderr}")
    return input_path


def process_video_for_compatibility(video_path):
    """Return a path to a version of *video_path* that OpenCV should decode.

    AV1/VP8/VP9 videos, and videos whose codec ffprobe cannot report, are
    re-encoded to H.264. An existing re-encode is reused when it is at least as
    new as the source and probes as H.264. The original path is returned when
    the file does not exist locally (e.g. it is a remote URL), ffmpeg is not
    available, or re-encoding fails.
    """
    if not os.path.exists(video_path):
        print(f"Video file does not exist: {video_path}")
        return video_path
    if not check_ffmpeg_available():
        print("ffmpeg not available, skipping re-encoding")
        return video_path

    codec = get_video_codec_info(video_path)
    if codec not in ('av01', 'av1', 'vp9', 'vp8', 'unknown'):
        return video_path

    # Skip the (slow) re-encode if a complete, up-to-date one already exists.
    cached_path = _default_h264_path(video_path)
    if (
        os.path.exists(cached_path)
        and os.path.getmtime(cached_path) >= os.path.getmtime(video_path)
        and get_video_codec_info(cached_path) == 'h264'
    ):
        return cached_path

    reencoded_path = reencode_video_to_h264(video_path, quality='fast')
    if os.path.exists(reencoded_path) and os.path.getsize(reencoded_path) > 1024:
        return reencoded_path
    print("Re-encoding failed, using original file")
    return video_path


def load_remote_dataset(repo_id: str, episode_id: int = 0, video_keys=None, download_dir=None):
    """Download one episode and make its videos OpenCV-compatible.

    Returns ``(video_paths, df)`` as produced by
    ``RemoteDatasetLoader.load_episode_data``, with every video path passed
    through ``process_video_for_compatibility``.
    """
    loader = RemoteDatasetLoader(repo_id)
    video_paths, df = loader.load_episode_data(episode_id, video_keys, download_dir)
    return [process_video_for_compatibility(path) for path in video_paths], df


# ------------------ Dash Initialization ------------------
app = dash.Dash(__name__, suppress_callback_exceptions=True)
server = app.server

# ------------------ Page Layout ------------------
_LABEL_STYLE = {
    "fontWeight": "600",
    "color": "#333",
    "marginRight": "10px",
    "fontSize": "1rem",
}

_INPUT_STYLE = {
    "padding": "12px 15px",
    "border": "2px solid #e1e5e9",
    "borderRadius": "8px",
    "fontSize": "14px",
    "transition": "border-color 0.3s ease",
    "outline": "none",
}

app.layout = html.Div([
    # Header with gradient background
    html.Div([
        html.H1("Keyframe Identification", style={
            "textAlign": "center",
            "marginBottom": "10px",
            "color": "white",
            "fontSize": "2.5rem",
            "fontWeight": "300",
            "textShadow": "2px 2px 4px rgba(0,0,0,0.3)",
        }),
        html.P("Interactive Joint Analysis with Video Synchronization", style={
            "textAlign": "center",
            "color": "rgba(255,255,255,0.9)",
            "fontSize": "1.1rem",
            "marginBottom": "0",
        }),
    ], style={
        "background": "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
        "padding": "30px 20px",
        "marginBottom": "30px",
        "borderRadius": "0 0 15px 15px",
        "boxShadow": "0 4px 20px rgba(0,0,0,0.1)",
    }),

    # Control Panel
    html.Div([
        html.Div([
            html.Label("Repository ID:", style=dict(_LABEL_STYLE)),
            dcc.Input(
                id="input-repo-id",
                type="text",
                value="zijian2022/sortingtest",
                style={"width": "350px", **_INPUT_STYLE},
                placeholder="Enter HuggingFace dataset repository ID",
            ),
        ], style={"marginBottom": "15px"}),
        html.Div([
            html.Label("Episode ID:", style=dict(_LABEL_STYLE)),
            dcc.Input(
                id="input-episode-id",
                type="number",
                value=0,
                min=0,
                style={"width": "120px", **_INPUT_STYLE},
            ),
            html.Button(
                "Load Data",
                id="btn-load",
                n_clicks=0,
                style={
                    "marginLeft": "20px",
                    "padding": "12px 25px",
                    "backgroundColor": "#667eea",
                    "color": "white",
                    "border": "none",
                    "borderRadius": "8px",
                    "fontSize": "14px",
                    "fontWeight": "600",
                    "cursor": "pointer",
                    "transition": "all 0.3s ease",
                    "boxShadow": "0 2px 10px rgba(102, 126, 234, 0.3)",
                },
            ),
        ]),
    ], style={
        "textAlign": "center",
        "marginBottom": "40px",
        "padding": "25px",
        "backgroundColor": "white",
        "borderRadius": "12px",
        "boxShadow": "0 4px 20px rgba(0,0,0,0.08)",
        "border": "1px solid #f0f0f0",
    }),

    # Loading spinner wrapping the data store
    dcc.Loading(
        id="loading",
        type="circle",
        style={"margin": "20px auto"},
        children=dcc.Store(id="store-data"),
    ),

    # Main Content Area
    html.Div(
        id="main-content",
        style={
            "backgroundColor": "#f8f9fa",
            "minHeight": "400px",
            "borderRadius": "12px",
            "padding": "20px",
        },
    ),
], style={
    "fontFamily": _FONT_FAMILY,
    "backgroundColor": "#f5f7fa",
    "minHeight": "100vh",
    "padding": "0",
})


# ------------------ Data Loading Callback ------------------
@app.callback(
    Output("store-data", "data"),
    Input("btn-load", "n_clicks"),
    State("input-repo-id", "value"),
    State("input-episode-id", "value"),
    prevent_initial_call=True,
)
def load_data_callback(n_clicks, repo_id, episode_id):
    """Load the requested episode and return the payload for ``store-data``.

    Returns an empty dict when the inputs are missing, loading fails, or the
    episode has no rows.
    """
    if not repo_id or episode_id is None:
        return {}
    try:
        video_paths, data_df = load_remote_dataset(
            repo_id=repo_id.strip(),
            episode_id=int(episode_id),
            download_dir="./downloaded_videos",
        )
        if data_df is None or data_df.empty:
            return {}
        # The store is sent along with every callback that takes it as an
        # Input (including each hover), so keep only the columns read by the
        # graph and video callbacks.
        kept_columns = [c for c in ("timestamp", "action") if c in data_df.columns]
        return {
            "video_paths": video_paths,
            "data_df": data_df[kept_columns].to_dict("records"),
            "columns": list(JOINT_NAMES),
            "timestamps": data_df["timestamp"].tolist(),
        }
    except Exception as e:
        print(f"Data loading error: {e}")
        return {}


# ------------------ Main Content Rendering Callback ------------------
_VIDEO_IMG_STYLE = {
    "width": "49%",
    "height": "180px",
    "objectFit": "contain",
    "display": "inline-block",
    "borderRadius": "6px",
    "border": "2px solid #e9ecef",
}


def _joint_row(i):
    """Build row *i*: the joint graph on the left, two camera images on the right."""
    return html.Div([
        # Joint Graph - Left 50%
        html.Div([
            dcc.Graph(id=f"graph-{i}"),
        ], style={
            "flex": "0 0 50%",
            "backgroundColor": "white",
            "borderRadius": "8px",
            "padding": "8px",
            "boxShadow": "0 2px 10px rgba(0,0,0,0.05)",
            "border": "1px solid #e9ecef",
            "marginRight": "2%",
        }),
        # Video Area - Right 48%
        html.Div([
            html.Img(id=f"video1-{i}", style=dict(_VIDEO_IMG_STYLE)),
            html.Img(id=f"video2-{i}", style=dict(_VIDEO_IMG_STYLE)),
        ], style={"flex": "0 0 48%"}),
    ], style={
        "marginBottom": "25px",
        "backgroundColor": "white",
        "borderRadius": "12px",
        "padding": "12px",
        "boxShadow": "0 4px 15px rgba(0,0,0,0.08)",
        "border": "1px solid #f0f0f0",
        "display": "flex",
        "alignItems": "flex-start",
        "minHeight": "250px",
    })


@app.callback(
    Output("main-content", "children"),
    Input("store-data", "data"),
)
def update_main_content(data):
    """Render one graph/video row per joint, or a placeholder when no data is loaded."""
    if not data or "data_df" not in data or len(data["data_df"]) == 0:
        return html.Div([
            html.Div("📊", style={"fontSize": "3rem", "marginBottom": "20px"}),
            html.H3("No Data Available", style={"color": "#666", "marginBottom": "10px"}),
            html.P("Please click the 'Load Data' button above to get data.",
                   style={"color": "#888", "fontSize": "1rem"}),
        ], style={
            "textAlign": "center",
            "padding": "60px 20px",
            "color": "#666",
        })
    return html.Div([_joint_row(i) for i, _ in enumerate(data["columns"])])


# ------------------ Shadow and Highlight Utility Functions ------------------
def find_intervals(mask):
    """Return inclusive ``(start, end)`` index pairs of each run of truthy values.

    >>> find_intervals([False, True, True, False, True])
    [(1, 2), (4, 4)]
    """
    intervals = []
    start = None
    for i, val in enumerate(mask):
        if val and start is None:
            start = i
        elif not val and start is not None:
            intervals.append((start, i - 1))
            start = None
    if start is not None:
        intervals.append((start, len(mask) - 1))
    return intervals


def _smoothed_joint_signals(joint_name, action_df, delta_t):
    """Return ``(smoothed_velocity, smoothed_angle)`` for one joint column.

    Velocity is the difference of consecutive values divided by *delta_t*.
    Both series are Gaussian-smoothed (sigma=1) and have one sample fewer than
    the input, aligned with ``timestamps[1:]``. Requires at least two rows.
    """
    angles = action_df[joint_name].values
    velocity = np.diff(angles) / delta_t
    return gaussian_filter1d(velocity, sigma=1), gaussian_filter1d(angles[1:], sigma=1)


def _interval_shadow(shadow_type, start, end, time_for_plot):
    """Build a shadow dict covering sample indices ``start..end`` (inclusive)."""
    return {
        'type': shadow_type,
        'start_time': time_for_plot[start],
        'end_time': time_for_plot[end],
        'start_idx': start,
        'end_idx': end,
    }


def _window_shadow(shadow_type, center, half_width, time_for_plot):
    """Build a shadow dict for ``center ± half_width`` samples, clipped to the series."""
    start = max(0, center - half_width)
    end = min(len(time_for_plot) - 1, center + half_width)
    return _interval_shadow(shadow_type, start, end, time_for_plot)


def get_shadow_info(joint_name, action_df, delta_t, time_for_plot):
    """Return the highlight regions ("shadows") for one joint.

    Three kinds are produced:
      * 'low_speed' — each run of at most ``k`` samples whose smoothed speed
        is below ``vel_threshold``;
      * 'max_value' / 'min_value' — a window of ``highlight_width`` samples on
        each side of the smoothed maximum / minimum.

    Returns an empty list when there are fewer than two samples (no velocity
    can be computed).
    """
    if len(action_df) < 2:
        return []
    smoothed_velocity, smoothed_angle = _smoothed_joint_signals(joint_name, action_df, delta_t)

    vel_threshold = 0.5  # speed threshold, in action units per timestamp unit
    highlight_width = 1  # samples highlighted on each side of the max/min
    k = 2                # longest low-speed run (in samples) still highlighted

    shadows = [
        _interval_shadow('low_speed', start, end, time_for_plot)
        for start, end in find_intervals(np.abs(smoothed_velocity) < vel_threshold)
        if end - start + 1 <= k
    ]
    shadows.append(_window_shadow('max_value', int(np.argmax(smoothed_angle)),
                                  highlight_width, time_for_plot))
    shadows.append(_window_shadow('min_value', int(np.argmin(smoothed_angle)),
                                  highlight_width, time_for_plot))
    return shadows


def generate_joint_graph(joint_name, idx, action_df, delta_t, time_for_plot, all_shadows):
    """Build the figure dict for one joint's smoothed-angle plot.

    Every shadow in ``all_shadows[joint_name]`` is drawn as a translucent red
    band spanning the full plot height. *idx* is accepted for call-site
    compatibility and is not used.
    """
    _, smoothed_angle = _smoothed_joint_signals(joint_name, action_df, delta_t)

    shapes = [
        {
            "type": "rect",
            "xref": "x",
            "yref": "paper",
            "x0": shadow['start_time'],
            "x1": shadow['end_time'],
            "y0": 0,
            "y1": 1,
            "fillcolor": "#ef4444",  # Fixed red
            "opacity": 0.4,
            "line": {"width": 0},
        }
        for shadow in all_shadows[joint_name]
    ]

    def axis(title_text):
        # ``title.font`` replaces the ``titlefont`` property removed in Plotly.py 6
        return {
            "title": {"text": title_text, "font": {"color": "#6b7280"}},
            "tickfont": {"color": "#6b7280"},
            "gridcolor": "#f3f4f6",
            "zerolinecolor": "#e5e7eb",
        }

    return {
        "data": [
            go.Scatter(
                x=time_for_plot,
                y=smoothed_angle,
                name="Joint Angle",
                line=dict(color='#f59e0b', width=2),
                hovertemplate='Time: %{x:.2f}s<br>Angle: %{y:.2f}°',
            )
        ],
        "layout": go.Layout(
            title={
                'text': joint_name.replace('_', ' ').title(),
                'font': {'size': 16, 'color': '#374151'},
            },
            xaxis=axis("Time (seconds)"),
            yaxis=axis("Angle (degrees)"),
            shapes=shapes,
            hovermode="x unified",
            height=220,
            margin=dict(t=30, b=30, l=50, r=30),
            showlegend=False,
            plot_bgcolor='white',
            paper_bgcolor='white',
            font={'family': _FONT_FAMILY},
            hoverlabel=dict(
                bgcolor="white",
                font_size=12,
                font_family=_FONT_FAMILY,
            ),
        ),
    }


# ------------------ Chart Update Callback ------------------
@app.callback(
    [Output(f"graph-{i}", "figure") for i in range(NUM_JOINTS)],
    [Input("store-data", "data")],
    prevent_initial_call=True,
)
def update_all_graphs(data):
    """Return one figure per joint built from the stored episode data."""
    if not data or "data_df" not in data or len(data["data_df"]) == 0:
        return [no_update] * NUM_JOINTS

    columns = data["columns"]
    df = pd.DataFrame.from_records(data["data_df"])
    if len(df) < 2:
        # Velocities (and hence the plotted series) need two samples
        print("Not enough samples to plot joint trajectories")
        return [no_update] * NUM_JOINTS

    # Expand each row's "action" vector into one column per joint name
    action_df = pd.DataFrame(df["action"].tolist(), columns=columns)
    timestamps = df["timestamp"].values
    delta_t = np.diff(timestamps)
    time_for_plot = timestamps[1:]

    all_shadows = {
        joint: get_shadow_info(joint, action_df, delta_t, time_for_plot)
        for joint in columns
    }
    return [
        generate_joint_graph(joint, i, action_df, delta_t, time_for_plot, all_shadows)
        for i, joint in enumerate(columns)
    ]


# ------------------ Video Frame Extraction Function ------------------
def get_video_frame(video_path, time_in_seconds):
    """Return the frame at *time_in_seconds* as a JPEG data URI, or None.

    Frames wider than 640 px are downscaled to 640 px wide (aspect ratio
    preserved) and encoded at JPEG quality 85. Any failure is logged and
    reported as None.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"❌ Cannot open video: {video_path}")
            return None
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            return None

        # round(), not int(): a timestamp of k / fps stored as a float can be
        # a hair below k / fps and would otherwise truncate to frame k - 1.
        frame_num = max(0, int(round(time_in_seconds * fps)))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        success, frame = cap.read()
        if not success or frame is None:
            return None

        height, width = frame.shape[:2]
        if width > 640:
            new_width = 640
            new_height = int(height * (new_width / width))
            frame = cv2.resize(frame, (new_width, new_height))

        encoded_ok, buffer = cv2.imencode('.jpg', frame, [int(cv2.IMWRITE_JPEG_QUALITY), 85])
        if not encoded_ok:
            return None
        encoded = base64.b64encode(buffer).decode('utf-8')
        return f"data:image/jpeg;base64,{encoded}"
    except Exception as e:
        print(f"❌ Exception extracting video frame: {e}")
        return None
    finally:
        if cap is not None:
            cap.release()


# ------------------ Video Frame Callback ------------------
def _render_video_frames(data, hover_data):
    """Return the ``src`` values for the two camera images of one row.

    The time shown is the hovered x value when the graph reports one,
    otherwise the first plotted timestamp. Each camera is handled on its own:
    a camera whose frame cannot be read (or that the dataset lacks) keeps its
    current image via ``no_update`` instead of also blanking the other one.
    """
    if not data or "data_df" not in data or len(data["data_df"]) == 0:
        return no_update, no_update

    records = data["data_df"]
    video_paths = data.get("video_paths") or []

    if hover_data and hover_data.get("points"):
        display_time = float(hover_data["points"][0]["x"])
    elif len(records) > 1:
        # Graphs start at the second timestamp (velocity needs a prior sample)
        display_time = float(records[1]["timestamp"])
    else:
        display_time = 0.0

    sources = [get_video_frame(path, display_time) or no_update for path in video_paths[:2]]
    sources.extend([no_update] * (2 - len(sources)))
    return tuple(sources)


for i in range(NUM_JOINTS):
    @app.callback(
        Output(f"video1-{i}", "src"),
        Output(f"video2-{i}", "src"),
        Input("store-data", "data"),
        Input(f"graph-{i}", "hoverData"),
        prevent_initial_call=True,
    )
    def update_video_frames(data, hover_data, idx=i):
        """Show both camera frames for the time point hovered on graph ``idx``."""
        return _render_video_frames(data, hover_data)


# ------------------ Start Application ------------------
if __name__ == "__main__":
    # NOTE(review): debug mode combined with host 0.0.0.0 makes the debug
    # tooling reachable from other machines; consider disabling it for
    # publicly exposed deployments.
    app.run(debug=True, host='0.0.0.0', port=7860)