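# NOTE(review): "Spaces: Sleeping" -- presumably status text captured from the
# hosting page when this file was copied; it is not part of the program.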
# ------------------ Imports ------------------
import base64
import json
import os
import tempfile
from pathlib import Path
from typing import Optional, Tuple
from urllib.parse import urljoin

import cv2
import dash
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import requests
from dash import Input, Output, State, callback_context, dcc, html, no_update
from scipy.ndimage import gaussian_filter1d
# ------------------ Download data ------------------
class RemoteDatasetLoader:
    """Load one episode of a dataset hosted on the Hugging Face Hub.

    All files are fetched over HTTP relative to
    ``https://huggingface.co/datasets/<repo_id>/resolve/main/``.
    """

    def __init__(self, repo_id: str, timeout: int = 30):
        """Store the repository id and derive the base URL of its files.

        Args:
            repo_id: Dataset repository identifier, e.g. ``"user/name"``.
            timeout: Timeout in seconds passed to every ``requests.get`` call.
        """
        self.repo_id = repo_id
        self.timeout = timeout
        self.base_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/"

    def _get_dataset_info(self) -> dict:
        """Download and decode ``meta/info.json``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        info_url = urljoin(self.base_url, "meta/info.json")
        response = requests.get(info_url, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def _get_episode_info(self, episode_id: int) -> dict:
        """Return the ``meta/episodes.jsonl`` record of ``episode_id``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            ValueError: If no record has ``episode_index == episode_id``.
        """
        episodes_url = urljoin(self.base_url, "meta/episodes.jsonl")
        response = requests.get(episodes_url, timeout=self.timeout)
        response.raise_for_status()
        # Decode line by line and stop at the first match instead of
        # materialising every record first.
        for line in response.text.splitlines():
            if not line.strip():
                continue
            episode = json.loads(line)
            if episode.get("episode_index") == episode_id:
                return episode
        raise ValueError(f"Episode {episode_id} not found")

    def _download_video(self, video_url: str, save_path: str) -> str:
        """Stream ``video_url`` to ``save_path`` and return ``save_path``.

        The payload is written to ``save_path + ".part"`` and renamed once
        complete, so an interrupted download never leaves a truncated file
        under the final name.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            OSError: If the target file cannot be written.
        """
        partial_path = save_path + ".part"
        try:
            # The context manager returns the streamed connection to the pool
            # even when the download fails half-way.
            with requests.get(video_url, timeout=self.timeout, stream=True) as response:
                response.raise_for_status()
                target_dir = os.path.dirname(save_path)
                if target_dir:
                    # os.makedirs("") raises, so skip it for bare file names.
                    os.makedirs(target_dir, exist_ok=True)
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, save_path)
        finally:
            # Only present here when something above failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        return save_path

    def load_episode_data(self, episode_id: int,
                          video_keys: Optional[list] = None,
                          download_dir: Optional[str] = None) -> Tuple[list, pd.DataFrame]:
        """Download the videos and the tabular data of one episode.

        Args:
            episode_id: Index of the episode to load.
            video_keys: Feature names of the videos to fetch. Defaults to every
                feature whose ``dtype`` is ``"video"`` in the dataset info.
                Only the first two keys are used.
            download_dir: Directory for the downloaded videos. A fresh
                temporary directory is created when omitted.

        Returns:
            ``(video_paths, df)``. Each entry of ``video_paths`` is the local
            file path, or the remote URL when that download failed. ``df`` is
            the episode's parquet data, or an empty DataFrame when it could
            not be read.

        Raises:
            requests.HTTPError: If the dataset metadata cannot be fetched.
            ValueError: If ``episode_id`` is not listed in the episode index.
        """
        dataset_info = self._get_dataset_info()
        # Called only for its validation effect: raises when the episode is unknown.
        self._get_episode_info(episode_id)

        if download_dir is None:
            download_dir = tempfile.mkdtemp(prefix="lerobot_videos_")
        if video_keys is None:
            video_keys = [key for key, feature in dataset_info["features"].items()
                          if feature["dtype"] == "video"]
        video_keys = video_keys[:2]

        chunks_size = dataset_info.get("chunks_size", 1000)
        episode_chunk = episode_id // chunks_size

        video_paths = []
        for position, video_key in enumerate(video_keys, start=1):
            video_url = self.base_url + dataset_info["video_path"].format(
                episode_chunk=episode_chunk,
                video_key=video_key,
                episode_index=episode_id
            )
            video_filename = f"episode_{episode_id}_{video_key}.mp4"
            local_path = os.path.join(download_dir, video_filename)
            try:
                downloaded_path = self._download_video(video_url, local_path)
                video_paths.append(downloaded_path)
                print(f"Downloaded video {position}: {downloaded_path}")
            except Exception as e:
                # Best effort: hand back the remote URL in place of a local file.
                print(f"Failed to download video {video_key}: {e}")
                video_paths.append(video_url)

        data_url = self.base_url + dataset_info["data_path"].format(
            episode_chunk=episode_chunk,
            episode_index=episode_id
        )
        try:
            df = pd.read_parquet(data_url)
            print(f"Loaded data with {len(df)} rows and {len(df.columns)} columns")
        except Exception as e:
            # Best effort: callers receive an empty frame rather than an exception.
            print(f"Failed to load data: {e}")
            df = pd.DataFrame()
        return video_paths, df
def load_remote_dataset(repo_id: str,
                        episode_id: int = 0,
                        video_keys: Optional[list] = None,
                        download_dir: Optional[str] = None) -> Tuple[list, pd.DataFrame]:
    """Load one episode from a Hugging Face Hub dataset repository.

    Convenience wrapper that builds a ``RemoteDatasetLoader`` for ``repo_id``
    and forwards the remaining arguments to its ``load_episode_data`` method.

    Returns:
        The ``(video_paths, dataframe)`` pair produced by ``load_episode_data``.
    """
    return RemoteDatasetLoader(repo_id).load_episode_data(
        episode_id,
        video_keys,
        download_dir,
    )
# Fetch episode 0 of the demo dataset at import time; videos are stored under
# ./downloaded_videos.
video_paths, data_df = load_remote_dataset(
    repo_id="zijian2022/sortingtest",
    episode_id=0,
    download_dir="./downloaded_videos"
)

# ------------------ Load data ------------------
df = data_df
# The loader returns an empty DataFrame when the parquet file cannot be read
# and may return fewer than two videos; fail here with a clear message instead
# of an opaque KeyError / IndexError further down.
if "timestamp" not in df.columns or "action" not in df.columns:
    raise RuntimeError("Episode data could not be loaded: 'timestamp'/'action' columns are missing")
if len(df) < 2:
    raise RuntimeError("Episode data needs at least two samples to compute velocities")
if len(video_paths) < 2:
    raise RuntimeError(f"Expected two video streams, got {len(video_paths)}")

# Names given to the components of each "action" vector, in order.
columns = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
timestamps = df["timestamp"].values
# Time step between consecutive samples; one element shorter than `timestamps`.
delta_t = np.diff(timestamps)
# Time axis aligned with the differenced signals (first sample dropped).
time_for_plot = timestamps[1:]
action_df = pd.DataFrame(df["action"].tolist(), columns=columns)

# ------------------ Video paths ------------------
video_path_1 = video_paths[0]
video_path_2 = video_paths[1]

# ------------------ Dash initialisation ------------------
app = dash.Dash(__name__)
server = app.server

# ------------------ Global store for shadow information ------------------
all_shadows = {}  # joint name -> list of shadow dicts for that joint
# ------------------ Video frame extraction ------------------
def get_video_frame(video_path, time_in_seconds):
    """Return the frame of ``video_path`` at ``time_in_seconds`` as a JPEG data URI.

    Args:
        video_path: Anything ``cv2.VideoCapture`` accepts as a source.
        time_in_seconds: Position in the video; converted to a frame number
            with the FPS reported by the capture. Negative positions are
            clamped to the first frame.

    Returns:
        A ``"data:image/jpeg;base64,..."`` string, or ``None`` when the video
        cannot be opened, the frame cannot be read, or JPEG encoding fails.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"❌ 无法打开视频: {video_path}")
            return None
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_num = max(0, int(time_in_seconds * fps))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        success, frame = cap.read()
    finally:
        # Release the capture on every path, including failures above.
        cap.release()

    if not success:
        return None
    encoded_ok, buffer = cv2.imencode('.jpg', frame)
    if not encoded_ok:
        return None
    encoded = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded}"
def find_intervals(mask):
    """Locate the runs of truthy entries in ``mask``.

    Args:
        mask: Sized sequence (list or 1-D array) of boolean-like values.

    Returns:
        A list of ``(first_index, last_index)`` tuples, both inclusive, one per
        maximal run of consecutive truthy entries, in order of appearance.
    """
    # Pad with False on both ends so every run has a rising and a falling edge.
    padded = np.concatenate(([False], np.asarray(mask, dtype=bool), [False]))
    edges = np.diff(padded.astype(np.int8))
    run_starts = np.flatnonzero(edges == 1)
    # A falling edge sits one position past the last truthy entry.
    run_ends = np.flatnonzero(edges == -1) - 1
    return [(int(first), int(last)) for first, last in zip(run_starts, run_ends)]
def get_shadow_info(joint_name):
    """Collect every red shadow (highlight band) for one joint.

    Reads the module-level ``action_df``, ``delta_t`` and ``time_for_plot``.
    Three kinds of shadow are produced, in this order:

    * ``'low_speed'``: runs of at most ``k`` consecutive samples whose
      smoothed velocity magnitude is below ``vel_threshold``;
    * ``'max_value'``: a window of ``highlight_width`` samples on each side of
      the maximum of the smoothed angle, clipped to the data range;
    * ``'min_value'``: the same window around the minimum.

    Returns:
        A list of dicts with keys ``type``, ``start_time``, ``end_time``,
        ``start_idx`` and ``end_idx`` (indices into ``time_for_plot``).
    """
    # Parameters
    vel_threshold = 0.5
    highlight_width = 3
    k = 2

    joint_angles = action_df[joint_name].values
    speed = gaussian_filter1d(np.diff(joint_angles) / delta_t, sigma=1)
    # Drop the first sample so the angle curve lines up with the velocity.
    angle_curve = gaussian_filter1d(joint_angles[1:], sigma=1)
    last_sample = len(time_for_plot) - 1

    def make_shadow(kind, first, last):
        # One shadow record spanning samples first..last (inclusive).
        return {
            'type': kind,
            'start_time': time_for_plot[first],
            'end_time': time_for_plot[last],
            'start_idx': first,
            'end_idx': last
        }

    def window_around(center):
        # Window of highlight_width samples each side, clipped to valid indices.
        return max(0, center - highlight_width), min(last_sample, center + highlight_width)

    # Short low-speed runs
    shadows = [
        make_shadow('low_speed', first, last)
        for first, last in find_intervals(np.abs(speed) < vel_threshold)
        if last - first + 1 <= k
    ]
    # Extremes of the smoothed angle
    shadows.append(make_shadow('max_value', *window_around(np.argmax(angle_curve))))
    shadows.append(make_shadow('min_value', *window_around(np.argmin(angle_curve))))
    return shadows
def is_hover_in_shadow(hover_time, shadows):
    """Tell whether ``hover_time`` lies inside any shadow.

    Both ends of a shadow (``start_time`` and ``end_time``) are inclusive.
    Returns ``False`` when ``shadows`` is empty.
    """
    return any(
        band['start_time'] <= hover_time <= band['end_time']
        for band in shadows
    )
def find_shadows_in_range(shadows, start_time, end_time):
    """Return the shadows that overlap the closed range ``[start_time, end_time]``.

    A shadow is left out only when it ends before ``start_time`` or begins
    after ``end_time``; touching at an endpoint counts as overlap. The input
    order is preserved and the same dict objects are returned.
    """
    def lies_outside(band):
        return band['end_time'] < start_time or band['start_time'] > end_time

    return [band for band in shadows if not lies_outside(band)]
# Pre-compute the shadow list of every joint once, at import time.
all_shadows.update((joint, get_shadow_info(joint)) for joint in columns)
# ------------------ Graph construction ------------------
def generate_joint_graph(joint_name, idx, highlighted_shadows=None):
    """Build the ``dcc.Graph`` showing one joint's smoothed angle over time.

    Every shadow stored in ``all_shadows[joint_name]`` is drawn as a
    full-height rectangle. A shadow whose ``start_time`` and ``end_time`` both
    match an entry of ``highlighted_shadows`` is drawn blue at opacity 0.6;
    all others are red at opacity 0.3.

    Args:
        joint_name: Column of ``action_df`` to plot.
        idx: Suffix of the component id (``"graph-<idx>"``).
        highlighted_shadows: Optional list of shadow dicts to emphasise.

    Returns:
        The configured ``dcc.Graph`` component.
    """
    angles = action_df[joint_name].values
    # Skip the first sample so the curve has the same length as time_for_plot.
    smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)

    # (start, end) pairs to emphasise; a set makes the per-shadow check O(1).
    highlighted_spans = {
        (shadow['start_time'], shadow['end_time'])
        for shadow in (highlighted_shadows or [])
    }

    shapes = []
    for shadow in all_shadows[joint_name]:
        is_highlighted = (shadow['start_time'], shadow['end_time']) in highlighted_spans
        shapes.append({
            "type": "rect",
            "xref": "x",
            "yref": "paper",  # y0/y1 are fractions of the plot height
            "x0": shadow['start_time'],
            "x1": shadow['end_time'],
            "y0": 0,
            "y1": 1,
            "fillcolor": "blue" if is_highlighted else "red",
            "opacity": 0.6 if is_highlighted else 0.3,
            "line": {"width": 0}
        })

    return dcc.Graph(
        id=f"graph-{idx}",
        figure={
            "data": [
                go.Scatter(
                    x=time_for_plot,
                    y=smoothed_angle,
                    name="Angle",
                    line=dict(color='orange')
                )
            ],
            "layout": go.Layout(
                title=joint_name,
                xaxis={"title": "Time (s)"},
                yaxis={"title": "Angle (deg)"},
                shapes=shapes,
                hovermode="x unified",
                height=250,
                margin=dict(t=30, b=30, l=50, r=50),
                showlegend=False,
            )
        },
        style={"height": "250px"}
    )
# ------------------ Layout ------------------
rows = []

# One row per joint: the angle plot on the left, both camera frames on the right.
for i, joint in enumerate(columns):
    graph_cell = html.Div(
        generate_joint_graph(joint, i),
        style={"width": "60%", "display": "inline-block", "verticalAlign": "top"},
    )
    frames_cell = html.Div(
        [
            html.Img(
                id=f"video1-{i}",
                style={"width": "49%", "height": "180px", "objectFit": "contain", "display": "inline-block"},
            ),
            html.Img(
                id=f"video2-{i}",
                style={"width": "49%", "height": "180px", "objectFit": "contain", "display": "inline-block"},
            ),
        ],
        style={"width": "38%", "display": "inline-block", "paddingLeft": "2%"},
    )
    rows.append(html.Div([graph_cell, frames_cell], style={"marginBottom": "15px"}))

# Timer that fires every 300 ms, and a store that remembers the hover state.
rows.extend([
    dcc.Interval(id="video-playback-interval", interval=300, n_intervals=0),
    dcc.Store(id="hover-state-store", data={"active": False, "last_update": 0}),
])

# Install the layout
app.layout = html.Div(rows)
# ------------------ Callback: watch hoverData and update the shadow highlighting ------------------
def _plain_joint_figures():
    """Return the figure of every joint with no shadow highlighted."""
    return [generate_joint_graph(joint, i).figure for i, joint in enumerate(columns)]


# Register the function with Dash; without this it is never invoked.
# Inputs: hoverData of every joint graph. Outputs: the figure of every graph.
@app.callback(
    [Output(f"graph-{i}", "figure") for i in range(len(columns))],
    [Input(f"graph-{i}", "hoverData") for i in range(len(columns))],
)
def update_shadow_highlighting(*args):
    """Highlight, on all graphs, the shadows near the hovered time.

    When the hover lands inside a shadow of the hovered joint, every joint's
    shadows overlapping a window of 20 samples on each side of the hovered
    time are drawn highlighted. When the hover carries no point or is outside
    every shadow of that joint, all graphs are reset to their plain state.

    Args:
        *args: The ``hoverData`` value of each graph, in ``columns`` order.
            Any further positional values are ignored.

    Returns:
        One figure per joint, or ``no_update`` for each when the callback was
        not triggered by a hover or an error occurred.
    """
    n_graphs = len(columns)
    neighbor_samples = 20  # half-width of the highlight window, in samples
    hover_datas = args[:n_graphs]
    unchanged = [no_update] * n_graphs

    ctx = dash.callback_context
    # Nothing triggered the callback
    if not ctx.triggered:
        return unchanged
    trigger_id = ctx.triggered[0]['prop_id']
    # Only react to hover events
    if 'hoverData' not in trigger_id:
        return unchanged

    # prop_id looks like "graph-<idx>.hoverData"; extract <idx>.
    graph_idx = int(trigger_id.split('-')[1].split('.')[0])
    hover_data = hover_datas[graph_idx]

    # No hover point: restore the plain figures
    if not hover_data or "points" not in hover_data or len(hover_data["points"]) == 0:
        return _plain_joint_figures()

    try:
        hover_time = float(hover_data["points"][0]["x"])
        triggered_joint = columns[graph_idx]

        # Hover outside every shadow of this joint: restore the plain figures
        if not is_hover_in_shadow(hover_time, all_shadows[triggered_joint]):
            return _plain_joint_figures()

        # Sample index matching the hovered time, then the surrounding window
        hover_idx = np.searchsorted(time_for_plot, hover_time)
        start_idx = max(0, hover_idx - neighbor_samples)
        end_idx = min(len(time_for_plot) - 1, hover_idx + neighbor_samples)
        start_time = time_for_plot[start_idx]
        end_time = time_for_plot[end_idx]

        # Redraw every joint with its in-window shadows highlighted
        return [
            generate_joint_graph(
                joint, i, find_shadows_in_range(all_shadows[joint], start_time, end_time)
            ).figure
            for i, joint in enumerate(columns)
        ]
    except Exception as e:
        # Callback boundary: report and leave the figures untouched.
        print(f"处理阴影高亮异常: {e}")
        return unchanged
# ------------------ Callback: watch hoverData / the timer and update the video frames ------------------
video_duration = timestamps[-1] - timestamps[0]


# Register the function with Dash; without this it is never invoked.
# Outputs: src of every video1-*/video2-* image, then the hover-state store.
# Inputs: hoverData of every graph, then the playback timer tick count.
# State: the current content of the hover-state store.
@app.callback(
    [Output(f"video1-{i}", "src") for i in range(len(columns))]
    + [Output(f"video2-{i}", "src") for i in range(len(columns))]
    + [Output("hover-state-store", "data")],
    [Input(f"graph-{i}", "hoverData") for i in range(len(columns))]
    + [Input("video-playback-interval", "n_intervals")],
    [State("hover-state-store", "data")],
)
def update_video_frames(*args):
    """Show the video frames for the hovered time, or auto-play when idle.

    * Triggered by a hover: both videos are sampled at the hovered time, the
      frames are sent to every image pair and the hover state becomes active.
    * Triggered by the timer: when no hover is active, or the last hover is
      more than 3 ticks old, playback advances 0.3 s per tick (looping over
      the episode duration) and the hover state becomes inactive. Otherwise
      nothing changes.

    Args:
        *args: The ``hoverData`` of each graph, followed by the timer's
            ``n_intervals`` and the stored hover state dict.

    Returns:
        A list with the ``src`` for each ``video1-*`` image, then each
        ``video2-*`` image, then the hover state; ``no_update`` is used for
        images that should stay as they are.
    """
    hover_datas = args[:-2]
    interval_count = args[-2]
    hover_state = args[-1]

    n_views = len(columns)
    unchanged_images = [no_update] * (2 * n_views)
    seconds_per_tick = 0.3  # playback step applied on each timer tick

    ctx = dash.callback_context
    try:
        if not ctx.triggered:
            return unchanged_images + [hover_state]
        trigger_id = ctx.triggered[0]['prop_id']

        # Triggered by hovering over a graph
        if 'hoverData' in trigger_id:
            # prop_id looks like "graph-<idx>.hoverData"; extract <idx>.
            graph_idx = int(trigger_id.split('-')[1].split('.')[0])
            hover_data = hover_datas[graph_idx]
            if hover_data and hover_data.get("points"):
                try:
                    hover_time = float(hover_data["points"][0]["x"])
                    frame1 = get_video_frame(video_path_1, hover_time)
                    frame2 = get_video_frame(video_path_2, hover_time)
                    # Only push an update when both frames were obtained
                    if frame1 and frame2:
                        new_hover_state = {"active": True, "last_update": interval_count}
                        return [frame1] * n_views + [frame2] * n_views + [new_hover_state]
                except Exception as e:
                    print(f"处理hover数据异常: {e}")
            return unchanged_images + [hover_state]

        # Triggered by the playback timer
        if 'video-playback-interval' in trigger_id:
            # A hover is stale once more than 3 ticks passed since its update
            hover_expired = (interval_count - hover_state.get("last_update", 0)) > 3
            if hover_state.get("active", False) and not hover_expired:
                # A hover is still active: pause auto-play
                return unchanged_images + [hover_state]

            # Loop over the episode; guard the modulo against a zero duration.
            elapsed = interval_count * seconds_per_tick
            offset = elapsed % video_duration if video_duration > 0 else 0.0
            t = timestamps[0] + offset
            frame1 = get_video_frame(video_path_1, t)
            frame2 = get_video_frame(video_path_2, t)
            new_hover_state = {"active": False, "last_update": interval_count}
            if frame1 and frame2:
                return [frame1] * n_views + [frame2] * n_views + [new_hover_state]
            return unchanged_images + [new_hover_state]

        return unchanged_images + [hover_state]
    except Exception as e:
        # Callback boundary: report and leave everything untouched.
        print(f"update_video_frames回调函数异常: {e}")
        return unchanged_images + [hover_state]
# ------------------ Launch the application ------------------
if __name__ == "__main__":
    # Start the Dash server with debug mode on, only when run as a script.
    app.run(debug=True)