zijian2022 committed on
Commit f4bb1fe (verified)
1 Parent(s): 45853d0

Update app.py

Files changed (1)
  1. app.py +221 -309
app.py CHANGED
@@ -1,6 +1,6 @@
1
  # ------------------ Imports ------------------
2
  import dash
3
- from dash import dcc, html, Input, Output, State, callback_context, no_update
4
  import plotly.graph_objects as go
5
  import pandas as pd
6
  import numpy as np
@@ -11,19 +11,11 @@ import requests
11
  import json
12
  import tempfile
13
  import os
14
- from pathlib import Path
15
- from typing import Tuple, Optional
16
  from urllib.parse import urljoin
17
  import subprocess
18
- import shutil
19
-
20
-
21
- # ------------------ Download data ------------------
22
- DOWNLOAD_DIR = tempfile.mkdtemp()
23
 
 
24
  class RemoteDatasetLoader:
25
- """Class for loading datasets remotely from the Hugging Face Hub."""
26
-
27
  def __init__(self, repo_id: str, timeout: int = 30):
28
  self.repo_id = repo_id
29
  self.timeout = timeout
@@ -45,9 +37,27 @@ class RemoteDatasetLoader:
45
  return episode
46
  raise ValueError(f"Episode {episode_id} not found")
47
 
48
  def _download_video(self, video_url: str, save_path: str) -> str:
49
  response = requests.get(video_url, timeout=self.timeout, stream=True)
50
  response.raise_for_status()
51
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
52
  with open(save_path, 'wb') as f:
53
  for chunk in response.iter_content(chunk_size=8192):
@@ -55,10 +65,10 @@ class RemoteDatasetLoader:
55
  return save_path
56
 
57
  def load_episode_data(self, episode_id: int,
58
- video_keys: Optional[list] = None,
59
- download_dir: Optional[str] = None) -> Tuple[list, pd.DataFrame]:
60
  dataset_info = self._get_dataset_info()
61
- episode_info = self._get_episode_info(episode_id)
62
 
63
  if download_dir is None:
64
  download_dir = tempfile.mkdtemp(prefix="lerobot_videos_")
@@ -79,10 +89,14 @@ class RemoteDatasetLoader:
79
  )
80
  video_filename = f"episode_{episode_id}_{video_key}.mp4"
81
  local_path = os.path.join(download_dir, video_filename)
82
  try:
83
  downloaded_path = self._download_video(video_url, local_path)
84
  video_paths.append(downloaded_path)
85
- print(f"Downloaded video {i+1}: {downloaded_path}")
86
  except Exception as e:
87
  print(f"Failed to download video {video_key}: {e}")
88
  video_paths.append(video_url)
@@ -93,17 +107,13 @@ class RemoteDatasetLoader:
93
  )
94
  try:
95
  df = pd.read_parquet(data_url)
96
- print(f"Loaded data with {len(df)} rows and {len(df.columns)} columns")
97
  except Exception as e:
98
  print(f"Failed to load data: {e}")
99
  df = pd.DataFrame()
100
 
101
  return video_paths, df
102
 
103
-
104
- # ------------------ Video re-encoding helpers ------------------
105
  def check_ffmpeg_available():
106
- """Check whether ffmpeg is available."""
107
  try:
108
  result = subprocess.run(['ffmpeg', '-version'],
109
  capture_output=True, text=True, timeout=5)
@@ -112,13 +122,11 @@ def check_ffmpeg_available():
112
  return False
113
 
114
  def get_video_codec_info(video_path):
115
- """Get the video codec info."""
116
  try:
117
  result = subprocess.run([
118
  'ffprobe', '-v', 'quiet', '-print_format', 'json',
119
  '-show_streams', video_path
120
  ], capture_output=True, text=True, timeout=10)
121
-
122
  if result.returncode == 0:
123
  info = json.loads(result.stdout)
124
  for stream in info.get('streams', []):
@@ -126,43 +134,32 @@ def get_video_codec_info(video_path):
126
  return stream.get('codec_name', 'unknown')
127
  except Exception as e:
128
  print(f"Failed to get video codec info: {e}")
129
-
130
  return 'unknown'
131
 
132
  def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
133
- """Re-encode the video to H.264."""
134
  if output_path is None:
135
  base_name = os.path.splitext(input_path)[0]
136
  output_path = f"{base_name}_h264.mp4"
137
-
138
- # Choose encoding parameters based on quality
139
  quality_params = {
140
  'fast': ['-preset', 'ultrafast', '-crf', '28'],
141
  'medium': ['-preset', 'medium', '-crf', '23'],
142
  'high': ['-preset', 'slow', '-crf', '18']
143
  }
144
-
145
  params = quality_params.get(quality, quality_params['medium'])
146
-
147
  try:
148
  cmd = [
149
  'ffmpeg', '-i', input_path,
150
- '-c:v', 'libx264', # H.264 video encoder
151
- '-c:a', 'aac', # AAC audio encoder
152
- '-movflags', '+faststart', # optimize for web playback
153
- '-y', # overwrite the output file
154
  ] + params + [output_path]
155
-
156
- print(f"Re-encoding video: {input_path} -> {output_path}")
157
  result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
158
-
159
  if result.returncode == 0:
160
- print(f"Re-encoding succeeded: {output_path}")
161
  return output_path
162
  else:
163
  print(f"Re-encoding failed: {result.stderr}")
164
  return input_path
165
-
166
  except subprocess.TimeoutExpired:
167
  print("Re-encoding timed out")
168
  return input_path
@@ -171,122 +168,105 @@ def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
171
  return input_path
172
 
173
  def process_video_for_compatibility(video_path):
174
- """Process the video to ensure playback compatibility."""
175
  if not os.path.exists(video_path):
176
  print(f"Video file does not exist: {video_path}")
177
  return video_path
178
-
179
- # Check whether ffmpeg is available
180
  if not check_ffmpeg_available():
181
  print("ffmpeg not available, skipping re-encode")
182
  return video_path
183
-
184
- # Get the video codec info
185
  codec = get_video_codec_info(video_path)
186
- print(f"Video codec: {codec}")
187
-
188
- # If the codec is AV1 or otherwise incompatible, re-encode to H.264
189
  if codec in ['av01', 'av1', 'vp9', 'vp8'] or codec == 'unknown':
190
- print(f"Incompatible codec detected ({codec}), re-encoding...")
191
  reencoded_path = reencode_video_to_h264(video_path, quality='fast')
192
-
193
- # Check that the re-encoded file exists and has a reasonable size
194
  if os.path.exists(reencoded_path) and os.path.getsize(reencoded_path) > 1024:
195
  return reencoded_path
196
  else:
197
  print("Re-encoding failed, using the original file")
198
  return video_path
199
  else:
200
- print(f"Video codec ({codec}) is compatible, no re-encoding needed")
201
  return video_path
202
 
203
-
204
  def load_remote_dataset(repo_id: str,
205
  episode_id: int = 0,
206
- video_keys: Optional[list] = None,
207
- download_dir: Optional[str] = None) -> Tuple[list, pd.DataFrame]:
208
  loader = RemoteDatasetLoader(repo_id)
209
  video_paths, df = loader.load_episode_data(episode_id, video_keys, download_dir)
210
-
211
- # Process videos for compatibility
212
  processed_video_paths = []
213
  for video_path in video_paths:
214
  processed_path = process_video_for_compatibility(video_path)
215
  processed_video_paths.append(processed_path)
216
-
217
  return processed_video_paths, df
218
 
219
-
220
- # ------------------ Load data ------------------
221
- print("Loading dataset...")
222
- video_paths, data_df = load_remote_dataset(
223
- repo_id="zijian2022/sortingtest",
224
- episode_id=0,
225
- download_dir="./downloaded_videos"
226
- )
227
-
228
- df = data_df
229
- columns = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
230
- timestamps = df["timestamp"].values
231
- delta_t = np.diff(timestamps)
232
- time_for_plot = timestamps[1:]
233
- action_df = pd.DataFrame(df["action"].tolist(), columns=columns)
234
-
235
- # ------------------ Video paths ------------------
236
- video_path_1 = video_paths[0]
237
- video_path_2 = video_paths[1]
238
-
239
- print(f"Video path 1: {video_path_1}")
240
- print(f"Video path 2: {video_path_2}")
241
-
242
  # ------------------ Dash initialization ------------------
243
- app = dash.Dash(__name__)
244
  server = app.server
245
 
246
- # ------------------ Global store for shadow info ------------------
247
- all_shadows = {} # shadow info for every joint
 
248
 
249
- # ------------------ Video frame extraction ------------------
250
- def get_video_frame(video_path, time_in_seconds):
251
- """Extract the frame at a given time from a video."""
252
  try:
253
- cap = cv2.VideoCapture(video_path)
254
- if not cap.isOpened():
255
- print(f"❌ Cannot open video: {video_path}")
256
- return None
257
-
258
- fps = cap.get(cv2.CAP_PROP_FPS)
259
- if fps <= 0:
260
- print(f"❌ Cannot get video frame rate: {video_path}")
261
- cap.release()
262
- return None
263
-
264
- frame_num = int(time_in_seconds * fps)
265
- cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
266
- success, frame = cap.read()
267
- cap.release()
268
-
269
- if success and frame is not None:
270
- # Resize the image to reduce data transfer
271
- height, width = frame.shape[:2]
272
- if width > 640: # downscale if wider than 640 px
273
- new_width = 640
274
- new_height = int(height * (new_width / width))
275
- frame = cv2.resize(frame, (new_width, new_height))
276
-
277
- # Encode as JPEG
277
- encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 85] # 85% quality
279
- _, buffer = cv2.imencode('.jpg', frame, encode_param)
280
- encoded = base64.b64encode(buffer).decode('utf-8')
281
- return f"data:image/jpeg;base64,{encoded}"
282
- else:
283
- print(f"❌ Cannot read frame: {video_path}, time: {time_in_seconds}s")
284
- return None
285
-
286
  except Exception as e:
287
- print(f" Exception while extracting video frame: {e}")
288
- return None
289
 
290
  def find_intervals(mask):
291
  intervals = []
292
  start = None
@@ -300,25 +280,17 @@ def find_intervals(mask):
300
  intervals.append((start, len(mask) - 1))
301
  return intervals
302
 
303
- def get_shadow_info(joint_name):
304
- """Get all red-shadow regions for the given joint."""
305
  angles = action_df[joint_name].values
306
  velocity = np.diff(angles) / delta_t
307
-
308
  smoothed_velocity = gaussian_filter1d(velocity, sigma=1)
309
  smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)
310
-
311
- # Parameters
312
  vel_threshold = 0.5
313
  highlight_width = 3
314
  k = 2
315
-
316
  shadows = []
317
-
318
- # Low-speed interval shadows
319
  low_speed_mask = np.abs(smoothed_velocity) < vel_threshold
320
  low_speed_intervals = find_intervals(low_speed_mask)
321
-
322
  for start, end in low_speed_intervals:
323
  if end - start + 1 <= k:
324
  shadows.append({
@@ -328,8 +300,6 @@ def get_shadow_info(joint_name):
328
  'start_idx': start,
329
  'end_idx': end
330
  })
331
-
332
- # Maximum-value shadow
333
  max_idx = np.argmax(smoothed_angle)
334
  s_max = max(0, max_idx - highlight_width)
335
  e_max = min(len(time_for_plot) - 1, max_idx + highlight_width)
@@ -340,8 +310,6 @@ def get_shadow_info(joint_name):
340
  'start_idx': s_max,
341
  'end_idx': e_max
342
  })
343
-
344
- # Minimum-value shadow
345
  min_idx = np.argmin(smoothed_angle)
346
  s_min = max(0, min_idx - highlight_width)
347
  e_min = min(len(time_for_plot) - 1, min_idx + highlight_width)
@@ -352,57 +320,28 @@ def get_shadow_info(joint_name):
352
  'start_idx': s_min,
353
  'end_idx': e_min
354
  })
355
-
356
  return shadows
357
 
358
  def is_hover_in_shadow(hover_time, shadows):
359
- """Check whether the hover time falls inside any shadow."""
360
  for shadow in shadows:
361
  if shadow['start_time'] <= hover_time <= shadow['end_time']:
362
  return True
363
  return False
364
 
365
  def find_shadows_in_range(shadows, start_time, end_time):
366
- """Find all shadows within the given time range."""
367
  shadows_in_range = []
368
  for shadow in shadows:
369
- # Check whether the shadow overlaps the given range
370
  if not (shadow['end_time'] < start_time or shadow['start_time'] > end_time):
371
  shadows_in_range.append(shadow)
372
  return shadows_in_range
373
 
374
- # Precompute shadow info for all joints
375
- print("Precomputing shadow info...")
376
- for joint in columns:
377
- all_shadows[joint] = get_shadow_info(joint)
378
-
379
- # ------------------ Graph generation ------------------
380
- def generate_joint_graph(joint_name, idx, highlighted_shadows=None):
381
  angles = action_df[joint_name].values
382
  velocity = np.diff(angles) / delta_t
383
-
384
  smoothed_velocity = gaussian_filter1d(velocity, sigma=1)
385
  smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)
386
-
387
- # 参数
388
- vel_threshold = 0.5
389
- highlight_width = 3
390
- k = 2
391
-
392
- # 找低速区间
393
- low_speed_mask = np.abs(smoothed_velocity) < vel_threshold
394
- low_speed_intervals = find_intervals(low_speed_mask)
395
-
396
- # 找最大最小点
397
- max_idx = np.argmax(smoothed_angle)
398
- min_idx = np.argmin(smoothed_angle)
399
-
400
  shapes = []
401
-
402
- # 获取当前关节的阴影信息
403
  current_shadows = all_shadows[joint_name]
404
-
405
- # 正常的红色阴影
406
  for shadow in current_shadows:
407
  is_highlighted = False
408
  if highlighted_shadows:
@@ -411,10 +350,8 @@ def generate_joint_graph(joint_name, idx, highlighted_shadows=None):
411
  shadow['end_time'] == h_shadow['end_time']):
412
  is_highlighted = True
413
  break
414
-
415
  color = "blue" if is_highlighted else "red"
416
  opacity = 0.6 if is_highlighted else 0.3
417
-
418
  shapes.append({
419
  "type": "rect",
420
  "xref": "x",
@@ -427,163 +364,138 @@ def generate_joint_graph(joint_name, idx, highlighted_shadows=None):
427
  "opacity": opacity,
428
  "line": {"width": 0}
429
  })
430
-
431
- return dcc.Graph(
432
- id=f"graph-{idx}",
433
- figure={
434
- "data": [
435
- go.Scatter(
436
- x=time_for_plot,
437
- y=smoothed_angle,
438
- name="Angle",
439
- line=dict(color='orange')
440
- )
441
- ],
442
- "layout": go.Layout(
443
- title=joint_name,
444
- xaxis={"title": "Time (s)"},
445
- yaxis={"title": "Angle (deg)"},
446
- shapes=shapes,
447
- hovermode="x unified",
448
- height=250,
449
- margin=dict(t=30, b=30, l=50, r=50),
450
- showlegend=False,
451
  )
452
- },
453
- style={"height": "250px"}
454
- )
455
-
456
- # ------------------ 布局 ------------------
457
- rows = []
458
-
459
- # 关节图 + 双视频帧
460
- for i, joint in enumerate(columns):
461
- rows.append(html.Div([
462
- html.Div(generate_joint_graph(joint, i), style={"width": "60%", "display": "inline-block", "verticalAlign": "top"}),
463
- html.Div([
464
- html.Img(id=f"video1-{i}", style={"width": "49%", "height": "180px", "objectFit": "contain", "display": "inline-block"}),
465
- html.Img(id=f"video2-{i}", style={"width": "49%", "height": "180px", "objectFit": "contain", "display": "inline-block"})
466
- ], style={"width": "38%", "display": "inline-block", "paddingLeft": "2%"})
467
- ], style={"marginBottom": "15px"}))
468
-
469
- # 设置 layout
470
- app.layout = html.Div([
471
- html.H1("机器人数据可视化 - 视频兼容性优化", style={"textAlign": "center", "marginBottom": "20px"}),
472
- html.Div(rows)
473
- ])
474
 
475
- # ------------------ Callback: listen to hoverData and update shadow highlighting ------------------
476
  @app.callback(
477
  [Output(f"graph-{i}", "figure") for i in range(6)],
478
- [Input(f"graph-{i}", "hoverData") for i in range(6)],
479
- [State(f"graph-{i}", "figure") for i in range(6)],
480
  )
481
- def update_shadow_highlighting(*args):
482
- hover_datas = args[:6]
483
- current_figures = args[6:]
484
-
485
- ctx = dash.callback_context
486
-
487
- # Check whether a hover triggered the callback
488
- if not ctx.triggered:
489
- return [no_update] * 6
490
-
491
- trigger_id = ctx.triggered[0]['prop_id']
492
-
493
- # If not triggered by a hover, do not update
494
- if 'hoverData' not in trigger_id:
495
- return [no_update] * 6
496
-
497
- # Extract the index of the triggering graph
498
- graph_idx = int(trigger_id.split('-')[1].split('.')[0])
499
- hover_data = hover_datas[graph_idx]
500
-
501
- # If there is no hover data, restore the normal state
502
- if not hover_data or "points" not in hover_data or len(hover_data["points"]) == 0:
503
- updated_figures = []
504
- for i, joint in enumerate(columns):
505
- updated_figures.append(generate_joint_graph(joint, i).figure)
506
- return updated_figures
507
-
508
- try:
509
- hover_time = float(hover_data["points"][0]["x"])
510
- triggered_joint = columns[graph_idx]
511
-
512
- # Check whether the hover falls inside a red shadow
513
- if not is_hover_in_shadow(hover_time, all_shadows[triggered_joint]):
514
- # 如果不在阴影内,恢复正常状态
515
- updated_figures = []
516
- for i, joint in enumerate(columns):
517
- updated_figures.append(generate_joint_graph(joint, i).figure)
518
- return updated_figures
519
-
520
- # Find the timestamp index for the hover time
521
- hover_idx = np.searchsorted(time_for_plot, hover_time)
522
-
523
- # Compute the range of surrounding timestamps
524
- start_idx = max(0, hover_idx - 20)
525
- end_idx = min(len(time_for_plot) - 1, hover_idx + 20)
526
- start_time = time_for_plot[start_idx]
527
- end_time = time_for_plot[end_idx]
528
-
529
- # Generate an updated figure for each joint
530
- updated_figures = []
531
- for i, joint in enumerate(columns):
532
- # Find this joint's shadows within the time range
533
- shadows_in_range = find_shadows_in_range(all_shadows[joint], start_time, end_time)
534
-
535
- # Generate the figure with highlighting
536
- updated_figure = generate_joint_graph(joint, i, shadows_in_range)
537
- updated_figures.append(updated_figure.figure)
538
-
539
- return updated_figures
540
-
541
- except Exception as e:
542
- print(f"Exception while handling shadow highlighting: {e}")
543
  return [no_update] * 6
 
544
 
545
- # ------------------ Callback: listen to hoverData and update video frames ------------------
546
- @app.callback(
547
- [Output(f"video1-{i}", "src") for i in range(6)] +
548
- [Output(f"video2-{i}", "src") for i in range(6)],
549
- [Input(f"graph-{i}", "hoverData") for i in range(6)]
550
- )
551
- def update_video_frames(*args):
552
- hover_datas = args
553
-
554
- # Get the callback trigger context
555
- ctx = dash.callback_context
556
-
557
  try:
558
- # Check whether a hover triggered the callback
559
- if ctx.triggered:
560
- trigger_id = ctx.triggered[0]['prop_id']
561
-
562
- # If triggered by a graph hover
563
- if 'hoverData' in trigger_id:
564
- # Extract the graph index from trigger_id
565
- graph_idx = int(trigger_id.split('-')[1].split('.')[0])
566
- hover_data = hover_datas[graph_idx]
567
-
568
- if hover_data and "points" in hover_data and len(hover_data["points"]) > 0:
569
- try:
570
- hover_time = float(hover_data["points"][0]["x"])
571
- frame1 = get_video_frame(video_path_1, hover_time)
572
- frame2 = get_video_frame(video_path_2, hover_time)
573
-
574
- # If frames were extracted successfully, return them for all videos
575
- if frame1 and frame2:
576
- return [frame1]*6 + [frame2]*6
577
- except Exception as e:
578
- print(f"Exception while handling hover data: {e}")
579
-
580
- return [no_update]*12
581
-
582
  except Exception as e:
583
- print(f"Exception in update_video_frames callback: {e}")
584
- return [no_update]*12
 
585
 
586
  # ------------------ Start the app ------------------
587
  if __name__ == "__main__":
588
- print("Starting the app...")
589
  app.run(debug=True, host='0.0.0.0', port=7860)
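The new version of app.py, listed below, replaces the module-level dataset loading with an interactive flow: a repo_id / episode_id form, a Load button, a dcc.Store that holds the loaded episode, and callbacks that render the graphs and video frames from that store. As a reading aid, here is a minimal, self-contained sketch of that dcc.Store pattern; load_records is a hypothetical stand-in for load_remote_dataset and is not part of the commit.

# Sketch only: dcc.Store-driven loading (load_records is hypothetical).
import dash
from dash import dcc, html, Input, Output, State

app = dash.Dash(__name__)

def load_records(repo_id, episode_id):
    # Hypothetical loader; the real app downloads videos and a parquet file.
    return [{"timestamp": 0.1 * t, "value": t} for t in range(10)]

app.layout = html.Div([
    dcc.Input(id="repo-id", type="text", value="user/dataset"),
    dcc.Input(id="episode-id", type="number", value=0, min=0),
    html.Button("Load", id="btn-load", n_clicks=0),
    dcc.Loading(dcc.Store(id="store-data")),
    html.Div(id="main-content"),
])

@app.callback(
    Output("store-data", "data"),
    Input("btn-load", "n_clicks"),
    State("repo-id", "value"),
    State("episode-id", "value"),
    prevent_initial_call=True,
)
def load_data(n_clicks, repo_id, episode_id):
    # Store whatever the loader returns; an empty dict signals failure.
    try:
        return {"records": load_records(repo_id, int(episode_id))}
    except Exception as exc:
        print(f"Loading failed: {exc}")
        return {}

@app.callback(Output("main-content", "children"), Input("store-data", "data"))
def render(data):
    # Render from the store so the layout itself stays static.
    if not data or not data.get("records"):
        return html.Div("Click Load to fetch data")
    return html.Pre(str(data["records"][:3]))

if __name__ == "__main__":
    app.run(debug=True)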
 
1
  # ------------------ Imports ------------------
2
  import dash
3
+ from dash import dcc, html, Input, Output, State, no_update
4
  import plotly.graph_objects as go
5
  import pandas as pd
6
  import numpy as np
 
11
  import json
12
  import tempfile
13
  import os
 
 
14
  from urllib.parse import urljoin
15
  import subprocess
16
 
17
+ # ------------------ Data download and processing ------------------
18
  class RemoteDatasetLoader:
 
 
19
  def __init__(self, repo_id: str, timeout: int = 30):
20
  self.repo_id = repo_id
21
  self.timeout = timeout
 
37
  return episode
38
  raise ValueError(f"Episode {episode_id} not found")
39
 
40
+ def _is_valid_mp4(self, file_path):
41
+ if not os.path.exists(file_path) or os.path.getsize(file_path) < 1024 * 100:
42
+ return False
43
+ # Use ffprobe to check whether this is a valid MP4
44
+ try:
45
+ result = subprocess.run([
46
+ 'ffprobe', '-v', 'error', '-select_streams', 'v:0',
47
+ '-show_entries', 'stream=codec_name', '-of', 'default=noprint_wrappers=1:nokey=1', file_path
48
+ ], capture_output=True, text=True, timeout=10)
49
+ if result.returncode == 0 and '264' in result.stdout:
50
+ return True
51
+ except Exception as e:
52
+ print(f"ffprobe video check failed: {e}")
53
+ return False
54
+
55
  def _download_video(self, video_url: str, save_path: str) -> str:
56
  response = requests.get(video_url, timeout=self.timeout, stream=True)
57
  response.raise_for_status()
58
+ # Check the Content-Type
59
+ if 'video' not in response.headers.get('Content-Type', ''):
60
+ raise ValueError(f"URL {video_url} did not return video content, Content-Type: {response.headers.get('Content-Type')}")
61
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
62
  with open(save_path, 'wb') as f:
63
  for chunk in response.iter_content(chunk_size=8192):
 
65
  return save_path
66
 
67
  def load_episode_data(self, episode_id: int,
68
+ video_keys=None,
69
+ download_dir=None):
70
  dataset_info = self._get_dataset_info()
71
+ self._get_episode_info(episode_id) # check that the episode exists
72
 
73
  if download_dir is None:
74
  download_dir = tempfile.mkdtemp(prefix="lerobot_videos_")
 
89
  )
90
  video_filename = f"episode_{episode_id}_{video_key}.mp4"
91
  local_path = os.path.join(download_dir, video_filename)
92
+ # Prefer a valid local MP4 if one already exists
93
+ if self._is_valid_mp4(local_path):
94
+ print(f"Valid local video already exists: {local_path}")
95
+ video_paths.append(local_path)
96
+ continue
97
  try:
98
  downloaded_path = self._download_video(video_url, local_path)
99
  video_paths.append(downloaded_path)
 
100
  except Exception as e:
101
  print(f"Failed to download video {video_key}: {e}")
102
  video_paths.append(video_url)
 
107
  )
108
  try:
109
  df = pd.read_parquet(data_url)
 
110
  except Exception as e:
111
  print(f"Failed to load data: {e}")
112
  df = pd.DataFrame()
113
 
114
  return video_paths, df
115
 
 
 
116
  def check_ffmpeg_available():
 
117
  try:
118
  result = subprocess.run(['ffmpeg', '-version'],
119
  capture_output=True, text=True, timeout=5)
 
122
  return False
123
 
124
  def get_video_codec_info(video_path):
 
125
  try:
126
  result = subprocess.run([
127
  'ffprobe', '-v', 'quiet', '-print_format', 'json',
128
  '-show_streams', video_path
129
  ], capture_output=True, text=True, timeout=10)
 
130
  if result.returncode == 0:
131
  info = json.loads(result.stdout)
132
  for stream in info.get('streams', []):
 
134
  return stream.get('codec_name', 'unknown')
135
  except Exception as e:
136
  print(f"Failed to get video codec info: {e}")
 
137
  return 'unknown'
138
 
139
  def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
 
140
  if output_path is None:
141
  base_name = os.path.splitext(input_path)[0]
142
  output_path = f"{base_name}_h264.mp4"
 
 
143
  quality_params = {
144
  'fast': ['-preset', 'ultrafast', '-crf', '28'],
145
  'medium': ['-preset', 'medium', '-crf', '23'],
146
  'high': ['-preset', 'slow', '-crf', '18']
147
  }
 
148
  params = quality_params.get(quality, quality_params['medium'])
 
149
  try:
150
  cmd = [
151
  'ffmpeg', '-i', input_path,
152
+ '-c:v', 'libx264',
153
+ '-c:a', 'aac',
154
+ '-movflags', '+faststart',
155
+ '-y',
156
  ] + params + [output_path]
 
 
157
  result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
 
158
  if result.returncode == 0:
 
159
  return output_path
160
  else:
161
  print(f"Re-encoding failed: {result.stderr}")
162
  return input_path
 
163
  except subprocess.TimeoutExpired:
164
  print("Re-encoding timed out")
165
  return input_path
 
168
  return input_path
169
 
170
  def process_video_for_compatibility(video_path):
 
171
  if not os.path.exists(video_path):
172
  print(f"Video file does not exist: {video_path}")
173
  return video_path
 
 
174
  if not check_ffmpeg_available():
175
  print("ffmpeg not available, skipping re-encode")
176
  return video_path
 
 
177
  codec = get_video_codec_info(video_path)
 
 
 
178
  if codec in ['av01', 'av1', 'vp9', 'vp8'] or codec == 'unknown':
 
179
  reencoded_path = reencode_video_to_h264(video_path, quality='fast')
 
 
180
  if os.path.exists(reencoded_path) and os.path.getsize(reencoded_path) > 1024:
181
  return reencoded_path
182
  else:
183
  print("Re-encoding failed, using the original file")
184
  return video_path
185
  else:
 
186
  return video_path
187
 
 
188
  def load_remote_dataset(repo_id: str,
189
  episode_id: int = 0,
190
+ video_keys=None,
191
+ download_dir=None):
192
  loader = RemoteDatasetLoader(repo_id)
193
  video_paths, df = loader.load_episode_data(episode_id, video_keys, download_dir)
 
 
194
  processed_video_paths = []
195
  for video_path in video_paths:
196
  processed_path = process_video_for_compatibility(video_path)
197
  processed_video_paths.append(processed_path)
 
198
  return processed_video_paths, df
199
 
200
  # ------------------ Dash initialization ------------------
201
+ app = dash.Dash(__name__, suppress_callback_exceptions=True)
202
  server = app.server
203
 
204
+ # ------------------ Page layout ------------------
205
+ app.layout = html.Div([
206
+ html.H1("Robot Data Visualization - Video Compatibility Optimization", style={"textAlign": "center", "marginBottom": "20px"}),
207
+ html.Div([
208
+ html.Label("repo_id:"),
209
+ dcc.Input(id="input-repo-id", type="text", value="zijian2022/sortingtest", style={"width": "300px"}),
210
+ html.Label("episode_id:", style={"marginLeft": "20px"}),
211
+ dcc.Input(id="input-episode-id", type="number", value=0, min=0, style={"width": "80px"}),
212
+ html.Button("Load", id="btn-load", n_clicks=0, style={"marginLeft": "20px"}),
213
+ ], style={"textAlign": "center", "marginBottom": "30px"}),
214
+ dcc.Loading(
215
+ id="loading",
216
+ type="default",
217
+ children=dcc.Store(id="store-data")
218
+ ),
219
+ html.Div(id="main-content")
220
+ ])
221
 
222
+ # ------------------ Data-loading callback ------------------
223
+ @app.callback(
224
+ Output("store-data", "data"),
225
+ Input("btn-load", "n_clicks"),
226
+ State("input-repo-id", "value"),
227
+ State("input-episode-id", "value"),
228
+ prevent_initial_call=True
229
+ )
230
+ def load_data_callback(n_clicks, repo_id, episode_id):
231
  try:
232
+ video_paths, data_df = load_remote_dataset(
233
+ repo_id=repo_id,
234
+ episode_id=int(episode_id),
235
+ download_dir="./downloaded_videos"
236
+ )
237
+ if data_df is None or data_df.empty:
238
+ return {}
239
+ return {
240
+ "video_paths": video_paths,
241
+ "data_df": data_df.to_dict("records"),
242
+ "columns": ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"],
243
+ "timestamps": data_df["timestamp"].tolist()
244
+ }
 
245
  except Exception as e:
246
+ print(f"Exception while loading data: {e}")
247
+ return {}
248
 
249
+ # ------------------ Main-content rendering callback ------------------
250
+ @app.callback(
251
+ Output("main-content", "children"),
252
+ Input("store-data", "data")
253
+ )
254
+ def update_main_content(data):
255
+ if not data or "data_df" not in data or len(data["data_df"]) == 0:
256
+ return html.Div("Click the 'Load' button above to fetch data", style={"textAlign": "center", "color": "red"})
257
+ columns = data["columns"]
258
+ rows = []
259
+ for i, joint in enumerate(columns):
260
+ rows.append(html.Div([
261
+ html.Div(dcc.Graph(id=f"graph-{i}"), style={"width": "60%", "display": "inline-block", "verticalAlign": "top"}),
262
+ html.Div([
263
+ html.Img(id=f"video1-{i}", style={"width": "49%", "height": "180px", "objectFit": "contain", "display": "inline-block"}),
264
+ html.Img(id=f"video2-{i}", style={"width": "49%", "height": "180px", "objectFit": "contain", "display": "inline-block"})
265
+ ], style={"width": "38%", "display": "inline-block", "paddingLeft": "2%"})
266
+ ], style={"marginBottom": "15px"}))
267
+ return html.Div(rows)
268
+
269
+ # ------------------ Shadow and highlight helpers ------------------
270
  def find_intervals(mask):
271
  intervals = []
272
  start = None
 
280
  intervals.append((start, len(mask) - 1))
281
  return intervals
282
 
283
+ def get_shadow_info(joint_name, action_df, delta_t, time_for_plot):
 
284
  angles = action_df[joint_name].values
285
  velocity = np.diff(angles) / delta_t
 
286
  smoothed_velocity = gaussian_filter1d(velocity, sigma=1)
287
  smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)
 
 
288
  vel_threshold = 0.5
289
  highlight_width = 3
290
  k = 2
 
291
  shadows = []
 
 
292
  low_speed_mask = np.abs(smoothed_velocity) < vel_threshold
293
  low_speed_intervals = find_intervals(low_speed_mask)
 
294
  for start, end in low_speed_intervals:
295
  if end - start + 1 <= k:
296
  shadows.append({
 
300
  'start_idx': start,
301
  'end_idx': end
302
  })
 
 
303
  max_idx = np.argmax(smoothed_angle)
304
  s_max = max(0, max_idx - highlight_width)
305
  e_max = min(len(time_for_plot) - 1, max_idx + highlight_width)
 
310
  'start_idx': s_max,
311
  'end_idx': e_max
312
  })
 
 
313
  min_idx = np.argmin(smoothed_angle)
314
  s_min = max(0, min_idx - highlight_width)
315
  e_min = min(len(time_for_plot) - 1, min_idx + highlight_width)
 
320
  'start_idx': s_min,
321
  'end_idx': e_min
322
  })
 
323
  return shadows
324
 
325
  def is_hover_in_shadow(hover_time, shadows):
 
326
  for shadow in shadows:
327
  if shadow['start_time'] <= hover_time <= shadow['end_time']:
328
  return True
329
  return False
330
 
331
  def find_shadows_in_range(shadows, start_time, end_time):
 
332
  shadows_in_range = []
333
  for shadow in shadows:
 
334
  if not (shadow['end_time'] < start_time or shadow['start_time'] > end_time):
335
  shadows_in_range.append(shadow)
336
  return shadows_in_range
337
 
338
+ def generate_joint_graph(joint_name, idx, action_df, delta_t, time_for_plot, all_shadows, highlighted_shadows=None):
339
  angles = action_df[joint_name].values
340
  velocity = np.diff(angles) / delta_t
 
341
  smoothed_velocity = gaussian_filter1d(velocity, sigma=1)
342
  smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)
 
343
  shapes = []
 
 
344
  current_shadows = all_shadows[joint_name]
 
 
345
  for shadow in current_shadows:
346
  is_highlighted = False
347
  if highlighted_shadows:
 
350
  shadow['end_time'] == h_shadow['end_time']):
351
  is_highlighted = True
352
  break
 
353
  color = "blue" if is_highlighted else "red"
354
  opacity = 0.6 if is_highlighted else 0.3
 
355
  shapes.append({
356
  "type": "rect",
357
  "xref": "x",
 
364
  "opacity": opacity,
365
  "line": {"width": 0}
366
  })
367
+ return {
368
+ "data": [
369
+ go.Scatter(
370
+ x=time_for_plot,
371
+ y=smoothed_angle,
372
+ name="Angle",
373
+ line=dict(color='orange')
 
374
  )
375
+ ],
376
+ "layout": go.Layout(
377
+ title=joint_name,
378
+ xaxis={"title": "Time (s)"},
379
+ yaxis={"title": "Angle (deg)"},
380
+ shapes=shapes,
381
+ hovermode="x unified",
382
+ height=250,
383
+ margin=dict(t=30, b=30, l=50, r=50),
384
+ showlegend=False,
385
+ )
386
+ }
 
387
 
388
+ # ------------------ Linked highlighting callback ------------------
389
  @app.callback(
390
  [Output(f"graph-{i}", "figure") for i in range(6)],
391
+ [Input("store-data", "data")] + [Input(f"graph-{i}", "hoverData") for i in range(6)],
392
+ prevent_initial_call=True
393
  )
394
+ def update_all_graphs(data, *hover_datas):
395
+ if not data or "data_df" not in data or len(data["data_df"]) == 0:
 
396
  return [no_update] * 6
397
+ columns = data["columns"]
398
+ df = pd.DataFrame.from_records(data["data_df"])
399
+ action_df = pd.DataFrame(df["action"].tolist(), columns=columns)
400
+ timestamps = df["timestamp"].values
401
+ delta_t = np.diff(timestamps)
402
+ time_for_plot = timestamps[1:]
403
+ all_shadows = {}
404
+ for joint in columns:
405
+ all_shadows[joint] = get_shadow_info(joint, action_df, delta_t, time_for_plot)
406
+
407
+ # Check whether any hover falls inside a shadow
408
+ for idx, hover_data in enumerate(hover_datas):
409
+ if hover_data and "points" in hover_data and len(hover_data["points"]) > 0:
410
+ hover_time = float(hover_data["points"][0]["x"])
411
+ triggered_joint = columns[idx]
412
+ if is_hover_in_shadow(hover_time, all_shadows[triggered_joint]):
413
+ hover_idx = np.searchsorted(time_for_plot, hover_time)
414
+ start_idx = max(0, hover_idx - 20)
415
+ end_idx = min(len(time_for_plot) - 1, hover_idx + 20)
416
+ start_time = time_for_plot[start_idx]
417
+ end_time = time_for_plot[end_idx]
418
+ figures = []
419
+ for i, joint in enumerate(columns):
420
+ shadows_in_range = find_shadows_in_range(all_shadows[joint], start_time, end_time)
421
+ fig = generate_joint_graph(joint, i, action_df, delta_t, time_for_plot, all_shadows, shadows_in_range)
422
+ figures.append(fig)
423
+ return figures
424
+ # No hover, or hover outside a shadow: show everything normally
425
+ return [
426
+ generate_joint_graph(joint, i, action_df, delta_t, time_for_plot, all_shadows)
427
+ for i, joint in enumerate(columns)
428
+ ]
429
 
430
+ # ------------------ Video frame extraction ------------------
431
+ def get_video_frame(video_path, time_in_seconds):
 
432
  try:
433
+ cap = cv2.VideoCapture(video_path)
434
+ if not cap.isOpened():
435
+ print(f"❌ Cannot open video: {video_path}")
436
+ return None
437
+ fps = cap.get(cv2.CAP_PROP_FPS)
438
+ if fps <= 0:
439
+ cap.release()
440
+ return None
441
+ frame_num = int(time_in_seconds * fps)
442
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
443
+ success, frame = cap.read()
444
+ cap.release()
445
+ if success and frame is not None:
446
+ height, width = frame.shape[:2]
447
+ if width > 640:
448
+ new_width = 640
449
+ new_height = int(height * (new_width / width))
450
+ frame = cv2.resize(frame, (new_width, new_height))
451
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 85]
452
+ _, buffer = cv2.imencode('.jpg', frame, encode_param)
453
+ encoded = base64.b64encode(buffer).decode('utf-8')
454
+ return f"data:image/jpeg;base64,{encoded}"
455
+ else:
456
+ return None
457
  except Exception as e:
458
+ print(f" Exception while extracting video frame: {e}")
459
+ return None
460
+
461
+ # ------------------ Video-frame callbacks ------------------
462
+ for i in range(6):
463
+ @app.callback(
464
+ Output(f"video1-{i}", "src"),
465
+ Output(f"video2-{i}", "src"),
466
+ Input("store-data", "data"),
467
+ Input(f"graph-{i}", "hoverData"),
468
+ prevent_initial_call=True
469
+ )
470
+ def update_video_frames(data, hover_data, idx=i):
471
+ if not data or "data_df" not in data or len(data["data_df"]) == 0:
472
+ return no_update, no_update
473
+ columns = data["columns"]
474
+ df = pd.DataFrame.from_records(data["data_df"])
475
+ timestamps = df["timestamp"].values
476
+ time_for_plot = timestamps[1:]
477
+ video_paths = data["video_paths"]
478
+
479
+ # Determine which time point to display
480
+ display_time = 0.0 # default to the start time
481
+ if hover_data and "points" in hover_data and len(hover_data["points"]) > 0:
482
+ # If there is hover data, use the hover time
483
+ display_time = float(hover_data["points"][0]["x"])
484
+ elif len(time_for_plot) > 0:
485
+ # Otherwise, use the start of the time axis
486
+ display_time = time_for_plot[0]
487
+
488
+ try:
489
+ frame1 = get_video_frame(video_paths[0], display_time)
490
+ frame2 = get_video_frame(video_paths[1], display_time)
491
+ if frame1 and frame2:
492
+ return frame1, frame2
493
+ else:
494
+ return no_update, no_update
495
+ except Exception as e:
496
+ print(f"Exception in update_video_frames callback: {e}")
497
+ return no_update, no_update
498
 
499
  # ------------------ Start the app ------------------
500
  if __name__ == "__main__":
 
501
  app.run(debug=True, host='0.0.0.0', port=7860)
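One note on the per-graph video callbacks above: they are registered inside a for i in range(6) loop, and the idx=i default argument is the usual way to capture the current loop value for each generated function, because Python closures otherwise bind the loop variable late. A standalone illustration of that pitfall (not part of app.py):

# Late binding vs. default-argument capture in a loop (illustrative only).
def make_handlers_late():
    handlers = []
    for i in range(3):
        handlers.append(lambda: i)          # every closure sees the final i
    return handlers

def make_handlers_bound():
    handlers = []
    for i in range(3):
        handlers.append(lambda idx=i: idx)  # default argument captures i now
    return handlers

print([h() for h in make_handlers_late()])   # [2, 2, 2]
print([h() for h in make_handlers_bound()])  # [0, 1, 2]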