Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -14,7 +14,12 @@ import os
|
|
14 |
from pathlib import Path
|
15 |
from typing import Tuple, Optional
|
16 |
from urllib.parse import urljoin
|
|
|
|
|
|
|
|
|
17 |
# ------------------ 下载数据 ------------------
|
|
|
18 |
|
19 |
class RemoteDatasetLoader:
|
20 |
"""从 Hugging Face Hub 远程加载数据集的类"""
|
@@ -96,21 +101,130 @@ class RemoteDatasetLoader:
|
|
96 |
return video_paths, df
|
97 |
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
def load_remote_dataset(repo_id: str,
|
100 |
episode_id: int = 0,
|
101 |
video_keys: Optional[list] = None,
|
102 |
download_dir: Optional[str] = None) -> Tuple[list, pd.DataFrame]:
|
103 |
loader = RemoteDatasetLoader(repo_id)
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
|
107 |
-
video_paths, data_df = load_remote_dataset(
|
108 |
-
repo_id="zijian2022/sortingtest",
|
109 |
-
episode_id=0,
|
110 |
-
download_dir="./downloaded_videos"
|
111 |
-
)
|
112 |
# ------------------ 加载数据 ------------------
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
df = data_df
|
115 |
columns = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
|
116 |
timestamps = df["timestamp"].values
|
@@ -119,10 +233,12 @@ time_for_plot = timestamps[1:]
|
|
119 |
action_df = pd.DataFrame(df["action"].tolist(), columns=columns)
|
120 |
|
121 |
# ------------------ 视频路径 ------------------
|
122 |
-
#video_path_1 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.laptop/episode_000000.mp4"
|
123 |
-
#video_path_2 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.phone/episode_000000.mp4"
|
124 |
video_path_1 = video_paths[0]
|
125 |
video_path_2 = video_paths[1]
|
|
|
|
|
|
|
|
|
126 |
# ------------------ Dash 初始化 ------------------
|
127 |
app = dash.Dash(__name__)
|
128 |
server = app.server
|
@@ -132,20 +248,43 @@ all_shadows = {} # 存储所有关节的阴影信息
|
|
132 |
|
133 |
# ------------------ 视频帧提取函数 ------------------
|
134 |
def get_video_frame(video_path, time_in_seconds):
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
return None
|
150 |
|
151 |
def find_intervals(mask):
|
@@ -233,6 +372,7 @@ def find_shadows_in_range(shadows, start_time, end_time):
|
|
233 |
return shadows_in_range
|
234 |
|
235 |
# 预计算所有关节的阴影信息
|
|
|
236 |
for joint in columns:
|
237 |
all_shadows[joint] = get_shadow_info(joint)
|
238 |
|
@@ -326,12 +466,11 @@ for i, joint in enumerate(columns):
|
|
326 |
], style={"width": "38%", "display": "inline-block", "paddingLeft": "2%"})
|
327 |
], style={"marginBottom": "15px"}))
|
328 |
|
329 |
-
# 添加定时器和存储组件
|
330 |
-
rows.append(dcc.Interval(id="video-playback-interval", interval=300, n_intervals=0))
|
331 |
-
rows.append(dcc.Store(id="hover-state-store", data={"active": False, "last_update": 0}))
|
332 |
-
|
333 |
# 设置 layout
|
334 |
-
app.layout = html.Div(
|
|
|
|
|
|
|
335 |
|
336 |
# ------------------ 回调:监听 hoverData 并更新阴影高亮 ------------------
|
337 |
@app.callback(
|
@@ -404,20 +543,13 @@ def update_shadow_highlighting(*args):
|
|
404 |
return [no_update] * 6
|
405 |
|
406 |
# ------------------ 回调:监听 hoverData 更新视频帧 ------------------
|
407 |
-
video_duration = timestamps[-1] - timestamps[0]
|
408 |
-
|
409 |
@app.callback(
|
410 |
[Output(f"video1-{i}", "src") for i in range(6)] +
|
411 |
-
[Output(f"video2-{i}", "src") for i in range(6)]
|
412 |
-
[
|
413 |
-
[Input(f"graph-{i}", "hoverData") for i in range(6)] +
|
414 |
-
[Input("video-playback-interval", "n_intervals")],
|
415 |
-
[State("hover-state-store", "data")]
|
416 |
)
|
417 |
def update_video_frames(*args):
|
418 |
-
hover_datas = args
|
419 |
-
interval_count = args[-2]
|
420 |
-
hover_state = args[-1]
|
421 |
|
422 |
# 获取触发回调的上下文
|
423 |
ctx = dash.callback_context
|
@@ -439,43 +571,19 @@ def update_video_frames(*args):
|
|
439 |
frame1 = get_video_frame(video_path_1, hover_time)
|
440 |
frame2 = get_video_frame(video_path_2, hover_time)
|
441 |
|
442 |
-
# 更新hover状态为活跃
|
443 |
-
new_hover_state = {"active": True, "last_update": interval_count}
|
444 |
-
|
445 |
# 如果成功获取帧,返回所有视频的帧
|
446 |
if frame1 and frame2:
|
447 |
-
return [frame1]*6 + [frame2]*6
|
448 |
except Exception as e:
|
449 |
print(f"处理hover数据异常: {e}")
|
450 |
-
|
451 |
-
|
452 |
-
if 'video-playback-interval' in trigger_id:
|
453 |
-
# 检查hover状态是否过期(超过3个interval周期没有更新)
|
454 |
-
hover_expired = (interval_count - hover_state.get("last_update", 0)) > 3
|
455 |
-
|
456 |
-
if not hover_state.get("active", False) or hover_expired:
|
457 |
-
# 没有hover或hover已过期时才自动播放
|
458 |
-
t = timestamps[0] + (interval_count * 0.3) % video_duration
|
459 |
-
frame1 = get_video_frame(video_path_1, t)
|
460 |
-
frame2 = get_video_frame(video_path_2, t)
|
461 |
-
|
462 |
-
# 更新hover状态为非活跃
|
463 |
-
new_hover_state = {"active": False, "last_update": interval_count}
|
464 |
-
|
465 |
-
if frame1 and frame2:
|
466 |
-
return [frame1]*6 + [frame2]*6 + [new_hover_state]
|
467 |
-
else:
|
468 |
-
return [no_update]*12 + [new_hover_state]
|
469 |
-
else:
|
470 |
-
# hover仍然活跃时,暂停自动播放
|
471 |
-
return [no_update]*12 + [hover_state]
|
472 |
-
|
473 |
-
return [no_update]*12 + [hover_state]
|
474 |
|
475 |
except Exception as e:
|
476 |
print(f"update_video_frames回调函数异常: {e}")
|
477 |
-
return [no_update]*12
|
478 |
|
479 |
# ------------------ 启动应用 ------------------
|
480 |
if __name__ == "__main__":
|
481 |
-
|
|
|
|
14 |
from pathlib import Path
|
15 |
from typing import Tuple, Optional
|
16 |
from urllib.parse import urljoin
|
17 |
+
import subprocess
|
18 |
+
import shutil
|
19 |
+
|
20 |
+
|
21 |
# ------------------ 下载数据 ------------------
|
22 |
+
DOWNLOAD_DIR = tempfile.mkdtemp()
|
23 |
|
24 |
class RemoteDatasetLoader:
|
25 |
"""从 Hugging Face Hub 远程加载数据集的类"""
|
|
|
101 |
return video_paths, df
|
102 |
|
103 |
|
104 |
+
# ------------------ 视频重编码函数 ------------------
|
105 |
+
def check_ffmpeg_available():
    """Return True if an ``ffmpeg`` executable can be located and run.

    The executable is first looked up on PATH; if found, ``ffmpeg -version``
    is executed with a 5 second timeout.

    Returns:
        bool: True when ``ffmpeg -version`` exits with status 0. False when
        the executable is missing, cannot be launched, or times out.
    """
    # Cheap PATH lookup first so a missing binary does not cost a process spawn.
    if shutil.which('ffmpeg') is None:
        return False

    try:
        result = subprocess.run(['ffmpeg', '-version'],
                                capture_output=True, text=True, timeout=5)
    except (subprocess.TimeoutExpired, OSError):
        # OSError covers FileNotFoundError as well as PermissionError and other
        # launch failures; none of them should abort the caller.
        return False

    return result.returncode == 0
|
113 |
+
|
114 |
+
def get_video_codec_info(video_path):
    """Return the codec name of the first video stream reported by ffprobe.

    Args:
        video_path: Path of the video file passed to ``ffprobe``.

    Returns:
        str: The ``codec_name`` of the first stream whose ``codec_type`` is
        ``"video"``. ``'unknown'`` is returned when ffprobe cannot be run,
        exits with a non-zero status, produces output that cannot be parsed,
        or reports no video stream.
    """
    try:
        result = subprocess.run(
            [
                'ffprobe', '-v', 'quiet', '-print_format', 'json',
                '-show_streams', video_path,
            ],
            capture_output=True,
            # Decode explicitly instead of relying on the locale encoding, and
            # never let an undecodable metadata byte abort the probe.
            encoding='utf-8',
            errors='replace',
            timeout=10,
        )

        if result.returncode == 0:
            info = json.loads(result.stdout)
            for stream in info.get('streams', []):
                if stream.get('codec_type') == 'video':
                    # Guard against a missing or null codec_name.
                    return stream.get('codec_name') or 'unknown'
        else:
            print(f"获取视频编码信息失败: ffprobe 返回码 {result.returncode}")
    except Exception as e:
        # Best-effort probe: any failure is reported and mapped to 'unknown'.
        print(f"获取视频编码信息失败: {e}")

    return 'unknown'
|
131 |
+
|
132 |
+
def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
    """Re-encode a video to H.264 (libx264) with AAC audio using ffmpeg.

    Args:
        input_path: Path of the source video.
        output_path: Destination path. When None, the input path with its
            extension replaced by ``_h264.mp4`` is used.
        quality: ``'fast'``, ``'medium'`` or ``'high'``; selects the x264
            preset/CRF pair. Any other value falls back to ``'medium'``.

    Returns:
        str: ``output_path`` when ffmpeg exits with status 0; otherwise
        ``input_path`` (ffmpeg failure, 300 second timeout, or any other
        error while launching ffmpeg).
    """
    if output_path is None:
        base_name = os.path.splitext(input_path)[0]
        output_path = f"{base_name}_h264.mp4"

    # Preset/CRF pairs per quality level.
    quality_params = {
        'fast': ['-preset', 'ultrafast', '-crf', '28'],
        'medium': ['-preset', 'medium', '-crf', '23'],
        'high': ['-preset', 'slow', '-crf', '18'],
    }
    params = quality_params.get(quality, quality_params['medium'])

    cmd = [
        'ffmpeg', '-i', input_path,
        '-c:v', 'libx264',          # H.264 video encoder
        '-c:a', 'aac',              # audio encoder
        '-movflags', '+faststart',  # move the index to the front for streaming
        '-y',                       # overwrite the output file
    ] + params + [output_path]

    # Remember whether the destination pre-existed so that cleanup after a
    # failure never deletes a file this call did not create.
    output_preexisted = os.path.exists(output_path)

    try:
        print(f"重编码视频: {input_path} -> {output_path}")
        result = subprocess.run(
            cmd,
            # Keep ffmpeg from reading (and blocking on) the parent's stdin.
            stdin=subprocess.DEVNULL,
            capture_output=True,
            # Decode explicitly so an undecodable byte in ffmpeg's log cannot
            # turn a successful encode into a reported failure.
            encoding='utf-8',
            errors='replace',
            timeout=300,
        )

        if result.returncode == 0:
            print(f"重编码成功: {output_path}")
            return output_path

        print(f"重编码失败: {result.stderr}")
    except subprocess.TimeoutExpired:
        print("重编码超时")
    except Exception as e:
        print(f"重编码异常: {e}")

    # Failure path: drop a truncated output that this call created.
    if not output_preexisted and os.path.isfile(output_path):
        try:
            os.remove(output_path)
        except OSError:
            pass  # Cleanup is best-effort; the caller falls back to the input.

    return input_path
|
172 |
+
|
173 |
+
def process_video_for_compatibility(video_path):
    """Re-encode a video to H.264 when its codec is AV1/VP9/VP8 or unknown.

    Args:
        video_path: Path of the video file to inspect.

    Returns:
        str: The path of the re-encoded file when re-encoding was needed and
        produced a file larger than 1024 bytes; otherwise ``video_path``
        unchanged (file missing, ffmpeg unavailable, codec already
        compatible, or re-encoding failed).
    """
    # isfile (not exists) so a directory is rejected here instead of being
    # handed to ffprobe/ffmpeg.
    if not os.path.isfile(video_path):
        print(f"视频文件不存在: {video_path}")
        return video_path

    # Without ffmpeg there is nothing we can do.
    if not check_ffmpeg_available():
        print("ffmpeg不可用,跳过重编码")
        return video_path

    codec = get_video_codec_info(video_path)
    print(f"视频编码格式: {codec}")

    # Codecs (plus 'unknown') that are re-encoded to H.264.
    if codec not in ('av01', 'av1', 'vp9', 'vp8', 'unknown'):
        print(f"视频编码 ({codec}) 兼容,无需重编码")
        return video_path

    print(f"检测到不兼容的编码格式 ({codec}),开始重编码...")
    reencoded_path = reencode_video_to_h264(video_path, quality='fast')

    # reencode_video_to_h264 returns its input path on failure, so the result
    # only counts as a success when it is a different, non-trivial file.
    reencoded_ok = (
        reencoded_path != video_path
        and os.path.exists(reencoded_path)
        and os.path.getsize(reencoded_path) > 1024
    )
    if reencoded_ok:
        return reencoded_path

    print("重编码失败,使用原始文件")
    return video_path
|
202 |
+
|
203 |
+
|
204 |
def load_remote_dataset(repo_id: str,
                        episode_id: int = 0,
                        video_keys: Optional[list] = None,
                        download_dir: Optional[str] = None) -> Tuple[list, pd.DataFrame]:
    """Load one episode through ``RemoteDatasetLoader`` and post-process its videos.

    Args:
        repo_id: Repository id used to construct the ``RemoteDatasetLoader``.
        episode_id: Episode index forwarded to ``load_episode_data``.
        video_keys: Forwarded unchanged to ``load_episode_data``.
        download_dir: Forwarded unchanged to ``load_episode_data``.

    Returns:
        A tuple ``(video_paths, df)``: every path returned by the loader after
        being passed through ``process_video_for_compatibility``, in the same
        order, and the loader's data frame untouched.
    """
    raw_video_paths, episode_df = RemoteDatasetLoader(repo_id).load_episode_data(
        episode_id, video_keys, download_dir
    )

    # Run each downloaded video through the compatibility step.
    compatible_paths = [
        process_video_for_compatibility(raw_path) for raw_path in raw_video_paths
    ]

    return compatible_paths, episode_df
|
218 |
|
219 |
|
|
|
|
|
|
|
|
|
|
|
220 |
# ------------------ 加载数据 ------------------
|
221 |
+
print("正在加载数据集...")
|
222 |
+
video_paths, data_df = load_remote_dataset(
|
223 |
+
repo_id="zijian2022/sortingtest",
|
224 |
+
episode_id=0,
|
225 |
+
download_dir="./downloaded_videos"
|
226 |
+
)
|
227 |
+
|
228 |
df = data_df
|
229 |
columns = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
|
230 |
timestamps = df["timestamp"].values
|
|
|
233 |
action_df = pd.DataFrame(df["action"].tolist(), columns=columns)
|
234 |
|
235 |
# ------------------ 视频路径 ------------------
|
|
|
|
|
236 |
video_path_1 = video_paths[0]
|
237 |
video_path_2 = video_paths[1]
|
238 |
+
|
239 |
+
print(f"视频路径1: {video_path_1}")
|
240 |
+
print(f"视频路径2: {video_path_2}")
|
241 |
+
|
242 |
# ------------------ Dash 初始化 ------------------
|
243 |
app = dash.Dash(__name__)
|
244 |
server = app.server
|
|
|
248 |
|
249 |
# ------------------ 视频帧提取函数 ------------------
|
250 |
def get_video_frame(video_path, time_in_seconds):
    """Extract the frame nearest to a given time as a JPEG data URI.

    Args:
        video_path: Path of the video file opened with ``cv2.VideoCapture``.
        time_in_seconds: Position in the video, in seconds. Negative values
            are treated as the first frame.

    Returns:
        str | None: ``"data:image/jpeg;base64,..."`` for the selected frame
        (scaled down to a width of 640 when wider), or None when the video
        cannot be opened, reports no frame rate, the frame cannot be read or
        encoded, or any exception occurs.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"❌ 无法打开视频: {video_path}")
            return None

        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            print(f"❌ 无法获取视频帧率: {video_path}")
            return None

        # Round instead of truncating: a timestamp such as k / fps can land
        # just below k after float arithmetic, and int() would then select the
        # previous frame. Clamp so negative times map to frame 0.
        frame_num = max(0, int(round(time_in_seconds * fps)))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        success, frame = cap.read()

        if not success or frame is None:
            print(f"❌ 无法读取帧: {video_path}, 时间: {time_in_seconds}s")
            return None

        # Shrink wide frames to cut the size of the data sent to the browser.
        height, width = frame.shape[:2]
        if width > 640:
            new_width = 640
            new_height = int(height * (new_width / width))
            frame = cv2.resize(frame, (new_width, new_height))

        # Encode as JPEG at quality 85.
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 85]
        encoded_ok, buffer = cv2.imencode('.jpg', frame, encode_param)
        if not encoded_ok:
            print(f"❌ 无法读取帧: {video_path}, 时间: {time_in_seconds}s")
            return None

        encoded = base64.b64encode(buffer).decode('utf-8')
        return f"data:image/jpeg;base64,{encoded}"

    except Exception as e:
        # Callback boundary: report and return None rather than propagate.
        print(f"❌ 提取视频帧异常: {e}")
        return None
    finally:
        # Always free the capture handle, including on the exception path.
        if cap is not None:
            cap.release()
|
289 |
|
290 |
def find_intervals(mask):
|
|
|
372 |
return shadows_in_range
|
373 |
|
374 |
# 预计算所有关节的阴影信息
|
375 |
+
print("正在预计算阴影信息...")
|
376 |
for joint in columns:
|
377 |
all_shadows[joint] = get_shadow_info(joint)
|
378 |
|
|
|
466 |
], style={"width": "38%", "display": "inline-block", "paddingLeft": "2%"})
|
467 |
], style={"marginBottom": "15px"}))
|
468 |
|
|
|
|
|
|
|
|
|
469 |
# 设置 layout
|
470 |
+
app.layout = html.Div([
|
471 |
+
html.H1("机器人数据可视化 - 视频兼容性优化", style={"textAlign": "center", "marginBottom": "20px"}),
|
472 |
+
html.Div(rows)
|
473 |
+
])
|
474 |
|
475 |
# ------------------ 回调:监听 hoverData 并更新阴影高亮 ------------------
|
476 |
@app.callback(
|
|
|
543 |
return [no_update] * 6
|
544 |
|
545 |
# ------------------ 回调:监听 hoverData 更新视频帧 ------------------
|
|
|
|
|
546 |
@app.callback(
|
547 |
[Output(f"video1-{i}", "src") for i in range(6)] +
|
548 |
+
[Output(f"video2-{i}", "src") for i in range(6)],
|
549 |
+
[Input(f"graph-{i}", "hoverData") for i in range(6)]
|
|
|
|
|
|
|
550 |
)
|
551 |
def update_video_frames(*args):
|
552 |
+
hover_datas = args
|
|
|
|
|
553 |
|
554 |
# 获取触发回调的上下文
|
555 |
ctx = dash.callback_context
|
|
|
571 |
frame1 = get_video_frame(video_path_1, hover_time)
|
572 |
frame2 = get_video_frame(video_path_2, hover_time)
|
573 |
|
|
|
|
|
|
|
574 |
# 如果成功获取帧,返回所有视频的帧
|
575 |
if frame1 and frame2:
|
576 |
+
return [frame1]*6 + [frame2]*6
|
577 |
except Exception as e:
|
578 |
print(f"处理hover数据异常: {e}")
|
579 |
+
|
580 |
+
return [no_update]*12
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
581 |
|
582 |
except Exception as e:
|
583 |
print(f"update_video_frames回调函数异常: {e}")
|
584 |
+
return [no_update]*12
|
585 |
|
586 |
# ------------------ 启动应用 ------------------
|
587 |
if __name__ == "__main__":
    print("应用启动中...")
    # The server binds to every interface, so the debugger and dev tools must
    # not be on by default. Set DASH_DEBUG=1 to enable them for local work.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in ("1", "true", "yes")
    # PORT lets the hosting environment choose the port; 8050 stays the default.
    app.run(debug=debug_enabled, host='0.0.0.0', port=int(os.environ.get("PORT", "8050")))
|