Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# ------------------
|
2 |
import dash
|
3 |
from dash import dcc, html, Input, Output, State, no_update
|
4 |
import plotly.graph_objects as go
|
@@ -14,7 +14,7 @@ import os
|
|
14 |
from urllib.parse import urljoin
|
15 |
import subprocess
|
16 |
|
17 |
-
# ------------------
|
18 |
class RemoteDatasetLoader:
|
19 |
def __init__(self, repo_id: str, timeout: int = 30):
|
20 |
self.repo_id = repo_id
|
@@ -40,7 +40,7 @@ class RemoteDatasetLoader:
|
|
40 |
def _is_valid_mp4(self, file_path):
|
41 |
if not os.path.exists(file_path) or os.path.getsize(file_path) < 1024 * 100:
|
42 |
return False
|
43 |
-
#
|
44 |
try:
|
45 |
result = subprocess.run([
|
46 |
'ffprobe', '-v', 'error', '-select_streams', 'v:0',
|
@@ -55,9 +55,9 @@ class RemoteDatasetLoader:
|
|
55 |
def _download_video(self, video_url: str, save_path: str) -> str:
|
56 |
response = requests.get(video_url, timeout=self.timeout, stream=True)
|
57 |
response.raise_for_status()
|
58 |
-
#
|
59 |
if 'video' not in response.headers.get('Content-Type', ''):
|
60 |
-
raise ValueError(f"URL {video_url}
|
61 |
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
62 |
with open(save_path, 'wb') as f:
|
63 |
for chunk in response.iter_content(chunk_size=8192):
|
@@ -68,7 +68,7 @@ class RemoteDatasetLoader:
|
|
68 |
video_keys=None,
|
69 |
download_dir=None):
|
70 |
dataset_info = self._get_dataset_info()
|
71 |
-
self._get_episode_info(episode_id) #
|
72 |
|
73 |
if download_dir is None:
|
74 |
download_dir = tempfile.mkdtemp(prefix="lerobot_videos_")
|
@@ -81,8 +81,8 @@ class RemoteDatasetLoader:
|
|
81 |
video_paths = []
|
82 |
chunks_size = dataset_info.get("chunks_size", 1000)
|
83 |
|
84 |
-
#
|
85 |
-
repo_name = self.repo_id.replace('/', '_') #
|
86 |
repo_dir = os.path.join(download_dir, repo_name)
|
87 |
os.makedirs(repo_dir, exist_ok=True)
|
88 |
|
@@ -94,7 +94,7 @@ class RemoteDatasetLoader:
|
|
94 |
)
|
95 |
video_filename = f"episode_{episode_id}_{video_key}.mp4"
|
96 |
local_path = os.path.join(repo_dir, video_filename)
|
97 |
-
#
|
98 |
if self._is_valid_mp4(local_path):
|
99 |
print(f"Local valid video found: {local_path}")
|
100 |
video_paths.append(local_path)
|
@@ -138,7 +138,7 @@ def get_video_codec_info(video_path):
|
|
138 |
if stream.get('codec_type') == 'video':
|
139 |
return stream.get('codec_name', 'unknown')
|
140 |
except Exception as e:
|
141 |
-
print(f"
|
142 |
return 'unknown'
|
143 |
|
144 |
def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
|
@@ -163,21 +163,21 @@ def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
|
|
163 |
if result.returncode == 0:
|
164 |
return output_path
|
165 |
else:
|
166 |
-
print(f"
|
167 |
return input_path
|
168 |
except subprocess.TimeoutExpired:
|
169 |
-
print("
|
170 |
return input_path
|
171 |
except Exception as e:
|
172 |
-
print(f"
|
173 |
return input_path
|
174 |
|
175 |
def process_video_for_compatibility(video_path):
|
176 |
if not os.path.exists(video_path):
|
177 |
-
print(f"
|
178 |
return video_path
|
179 |
if not check_ffmpeg_available():
|
180 |
-
print("ffmpeg
|
181 |
return video_path
|
182 |
codec = get_video_codec_info(video_path)
|
183 |
if codec in ['av01', 'av1', 'vp9', 'vp8'] or codec == 'unknown':
|
@@ -185,7 +185,7 @@ def process_video_for_compatibility(video_path):
|
|
185 |
if os.path.exists(reencoded_path) and os.path.getsize(reencoded_path) > 1024:
|
186 |
return reencoded_path
|
187 |
else:
|
188 |
-
print("
|
189 |
return video_path
|
190 |
else:
|
191 |
return video_path
|
@@ -202,15 +202,15 @@ def load_remote_dataset(repo_id: str,
|
|
202 |
processed_video_paths.append(processed_path)
|
203 |
return processed_video_paths, df
|
204 |
|
205 |
-
# ------------------ Dash
|
206 |
app = dash.Dash(__name__, suppress_callback_exceptions=True)
|
207 |
server = app.server
|
208 |
|
209 |
-
# ------------------
|
210 |
app.layout = html.Div([
|
211 |
# Header with gradient background
|
212 |
html.Div([
|
213 |
-
html.H1("
|
214 |
style={
|
215 |
"textAlign": "center",
|
216 |
"marginBottom": "10px",
|
@@ -340,7 +340,7 @@ app.layout = html.Div([
|
|
340 |
"padding": "0"
|
341 |
})
|
342 |
|
343 |
-
# ------------------
|
344 |
@app.callback(
|
345 |
Output("store-data", "data"),
|
346 |
Input("btn-load", "n_clicks"),
|
@@ -367,7 +367,7 @@ def load_data_callback(n_clicks, repo_id, episode_id):
|
|
367 |
print(f"Data loading error: {e}")
|
368 |
return {}
|
369 |
|
370 |
-
# ------------------
|
371 |
@app.callback(
|
372 |
Output("main-content", "children"),
|
373 |
Input("store-data", "data")
|
@@ -389,7 +389,7 @@ def update_main_content(data):
|
|
389 |
rows = []
|
390 |
for i, joint in enumerate(columns):
|
391 |
rows.append(html.Div([
|
392 |
-
#
|
393 |
html.Div([
|
394 |
dcc.Graph(id=f"graph-{i}")
|
395 |
], style={
|
@@ -401,7 +401,7 @@ def update_main_content(data):
|
|
401 |
"border": "1px solid #e9ecef",
|
402 |
"marginRight": "2%"
|
403 |
}),
|
404 |
-
#
|
405 |
html.Div([
|
406 |
html.Img(id=f"video1-{i}", style={
|
407 |
"width": "49%",
|
@@ -435,7 +435,7 @@ def update_main_content(data):
|
|
435 |
}))
|
436 |
return html.Div(rows)
|
437 |
|
438 |
-
# ------------------
|
439 |
def find_intervals(mask):
|
440 |
intervals = []
|
441 |
start = None
|
@@ -509,7 +509,7 @@ def generate_joint_graph(joint_name, idx, action_df, delta_t, time_for_plot, all
|
|
509 |
"x1": shadow['end_time'],
|
510 |
"y0": 0,
|
511 |
"y1": 1,
|
512 |
-
"fillcolor": "#ef4444", #
|
513 |
"opacity": 0.4,
|
514 |
"line": {"width": 0}
|
515 |
})
|
@@ -558,7 +558,7 @@ def generate_joint_graph(joint_name, idx, action_df, delta_t, time_for_plot, all
|
|
558 |
)
|
559 |
}
|
560 |
|
561 |
-
# ------------------
|
562 |
@app.callback(
|
563 |
[Output(f"graph-{i}", "figure") for i in range(6)],
|
564 |
[Input("store-data", "data")],
|
@@ -578,18 +578,18 @@ def update_all_graphs(data):
|
|
578 |
for joint in columns:
|
579 |
all_shadows[joint] = get_shadow_info(joint, action_df, delta_t, time_for_plot)
|
580 |
|
581 |
-
#
|
582 |
return [
|
583 |
generate_joint_graph(joint, i, action_df, delta_t, time_for_plot, all_shadows)
|
584 |
for i, joint in enumerate(columns)
|
585 |
]
|
586 |
|
587 |
-
# ------------------
|
588 |
def get_video_frame(video_path, time_in_seconds):
|
589 |
try:
|
590 |
cap = cv2.VideoCapture(video_path)
|
591 |
if not cap.isOpened():
|
592 |
-
print(f"❌
|
593 |
return None
|
594 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
595 |
if fps <= 0:
|
@@ -612,10 +612,10 @@ def get_video_frame(video_path, time_in_seconds):
|
|
612 |
else:
|
613 |
return None
|
614 |
except Exception as e:
|
615 |
-
print(f"❌
|
616 |
return None
|
617 |
|
618 |
-
# ------------------
|
619 |
for i in range(6):
|
620 |
@app.callback(
|
621 |
Output(f"video1-{i}", "src"),
|
@@ -633,13 +633,13 @@ for i in range(6):
|
|
633 |
time_for_plot = timestamps[1:]
|
634 |
video_paths = data["video_paths"]
|
635 |
|
636 |
-
#
|
637 |
-
display_time = 0.0 #
|
638 |
if hover_data and "points" in hover_data and len(hover_data["points"]) > 0:
|
639 |
-
#
|
640 |
display_time = float(hover_data["points"][0]["x"])
|
641 |
elif len(time_for_plot) > 0:
|
642 |
-
#
|
643 |
display_time = time_for_plot[0]
|
644 |
|
645 |
try:
|
@@ -653,6 +653,6 @@ for i in range(6):
|
|
653 |
print(f"update_video_frames callback error: {e}")
|
654 |
return no_update, no_update
|
655 |
|
656 |
-
# ------------------
|
657 |
if __name__ == "__main__":
|
658 |
app.run(debug=True, host='0.0.0.0', port=7860)
|
|
|
1 |
+
# ------------------ Import Libraries ------------------
|
2 |
import dash
|
3 |
from dash import dcc, html, Input, Output, State, no_update
|
4 |
import plotly.graph_objects as go
|
|
|
14 |
from urllib.parse import urljoin
|
15 |
import subprocess
|
16 |
|
17 |
+
# ------------------ Data Download and Processing ------------------
|
18 |
class RemoteDatasetLoader:
|
19 |
def __init__(self, repo_id: str, timeout: int = 30):
|
20 |
self.repo_id = repo_id
|
|
|
40 |
def _is_valid_mp4(self, file_path):
|
41 |
if not os.path.exists(file_path) or os.path.getsize(file_path) < 1024 * 100:
|
42 |
return False
|
43 |
+
# Use ffprobe to check if it is a valid mp4
|
44 |
try:
|
45 |
result = subprocess.run([
|
46 |
'ffprobe', '-v', 'error', '-select_streams', 'v:0',
|
|
|
55 |
def _download_video(self, video_url: str, save_path: str) -> str:
|
56 |
response = requests.get(video_url, timeout=self.timeout, stream=True)
|
57 |
response.raise_for_status()
|
58 |
+
# Check Content-Type
|
59 |
if 'video' not in response.headers.get('Content-Type', ''):
|
60 |
+
raise ValueError(f"URL {video_url} does not return video content, Content-Type: {response.headers.get('Content-Type')}")
|
61 |
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
62 |
with open(save_path, 'wb') as f:
|
63 |
for chunk in response.iter_content(chunk_size=8192):
|
|
|
68 |
video_keys=None,
|
69 |
download_dir=None):
|
70 |
dataset_info = self._get_dataset_info()
|
71 |
+
self._get_episode_info(episode_id) # Check if episode exists
|
72 |
|
73 |
if download_dir is None:
|
74 |
download_dir = tempfile.mkdtemp(prefix="lerobot_videos_")
|
|
|
81 |
video_paths = []
|
82 |
chunks_size = dataset_info.get("chunks_size", 1000)
|
83 |
|
84 |
+
# Create repo-specific subdirectory
|
85 |
+
repo_name = self.repo_id.replace('/', '_') # Replace / with _ to avoid path issues
|
86 |
repo_dir = os.path.join(download_dir, repo_name)
|
87 |
os.makedirs(repo_dir, exist_ok=True)
|
88 |
|
|
|
94 |
)
|
95 |
video_filename = f"episode_{episode_id}_{video_key}.mp4"
|
96 |
local_path = os.path.join(repo_dir, video_filename)
|
97 |
+
# Prefer loading local valid mp4
|
98 |
if self._is_valid_mp4(local_path):
|
99 |
print(f"Local valid video found: {local_path}")
|
100 |
video_paths.append(local_path)
|
|
|
138 |
if stream.get('codec_type') == 'video':
|
139 |
return stream.get('codec_name', 'unknown')
|
140 |
except Exception as e:
|
141 |
+
print(f"Failed to get video codec info: {e}")
|
142 |
return 'unknown'
|
143 |
|
144 |
def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
|
|
|
163 |
if result.returncode == 0:
|
164 |
return output_path
|
165 |
else:
|
166 |
+
print(f"Re-encoding failed: {result.stderr}")
|
167 |
return input_path
|
168 |
except subprocess.TimeoutExpired:
|
169 |
+
print("Re-encoding timeout")
|
170 |
return input_path
|
171 |
except Exception as e:
|
172 |
+
print(f"Re-encoding exception: {e}")
|
173 |
return input_path
|
174 |
|
175 |
def process_video_for_compatibility(video_path):
|
176 |
if not os.path.exists(video_path):
|
177 |
+
print(f"Video file does not exist: {video_path}")
|
178 |
return video_path
|
179 |
if not check_ffmpeg_available():
|
180 |
+
print("ffmpeg not available, skipping re-encoding")
|
181 |
return video_path
|
182 |
codec = get_video_codec_info(video_path)
|
183 |
if codec in ['av01', 'av1', 'vp9', 'vp8'] or codec == 'unknown':
|
|
|
185 |
if os.path.exists(reencoded_path) and os.path.getsize(reencoded_path) > 1024:
|
186 |
return reencoded_path
|
187 |
else:
|
188 |
+
print("Re-encoding failed, using original file")
|
189 |
return video_path
|
190 |
else:
|
191 |
return video_path
|
|
|
202 |
processed_video_paths.append(processed_path)
|
203 |
return processed_video_paths, df
|
204 |
|
205 |
+
# ------------------ Dash Initialization ------------------
|
206 |
app = dash.Dash(__name__, suppress_callback_exceptions=True)
|
207 |
server = app.server
|
208 |
|
209 |
+
# ------------------ Page Layout ------------------
|
210 |
app.layout = html.Div([
|
211 |
# Header with gradient background
|
212 |
html.Div([
|
213 |
+
html.H1("Keyframe Identification",
|
214 |
style={
|
215 |
"textAlign": "center",
|
216 |
"marginBottom": "10px",
|
|
|
340 |
"padding": "0"
|
341 |
})
|
342 |
|
343 |
+
# ------------------ Data Loading Callback ------------------
|
344 |
@app.callback(
|
345 |
Output("store-data", "data"),
|
346 |
Input("btn-load", "n_clicks"),
|
|
|
367 |
print(f"Data loading error: {e}")
|
368 |
return {}
|
369 |
|
370 |
+
# ------------------ Main Content Rendering Callback ------------------
|
371 |
@app.callback(
|
372 |
Output("main-content", "children"),
|
373 |
Input("store-data", "data")
|
|
|
389 |
rows = []
|
390 |
for i, joint in enumerate(columns):
|
391 |
rows.append(html.Div([
|
392 |
+
# Joint Graph - Left 50%
|
393 |
html.Div([
|
394 |
dcc.Graph(id=f"graph-{i}")
|
395 |
], style={
|
|
|
401 |
"border": "1px solid #e9ecef",
|
402 |
"marginRight": "2%"
|
403 |
}),
|
404 |
+
# Video Area - Right 48%
|
405 |
html.Div([
|
406 |
html.Img(id=f"video1-{i}", style={
|
407 |
"width": "49%",
|
|
|
435 |
}))
|
436 |
return html.Div(rows)
|
437 |
|
438 |
+
# ------------------ Shadow and Highlight Utility Functions ------------------
|
439 |
def find_intervals(mask):
|
440 |
intervals = []
|
441 |
start = None
|
|
|
509 |
"x1": shadow['end_time'],
|
510 |
"y0": 0,
|
511 |
"y1": 1,
|
512 |
+
"fillcolor": "#ef4444", # Fixed red
|
513 |
"opacity": 0.4,
|
514 |
"line": {"width": 0}
|
515 |
})
|
|
|
558 |
)
|
559 |
}
|
560 |
|
561 |
+
# ------------------ Chart Update Callback ------------------
|
562 |
@app.callback(
|
563 |
[Output(f"graph-{i}", "figure") for i in range(6)],
|
564 |
[Input("store-data", "data")],
|
|
|
578 |
for joint in columns:
|
579 |
all_shadows[joint] = get_shadow_info(joint, action_df, delta_t, time_for_plot)
|
580 |
|
581 |
+
# Generate all charts, no highlight logic
|
582 |
return [
|
583 |
generate_joint_graph(joint, i, action_df, delta_t, time_for_plot, all_shadows)
|
584 |
for i, joint in enumerate(columns)
|
585 |
]
|
586 |
|
587 |
+
# ------------------ Video Frame Extraction Function ------------------
|
588 |
def get_video_frame(video_path, time_in_seconds):
|
589 |
try:
|
590 |
cap = cv2.VideoCapture(video_path)
|
591 |
if not cap.isOpened():
|
592 |
+
print(f"❌ Cannot open video: {video_path}")
|
593 |
return None
|
594 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
595 |
if fps <= 0:
|
|
|
612 |
else:
|
613 |
return None
|
614 |
except Exception as e:
|
615 |
+
print(f"❌ Exception extracting video frame: {e}")
|
616 |
return None
|
617 |
|
618 |
+
# ------------------ Video Frame Callback ------------------
|
619 |
for i in range(6):
|
620 |
@app.callback(
|
621 |
Output(f"video1-{i}", "src"),
|
|
|
633 |
time_for_plot = timestamps[1:]
|
634 |
video_paths = data["video_paths"]
|
635 |
|
636 |
+
# Determine the time point to display
|
637 |
+
display_time = 0.0 # Default to start time
|
638 |
if hover_data and "points" in hover_data and len(hover_data["points"]) > 0:
|
639 |
+
# If there is hover data, use hover time
|
640 |
display_time = float(hover_data["points"][0]["x"])
|
641 |
elif len(time_for_plot) > 0:
|
642 |
+
# If no hover data, use the start time of the timeline
|
643 |
display_time = time_for_plot[0]
|
644 |
|
645 |
try:
|
|
|
653 |
print(f"update_video_frames callback error: {e}")
|
654 |
return no_update, no_update
|
655 |
|
656 |
+
# ------------------ Start Application ------------------
|
657 |
if __name__ == "__main__":
|
658 |
app.run(debug=True, host='0.0.0.0', port=7860)
|