zijian2022 commited on
Commit
1923cfd
·
verified ·
1 Parent(s): 5970c12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -36
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # ------------------ 导入库 ------------------
2
  import dash
3
  from dash import dcc, html, Input, Output, State, no_update
4
  import plotly.graph_objects as go
@@ -14,7 +14,7 @@ import os
14
  from urllib.parse import urljoin
15
  import subprocess
16
 
17
- # ------------------ 数据下载与处理 ------------------
18
  class RemoteDatasetLoader:
19
  def __init__(self, repo_id: str, timeout: int = 30):
20
  self.repo_id = repo_id
@@ -40,7 +40,7 @@ class RemoteDatasetLoader:
40
  def _is_valid_mp4(self, file_path):
41
  if not os.path.exists(file_path) or os.path.getsize(file_path) < 1024 * 100:
42
  return False
43
- # ffprobe检查是否为有效mp4
44
  try:
45
  result = subprocess.run([
46
  'ffprobe', '-v', 'error', '-select_streams', 'v:0',
@@ -55,9 +55,9 @@ class RemoteDatasetLoader:
55
  def _download_video(self, video_url: str, save_path: str) -> str:
56
  response = requests.get(video_url, timeout=self.timeout, stream=True)
57
  response.raise_for_status()
58
- # 检查Content-Type
59
  if 'video' not in response.headers.get('Content-Type', ''):
60
- raise ValueError(f"URL {video_url} 返回的不是视频内容,Content-Type: {response.headers.get('Content-Type')}")
61
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
62
  with open(save_path, 'wb') as f:
63
  for chunk in response.iter_content(chunk_size=8192):
@@ -68,7 +68,7 @@ class RemoteDatasetLoader:
68
  video_keys=None,
69
  download_dir=None):
70
  dataset_info = self._get_dataset_info()
71
- self._get_episode_info(episode_id) # 检查episode是否存在
72
 
73
  if download_dir is None:
74
  download_dir = tempfile.mkdtemp(prefix="lerobot_videos_")
@@ -81,8 +81,8 @@ class RemoteDatasetLoader:
81
  video_paths = []
82
  chunks_size = dataset_info.get("chunks_size", 1000)
83
 
84
- # 创建repo特定的子目录
85
- repo_name = self.repo_id.replace('/', '_') # / 替换为 _ 避免路径问题
86
  repo_dir = os.path.join(download_dir, repo_name)
87
  os.makedirs(repo_dir, exist_ok=True)
88
 
@@ -94,7 +94,7 @@ class RemoteDatasetLoader:
94
  )
95
  video_filename = f"episode_{episode_id}_{video_key}.mp4"
96
  local_path = os.path.join(repo_dir, video_filename)
97
- # 优先加载本地有效mp4
98
  if self._is_valid_mp4(local_path):
99
  print(f"Local valid video found: {local_path}")
100
  video_paths.append(local_path)
@@ -138,7 +138,7 @@ def get_video_codec_info(video_path):
138
  if stream.get('codec_type') == 'video':
139
  return stream.get('codec_name', 'unknown')
140
  except Exception as e:
141
- print(f"获取视频编码信息失败: {e}")
142
  return 'unknown'
143
 
144
  def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
@@ -163,21 +163,21 @@ def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
163
  if result.returncode == 0:
164
  return output_path
165
  else:
166
- print(f"重编码失败: {result.stderr}")
167
  return input_path
168
  except subprocess.TimeoutExpired:
169
- print("重编码超时")
170
  return input_path
171
  except Exception as e:
172
- print(f"重编码异常: {e}")
173
  return input_path
174
 
175
  def process_video_for_compatibility(video_path):
176
  if not os.path.exists(video_path):
177
- print(f"视频文件不存在: {video_path}")
178
  return video_path
179
  if not check_ffmpeg_available():
180
- print("ffmpeg不可用,跳过重编码")
181
  return video_path
182
  codec = get_video_codec_info(video_path)
183
  if codec in ['av01', 'av1', 'vp9', 'vp8'] or codec == 'unknown':
@@ -185,7 +185,7 @@ def process_video_for_compatibility(video_path):
185
  if os.path.exists(reencoded_path) and os.path.getsize(reencoded_path) > 1024:
186
  return reencoded_path
187
  else:
188
- print("重编码失败,使用原始文件")
189
  return video_path
