zijian2022 committed on
Commit
888d740
·
verified ·
1 Parent(s): ab24e1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -68
app.py CHANGED
@@ -14,7 +14,12 @@ import os
14
  from pathlib import Path
15
  from typing import Tuple, Optional
16
  from urllib.parse import urljoin
 
 
 
 
17
  # ------------------ 下载数据 ------------------
 
18
 
19
  class RemoteDatasetLoader:
20
  """从 Hugging Face Hub 远程加载数据集的类"""
@@ -96,21 +101,130 @@ class RemoteDatasetLoader:
96
  return video_paths, df
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def load_remote_dataset(repo_id: str,
100
  episode_id: int = 0,
101
  video_keys: Optional[list] = None,
102
  download_dir: Optional[str] = None) -> Tuple[list, pd.DataFrame]:
103
  loader = RemoteDatasetLoader(repo_id)
104
- return loader.load_episode_data(episode_id, video_keys, download_dir)
 
 
 
 
 
 
 
 
105
 
106
 
107
- video_paths, data_df = load_remote_dataset(
108
- repo_id="zijian2022/sortingtest",
109
- episode_id=0,
110
- download_dir="./downloaded_videos"
111
- )
112
  # ------------------ 加载数据 ------------------
113
- #df = pd.read_parquet("./data/clean_data/uni_boxing_object_vfm/data/chunk-000/episode_000000.parquet")
 
 
 
 
 
 
114
  df = data_df
115
  columns = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
116
  timestamps = df["timestamp"].values
@@ -119,10 +233,12 @@ time_for_plot = timestamps[1:]
119
  action_df = pd.DataFrame(df["action"].tolist(), columns=columns)
120
 
121
  # ------------------ 视频路径 ------------------
122
- #video_path_1 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.laptop/episode_000000.mp4"
123
- #video_path_2 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.phone/episode_000000.mp4"
124
  video_path_1 = video_paths[0]
125
  video_path_2 = video_paths[1]
 
 
 
 
126
  # ------------------ Dash 初始化 ------------------
127
  app = dash.Dash(__name__)
128
  server = app.server
@@ -132,20 +248,43 @@ all_shadows = {} # 存储所有关节的阴影信息
132
 
133
  # ------------------ 视频帧提取函数 ------------------
134
  def get_video_frame(video_path, time_in_seconds):
135
- cap = cv2.VideoCapture(video_path)
136
- if not cap.isOpened():
137
- print(f"❌ 无法打开视频: {video_path}")
138
- return None
139
- fps = cap.get(cv2.CAP_PROP_FPS)
140
- frame_num = int(time_in_seconds * fps)
141
- cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
142
- success, frame = cap.read()
143
- cap.release()
144
- if success:
145
- _, buffer = cv2.imencode('.jpg', frame)
146
- encoded = base64.b64encode(buffer).decode('utf-8')
147
- return f"data:image/jpeg;base64,{encoded}"
148
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  return None
150
 
151
  def find_intervals(mask):
@@ -233,6 +372,7 @@ def find_shadows_in_range(shadows, start_time, end_time):
233
  return shadows_in_range
234
 
235
  # 预计算所有关节的阴影信息
 
236
  for joint in columns:
237
  all_shadows[joint] = get_shadow_info(joint)
238
 
@@ -326,12 +466,11 @@ for i, joint in enumerate(columns):
326
  ], style={"width": "38%", "display": "inline-block", "paddingLeft": "2%"})
327
  ], style={"marginBottom": "15px"}))
328
 
329
- # 添加定时器和存储组件
330
- rows.append(dcc.Interval(id="video-playback-interval", interval=300, n_intervals=0))
331
- rows.append(dcc.Store(id="hover-state-store", data={"active": False, "last_update": 0}))
332
-
333
  # 设置 layout
334
- app.layout = html.Div(rows)
 
 
 
335
 
336
  # ------------------ 回调:监听 hoverData 并更新阴影高亮 ------------------
