# NOTE: scraped page header ("Spaces: Running") -- not part of the program.
# ------------------ Imports ------------------
import dash | |
from dash import dcc, html, Input, Output, State, callback_context, no_update | |
import plotly.graph_objects as go | |
import pandas as pd | |
import numpy as np | |
import cv2 | |
import base64 | |
from scipy.ndimage import gaussian_filter1d | |
# ------------------ Joint names ------------------
# Column labels applied to the per-sample "action" vectors, in order.
columns = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]

# ------------------ Load episode data ------------------
df = pd.read_parquet("./data/clean_data/uni_boxing_object_vfm/data/chunk-000/episode_000000.parquet")
action_df = pd.DataFrame(df["action"].tolist(), columns=columns)

# Raw sample timestamps, the step between consecutive samples, and the time
# axis for plotting. np.diff yields one value fewer than its input, so the
# plotting axis drops the first timestamp to stay the same length.
timestamps = df["timestamp"].values
time_for_plot = timestamps[1:]
delta_t = np.diff(timestamps)

# ------------------ Video paths ------------------
video_path_1 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.laptop/episode_000000.mp4"
video_path_2 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.phone/episode_000000.mp4"

# ------------------ Dash initialisation ------------------
app = dash.Dash(__name__)
server = app.server

# ------------------ Global shadow store ------------------
# Maps joint name -> list of shadow dicts (filled once at import time).
all_shadows = {}
# ------------------ Video frame extraction ------------------
def get_video_frame(video_path, time_in_seconds):
    """Return one video frame as a base64-encoded JPEG data URI.

    The frame index is ``int(time_in_seconds * fps)``, where ``fps`` is read
    from the video itself, so ``time_in_seconds`` is an offset from the start
    of the video.

    Args:
        video_path: Path of the video file to open with OpenCV.
        time_in_seconds: Position in the video, in seconds.

    Returns:
        A ``data:image/jpeg;base64,...`` string, or ``None`` when the video
        cannot be opened, the frame cannot be read, or JPEG encoding fails.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"❌ 无法打开视频: {video_path}")
            return None
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Clamp so a slightly negative time never requests a negative frame.
        frame_num = max(0, int(time_in_seconds * fps))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        success, frame = cap.read()
    finally:
        # Always free the decoder handle, including on the early return and
        # when an exception escapes (this runs on every hover/interval tick).
        cap.release()

    if not success:
        return None

    encoded_ok, buffer = cv2.imencode('.jpg', frame)
    if not encoded_ok:
        return None
    encoded = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded}"
def find_intervals(mask):
    """Return the runs of truthy values in ``mask`` as inclusive index pairs.

    Args:
        mask: Sized sequence of truthy/falsy values.

    Returns:
        List of ``(first_index, last_index)`` tuples, one per maximal run of
        consecutive truthy entries, in order of appearance.
    """
    runs = []
    run_start = None
    for pos, flag in enumerate(mask):
        if flag:
            # Open a new run only if we are not already inside one.
            if run_start is None:
                run_start = pos
            continue
        if run_start is not None:
            # A falsy entry closes the run at the previous position.
            runs.append((run_start, pos - 1))
            run_start = None
    # A run that reaches the end of the mask is closed at the last index.
    if run_start is not None:
        runs.append((run_start, len(mask) - 1))
    return runs
def get_shadow_info(joint_name, vel_threshold=0.5, highlight_width=3, k=2, sigma=1):
    """Collect every red shadow (highlight interval) for one joint.

    Three kinds of shadow are produced, in this order:

    * ``'low_speed'``: runs where the absolute smoothed velocity is below
      ``vel_threshold`` and the run is at most ``k`` samples long;
    * ``'max_value'``: a window of ``highlight_width`` samples on each side
      of the maximum of the smoothed angle;
    * ``'min_value'``: the same window around the minimum.

    Args:
        joint_name: Column of ``action_df`` to analyse.
        vel_threshold: Upper bound on absolute smoothed velocity for a sample
            to count as low speed.
        highlight_width: Half-width, in samples, of the max/min windows.
        k: Maximum length, in samples, of a low-speed run that is kept.
        sigma: Standard deviation passed to ``gaussian_filter1d`` for both
            the velocity and the angle.

    Returns:
        List of dicts with keys ``'type'``, ``'start_time'``, ``'end_time'``,
        ``'start_idx'`` and ``'end_idx'`` (indices into ``time_for_plot``).
        Empty when the joint has fewer than two samples.
    """
    angles = action_df[joint_name].values
    if len(angles) < 2:
        # No velocity can be formed and argmax/argmin would fail on empty data.
        return []

    velocity = np.diff(angles) / delta_t
    smoothed_velocity = gaussian_filter1d(velocity, sigma=sigma)
    # The first angle is dropped so angles align with the velocity samples.
    smoothed_angle = gaussian_filter1d(angles[1:], sigma=sigma)
    last_idx = len(time_for_plot) - 1

    def _shadow(kind, start, end):
        # Build one shadow record from inclusive sample indices.
        return {
            'type': kind,
            'start_time': time_for_plot[start],
            'end_time': time_for_plot[end],
            'start_idx': start,
            'end_idx': end
        }

    low_speed_mask = np.abs(smoothed_velocity) < vel_threshold
    shadows = [
        _shadow('low_speed', start, end)
        for start, end in find_intervals(low_speed_mask)
        if end - start + 1 <= k
    ]

    extrema = (
        ('max_value', np.argmax(smoothed_angle)),
        ('min_value', np.argmin(smoothed_angle)),
    )
    for kind, centre in extrema:
        # Clamp the window to the valid index range of the time axis.
        window_start = max(0, centre - highlight_width)
        window_end = min(last_idx, centre + highlight_width)
        shadows.append(_shadow(kind, window_start, window_end))
    return shadows
def is_hover_in_shadow(hover_time, shadows):
    """Tell whether ``hover_time`` lies inside any shadow (bounds inclusive).

    Args:
        hover_time: Time value to test.
        shadows: Iterable of dicts with ``'start_time'`` and ``'end_time'``.

    Returns:
        ``True`` if at least one shadow contains ``hover_time``, else ``False``.
    """
    return any(
        entry['start_time'] <= hover_time <= entry['end_time']
        for entry in shadows
    )
def find_shadows_in_range(shadows, start_time, end_time):
    """Select the shadows that overlap the window ``[start_time, end_time]``.

    Args:
        shadows: Iterable of dicts with ``'start_time'`` and ``'end_time'``.
        start_time: Lower bound of the window.
        end_time: Upper bound of the window.

    Returns:
        List of the overlapping shadow dicts, in their original order.
        Shadows that merely touch a window edge count as overlapping.
    """
    def _misses_window(entry):
        # A shadow misses the window when it ends before or starts after it.
        return entry['end_time'] < start_time or entry['start_time'] > end_time

    return [entry for entry in shadows if not _misses_window(entry)]
# Precompute the shadow list of every joint once, at import time, so the
# callbacks can look them up instead of recomputing them.
all_shadows.update((joint_name, get_shadow_info(joint_name)) for joint_name in columns)
# ------------------ Graph generation ------------------
def generate_joint_graph(joint_name, idx, highlighted_shadows=None):
    """Build the ``dcc.Graph`` showing one joint's smoothed angle over time.

    Every shadow stored in ``all_shadows[joint_name]`` is drawn as a
    full-height rectangle. A shadow whose start and end times both equal
    those of an entry in ``highlighted_shadows`` is drawn blue with opacity
    0.6; all others are red with opacity 0.3.

    Args:
        joint_name: Column of ``action_df`` to plot; also the graph title.
        idx: Index used to build the component id ``graph-<idx>``.
        highlighted_shadows: Optional list of shadow dicts to highlight.
            ``None`` or an empty list highlights nothing.

    Returns:
        A ``dcc.Graph`` whose ``figure`` holds the trace and the layout.
    """
    angles = action_df[joint_name].values
    # The first sample is dropped so the trace matches time_for_plot in length.
    smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)

    to_highlight = highlighted_shadows or []

    def _is_highlighted(shadow):
        # Shadows are matched by their (start_time, end_time) pair.
        return any(
            shadow['start_time'] == other['start_time']
            and shadow['end_time'] == other['end_time']
            for other in to_highlight
        )

    shapes = []
    for shadow in all_shadows[joint_name]:
        highlighted = _is_highlighted(shadow)
        shapes.append({
            "type": "rect",
            "xref": "x",
            # "paper" makes y0/y1 fractions of the plot height, so the
            # rectangle spans the whole plot whatever the y-axis range is.
            "yref": "paper",
            "x0": shadow['start_time'],
            "x1": shadow['end_time'],
            "y0": 0,
            "y1": 1,
            "fillcolor": "blue" if highlighted else "red",
            "opacity": 0.6 if highlighted else 0.3,
            "line": {"width": 0}
        })

    return dcc.Graph(
        id=f"graph-{idx}",
        figure={
            "data": [
                go.Scatter(
                    x=time_for_plot,
                    y=smoothed_angle,
                    name="Angle",
                    line=dict(color='orange')
                )
            ],
            "layout": go.Layout(
                title=joint_name,
                xaxis={"title": "Time (s)"},
                yaxis={"title": "Angle (deg)"},
                shapes=shapes,
                hovermode="x unified",
                height=250,
                margin=dict(t=30, b=30, l=50, r=50),
                showlegend=False,
            )
        },
        style={"height": "250px"}
    )
# ------------------ Layout ------------------
def _camera_image(component_id):
    # One <img> placeholder; its src is filled in later by a callback.
    return html.Img(
        id=component_id,
        style={"width": "49%", "height": "180px", "objectFit": "contain", "display": "inline-block"},
    )


def _joint_row(i, joint):
    # One row: the joint graph on the left, the two camera frames on the right.
    graph_cell = html.Div(
        generate_joint_graph(joint, i),
        style={"width": "60%", "display": "inline-block", "verticalAlign": "top"},
    )
    video_cell = html.Div(
        [_camera_image(f"video1-{i}"), _camera_image(f"video2-{i}")],
        style={"width": "38%", "display": "inline-block", "paddingLeft": "2%"},
    )
    return html.Div([graph_cell, video_cell], style={"marginBottom": "15px"})


rows = []
for i, joint in enumerate(columns):
    rows.append(_joint_row(i, joint))

# Timer for playback plus a store that remembers the hover state.
rows.append(dcc.Interval(id="video-playback-interval", interval=300, n_intervals=0))
rows.append(dcc.Store(id="hover-state-store", data={"active": False, "last_update": 0}))

app.layout = html.Div(rows)
# ------------------ Callback: watch hoverData and update shadow highlighting ------------------
@app.callback(
    [Output(f"graph-{i}", "figure") for i in range(len(columns))],
    [Input(f"graph-{i}", "hoverData") for i in range(len(columns))],
)
def update_shadow_highlighting(*args):
    """Recolour shadows on every graph in response to a hover on one graph.

    Registered with one ``hoverData`` input and one ``figure`` output per
    joint graph. When the hovered time falls inside a shadow of the hovered
    joint, every graph highlights its own shadows that overlap a window of
    20 samples on each side of the hover position. Otherwise all graphs are
    redrawn without highlighting.

    Args:
        *args: The ``hoverData`` of each graph, in ``columns`` order. Any
            further positional values are ignored.

    Returns:
        One figure per graph, or ``no_update`` for every graph when the
        callback was not triggered by a hover or an error occurred.
    """
    n_graphs = len(columns)
    hover_datas = args[:n_graphs]
    unchanged = [no_update] * n_graphs

    def _render_all(window=None):
        # Rebuild every figure; with a (start, end) window, highlight the
        # shadows of each joint that overlap it.
        figures = []
        for i, joint in enumerate(columns):
            highlighted = None
            if window is not None:
                highlighted = find_shadows_in_range(all_shadows[joint], *window)
            figures.append(generate_joint_graph(joint, i, highlighted).figure)
        return figures

    ctx = dash.callback_context
    if not ctx.triggered:
        return unchanged
    trigger_id = ctx.triggered[0]['prop_id']
    # Only hover events are handled here.
    if 'hoverData' not in trigger_id:
        return unchanged

    # prop_id looks like "graph-<idx>.hoverData"; pull out <idx>.
    graph_idx = int(trigger_id.split('-')[1].split('.')[0])
    hover_data = hover_datas[graph_idx]

    # No hover data: restore the normal (un-highlighted) state.
    if not hover_data or "points" not in hover_data or len(hover_data["points"]) == 0:
        return _render_all()

    try:
        hover_time = float(hover_data["points"][0]["x"])
        triggered_joint = columns[graph_idx]

        # Hovering outside every red shadow also restores the normal state.
        if not is_hover_in_shadow(hover_time, all_shadows[triggered_joint]):
            return _render_all()

        # Window of 20 samples on each side of the hovered sample, clamped
        # to the valid range of the time axis.
        hover_idx = np.searchsorted(time_for_plot, hover_time)
        start_idx = max(0, hover_idx - 20)
        end_idx = min(len(time_for_plot) - 1, hover_idx + 20)
        return _render_all((time_for_plot[start_idx], time_for_plot[end_idx]))
    except Exception as e:
        # Callback boundary: report and leave the graphs untouched rather
        # than surfacing an error in the browser.
        print(f"处理阴影高亮异常: {e}")
        return unchanged
# ------------------ Callback: watch hoverData and update the video frames ------------------
video_duration = timestamps[-1] - timestamps[0]

# Seconds of video advanced per interval tick during automatic playback.
# NOTE(review): presumably meant to mirror the 300 ms period of the
# "video-playback-interval" component -- keep the two in sync.
_PLAYBACK_STEP_SECONDS = 0.3
# A hover older than this many interval ticks no longer pauses playback.
_HOVER_TIMEOUT_TICKS = 3


@app.callback(
    [Output(f"video1-{i}", "src") for i in range(len(columns))]
    + [Output(f"video2-{i}", "src") for i in range(len(columns))]
    + [Output("hover-state-store", "data")],
    [Input(f"graph-{i}", "hoverData") for i in range(len(columns))]
    + [Input("video-playback-interval", "n_intervals")],
    [State("hover-state-store", "data")],
)
def update_video_frames(*args):
    """Update both camera images of every row from a hover or a timer tick.

    On a hover, the frames at the hovered time are shown and the hover is
    recorded as active. On a timer tick, playback advances automatically
    unless a hover is active and was recorded within the last
    ``_HOVER_TIMEOUT_TICKS`` ticks.

    Args:
        *args: The ``hoverData`` of each graph, followed by the interval's
            ``n_intervals`` and then the stored hover-state dict.

    Returns:
        A list with the first camera's image source once per row, then the
        second camera's once per row, then the hover-state dict. Image
        entries are ``no_update`` when no new frame is available.
    """
    hover_datas = args[:-2]
    interval_count = args[-2]
    hover_state = args[-1]
    n_rows = len(columns)
    keep_images = [no_update] * (2 * n_rows)

    def _frames_at(t):
        # Both camera frames at time t, repeated for every row, or None if
        # either frame is unavailable.
        frame1 = get_video_frame(video_path_1, t)
        frame2 = get_video_frame(video_path_2, t)
        if frame1 and frame2:
            return [frame1] * n_rows + [frame2] * n_rows
        return None

    ctx = dash.callback_context
    try:
        if not ctx.triggered:
            return keep_images + [hover_state]
        trigger_id = ctx.triggered[0]['prop_id']

        if 'hoverData' in trigger_id:
            # prop_id looks like "graph-<idx>.hoverData"; pull out <idx>.
            graph_idx = int(trigger_id.split('-')[1].split('.')[0])
            hover_data = hover_datas[graph_idx]
            if hover_data and "points" in hover_data and len(hover_data["points"]) > 0:
                try:
                    hover_time = float(hover_data["points"][0]["x"])
                    images = _frames_at(hover_time)
                    if images is not None:
                        # Mark the hover as active so timer ticks pause playback.
                        new_hover_state = {"active": True, "last_update": interval_count}
                        return images + [new_hover_state]
                except Exception as e:
                    print(f"处理hover数据异常: {e}")

        if 'video-playback-interval' in trigger_id:
            ticks_since_hover = interval_count - hover_state.get("last_update", 0)
            hover_expired = ticks_since_hover > _HOVER_TIMEOUT_TICKS
            if hover_state.get("active", False) and not hover_expired:
                # A recent hover is still active: keep playback paused.
                return keep_images + [hover_state]

            # Loop the playback position over the duration of the episode.
            t = timestamps[0] + (interval_count * _PLAYBACK_STEP_SECONDS) % video_duration
            new_hover_state = {"active": False, "last_update": interval_count}
            images = _frames_at(t)
            if images is not None:
                return images + [new_hover_state]
            return keep_images + [new_hover_state]

        return keep_images + [hover_state]
    except Exception as e:
        # Callback boundary: report and leave the images untouched.
        print(f"update_video_frames回调函数异常: {e}")
        return keep_images + [hover_state]
# ------------------ Start the application ------------------
if __name__ == '__main__':
    # Newer Dash releases expose `app.run` and have retired `app.run_server`;
    # prefer the former and fall back for older installations.
    run_app = getattr(app, "run", None) or app.run_server
    run_app(host="0.0.0.0", port=7860, debug=False)