# keyframe / app.py
# Page header captured with the file: author "zijian2022", commit "Update app.py", revision 1923cfd (verified).
# ------------------ Import Libraries ------------------
import dash
from dash import dcc, html, Input, Output, State, no_update
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import cv2
import base64
from scipy.ndimage import gaussian_filter1d
import requests
import json
import tempfile
import os
from urllib.parse import urljoin
import subprocess
# ------------------ Data Download and Processing ------------------
class RemoteDatasetLoader:
    """Load one episode (videos plus tabular data) from a Hugging Face dataset repo.

    The repository is expected to expose ``meta/info.json`` (with ``features``,
    ``video_path``, ``data_path`` and optionally ``chunks_size`` entries) and
    ``meta/episodes.jsonl`` under its ``resolve/main/`` URL.
    """

    # Files smaller than this are treated as broken/incomplete downloads.
    _MIN_VIDEO_BYTES = 1024 * 100

    def __init__(self, repo_id: str, timeout: int = 30):
        """Remember the repo id, the per-request timeout and the base download URL."""
        self.repo_id = repo_id
        self.timeout = timeout
        self.base_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/"

    def _get_dataset_info(self) -> dict:
        """Download and parse ``meta/info.json``; raises on HTTP errors."""
        info_url = urljoin(self.base_url, "meta/info.json")
        response = requests.get(info_url, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def _get_episode_info(self, episode_id: int) -> dict:
        """Return the ``meta/episodes.jsonl`` record whose ``episode_index`` matches.

        Raises:
            ValueError: if no record has the requested index.
        """
        episodes_url = urljoin(self.base_url, "meta/episodes.jsonl")
        response = requests.get(episodes_url, timeout=self.timeout)
        response.raise_for_status()
        for line in response.text.splitlines():
            if not line.strip():
                continue
            episode = json.loads(line)
            if episode.get("episode_index") == episode_id:
                return episode
        raise ValueError(f"Episode {episode_id} not found")

    def _is_valid_mp4(self, file_path):
        """Return True if ``file_path`` is a large-enough file whose first video stream is H.264.

        Any ffprobe failure (missing binary, timeout, ...) is logged and
        reported as "not valid".
        """
        if not os.path.exists(file_path) or os.path.getsize(file_path) < self._MIN_VIDEO_BYTES:
            return False
        # Use ffprobe to check if it is a valid mp4
        try:
            result = subprocess.run([
                'ffprobe', '-v', 'error', '-select_streams', 'v:0',
                '-show_entries', 'stream=codec_name',
                '-of', 'default=noprint_wrappers=1:nokey=1', file_path
            ], capture_output=True, text=True, timeout=10)
            if result.returncode == 0 and '264' in result.stdout:
                return True
        except Exception as e:
            print(f"ffprobe video check failed: {e}")
        return False

    def _download_video(self, video_url: str, save_path: str) -> str:
        """Stream ``video_url`` to ``save_path`` and return ``save_path``.

        The body is written to ``<save_path>.part`` first and renamed only
        after the transfer completed, so an interrupted download never leaves
        a truncated file at the final cache location.

        Raises:
            requests.HTTPError: on a non-2xx response.
            ValueError: if the response Content-Type is not a video type.
        """
        partial_path = f"{save_path}.part"
        # The context manager releases the streamed connection on every path.
        with requests.get(video_url, timeout=self.timeout, stream=True) as response:
            response.raise_for_status()
            # Check Content-Type
            if 'video' not in response.headers.get('Content-Type', ''):
                raise ValueError(f"URL {video_url} does not return video content, Content-Type: {response.headers.get('Content-Type')}")
            target_dir = os.path.dirname(save_path)
            if target_dir:  # os.makedirs("") raises for bare filenames
                os.makedirs(target_dir, exist_ok=True)
            try:
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                os.replace(partial_path, save_path)
            finally:
                # Still present only when the transfer or rename failed.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
        return save_path

    def load_episode_data(self, episode_id: int,
                          video_keys=None,
                          download_dir=None):
        """Fetch the videos and the parquet table of one episode.

        Args:
            episode_id: index of the episode to load.
            video_keys: feature names of the videos to fetch. When ``None``,
                the features with ``dtype == "video"`` are used, limited to
                the first two.
            download_dir: directory for cached videos; a fresh temporary
                directory is created when ``None``.

        Returns:
            ``(video_paths, df)``. Each entry of ``video_paths`` is a local
            file path, or the remote URL when the download failed. ``df`` is
            an empty DataFrame when the parquet file could not be read.

        Raises:
            ValueError: if the episode does not exist.
        """
        dataset_info = self._get_dataset_info()
        self._get_episode_info(episode_id)  # Check if episode exists
        if download_dir is None:
            download_dir = tempfile.mkdtemp(prefix="lerobot_videos_")
        if video_keys is None:
            video_keys = [key for key, feature in dataset_info["features"].items()
                          if feature["dtype"] == "video"]
            video_keys = video_keys[:2]
        chunks_size = dataset_info.get("chunks_size", 1000)
        episode_chunk = episode_id // chunks_size
        # Create repo-specific subdirectory
        repo_name = self.repo_id.replace('/', '_')  # Replace / with _ to avoid path issues
        repo_dir = os.path.join(download_dir, repo_name)
        os.makedirs(repo_dir, exist_ok=True)

        video_paths = []
        for video_key in video_keys:
            video_url = self.base_url + dataset_info["video_path"].format(
                episode_chunk=episode_chunk,
                video_key=video_key,
                episode_index=episode_id
            )
            video_filename = f"episode_{episode_id}_{video_key}.mp4"
            local_path = os.path.join(repo_dir, video_filename)
            # Prefer loading local valid mp4
            if self._is_valid_mp4(local_path):
                print(f"Local valid video found: {local_path}")
                video_paths.append(local_path)
                continue
            try:
                video_paths.append(self._download_video(video_url, local_path))
            except Exception as e:
                # Best effort: hand the remote URL to the caller instead.
                print(f"Failed to download video {video_key}: {e}")
                video_paths.append(video_url)

        data_url = self.base_url + dataset_info["data_path"].format(
            episode_chunk=episode_chunk,
            episode_index=episode_id
        )
        try:
            df = pd.read_parquet(data_url)
        except Exception as e:
            print(f"Failed to load data: {e}")
            df = pd.DataFrame()
        return video_paths, df
def check_ffmpeg_available():
    """Return True when ``ffmpeg -version`` can be launched and exits with status 0.

    A missing or non-executable binary (any ``OSError``, which includes
    ``FileNotFoundError`` and ``PermissionError``) or a timeout counts as
    "not available" rather than propagating to the caller.
    """
    try:
        result = subprocess.run(['ffmpeg', '-version'],
                                capture_output=True, text=True, timeout=5)
    except (subprocess.TimeoutExpired, OSError):
        return False
    return result.returncode == 0
def get_video_codec_info(video_path):
    """Return the codec name of the first video stream reported by ffprobe.

    Falls back to ``'unknown'`` when ffprobe exits with an error, reports no
    video stream, or any exception occurs (the exception is printed).
    """
    probe_cmd = [
        'ffprobe', '-v', 'quiet', '-print_format', 'json',
        '-show_streams', video_path
    ]
    try:
        probe = subprocess.run(probe_cmd, capture_output=True, text=True, timeout=10)
        if probe.returncode != 0:
            return 'unknown'
        all_streams = json.loads(probe.stdout).get('streams', [])
        first_video = next(
            (entry for entry in all_streams if entry.get('codec_type') == 'video'),
            None,
        )
        if first_video is not None:
            return first_video.get('codec_name', 'unknown')
    except Exception as e:
        print(f"Failed to get video codec info: {e}")
    return 'unknown'
def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
    """Transcode ``input_path`` to H.264/AAC with ffmpeg.

    Args:
        input_path: source video file.
        output_path: destination; defaults to ``<input stem>_h264.mp4``.
        quality: ``'fast'``, ``'medium'`` or ``'high'``; any other value is
            treated as ``'medium'``.

    Returns:
        ``output_path`` when ffmpeg exits with status 0, otherwise
        ``input_path`` (failure, timeout or any exception, each printed).
    """
    if output_path is None:
        stem, _extension = os.path.splitext(input_path)
        output_path = f"{stem}_h264.mp4"

    presets = {
        'fast': ['-preset', 'ultrafast', '-crf', '28'],
        'medium': ['-preset', 'medium', '-crf', '23'],
        'high': ['-preset', 'slow', '-crf', '18']
    }
    encoder_flags = presets[quality] if quality in presets else presets['medium']
    ffmpeg_cmd = [
        'ffmpeg', '-i', input_path,
        '-c:v', 'libx264',
        '-c:a', 'aac',
        '-movflags', '+faststart',
        '-y',
        *encoder_flags,
        output_path,
    ]

    try:
        completed = subprocess.run(ffmpeg_cmd, capture_output=True, text=True, timeout=300)
    except subprocess.TimeoutExpired:
        print("Re-encoding timeout")
        return input_path
    except Exception as e:
        print(f"Re-encoding exception: {e}")
        return input_path

    if completed.returncode == 0:
        return output_path
    print(f"Re-encoding failed: {completed.stderr}")
    return input_path
def process_video_for_compatibility(video_path):
    """Return a path to an H.264 version of ``video_path`` when one can be produced.

    The input path is returned unchanged when the file does not exist, when
    ffmpeg is unavailable, or when the detected codec is not one of
    av01/av1/vp9/vp8/unknown. Otherwise the video is re-encoded with the
    'fast' preset and the result is used if it exists and exceeds 1 KiB.
    """
    if not os.path.exists(video_path):
        print(f"Video file does not exist: {video_path}")
        return video_path
    if not check_ffmpeg_available():
        print("ffmpeg not available, skipping re-encoding")
        return video_path

    # 'unknown' is re-encoded as well, since the codec could not be identified.
    needs_reencode = get_video_codec_info(video_path) in ('av01', 'av1', 'vp9', 'vp8', 'unknown')
    if not needs_reencode:
        return video_path

    candidate = reencode_video_to_h264(video_path, quality='fast')
    if os.path.exists(candidate) and os.path.getsize(candidate) > 1024:
        return candidate
    print("Re-encoding failed, using original file")
    return video_path
def load_remote_dataset(repo_id: str,
                        episode_id: int = 0,
                        video_keys=None,
                        download_dir=None):
    """Load one episode from ``repo_id`` and make its videos decoder-friendly.

    Returns:
        ``(video_paths, df)`` as produced by
        ``RemoteDatasetLoader.load_episode_data``, with every video path
        passed through ``process_video_for_compatibility``.
    """
    raw_paths, frame_table = RemoteDatasetLoader(repo_id).load_episode_data(
        episode_id, video_keys, download_dir
    )
    compatible_paths = [process_video_for_compatibility(path) for path in raw_paths]
    return compatible_paths, frame_table
# ------------------ Dash Initialization ------------------
app = dash.Dash(__name__, suppress_callback_exceptions=True)
server = app.server

# ------------------ Page Layout ------------------
# Styles shared by the two form labels and the two text/number inputs.
_LABEL_STYLE = {
    "fontWeight": "600",
    "color": "#333",
    "marginRight": "10px",
    "fontSize": "1rem"
}
_INPUT_BASE_STYLE = {
    "padding": "12px 15px",
    "border": "2px solid #e1e5e9",
    "borderRadius": "8px",
    "fontSize": "14px",
    "transition": "border-color 0.3s ease",
    "outline": "none"
}

# Header with gradient background
_header = html.Div(
    [
        html.H1(
            "Keyframe Identification",
            style={
                "textAlign": "center",
                "marginBottom": "10px",
                "color": "white",
                "fontSize": "2.5rem",
                "fontWeight": "300",
                "textShadow": "2px 2px 4px rgba(0,0,0,0.3)"
            },
        ),
        html.P(
            "Interactive Joint Analysis with Video Synchronization",
            style={
                "textAlign": "center",
                "color": "rgba(255,255,255,0.9)",
                "fontSize": "1.1rem",
                "marginBottom": "0"
            },
        ),
    ],
    style={
        "background": "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
        "padding": "30px 20px",
        "marginBottom": "30px",
        "borderRadius": "0 0 15px 15px",
        "boxShadow": "0 4px 20px rgba(0,0,0,0.1)"
    },
)

# Control Panel: repository row
_repo_row = html.Div(
    [
        html.Label("Repository ID:", style=dict(_LABEL_STYLE)),
        dcc.Input(
            id="input-repo-id",
            type="text",
            value="zijian2022/sortingtest",
            style={"width": "350px", **_INPUT_BASE_STYLE},
            placeholder="Enter HuggingFace dataset repository ID"
        ),
    ],
    style={"marginBottom": "15px"},
)

# Control Panel: episode row with the load button
_episode_row = html.Div([
    html.Label("Episode ID:", style=dict(_LABEL_STYLE)),
    dcc.Input(
        id="input-episode-id",
        type="number",
        value=0,
        min=0,
        style={"width": "120px", **_INPUT_BASE_STYLE}
    ),
    html.Button(
        "Load Data",
        id="btn-load",
        n_clicks=0,
        style={
            "marginLeft": "20px",
            "padding": "12px 25px",
            "backgroundColor": "#667eea",
            "color": "white",
            "border": "none",
            "borderRadius": "8px",
            "fontSize": "14px",
            "fontWeight": "600",
            "cursor": "pointer",
            "transition": "all 0.3s ease",
            "boxShadow": "0 2px 10px rgba(102, 126, 234, 0.3)"
        }
    ),
])

_control_panel = html.Div(
    [_repo_row, _episode_row],
    style={
        "textAlign": "center",
        "marginBottom": "40px",
        "padding": "25px",
        "backgroundColor": "white",
        "borderRadius": "12px",
        "boxShadow": "0 4px 20px rgba(0,0,0,0.08)",
        "border": "1px solid #f0f0f0"
    },
)

app.layout = html.Div(
    [
        _header,
        _control_panel,
        # Loading and Data Store
        dcc.Loading(
            id="loading",
            type="circle",
            style={"margin": "20px auto"},
            children=dcc.Store(id="store-data")
        ),
        # Main Content Area (filled in by a callback once data is stored)
        html.Div(
            id="main-content",
            style={
                "backgroundColor": "#f8f9fa",
                "minHeight": "400px",
                "borderRadius": "12px",
                "padding": "20px"
            }
        ),
    ],
    style={
        "fontFamily": "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif",
        "backgroundColor": "#f5f7fa",
        "minHeight": "100vh",
        "padding": "0"
    },
)
# ------------------ Data Loading Callback ------------------
@app.callback(
    Output("store-data", "data"),
    Input("btn-load", "n_clicks"),
    State("input-repo-id", "value"),
    State("input-episode-id", "value"),
    prevent_initial_call=True
)
def load_data_callback(n_clicks, repo_id, episode_id):
    """Download the requested episode and put it into the ``store-data`` store.

    Returns a dict with the video paths, the table rows, the joint column
    names and the timestamps, or an empty dict when loading fails or the
    table is empty (errors are printed, never raised).
    """
    joint_names = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
    try:
        paths, frame_table = load_remote_dataset(
            repo_id=repo_id,
            episode_id=int(episode_id),
            download_dir="./downloaded_videos"
        )
        has_rows = frame_table is not None and not frame_table.empty
        if not has_rows:
            return {}
        payload = {
            "video_paths": paths,
            "data_df": frame_table.to_dict("records"),
            "columns": joint_names,
            "timestamps": frame_table["timestamp"].tolist()
        }
    except Exception as e:
        print(f"Data loading error: {e}")
        return {}
    return payload
# ------------------ Main Content Rendering Callback ------------------
@app.callback(
    Output("main-content", "children"),
    Input("store-data", "data")
)
def update_main_content(data):
    """Build the main area: a placeholder, or one graph-plus-two-images row per joint."""
    has_rows = bool(data) and "data_df" in data and len(data["data_df"]) != 0
    if not has_rows:
        placeholder_parts = [
            html.Div("📊", style={"fontSize": "3rem", "marginBottom": "20px"}),
            html.H3("No Data Available", style={"color": "#666", "marginBottom": "10px"}),
            html.P("Please click the 'Load Data' button above to get data.",
                   style={"color": "#888", "fontSize": "1rem"})
        ]
        return html.Div(placeholder_parts, style={
            "textAlign": "center",
            "padding": "60px 20px",
            "color": "#666"
        })

    def frame_image(component_id):
        # Both camera images of a row use the same styling.
        return html.Img(id=component_id, style={
            "width": "49%",
            "height": "180px",
            "objectFit": "contain",
            "display": "inline-block",
            "borderRadius": "6px",
            "border": "2px solid #e9ecef"
        })

    def joint_row(row_idx):
        # Joint Graph - Left 50%
        graph_cell = html.Div([
            dcc.Graph(id=f"graph-{row_idx}")
        ], style={
            "flex": "0 0 50%",
            "backgroundColor": "white",
            "borderRadius": "8px",
            "padding": "8px",
            "boxShadow": "0 2px 10px rgba(0,0,0,0.05)",
            "border": "1px solid #e9ecef",
            "marginRight": "2%"
        })
        # Video Area - Right 48%
        video_cell = html.Div([
            frame_image(f"video1-{row_idx}"),
            frame_image(f"video2-{row_idx}")
        ], style={
            "flex": "0 0 48%"
        })
        return html.Div([graph_cell, video_cell], style={
            "marginBottom": "25px",
            "backgroundColor": "white",
            "borderRadius": "12px",
            "padding": "12px",
            "boxShadow": "0 4px 15px rgba(0,0,0,0.08)",
            "border": "1px solid #f0f0f0",
            "display": "flex",
            "alignItems": "flex-start",
            "minHeight": "250px"
        })

    return html.Div([joint_row(row_idx) for row_idx, _ in enumerate(data["columns"])])
# ------------------ Shadow and Highlight Utility Functions ------------------
def find_intervals(mask):
    """Return ``(first_index, last_index)`` pairs for every run of truthy values in ``mask``.

    Both indices are inclusive; a run that reaches the end of ``mask`` is
    closed at ``len(mask) - 1``.
    """
    runs = []
    run_start = None
    for position, flag in enumerate(mask):
        if flag:
            # Open a new run only if none is in progress.
            if run_start is None:
                run_start = position
        elif run_start is not None:
            runs.append((run_start, position - 1))
            run_start = None
    if run_start is not None:
        runs.append((run_start, len(mask) - 1))
    return runs
def get_shadow_info(joint_name, action_df, delta_t, time_for_plot, *,
                    vel_threshold=0.5, highlight_width=1, k=2):
    """Compute the time spans to highlight for one joint.

    Three kinds of spans are produced, in this order:

    * ``'low_speed'``: runs where the smoothed absolute velocity is below
      ``vel_threshold`` and that are at most ``k`` samples long;
    * ``'max_value'``: ``highlight_width`` samples on each side of the maximum
      of the smoothed angle;
    * ``'min_value'``: the same around the minimum.

    Args:
        joint_name: column of ``action_df`` to analyse.
        action_df: table with one column per joint.
        delta_t: time differences between consecutive samples (one element
            fewer than the joint column).
        time_for_plot: time axis aligned with the velocity samples.
        vel_threshold: velocity magnitude below which a sample is "slow".
        highlight_width: half-width, in samples, of the max/min spans.
        k: maximum length, in samples, of a reported low-speed run.

    Returns:
        A list of dicts with keys ``type``, ``start_time``, ``end_time``,
        ``start_idx`` and ``end_idx``. Empty when fewer than two samples are
        available, because no velocity can be derived then.
    """
    angles = action_df[joint_name].values
    # With fewer than two samples np.diff is empty and argmax/argmin would raise.
    if len(angles) < 2 or len(time_for_plot) == 0:
        return []

    velocity = np.diff(angles) / delta_t
    smoothed_velocity = gaussian_filter1d(velocity, sigma=1)
    # Drop the first angle so the series lines up with the velocity samples.
    smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)
    last_idx = len(time_for_plot) - 1

    def make_shadow(kind, start, end):
        return {
            'type': kind,
            'start_time': time_for_plot[start],
            'end_time': time_for_plot[end],
            'start_idx': start,
            'end_idx': end
        }

    low_speed_mask = np.abs(smoothed_velocity) < vel_threshold
    shadows = [
        make_shadow('low_speed', start, end)
        for start, end in find_intervals(low_speed_mask)
        if end - start + 1 <= k
    ]

    extremes = (
        ('max_value', np.argmax(smoothed_angle)),
        ('min_value', np.argmin(smoothed_angle)),
    )
    for kind, peak_idx in extremes:
        start = max(0, peak_idx - highlight_width)
        end = min(last_idx, peak_idx + highlight_width)
        shadows.append(make_shadow(kind, start, end))
    return shadows
