# ------------------ Import Libraries ------------------
import dash
from dash import dcc, html, Input, Output, State, no_update
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import cv2
import base64
from scipy.ndimage import gaussian_filter1d
import requests
import json
import tempfile
import os
from urllib.parse import urljoin
import subprocess
# ------------------ Data Download and Processing ------------------
class RemoteDatasetLoader:
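    """Fetch per-episode videos and parquet data for a LeRobot-style dataset
    hosted on the HuggingFace Hub (assumes the repo layout used below:
    meta/info.json, meta/episodes.jsonl, and chunked video/data paths)."""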
    def __init__(self, repo_id: str, timeout: int = 30):
        self.repo_id = repo_id
        self.timeout = timeout
        self.base_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/"

    def _get_dataset_info(self) -> dict:
        info_url = urljoin(self.base_url, "meta/info.json")
        response = requests.get(info_url, timeout=self.timeout)
        response.raise_for_status()
        return response.json()
    def _get_episode_info(self, episode_id: int) -> dict:
        episodes_url = urljoin(self.base_url, "meta/episodes.jsonl")
        response = requests.get(episodes_url, timeout=self.timeout)
        response.raise_for_status()
        episodes = [json.loads(line) for line in response.text.splitlines() if line.strip()]
        for episode in episodes:
            if episode.get("episode_index") == episode_id:
                return episode
        raise ValueError(f"Episode {episode_id} not found")
    def _is_valid_mp4(self, file_path):
        if not os.path.exists(file_path) or os.path.getsize(file_path) < 1024 * 100:
            return False
        # Use ffprobe to check if it is a valid mp4
        try:
            result = subprocess.run([
                'ffprobe', '-v', 'error', '-select_streams', 'v:0',
                '-show_entries', 'stream=codec_name', '-of', 'default=noprint_wrappers=1:nokey=1', file_path
            ], capture_output=True, text=True, timeout=10)
            if result.returncode == 0 and '264' in result.stdout:
                return True
        except Exception as e:
            print(f"ffprobe video check failed: {e}")
        return False
    def _download_video(self, video_url: str, save_path: str) -> str:
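        """Stream the remote file to save_path; raise early if the server does
        not report a video Content-Type (e.g. an HTML error page)."""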
        response = requests.get(video_url, timeout=self.timeout, stream=True)
        response.raise_for_status()
        # Check Content-Type
        if 'video' not in response.headers.get('Content-Type', ''):
            raise ValueError(f"URL {video_url} does not return video content, Content-Type: {response.headers.get('Content-Type')}")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        return save_path
    def load_episode_data(self, episode_id: int,
                          video_keys=None,
                          download_dir=None):
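        """Return (video_paths, dataframe) for one episode. Valid local MP4s
        are reused; a failed download falls back to the remote URL."""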
        dataset_info = self._get_dataset_info()
        self._get_episode_info(episode_id)  # Check if episode exists
        if download_dir is None:
            download_dir = tempfile.mkdtemp(prefix="lerobot_videos_")
        if video_keys is None:
            video_keys = [key for key, feature in dataset_info["features"].items()
                          if feature["dtype"] == "video"]
        video_keys = video_keys[:2]  # The UI displays at most two camera streams
        video_paths = []
        chunks_size = dataset_info.get("chunks_size", 1000)
        # Create repo-specific subdirectory
        repo_name = self.repo_id.replace('/', '_')  # Replace / with _ to avoid path issues
        repo_dir = os.path.join(download_dir, repo_name)
        os.makedirs(repo_dir, exist_ok=True)
        for video_key in video_keys:
            video_url = self.base_url + dataset_info["video_path"].format(
                episode_chunk=episode_id // chunks_size,
                video_key=video_key,
                episode_index=episode_id
            )
            video_filename = f"episode_{episode_id}_{video_key}.mp4"
            local_path = os.path.join(repo_dir, video_filename)
            # Prefer loading local valid mp4
            if self._is_valid_mp4(local_path):
                print(f"Local valid video found: {local_path}")
                video_paths.append(local_path)
                continue
            try:
                downloaded_path = self._download_video(video_url, local_path)
                video_paths.append(downloaded_path)
            except Exception as e:
                print(f"Failed to download video {video_key}: {e}")
                video_paths.append(video_url)
        data_url = self.base_url + dataset_info["data_path"].format(
            episode_chunk=episode_id // chunks_size,
            episode_index=episode_id
        )
        try:
            df = pd.read_parquet(data_url)
        except Exception as e:
            print(f"Failed to load data: {e}")
            df = pd.DataFrame()
        return video_paths, df
def check_ffmpeg_available():
    try:
        result = subprocess.run(['ffmpeg', '-version'],
                                capture_output=True, text=True, timeout=5)
        return result.returncode == 0
    except (subprocess.TimeoutExpired, FileNotFoundError):
        return False
def get_video_codec_info(video_path):
    try:
        result = subprocess.run([
            'ffprobe', '-v', 'quiet', '-print_format', 'json',
            '-show_streams', video_path
        ], capture_output=True, text=True, timeout=10)
        if result.returncode == 0:
            info = json.loads(result.stdout)
            for stream in info.get('streams', []):
                if stream.get('codec_type') == 'video':
                    return stream.get('codec_name', 'unknown')
    except Exception as e:
        print(f"Failed to get video codec info: {e}")
    return 'unknown'
def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
    if output_path is None:
        base_name = os.path.splitext(input_path)[0]
        output_path = f"{base_name}_h264.mp4"
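    # Lower CRF means higher quality and larger files; the x264 preset trades
    # encoding speed for compression efficiency.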
    quality_params = {
        'fast': ['-preset', 'ultrafast', '-crf', '28'],
        'medium': ['-preset', 'medium', '-crf', '23'],
        'high': ['-preset', 'slow', '-crf', '18']
    }
    params = quality_params.get(quality, quality_params['medium'])
    try:
        cmd = [
            'ffmpeg', '-i', input_path,
            '-c:v', 'libx264',
            '-c:a', 'aac',
            '-movflags', '+faststart',
            '-y',
        ] + params + [output_path]
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
        if result.returncode == 0:
            return output_path
        else:
            print(f"Re-encoding failed: {result.stderr}")
            return input_path
    except subprocess.TimeoutExpired:
        print("Re-encoding timeout")
        return input_path
    except Exception as e:
        print(f"Re-encoding exception: {e}")
        return input_path
def process_video_for_compatibility(video_path):
    if not os.path.exists(video_path):
        print(f"Video file does not exist: {video_path}")
        return video_path
    if not check_ffmpeg_available():
        print("ffmpeg not available, skipping re-encoding")
        return video_path
    codec = get_video_codec_info(video_path)
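    # Common OpenCV builds and browsers frequently lack AV1/VP9 decoders, so
    # those codecs (and undetectable ones) are re-encoded to H.264.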
    if codec in ['av01', 'av1', 'vp9', 'vp8'] or codec == 'unknown':
        reencoded_path = reencode_video_to_h264(video_path, quality='fast')
        if os.path.exists(reencoded_path) and os.path.getsize(reencoded_path) > 1024:
            return reencoded_path
        else:
            print("Re-encoding failed, using original file")
            return video_path
    else:
        return video_path
def load_remote_dataset(repo_id: str,
                        episode_id: int = 0,
                        video_keys=None,
                        download_dir=None):
    loader = RemoteDatasetLoader(repo_id)
    video_paths, df = loader.load_episode_data(episode_id, video_keys, download_dir)
    processed_video_paths = []
    for video_path in video_paths:
        processed_path = process_video_for_compatibility(video_path)
        processed_video_paths.append(processed_path)
    return processed_video_paths, df
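# Example usage (requires network access; any compatible dataset id works):
#   videos, df = load_remote_dataset("zijian2022/sortingtest", episode_id=0)
#   print(videos)       # local .mp4 paths (or remote URLs on download failure)
#   print(df.columns)   # includes at least "timestamp" and "action"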
# ------------------ Dash Initialization ------------------
app = dash.Dash(__name__, suppress_callback_exceptions=True)
server = app.server
# ------------------ Page Layout ------------------
app.layout = html.Div([
    # Header with gradient background
    html.Div([
        html.H1("Keyframe Identification",
                style={
                    "textAlign": "center",
                    "marginBottom": "10px",
                    "color": "white",
                    "fontSize": "2.5rem",
                    "fontWeight": "300",
                    "textShadow": "2px 2px 4px rgba(0,0,0,0.3)"
                }),
        html.P("Interactive Joint Analysis with Video Synchronization",
               style={
                   "textAlign": "center",
                   "color": "rgba(255,255,255,0.9)",
                   "fontSize": "1.1rem",
                   "marginBottom": "0"
               })
    ], style={
        "background": "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
        "padding": "30px 20px",
        "marginBottom": "30px",
        "borderRadius": "0 0 15px 15px",
        "boxShadow": "0 4px 20px rgba(0,0,0,0.1)"
    }),
    # Control Panel
    html.Div([
        html.Div([
            html.Label("Repository ID:",
                       style={
                           "fontWeight": "600",
                           "color": "#333",
                           "marginRight": "10px",
                           "fontSize": "1rem"
                       }),
            dcc.Input(
                id="input-repo-id",
                type="text",
                value="zijian2022/sortingtest",
                style={
                    "width": "350px",
                    "padding": "12px 15px",
                    "border": "2px solid #e1e5e9",
                    "borderRadius": "8px",
                    "fontSize": "14px",
                    "transition": "border-color 0.3s ease",
                    "outline": "none"
                },
                placeholder="Enter HuggingFace dataset repository ID"
            ),
        ], style={"marginBottom": "15px"}),
        html.Div([
            html.Label("Episode ID:",
                       style={
                           "fontWeight": "600",
                           "color": "#333",
                           "marginRight": "10px",
                           "fontSize": "1rem"
                       }),
            dcc.Input(
                id="input-episode-id",
                type="number",
                value=0,
                min=0,
                style={
                    "width": "120px",
                    "padding": "12px 15px",
                    "border": "2px solid #e1e5e9",
                    "borderRadius": "8px",
                    "fontSize": "14px",
                    "transition": "border-color 0.3s ease",
                    "outline": "none"
                }
            ),
            html.Button(
                "Load Data",
                id="btn-load",
                n_clicks=0,
                style={
                    "marginLeft": "20px",
                    "padding": "12px 25px",
                    "backgroundColor": "#667eea",
                    "color": "white",
                    "border": "none",
                    "borderRadius": "8px",
                    "fontSize": "14px",
                    "fontWeight": "600",
                    "cursor": "pointer",
                    "transition": "all 0.3s ease",
                    "boxShadow": "0 2px 10px rgba(102, 126, 234, 0.3)"
                }
            ),
        ]),
    ], style={
        "textAlign": "center",
        "marginBottom": "40px",
        "padding": "25px",
        "backgroundColor": "white",
        "borderRadius": "12px",
        "boxShadow": "0 4px 20px rgba(0,0,0,0.08)",
        "border": "1px solid #f0f0f0"
    }),
    # Loading and Data Store
    dcc.Loading(
        id="loading",
        type="circle",
        style={"margin": "20px auto"},
        children=dcc.Store(id="store-data")
    ),
    # Main Content Area
    html.Div(
        id="main-content",
        style={
            "backgroundColor": "#f8f9fa",
            "minHeight": "400px",
            "borderRadius": "12px",
            "padding": "20px"
        }
    ),
], style={
    "fontFamily": "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif",
    "backgroundColor": "#f5f7fa",
    "minHeight": "100vh",
    "padding": "0"
})
# ------------------ Data Loading Callback ------------------
@app.callback(
    Output("store-data", "data"),
    Input("btn-load", "n_clicks"),
    State("input-repo-id", "value"),
    State("input-episode-id", "value"),
)
def load_data_callback(n_clicks, repo_id, episode_id):
    try:
        video_paths, data_df = load_remote_dataset(
            repo_id=repo_id,
            episode_id=int(episode_id),
            download_dir="./downloaded_videos"
        )
        if data_df is None or data_df.empty:
            return {}
        return {
            "video_paths": video_paths,
            "data_df": data_df.to_dict("records"),
            "columns": ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"],
            "timestamps": data_df["timestamp"].tolist()
        }
    except Exception as e:
        print(f"Data loading error: {e}")
        return {}
# ------------------ Main Content Rendering Callback ------------------
@app.callback(
    Output("main-content", "children"),
    Input("store-data", "data"),
)
def update_main_content(data):
    if not data or "data_df" not in data or len(data["data_df"]) == 0:
        return html.Div([
            html.Div("📊", style={"fontSize": "3rem", "marginBottom": "20px"}),
            html.H3("No Data Available", style={"color": "#666", "marginBottom": "10px"}),
            html.P("Please click the 'Load Data' button above to get data.",
                   style={"color": "#888", "fontSize": "1rem"})
        ], style={
            "textAlign": "center",
            "padding": "60px 20px",
            "color": "#666"
        })
    columns = data["columns"]
    rows = []
    for i, joint in enumerate(columns):
        rows.append(html.Div([
            # Joint Graph - Left 50%
            html.Div([
                dcc.Graph(id=f"graph-{i}")
            ], style={
                "flex": "0 0 50%",
                "backgroundColor": "white",
                "borderRadius": "8px",
                "padding": "8px",
                "boxShadow": "0 2px 10px rgba(0,0,0,0.05)",
                "border": "1px solid #e9ecef",
                "marginRight": "2%"
            }),
            # Video Area - Right 48%
            html.Div([
                html.Img(id=f"video1-{i}", style={
                    "width": "49%",
                    "height": "180px",
                    "objectFit": "contain",
                    "display": "inline-block",
                    "borderRadius": "6px",
                    "border": "2px solid #e9ecef"
                }),
                html.Img(id=f"video2-{i}", style={
                    "width": "49%",
                    "height": "180px",
                    "objectFit": "contain",
                    "display": "inline-block",
                    "borderRadius": "6px",
                    "border": "2px solid #e9ecef"
                })
            ], style={
                "flex": "0 0 48%"
            })
        ], style={
            "marginBottom": "25px",
            "backgroundColor": "white",
            "borderRadius": "12px",
            "padding": "12px",
            "boxShadow": "0 4px 15px rgba(0,0,0,0.08)",
            "border": "1px solid #f0f0f0",
            "display": "flex",
            "alignItems": "flex-start",
            "minHeight": "250px"
        }))
    return html.Div(rows)
# ------------------ Shadow and Highlight Utility Functions ------------------
def find_intervals(mask):
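    """Return (start, end) inclusive index pairs for each contiguous run of
    True values in a boolean mask."""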
    intervals = []
    start = None
    for i, val in enumerate(mask):
        if val and start is None:
            start = i
        elif not val and start is not None:
            intervals.append((start, i - 1))
            start = None
    if start is not None:
        intervals.append((start, len(mask) - 1))
    return intervals
def get_shadow_info(joint_name, action_df, delta_t, time_for_plot):
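    """Compute highlight ("shadow") regions for one joint: short low-velocity
    dwells (runs of at most k samples below vel_threshold) plus small windows
    around the extrema of the smoothed angle curve."""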
    angles = action_df[joint_name].values
    velocity = np.diff(angles) / delta_t
    smoothed_velocity = gaussian_filter1d(velocity, sigma=1)
    smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)
    vel_threshold = 0.5
    highlight_width = 1
    k = 2
    shadows = []
    low_speed_mask = np.abs(smoothed_velocity) < vel_threshold
    low_speed_intervals = find_intervals(low_speed_mask)
    for start, end in low_speed_intervals:
        if end - start + 1 <= k:
            shadows.append({
                'type': 'low_speed',
                'start_time': time_for_plot[start],
                'end_time': time_for_plot[end],
                'start_idx': start,
                'end_idx': end
            })
    max_idx = np.argmax(smoothed_angle)
    s_max = max(0, max_idx - highlight_width)
    e_max = min(len(time_for_plot) - 1, max_idx + highlight_width)
    shadows.append({
        'type': 'max_value',
        'start_time': time_for_plot[s_max],
        'end_time': time_for_plot[e_max],
        'start_idx': s_max,
        'end_idx': e_max
    })
    min_idx = np.argmin(smoothed_angle)
    s_min = max(0, min_idx - highlight_width)
    e_min = min(len(time_for_plot) - 1, min_idx + highlight_width)
    shadows.append({
        'type': 'min_value',
        'start_time': time_for_plot[s_min],
        'end_time': time_for_plot[e_min],
        'start_idx': s_min,
        'end_idx': e_min
    })
    return shadows
def generate_joint_graph(joint_name, idx, action_df, delta_t, time_for_plot, all_shadows):
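    """Build the Plotly figure for one joint: the smoothed angle trace with
    red shaded rectangles marking the precomputed shadow intervals."""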
    angles = action_df[joint_name].values
    velocity = np.diff(angles) / delta_t
    smoothed_velocity = gaussian_filter1d(velocity, sigma=1)
    smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)
    shapes = []
    current_shadows = all_shadows[joint_name]
    for shadow in current_shadows:
        shapes.append({
            "type": "rect",
            "xref": "x",
            "yref": "paper",
            "x0": shadow['start_time'],
            "x1": shadow['end_time'],
            "y0": 0,
            "y1": 1,
            "fillcolor": "#ef4444",  # Fixed red
            "opacity": 0.4,
            "line": {"width": 0}
        })
    return {
        "data": [
            go.Scatter(
                x=time_for_plot,
                y=smoothed_angle,
                name="Joint Angle",
                line=dict(color='#f59e0b', width=2),
                hovertemplate='<b>Time:</b> %{x:.2f}s<br><b>Angle:</b> %{y:.2f}°<extra></extra>'
            )
        ],
        "layout": go.Layout(
            title={
                'text': joint_name.replace('_', ' ').title(),
                'font': {'size': 16, 'color': '#374151'}
            },
            xaxis={
                "title": {"text": "Time (seconds)", "font": {"color": "#6b7280"}},
                "tickfont": {"color": "#6b7280"},
                "gridcolor": "#f3f4f6",
                "zerolinecolor": "#e5e7eb"
            },
            yaxis={
                "title": {"text": "Angle (degrees)", "font": {"color": "#6b7280"}},
                "tickfont": {"color": "#6b7280"},
                "gridcolor": "#f3f4f6",
                "zerolinecolor": "#e5e7eb"
            },
            shapes=shapes,
            hovermode="x unified",
            height=220,
            margin=dict(t=30, b=30, l=50, r=30),
            showlegend=False,
            plot_bgcolor='white',
            paper_bgcolor='white',
            font={'family': "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif"},
            hoverlabel=dict(
                bgcolor="white",
                font_size=12,
                font_family="'Segoe UI', Tahoma, Geneva, Verdana, sans-serif"
            )
        )
    }
# ------------------ Chart Update Callback ------------------
@app.callback(
    [Output(f"graph-{i}", "figure") for i in range(6)],
    Input("store-data", "data"),
)
def update_all_graphs(data):
    if not data or "data_df" not in data or len(data["data_df"]) == 0:
        return [no_update] * 6
    columns = data["columns"]
    df = pd.DataFrame.from_records(data["data_df"])
    action_df = pd.DataFrame(df["action"].tolist(), columns=columns)
    timestamps = df["timestamp"].values
    delta_t = np.diff(timestamps)
    time_for_plot = timestamps[1:]
    all_shadows = {}
    for joint in columns:
        all_shadows[joint] = get_shadow_info(joint, action_df, delta_t, time_for_plot)
    # Generate all charts, no highlight logic
    return [
        generate_joint_graph(joint, i, action_df, delta_t, time_for_plot, all_shadows)
        for i, joint in enumerate(columns)
    ]
# ------------------ Video Frame Extraction Function ------------------
def get_video_frame(video_path, time_in_seconds):
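    """Seek to time_in_seconds, grab one frame, downscale it to at most 640 px
    wide, and return it as a base64 JPEG data URI (None on any failure)."""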
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"❌ Cannot open video: {video_path}")
            return None
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            cap.release()
            return None
        frame_num = int(time_in_seconds * fps)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        success, frame = cap.read()
        cap.release()
        if success and frame is not None:
            height, width = frame.shape[:2]
            if width > 640:
                new_width = 640
                new_height = int(height * (new_width / width))
                frame = cv2.resize(frame, (new_width, new_height))
            encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 85]
            _, buffer = cv2.imencode('.jpg', frame, encode_param)
            encoded = base64.b64encode(buffer).decode('utf-8')
            return f"data:image/jpeg;base64,{encoded}"
        else:
            return None
    except Exception as e:
        print(f"❌ Exception extracting video frame: {e}")
        return None
# ------------------ Video Frame Callback ------------------ | |
for i in range(6): | |
def update_video_frames(data, hover_data, idx=i): | |
if not data or "data_df" not in data or len(data["data_df"]) == 0: | |
return no_update, no_update | |
columns = data["columns"] | |
df = pd.DataFrame.from_records(data["data_df"]) | |
timestamps = df["timestamp"].values | |
time_for_plot = timestamps[1:] | |
video_paths = data["video_paths"] | |
# Determine the time point to display | |
display_time = 0.0 # Default to start time | |
if hover_data and "points" in hover_data and len(hover_data["points"]) > 0: | |
# If there is hover data, use hover time | |
display_time = float(hover_data["points"][0]["x"]) | |
elif len(time_for_plot) > 0: | |
# If no hover data, use the start time of the timeline | |
display_time = time_for_plot[0] | |
try: | |
frame1 = get_video_frame(video_paths[0], display_time) | |
frame2 = get_video_frame(video_paths[1], display_time) | |
if frame1 and frame2: | |
return frame1, frame2 | |
else: | |
return no_update, no_update | |
except Exception as e: | |
print(f"update_video_frames callback error: {e}") | |
return no_update, no_update | |
# ------------------ Start Application ------------------
if __name__ == "__main__":
    app.run(debug=True, host='0.0.0.0', port=7860)