190
  else:
191
  return video_path
@@ -202,15 +202,15 @@ def load_remote_dataset(repo_id: str,
202
  processed_video_paths.append(processed_path)
203
  return processed_video_paths, df
204
 
205
- # ------------------ Dash 初始化 ------------------
206
  app = dash.Dash(__name__, suppress_callback_exceptions=True)
207
  server = app.server
208
 
209
- # ------------------ 页面布局 ------------------
210
  app.layout = html.Div([
211
  # Header with gradient background
212
  html.Div([
213
- html.H1("Robot Data Visualization",
214
  style={
215
  "textAlign": "center",
216
  "marginBottom": "10px",
@@ -340,7 +340,7 @@ app.layout = html.Div([
340
  "padding": "0"
341
  })
342
 
343
- # ------------------ 数据加载回调 ------------------
344
  @app.callback(
345
  Output("store-data", "data"),
346
  Input("btn-load", "n_clicks"),
@@ -367,7 +367,7 @@ def load_data_callback(n_clicks, repo_id, episode_id):
367
  print(f"Data loading error: {e}")
368
  return {}
369
 
370
- # ------------------ 主内容渲染回调 ------------------
371
  @app.callback(
372
  Output("main-content", "children"),
373
  Input("store-data", "data")
@@ -389,7 +389,7 @@ def update_main_content(data):
389
  rows = []
390
  for i, joint in enumerate(columns):
391
  rows.append(html.Div([
392
- # 关节图 - 左侧50%
393
  html.Div([
394
  dcc.Graph(id=f"graph-{i}")
395
  ], style={
@@ -401,7 +401,7 @@ def update_main_content(data):
401
  "border": "1px solid #e9ecef",
402
  "marginRight": "2%"
403
  }),
404
- # 视频区域 - 右侧48%
405
  html.Div([
406
  html.Img(id=f"video1-{i}", style={
407
  "width": "49%",
@@ -435,7 +435,7 @@ def update_main_content(data):
435
  }))
436
  return html.Div(rows)
437
 
438
- # ------------------ 阴影与高亮工具函数 ------------------
439
  def find_intervals(mask):
440
  intervals = []
441
  start = None
@@ -509,7 +509,7 @@ def generate_joint_graph(joint_name, idx, action_df, delta_t, time_for_plot, all
509
  "x1": shadow['end_time'],
510
  "y0": 0,
511
  "y1": 1,
512
- "fillcolor": "#ef4444", # 固定红色
513
  "opacity": 0.4,
514
  "line": {"width": 0}
515
  })
@@ -558,7 +558,7 @@ def generate_joint_graph(joint_name, idx, action_df, delta_t, time_for_plot, all
558
  )
559
  }
560
 
561
- # ------------------ 图表更新回调 ------------------
562
  @app.callback(
563
  [Output(f"graph-{i}", "figure") for i in range(6)],
564
  [Input("store-data", "data")],
@@ -578,18 +578,18 @@ def update_all_graphs(data):
578
  for joint in columns:
579
  all_shadows[joint] = get_shadow_info(joint, action_df, delta_t, time_for_plot)
580
 
581
- # 生成所有图表,无高亮逻辑
582
  return [
583
  generate_joint_graph(joint, i, action_df, delta_t, time_for_plot, all_shadows)
584
  for i, joint in enumerate(columns)
585
  ]
586
 
587
- # ------------------ 视频帧提取函数 ------------------
588
  def get_video_frame(video_path, time_in_seconds):
589
  try:
590
  cap = cv2.VideoCapture(video_path)
591
  if not cap.isOpened():
592
- print(f"❌ 无法打开视频: {video_path}")
593
  return None
594
  fps = cap.get(cv2.CAP_PROP_FPS)
595
  if fps <= 0:
@@ -612,10 +612,10 @@ def get_video_frame(video_path, time_in_seconds):
612
  else:
613
  return None
614
  except Exception as e:
615
- print(f"❌ 提取视频帧异常: {e}")
616
  return None
617
 
618
- # ------------------ 视频帧回调 ------------------
619
  for i in range(6):
620
  @app.callback(
621
  Output(f"video1-{i}", "src"),
@@ -633,13 +633,13 @@ for i in range(6):
633
  time_for_plot = timestamps[1:]
634
  video_paths = data["video_paths"]
635
 
636
- # 确定要显示的时间点
637
- display_time = 0.0 # 默认显示开始时间
638
  if hover_data and "points" in hover_data and len(hover_data["points"]) > 0:
639
- # 如果有hover数据,使用hover时间
640
  display_time = float(hover_data["points"][0]["x"])
641
  elif len(time_for_plot) > 0:
642
- # 如果没有hover数据,使用时间轴开始时间
643
  display_time = time_for_plot[0]
644
 
645
  try:
@@ -653,6 +653,6 @@ for i in range(6):
653
  print(f"update_video_frames callback error: {e}")
654
  return no_update, no_update
655
 
656
- # ------------------ 启动应用 ------------------
657
  if __name__ == "__main__":
658
  app.run(debug=True, host='0.0.0.0', port=7860)
 
1
+ # ------------------ Import Libraries ------------------
2
  import dash
3
  from dash import dcc, html, Input, Output, State, no_update
4
  import plotly.graph_objects as go
 
14
  from urllib.parse import urljoin
15
  import subprocess
16
 
17
+ # ------------------ Data Download and Processing ------------------
18
  class RemoteDatasetLoader:
19
  def __init__(self, repo_id: str, timeout: int = 30):
20
  self.repo_id = repo_id
 
40
  def _is_valid_mp4(self, file_path):
41
  if not os.path.exists(file_path) or os.path.getsize(file_path) < 1024 * 100:
42
  return False
43
+ # Use ffprobe to check if it is a valid mp4
44
  try:
45
  result = subprocess.run([
46
  'ffprobe', '-v', 'error', '-select_streams', 'v:0',
 
55
  def _download_video(self, video_url: str, save_path: str) -> str:
56
  response = requests.get(video_url, timeout=self.timeout, stream=True)
57
  response.raise_for_status()
58
+ # Check Content-Type
59
  if 'video' not in response.headers.get('Content-Type', ''):
60
+ raise ValueError(f"URL {video_url} does not return video content, Content-Type: {response.headers.get('Content-Type')}")
61
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
62
  with open(save_path, 'wb') as f:
63
  for chunk in response.iter_content(chunk_size=8192):
 
68
  video_keys=None,
69
  download_dir=None):
70
  dataset_info = self._get_dataset_info()
71
+ self._get_episode_info(episode_id) # Check if episode exists
72
 
73
  if download_dir is None:
74
  download_dir = tempfile.mkdtemp(prefix="lerobot_videos_")
 
81
  video_paths = []
82
  chunks_size = dataset_info.get("chunks_size", 1000)
83
 
84
+ # Create repo-specific subdirectory
85
+ repo_name = self.repo_id.replace('/', '_') # Replace / with _ to avoid path issues
86
  repo_dir = os.path.join(download_dir, repo_name)
87
  os.makedirs(repo_dir, exist_ok=True)
88
 
 
94
  )
95
  video_filename = f"episode_{episode_id}_{video_key}.mp4"
96
  local_path = os.path.join(repo_dir, video_filename)
97
+ # Prefer loading local valid mp4
98
  if self._is_valid_mp4(local_path):
99
  print(f"Local valid video found: {local_path}")
100
  video_paths.append(local_path)
 
138
  if stream.get('codec_type') == 'video':
139
  return stream.get('codec_name', 'unknown')
140
  except Exception as e:
141
+ print(f"Failed to get video codec info: {e}")
142
  return 'unknown'
143
 
144
  def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
 
163
  if result.returncode == 0:
164
  return output_path
165
  else:
166
+ print(f"Re-encoding failed: {result.stderr}")
167
  return input_path
168
  except subprocess.TimeoutExpired:
169
+ print("Re-encoding timeout")
170
  return input_path
171
  except Exception as e:
172
+ print(f"Re-encoding exception: {e}")
173
  return input_path
174
 
175
  def process_video_for_compatibility(video_path):
176
  if not os.path.exists(video_path):
177
+ print(f"Video file does not exist: {video_path}")
178
  return video_path
179
  if not check_ffmpeg_available():
180
+ print("ffmpeg not available, skipping re-encoding")
181
  return video_path
182
  codec = get_video_codec_info(video_path)
183
  if codec in ['av01', 'av1', 'vp9', 'vp8'] or codec == 'unknown':
 
185
  if os.path.exists(reencoded_path) and os.path.getsize(reencoded_path) > 1024:
186
  return reencoded_path
187
  else:
188
+ print("Re-encoding failed, using original file")
189
  return video_path
190
  else:
191
  return video_path
 
202
  processed_video_paths.append(processed_path)
203
  return processed_video_paths, df
204
 
205
+ # ------------------ Dash Initialization ------------------
206
  app = dash.Dash(__name__, suppress_callback_exceptions=True)
207
  server = app.server
208
 
209
+ # ------------------ Page Layout ------------------
210
  app.layout = html.Div([
211
  # Header with gradient background
212
  html.Div([
213
+ html.H1("Keyframe Identification",
214
  style={
215
  "textAlign": "center",
216
  "marginBottom": "10px",
 
340
  "padding": "0"
341
  })
342
 
343
+ # ------------------ Data Loading Callback ------------------
344
  @app.callback(
345
  Output("store-data", "data"),
346
  Input("btn-load", "n_clicks"),
 
367
  print(f"Data loading error: {e}")
368
  return {}
369
 
370
+ # ------------------ Main Content Rendering Callback ------------------
371
  @app.callback(
372
  Output("main-content", "children"),
373
  Input("store-data", "data")
 
389
  rows = []
390
  for i, joint in enumerate(columns):
391
  rows.append(html.Div([
392
+ # Joint Graph - Left 50%
393
  html.Div([
394
  dcc.Graph(id=f"graph-{i}")
395
  ], style={
 
401
  "border": "1px solid #e9ecef",
402
  "marginRight": "2%"
403
  }),
404
+ # Video Area - Right 48%
405
  html.Div([
406
  html.Img(id=f"video1-{i}", style={
407
  "width": "49%",
 
435
  }))
436
  return html.Div(rows)
437
 
438
+ # ------------------ Shadow and Highlight Utility Functions ------------------
439
  def find_intervals(mask):
440
  intervals = []
441
  start = None
 
509
  "x1": shadow['end_time'],
510
  "y0": 0,
511
  "y1": 1,
512
+ "fillcolor": "#ef4444", # Fixed red
513
  "opacity": 0.4,
514
  "line": {"width": 0}
515
  })
 
558
  )
559
  }
560
 
561
+ # ------------------ Chart Update Callback ------------------
562
  @app.callback(
563
  [Output(f"graph-{i}", "figure") for i in range(6)],
564
  [Input("store-data", "data")],
 
578
  for joint in columns:
579
  all_shadows[joint] = get_shadow_info(joint, action_df, delta_t, time_for_plot)
580
 
581
+ # Generate all charts, no highlight logic
582
  return [
583
  generate_joint_graph(joint, i, action_df, delta_t, time_for_plot, all_shadows)
584
  for i, joint in enumerate(columns)
585
  ]
586
 
587
+ # ------------------ Video Frame Extraction Function ------------------
588
  def get_video_frame(video_path, time_in_seconds):
589
  try:
590
  cap = cv2.VideoCapture(video_path)
591
  if not cap.isOpened():
592
+ print(f"❌ Cannot open video: {video_path}")
593
  return None
594
  fps = cap.get(cv2.CAP_PROP_FPS)
595
  if fps <= 0:
 
612
  else:
613
  return None
614
  except Exception as e:
615
+ print(f"❌ Exception extracting video frame: {e}")
616
  return None
617
 
618
+ # ------------------ Video Frame Callback ------------------
619
  for i in range(6):
620
  @app.callback(
621
  Output(f"video1-{i}", "src"),
 
633
  time_for_plot = timestamps[1:]
634
  video_paths = data["video_paths"]
635
 
636
+ # Determine the time point to display
637
+ display_time = 0.0 # Default to start time
638
  if hover_data and "points" in hover_data and len(hover_data["points"]) > 0:
639
+ # If there is hover data, use hover time
640
  display_time = float(hover_data["points"][0]["x"])
641
  elif len(time_for_plot) > 0:
642
+ # If no hover data, use the start time of the timeline
643
  display_time = time_for_plot[0]
644
 
645
  try:
 
653
  print(f"update_video_frames callback error: {e}")
654
  return no_update, no_update
655
 
656
+ # ------------------ Start Application ------------------
657
  if __name__ == "__main__":
658
  app.run(debug=True, host='0.0.0.0', port=7860)