def generate_joint_graph(joint_name, idx, action_df, delta_t, time_for_plot, all_shadows):
    """Build the figure dict for one joint: smoothed angle curve plus highlight bands.

    Args:
        joint_name: column of ``action_df`` to plot; also used for the title.
        idx: position of the joint; accepted for interface compatibility and
            not used here.
        action_df: table with one column per joint.
        delta_t: accepted for interface compatibility and not used here.
        time_for_plot: x values; expected to have one element fewer than the
            joint column because the first angle sample is dropped.
        all_shadows: mapping from joint name to a list of dicts with
            ``start_time``/``end_time`` keys; a missing joint gets no bands.

    Returns:
        A dict with ``"data"`` and ``"layout"`` entries.
    """
    angles = action_df[joint_name].values
    # The first sample is dropped so the curve aligns with ``time_for_plot``.
    smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)

    # One full-height translucent rectangle per highlighted time span.
    shapes = [
        {
            "type": "rect",
            "xref": "x",
            "yref": "paper",
            "x0": shadow['start_time'],
            "x1": shadow['end_time'],
            "y0": 0,
            "y1": 1,
            "fillcolor": "#ef4444",  # Fixed red
            "opacity": 0.4,
            "line": {"width": 0}
        }
        for shadow in all_shadows.get(joint_name, [])
    ]

    def axis(title_text):
        # The axis title font lives under ``title.font``; the older
        # ``titlefont`` shortcut is deprecated in Plotly.
        return {
            "title": {"text": title_text, "font": {"color": "#6b7280"}},
            "tickfont": {"color": "#6b7280"},
            "gridcolor": "#f3f4f6",
            "zerolinecolor": "#e5e7eb"
        }

    return {
        "data": [
            go.Scatter(
                x=time_for_plot,
                y=smoothed_angle,
                name="Joint Angle",
                line=dict(color='#f59e0b', width=2),
                hovertemplate='<b>Time:</b> %{x:.2f}s<br><b>Angle:</b> %{y:.2f}°<extra></extra>'
            )
        ],
        "layout": go.Layout(
            title={
                'text': joint_name.replace('_', ' ').title(),
                'font': {'size': 16, 'color': '#374151'}
            },
            xaxis=axis("Time (seconds)"),
            yaxis=axis("Angle (degrees)"),
            shapes=shapes,
            hovermode="x unified",
            height=220,
            margin=dict(t=30, b=30, l=50, r=30),
            showlegend=False,
            plot_bgcolor='white',
            paper_bgcolor='white',
            font={'family': "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif"},
            hoverlabel=dict(
                bgcolor="white",
                font_size=12,
                font_family="'Segoe UI', Tahoma, Geneva, Verdana, sans-serif"
            )
        )
    }