337
  @app.callback(
@@ -404,20 +543,13 @@ def update_shadow_highlighting(*args):
404
  return [no_update] * 6
405
 
406
  # ------------------ 回调:监听 hoverData 更新视频帧 ------------------
407
- video_duration = timestamps[-1] - timestamps[0]
408
-
409
  @app.callback(
410
  [Output(f"video1-{i}", "src") for i in range(6)] +
411
- [Output(f"video2-{i}", "src") for i in range(6)] +
412
- [Output("hover-state-store", "data")],
413
- [Input(f"graph-{i}", "hoverData") for i in range(6)] +
414
- [Input("video-playback-interval", "n_intervals")],
415
- [State("hover-state-store", "data")]
416
  )
417
  def update_video_frames(*args):
418
- hover_datas = args[:-2]
419
- interval_count = args[-2]
420
- hover_state = args[-1]
421
 
422
  # 获取触发回调的上下文
423
  ctx = dash.callback_context
@@ -439,43 +571,19 @@ def update_video_frames(*args):
439
  frame1 = get_video_frame(video_path_1, hover_time)
440
  frame2 = get_video_frame(video_path_2, hover_time)
441
 
442
- # 更新hover状态为活跃
443
- new_hover_state = {"active": True, "last_update": interval_count}
444
-
445
  # 如果成功获取帧,返回所有视频的帧
446
  if frame1 and frame2:
447
- return [frame1]*6 + [frame2]*6 + [new_hover_state]
448
  except Exception as e:
449
  print(f"处理hover数据异常: {e}")
450
-
451
- # 如果是interval触发的
452
- if 'video-playback-interval' in trigger_id:
453
- # 检查hover状态是否过期(超过3个interval周期没有更新)
454
- hover_expired = (interval_count - hover_state.get("last_update", 0)) > 3
455
-
456
- if not hover_state.get("active", False) or hover_expired:
457
- # 没有hover或hover已过期时才自动播放
458
- t = timestamps[0] + (interval_count * 0.3) % video_duration
459
- frame1 = get_video_frame(video_path_1, t)
460
- frame2 = get_video_frame(video_path_2, t)
461
-
462
- # 更新hover状态为非活跃
463
- new_hover_state = {"active": False, "last_update": interval_count}
464
-
465
- if frame1 and frame2:
466
- return [frame1]*6 + [frame2]*6 + [new_hover_state]
467
- else:
468
- return [no_update]*12 + [new_hover_state]
469
- else:
470
- # hover仍然活跃时,暂停自动播放
471
- return [no_update]*12 + [hover_state]
472
-
473
- return [no_update]*12 + [hover_state]
474
 
475
  except Exception as e:
476
  print(f"update_video_frames回调函数异常: {e}")
477
- return [no_update]*12 + [hover_state]
478
 
479
  # ------------------ 启动应用 ------------------
480
  if __name__ == "__main__":
481
- app.run(debug=True)
 
 
14
  from pathlib import Path
15
  from typing import Tuple, Optional
16
  from urllib.parse import urljoin
17
+ import subprocess
18
+ import shutil
19
+
20
+
21
  # ------------------ 下载数据 ------------------
22
+ DOWNLOAD_DIR = tempfile.mkdtemp()
23
 
24
  class RemoteDatasetLoader:
25
  """从 Hugging Face Hub 远程加载数据集的类"""
 
101
  return video_paths, df
102
 
103
 
104
+ # ------------------ 视频重编码函数 ------------------
def check_ffmpeg_available():
    """Report whether the ``ffmpeg`` executable can be launched.

    Runs ``ffmpeg -version`` with a 5-second timeout.

    Returns:
        bool: True when the command exits with status 0; False when it exits
        with a non-zero status, times out, or cannot be started at all.
    """
    try:
        result = subprocess.run(
            ['ffmpeg', '-version'],
            capture_output=True, text=True, timeout=5,
        )
    except (OSError, subprocess.SubprocessError):
        # OSError covers a missing binary (FileNotFoundError) as well as one
        # that is present but cannot be executed (e.g. PermissionError);
        # SubprocessError covers the timeout. Either way ffmpeg is unusable.
        return False
    return result.returncode == 0
113
+
def get_video_codec_info(video_path):
    """Return the codec name of the first video stream reported by ffprobe.

    Runs ``ffprobe -show_streams`` with JSON output and a 10-second timeout.

    Args:
        video_path: Path of the media file handed to ``ffprobe``.

    Returns:
        str: The ``codec_name`` of the first stream whose ``codec_type`` is
        ``'video'``. ``'unknown'`` is returned when ffprobe cannot be run,
        exits with a non-zero status, times out, produces output that cannot
        be parsed, or reports no video stream.
    """
    try:
        result = subprocess.run([
            'ffprobe', '-v', 'quiet', '-print_format', 'json',
            '-show_streams', video_path
        ], capture_output=True, text=True, timeout=10)

        if result.returncode == 0:
            info = json.loads(result.stdout)
            # `or []` also tolerates an explicit null "streams" entry.
            for stream in info.get('streams') or []:
                if stream.get('codec_type') == 'video':
                    return stream.get('codec_name') or 'unknown'
    except (OSError, subprocess.SubprocessError,
            ValueError, TypeError, AttributeError) as e:
        # OSError: ffprobe missing or not executable.
        # SubprocessError: the timeout expired.
        # ValueError: undecodable output or invalid JSON.
        # TypeError / AttributeError: JSON with an unexpected shape.
        print(f"获取视频编码信息失败: {e}")

    return 'unknown'
131
+
def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
    """Re-encode a video to H.264 video / AAC audio with ffmpeg.

    The encode is written to a scratch file next to ``output_path`` and only
    moved into place once ffmpeg has exited successfully, so a failed or
    timed-out run never leaves a truncated file at ``output_path``.

    Args:
        input_path: Source video file.
        output_path: Destination file. Defaults to the input path with its
            extension replaced by ``_h264.mp4``.
        quality: ``'fast'``, ``'medium'`` or ``'high'``; selects the x264
            preset/CRF pair. Unrecognised values fall back to ``'medium'``.

    Returns:
        ``output_path`` when ffmpeg exits with status 0 and the result has
        been moved into place; otherwise ``input_path``, so callers can keep
        using the original file.
    """
    if output_path is None:
        base_name = os.path.splitext(input_path)[0]
        output_path = f"{base_name}_h264.mp4"

    # Preset / CRF pair for each supported quality level.
    quality_params = {
        'fast': ['-preset', 'ultrafast', '-crf', '28'],
        'medium': ['-preset', 'medium', '-crf', '23'],
        'high': ['-preset', 'slow', '-crf', '18']
    }

    params = quality_params.get(quality, quality_params['medium'])

    # Keep the extension on the scratch name: ffmpeg infers the container
    # format from it.
    out_root, out_ext = os.path.splitext(output_path)
    partial_path = f"{out_root}.partial{out_ext}"

    cmd = [
        'ffmpeg', '-i', input_path,
        '-c:v', 'libx264',          # H.264 video encoder
        '-c:a', 'aac',              # audio encoder
        '-movflags', '+faststart',  # optimise for network playback
        '-y',                       # overwrite the output file
    ] + params + [partial_path]

    print(f"重编码视频: {input_path} -> {output_path}")
    try:
        result = subprocess.run(
            cmd,
            stdin=subprocess.DEVNULL,  # ffmpeg must not read the parent's stdin
            capture_output=True,
            text=True,
            errors='replace',          # stderr may contain undecodable bytes
            timeout=300,
        )
        if result.returncode == 0:
            os.replace(partial_path, output_path)
            print(f"重编码成功: {output_path}")
            return output_path
        print(f"重编码失败: {result.stderr}")
    except subprocess.TimeoutExpired:
        print("重编码超时")
    except (OSError, subprocess.SubprocessError) as e:
        print(f"重编码异常: {e}")

    # Only failure paths reach this point: drop whatever ffmpeg left behind.
    try:
        os.remove(partial_path)
    except OSError:
        pass  # nothing was written, or it cannot be removed; not fatal
    return input_path
172
+
def process_video_for_compatibility(video_path):
    """Return a path to a playable version of ``video_path``.

    The file is re-encoded to H.264 (``'fast'`` quality) when ffprobe reports
    one of ``av01``, ``av1``, ``vp9``, ``vp8``, or when the codec could not be
    determined. In every other situation the original path is returned
    unchanged: the file does not exist, ffmpeg is unavailable, the codec is
    already acceptable, or the re-encode did not produce a usable file.

    Args:
        video_path: Path of the video to inspect.

    Returns:
        The re-encoded file's path on success, otherwise ``video_path``.
    """
    if not os.path.exists(video_path):
        print(f"视频文件不存在: {video_path}")
        return video_path

    # Without ffmpeg there is nothing we can do; keep the original file.
    if not check_ffmpeg_available():
        print("ffmpeg不可用,跳过重编码")
        return video_path

    codec = get_video_codec_info(video_path)
    print(f"视频编码格式: {codec}")

    if codec not in ('av01', 'av1', 'vp9', 'vp8', 'unknown'):
        print(f"视频编码 ({codec}) 兼容,无需重编码")
        return video_path

    print(f"检测到不兼容的编码格式 ({codec}),开始重编码...")
    reencoded_path = reencode_video_to_h264(video_path, quality='fast')

    # reencode_video_to_h264 hands back the input path when it fails, so a
    # returned path equal to the input must not be mistaken for a successful
    # encode. The size check rejects empty or truncated outputs.
    if (reencoded_path != video_path
            and os.path.exists(reencoded_path)
            and os.path.getsize(reencoded_path) > 1024):
        return reencoded_path

    print("重编码失败,使用原始文件")
    return video_path
202
+
203
+
def load_remote_dataset(repo_id: str,
                        episode_id: int = 0,
                        video_keys: Optional[list] = None,
                        download_dir: Optional[str] = None) -> Tuple[list, pd.DataFrame]:
    """Load one episode through ``RemoteDatasetLoader`` and post-process its videos.

    Args:
        repo_id: Passed to the ``RemoteDatasetLoader`` constructor.
        episode_id: Forwarded to ``load_episode_data``.
        video_keys: Forwarded to ``load_episode_data``.
        download_dir: Forwarded to ``load_episode_data``.

    Returns:
        A ``(video_paths, df)`` tuple: the video paths returned by the loader,
        each passed through ``process_video_for_compatibility``, and the
        DataFrame returned by the loader unchanged.
    """
    loader = RemoteDatasetLoader(repo_id)
    video_paths, df = loader.load_episode_data(episode_id, video_keys, download_dir)

    # Swap each video for a compatible (possibly re-encoded) version.
    processed_video_paths = [
        process_video_for_compatibility(video_path) for video_path in video_paths
    ]

    return processed_video_paths, df
218
 
219
 
 
 
 
 
 
220
  # ------------------ 加载数据 ------------------
221
+ print("正在加载数据集...")
222
+ video_paths, data_df = load_remote_dataset(
223
+ repo_id="zijian2022/sortingtest",
224
+ episode_id=0,
225
+ download_dir="./downloaded_videos"
226
+ )
227
+
228
  df = data_df
229
  columns = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
230
  timestamps = df["timestamp"].values
 
233
  action_df = pd.DataFrame(df["action"].tolist(), columns=columns)
234
 
235
  # ------------------ 视频路径 ------------------
 
 
236
  video_path_1 = video_paths[0]
237
  video_path_2 = video_paths[1]
238
+
239
+ print(f"视频路径1: {video_path_1}")
240
+ print(f"视频路径2: {video_path_2}")
241
+
242
  # ------------------ Dash 初始化 ------------------
243
  app = dash.Dash(__name__)
244
  server = app.server
 
248
 
249
  # ------------------ 视频帧提取函数 ------------------
def get_video_frame(video_path, time_in_seconds):
    """Grab the frame at a time offset and return it as a JPEG data URI.

    Args:
        video_path: File opened with ``cv2.VideoCapture``.
        time_in_seconds: Offset into the video. It is converted to a frame
            index using the FPS reported by the capture; negative offsets are
            clamped to the first frame.

    Returns:
        A ``"data:image/jpeg;base64,..."`` string, or None when the video
        cannot be opened, reports a non-positive FPS, the frame cannot be
        read or encoded, or any exception is raised along the way. Frames
        wider than 640 px are scaled down to 640 px wide (aspect ratio kept)
        before encoding.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"❌ 无法打开视频: {video_path}")
            return None

        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            print(f"❌ 无法获取视频帧率: {video_path}")
            return None

        # Never seek to a negative frame index.
        frame_num = max(0, int(time_in_seconds * fps))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        success, frame = cap.read()

        if not success or frame is None:
            print(f"❌ 无法读取帧: {video_path}, 时间: {time_in_seconds}s")
            return None

        # Shrink wide frames to cut down the amount of data sent to the browser.
        height, width = frame.shape[:2]
        if width > 640:
            new_width = 640
            new_height = max(1, int(height * (new_width / width)))
            frame = cv2.resize(frame, (new_width, new_height))

        # Encode as JPEG at 85% quality.
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 85]
        encoded_ok, buffer = cv2.imencode('.jpg', frame, encode_param)
        if not encoded_ok:
            print(f"❌ 无法读取帧: {video_path}, 时间: {time_in_seconds}s")
            return None

        encoded = base64.b64encode(buffer).decode('utf-8')
        return f"data:image/jpeg;base64,{encoded}"

    except Exception as e:
        # Deliberately broad: this runs inside a UI callback, where a missing
        # frame is preferable to a failed callback.
        print(f"❌ 提取视频帧异常: {e}")
        return None
    finally:
        # Release the capture on every path, including exceptions.
        if cap is not None:
            cap.release()
289
 
290
  def find_intervals(mask):
 
372
  return shadows_in_range
373
 
374
  # 预计算所有关节的阴影信息
375
+ print("正在预计算阴影信息...")
376
  for joint in columns:
377
  all_shadows[joint] = get_shadow_info(joint)
378
 
 
466
  ], style={"width": "38%", "display": "inline-block", "paddingLeft": "2%"})
467
  ], style={"marginBottom": "15px"}))
468
 
 
 
 
 
469
  # 设置 layout
470
+ app.layout = html.Div([
471
+ html.H1("机器人数据可视化 - 视频兼容性优化", style={"textAlign": "center", "marginBottom": "20px"}),
472
+ html.Div(rows)
473
+ ])
474
 
475
  # ------------------ 回调:监听 hoverData 并更新阴影高亮 ------------------
476
  @app.callback(
 
543
  return [no_update] * 6
544
 
545
  # ------------------ 回调:监听 hoverData 更新视频帧 ------------------
 
 
546
  @app.callback(
547
  [Output(f"video1-{i}", "src") for i in range(6)] +
548
+ [Output(f"video2-{i}", "src") for i in range(6)],
549
+ [Input(f"graph-{i}", "hoverData") for i in range(6)]
 
 
 
550
  )
551
  def update_video_frames(*args):
552
+ hover_datas = args
 
 
553
 
554
  # 获取触发回调的上下文
555
  ctx = dash.callback_context
 
571
  frame1 = get_video_frame(video_path_1, hover_time)
572
  frame2 = get_video_frame(video_path_2, hover_time)
573
 
 
 
 
574
  # 如果成功获取帧,返回所有视频的帧
575
  if frame1 and frame2:
576
+ return [frame1]*6 + [frame2]*6
577
  except Exception as e:
578
  print(f"处理hover数据异常: {e}")
579
+
580
+ return [no_update]*12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
 
582
  except Exception as e:
583
  print(f"update_video_frames回调函数异常: {e}")
584
+ return [no_update]*12
585
 
586
  # ------------------ 启动应用 ------------------
if __name__ == "__main__":
    import os  # local import keeps this launcher block self-contained

    print("应用启动中...")
    # Debug mode enables the Werkzeug interactive debugger, which lets anyone
    # who can reach the port execute arbitrary Python. It is therefore opt-in
    # (DASH_DEBUG=1) and, when enabled, the server listens on loopback only.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    host = '127.0.0.1' if debug_enabled else '0.0.0.0'
    app.run(debug=debug_enabled, host=host, port=8050)