# ------------------ Chart Update Callback ------------------
@app.callback(
    [Output(f"graph-{i}", "figure") for i in range(6)],
    [Input("store-data", "data")],
    prevent_initial_call=True
)
def update_all_graphs(data):
    """Recompute the six joint figures whenever the stored episode changes.

    Leaves all figures untouched when the store holds no rows.
    """
    if not (data and "data_df" in data and len(data["data_df"]) != 0):
        return [no_update] * 6

    joint_names = data["columns"]
    records = pd.DataFrame.from_records(data["data_df"])
    # Each "action" entry is expanded into one column per joint.
    action_df = pd.DataFrame(records["action"].tolist(), columns=joint_names)
    stamps = records["timestamp"].values
    delta_t = np.diff(stamps)
    time_for_plot = stamps[1:]

    all_shadows = {
        name: get_shadow_info(name, action_df, delta_t, time_for_plot)
        for name in joint_names
    }

    # Generate all charts, no highlight logic
    figures = []
    for position, name in enumerate(joint_names):
        figures.append(
            generate_joint_graph(name, position, action_df, delta_t, time_for_plot, all_shadows)
        )
    return figures
# ------------------ Video Frame Extraction Function ------------------
def get_video_frame(video_path, time_in_seconds):
    """Return the frame nearest to ``time_in_seconds`` as a JPEG data URI.

    Frames wider than 640 pixels are scaled down to 640 pixels wide, keeping
    the aspect ratio, before encoding at JPEG quality 85.

    Returns:
        A ``data:image/jpeg;base64,...`` string, or ``None`` when the video
        cannot be opened, reports a non-positive frame rate, the frame cannot
        be read or encoded, or any exception occurs (which is printed).
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"❌ Cannot open video: {video_path}")
            return None
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            return None
        # Round instead of truncating: a timestamp that is a hair below a
        # frame boundary (e.g. 20.9999996) must map to frame 21, not 20.
        frame_num = max(0, int(round(time_in_seconds * fps)))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        success, frame = cap.read()
        if not success or frame is None:
            return None

        height, width = frame.shape[:2]
        if width > 640:
            new_width = 640
            new_height = max(1, int(height * (new_width / width)))
            frame = cv2.resize(frame, (new_width, new_height))

        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 85]
        encoded_ok, buffer = cv2.imencode('.jpg', frame, encode_param)
        if not encoded_ok:
            return None
        encoded = base64.b64encode(buffer).decode('utf-8')
        return f"data:image/jpeg;base64,{encoded}"
    except Exception as e:
        print(f"❌ Exception extracting video frame: {e}")
        return None
    finally:
        # Release the capture on every path, including early returns and errors.
        if cap is not None:
            cap.release()
# ------------------ Video Frame Callback ------------------
# One callback per joint row: hovering a row's graph updates that row's two images.
for i in range(6):
    @app.callback(
        Output(f"video1-{i}", "src"),
        Output(f"video2-{i}", "src"),
        Input("store-data", "data"),
        Input(f"graph-{i}", "hoverData"),
        prevent_initial_call=True
    )
    def update_video_frames(data, hover_data, idx=i):
        """Show the video frames matching the hovered time of this row's graph.

        Without hover data the second stored timestamp (the first plotted
        time) is used. Each image is updated independently, so a dataset with
        a single video, or one unreadable video, still refreshes the other
        image; an image without a frame is left unchanged.
        """
        if not data or "data_df" not in data or len(data["data_df"]) == 0:
            return no_update, no_update
        try:
            # Determine the time point to display
            display_time = 0.0  # Default to start time
            points = (hover_data or {}).get("points") or []
            if points:
                # If there is hover data, use hover time
                display_time = float(points[0]["x"])
            else:
                # Read the stored timestamp list instead of rebuilding a
                # DataFrame from all rows on every hover event.
                timestamps = data.get("timestamps")
                if timestamps is None:
                    timestamps = [row["timestamp"] for row in data["data_df"]]
                if len(timestamps) > 1:
                    # The plotted timeline starts at the second sample.
                    display_time = float(timestamps[1])

            video_paths = list(data.get("video_paths") or [])[:2]
            frames = [get_video_frame(path, display_time) for path in video_paths]
            frames += [None] * (2 - len(frames))
            return tuple(frame if frame else no_update for frame in frames)
        except Exception as e:
            print(f"update_video_frames callback error: {e}")
            return no_update, no_update
# ------------------ Start Application ------------------
if __name__ == "__main__":
    # The server listens on all interfaces, so debug mode (interactive
    # debugger and dev tools) stays off unless explicitly requested with
    # DASH_DEBUG=1/true/yes/on.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860)