# ------------------ Imports ------------------
import dash
from dash import dcc, html, Input, Output, State, callback_context, no_update
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import cv2
import base64
import functools
from scipy.ndimage import gaussian_filter1d

# ------------------ Load data ------------------
df = pd.read_parquet("./data/clean_data/uni_boxing_object_vfm/data/chunk-000/episode_000000.parquet")
columns = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
timestamps = df["timestamp"].values
# Velocities are defined between consecutive samples, so every derived series
# (and its x-axis, time_for_plot) is one element shorter than `timestamps`.
delta_t = np.diff(timestamps)
time_for_plot = timestamps[1:]
# Unpack each row's "action" list into one column per joint
# (assumes each list has exactly len(columns) entries -- TODO confirm against the dataset schema).
action_df = pd.DataFrame(df["action"].tolist(), columns=columns)

# ------------------ Video paths ------------------
video_path_1 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.laptop/episode_000000.mp4"
video_path_2 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.phone/episode_000000.mp4"

# ------------------ Dash initialisation ------------------
app = dash.Dash(__name__)
server = app.server

# ------------------ Shadow parameters ------------------
SMOOTH_SIGMA = 1        # Gaussian sigma (in samples) used to smooth angle and velocity
VEL_THRESHOLD = 0.5     # |smoothed velocity| below this counts as "low speed" (angle units per time unit)
HIGHLIGHT_WIDTH = 3     # half-width (in samples) of the band shaded around the max / min angle
MAX_LOW_SPEED_LEN = 2   # only low-speed runs at most this many samples long are shaded

# ------------------ Global shadow store ------------------
all_shadows = {}  # joint name -> list of shadow dicts (filled once, below)


# ------------------ Video frame extraction ------------------
def get_video_frame(video_path, time_in_seconds):
    """Decode the frame at ``time_in_seconds`` and return it as a JPEG data URI.

    Args:
        video_path: Path to the video file.
        time_in_seconds: Playback position, in seconds from the start of the video.

    Returns:
        A ``data:image/jpeg;base64,...`` string, or ``None`` if the video cannot be
        opened, the frame cannot be read, or JPEG encoding fails.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"❌ 无法打开视频: {video_path}")
            return None
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Round instead of truncating: float error (e.g. int(0.29 * 100) == 28) would
        # otherwise pick the previous frame when time * fps should be an integer.
        frame_num = max(0, int(round(time_in_seconds * fps)))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        success, frame = cap.read()
    finally:
        # Release the capture on every path, including when an exception is raised.
        cap.release()

    if not success:
        return None
    ok, buffer = cv2.imencode('.jpg', frame)
    if not ok:
        return None
    encoded = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded}"


def find_intervals(mask):
    """Return the inclusive ``(start, end)`` index pairs of every run of truthy values in ``mask``."""
    intervals = []
    start = None
    for i, val in enumerate(mask):
        if val and start is None:
            start = i
        elif not val and start is not None:
            intervals.append((start, i - 1))
            start = None
    if start is not None:
        # A run that reaches the end of the mask is closed at the last index.
        intervals.append((start, len(mask) - 1))
    return intervals


@functools.lru_cache(maxsize=None)
def _smoothed_signals(joint_name):
    """Return ``(smoothed_velocity, smoothed_angle)`` for one joint, both aligned with ``time_for_plot``.

    Cached because shadow detection and every figure rebuild (one per hover event)
    need the same series. The key space is the fixed set of joint names. The arrays
    are marked read-only because the cache shares them between callers.
    """
    angles = action_df[joint_name].values
    velocity = np.diff(angles) / delta_t
    smoothed_velocity = gaussian_filter1d(velocity, sigma=SMOOTH_SIGMA)
    # Drop the first sample so the angle series has the same length as velocity / time_for_plot.
    smoothed_angle = gaussian_filter1d(angles[1:], sigma=SMOOTH_SIGMA)
    smoothed_velocity.setflags(write=False)
    smoothed_angle.setflags(write=False)
    return smoothed_velocity, smoothed_angle


def _band_shadow(kind, center_idx, half_width):
    """Build a shadow dict covering ``center_idx +/- half_width`` samples, clipped to the data range."""
    start = max(0, center_idx - half_width)
    end = min(len(time_for_plot) - 1, center_idx + half_width)
    return {
        'type': kind,
        'start_time': time_for_plot[start],
        'end_time': time_for_plot[end],
        'start_idx': start,
        'end_idx': end,
    }


def get_shadow_info(joint_name, vel_threshold=VEL_THRESHOLD,
                    highlight_width=HIGHLIGHT_WIDTH, k=MAX_LOW_SPEED_LEN):
    """Return every red shadow region for one joint.

    Three kinds of shadow are produced:

    * ``'low_speed'``: runs where |smoothed velocity| < ``vel_threshold`` that are
      at most ``k`` samples long.
    * ``'max_value'``: a band of ``highlight_width`` samples each side of the
      smoothed-angle maximum.
    * ``'min_value'``: the same band around the smoothed-angle minimum.

    Each shadow is a dict with keys ``type``, ``start_time``, ``end_time`` (values
    from ``time_for_plot``), and ``start_idx`` / ``end_idx`` (inclusive indices into it).
    """
    smoothed_velocity, smoothed_angle = _smoothed_signals(joint_name)

    low_speed_mask = np.abs(smoothed_velocity) < vel_threshold
    shadows = [
        {
            'type': 'low_speed',
            'start_time': time_for_plot[start],
            'end_time': time_for_plot[end],
            'start_idx': start,
            'end_idx': end,
        }
        for start, end in find_intervals(low_speed_mask)
        if end - start + 1 <= k
    ]
    shadows.append(_band_shadow('max_value', int(np.argmax(smoothed_angle)), highlight_width))
    shadows.append(_band_shadow('min_value', int(np.argmin(smoothed_angle)), highlight_width))
    return shadows


def is_hover_in_shadow(hover_time, shadows):
    """Return True if ``hover_time`` lies inside (inclusively) any of ``shadows``."""
    return any(shadow['start_time'] <= hover_time <= shadow['end_time'] for shadow in shadows)


def find_shadows_in_range(shadows, start_time, end_time):
    """Return the shadows that overlap the closed interval ``[start_time, end_time]``."""
    return [
        shadow for shadow in shadows
        if not (shadow['end_time'] < start_time or shadow['start_time'] > end_time)
    ]


# Precompute the shadow information for every joint.
for joint in columns:
    all_shadows[joint] = get_shadow_info(joint)


# ------------------ Graph builders ------------------
def generate_joint_graph(joint_name, idx, highlighted_shadows=None):
    """Build the ``dcc.Graph`` for one joint: the smoothed angle trace plus its shadows.

    Args:
        joint_name: Column of ``action_df`` to plot.
        idx: Index used to form the component id ``graph-{idx}``.
        highlighted_shadows: Optional shadows (as returned by ``get_shadow_info``)
            to draw in blue instead of red. They are matched to this joint's
            shadows by their ``(start_time, end_time)`` span.

    Returns:
        A ``dcc.Graph``; callbacks reuse its ``.figure``.
    """
    _, smoothed_angle = _smoothed_signals(joint_name)
    highlighted_spans = {
        (shadow['start_time'], shadow['end_time']) for shadow in (highlighted_shadows or ())
    }

    shapes = []
    for shadow in all_shadows[joint_name]:
        is_highlighted = (shadow['start_time'], shadow['end_time']) in highlighted_spans
        shapes.append({
            "type": "rect",
            "xref": "x",
            "yref": "paper",  # span the full plot height regardless of the y range
            "x0": shadow['start_time'],
            "x1": shadow['end_time'],
            "y0": 0,
            "y1": 1,
            "fillcolor": "blue" if is_highlighted else "red",
            "opacity": 0.6 if is_highlighted else 0.3,
            "line": {"width": 0},
        })

    return dcc.Graph(
        id=f"graph-{idx}",
        figure={
            "data": [
                go.Scatter(
                    x=time_for_plot,
                    y=smoothed_angle,
                    name="Angle",
                    line=dict(color='orange'),
                )
            ],
            "layout": go.Layout(
                title=joint_name,
                xaxis={"title": "Time (s)"},
                yaxis={"title": "Angle (deg)"},
                shapes=shapes,
                hovermode="x unified",
                height=250,
                margin=dict(t=30, b=30, l=50, r=50),
                showlegend=False,
            ),
        },
        style={"height": "250px"},
    )


# ------------------ Layout ------------------
rows = []

# One row per joint: joint graph + the two video frames.
for i, joint in enumerate(columns):
    rows.append(html.Div([
        html.Div(generate_joint_graph(joint, i),
                 style={"width": "60%", "display": "inline-block", "verticalAlign": "top"}),
        html.Div([
            html.Img(id=f"video1-{i}",
                     style={"width": "49%", "height": "180px", "objectFit": "contain", "display": "inline-block"}),
            html.Img(id=f"video2-{i}",
                     style={"width": "49%", "height": "180px", "objectFit": "contain", "display": "inline-block"}),
        ], style={"width": "38%", "display": "inline-block", "paddingLeft": "2%"}),
    ], style={"marginBottom": "15px"}))

# Autoplay timer (fires every 300 ms).
rows.append(dcc.Interval(id="video-playback-interval", interval=300, n_intervals=0))
# Store tracking whether a hover is currently driving the video frames.
rows.append(dcc.Store(id="hover-state-store", data={"active": False, "last_update": 0}))

# Set the layout.
app.layout = html.Div(rows)

# Number of samples on each side of the hovered point used to decide which
# shadows (across all joints) get highlighted.
HOVER_WINDOW = 20


def _graph_index(trigger_id):
    """Extract ``i`` from a callback ``prop_id`` such as ``"graph-3.hoverData"``."""
    component_id = trigger_id.split('.')[0]
    return int(component_id.rsplit('-', 1)[1])


def _default_figures():
    """Return the un-highlighted figure of every joint graph, in ``columns`` order."""
    return [generate_joint_graph(joint, i).figure for i, joint in enumerate(columns)]


# ------------------ Callback: highlight shadows on hover ------------------
@app.callback(
    [Output(f"graph-{i}", "figure") for i in range(len(columns))],
    [Input(f"graph-{i}", "hoverData") for i in range(len(columns))],
    [State(f"graph-{i}", "figure") for i in range(len(columns))],
)
def update_shadow_highlighting(*args):
    """Highlight, in every graph, the shadows near a hovered point that lies in a shadow.

    If the hovered point is inside one of the triggering joint's shadows, every
    joint's shadows overlapping ``HOVER_WINDOW`` samples either side of it are
    drawn in blue. Otherwise, or when there is no hover data, all graphs are
    restored to their default colours.
    """
    n_graphs = len(columns)
    hover_datas = args[:n_graphs]
    # args[n_graphs:] are the current figures (State). They are not needed because
    # every figure is rebuilt from the precomputed shadows.
    ctx = dash.callback_context

    if not ctx.triggered:
        return [no_update] * n_graphs
    trigger_id = ctx.triggered[0]['prop_id']
    # Ignore anything that was not a hover event (e.g. the initial "." trigger).
    if 'hoverData' not in trigger_id:
        return [no_update] * n_graphs

    graph_idx = _graph_index(trigger_id)
    hover_data = hover_datas[graph_idx]
    # No hover data (or an empty / None point list): restore the default figures.
    if not hover_data or not hover_data.get("points"):
        return _default_figures()

    try:
        hover_time = float(hover_data["points"][0]["x"])
        triggered_joint = columns[graph_idx]

        # Only hovers inside one of the triggering joint's red shadows highlight anything.
        if not is_hover_in_shadow(hover_time, all_shadows[triggered_joint]):
            return _default_figures()

        # Time window of HOVER_WINDOW samples before/after the hovered sample.
        hover_idx = int(np.searchsorted(time_for_plot, hover_time))
        start_idx = max(0, hover_idx - HOVER_WINDOW)
        end_idx = min(len(time_for_plot) - 1, hover_idx + HOVER_WINDOW)
        start_time = time_for_plot[start_idx]
        end_time = time_for_plot[end_idx]

        return [
            generate_joint_graph(
                joint, i, find_shadows_in_range(all_shadows[joint], start_time, end_time)
            ).figure
            for i, joint in enumerate(columns)
        ]
    except Exception as e:
        print(f"处理阴影高亮异常: {e}")
        return [no_update] * n_graphs


# ------------------ Callback: update video frames on hover / autoplay ------------------
video_duration = timestamps[-1] - timestamps[0]

# Seconds of video advanced per autoplay tick (the Interval component fires every 300 ms).
AUTOPLAY_STEP_S = 0.3
# Autoplay resumes once more than this many ticks have passed since the last hover update.
HOVER_TIMEOUT_TICKS = 3


@app.callback(
    [Output(f"video1-{i}", "src") for i in range(len(columns))]
    + [Output(f"video2-{i}", "src") for i in range(len(columns))]
    + [Output("hover-state-store", "data")],
    [Input(f"graph-{i}", "hoverData") for i in range(len(columns))]
    + [Input("video-playback-interval", "n_intervals")],
    [State("hover-state-store", "data")],
)
def update_video_frames(*args):
    """Show the video frames for the hovered time, or autoplay when no hover is active.

    A hover event shows the frame at the hovered x value (seconds) from both
    cameras in every row and marks the hover as active. An interval tick advances
    autoplay by ``AUTOPLAY_STEP_S`` per tick, looping over the episode duration. It
    does this only when no hover is active or the last hover update is more than
    ``HOVER_TIMEOUT_TICKS`` ticks old.

    Returns:
        ``len(columns)`` camera-1 sources, ``len(columns)`` camera-2 sources, and
        the new hover-state dict.
    """
    n_graphs = len(columns)
    hover_datas = args[:-2]
    interval_count = args[-2]
    # Tolerate an uninitialised store.
    hover_state = args[-1] or {"active": False, "last_update": 0}
    unchanged_frames = [no_update] * (2 * n_graphs)
    ctx = dash.callback_context

    try:
        if not ctx.triggered:
            return unchanged_frames + [hover_state]
        trigger_id = ctx.triggered[0]['prop_id']

        # Hover on a graph: show the frame at the hovered time.
        if 'hoverData' in trigger_id:
            hover_data = hover_datas[_graph_index(trigger_id)]
            if hover_data and hover_data.get("points"):
                try:
                    hover_time = float(hover_data["points"][0]["x"])
                    frame1 = get_video_frame(video_path_1, hover_time)
                    frame2 = get_video_frame(video_path_2, hover_time)
                    # Only switch to hover mode if both frames were decoded.
                    if frame1 and frame2:
                        new_hover_state = {"active": True, "last_update": interval_count}
                        return [frame1] * n_graphs + [frame2] * n_graphs + [new_hover_state]
                except Exception as e:
                    print(f"处理hover数据异常: {e}")

        # Interval tick: autoplay unless a recent hover is still in control.
        if 'video-playback-interval' in trigger_id:
            hover_expired = (interval_count - hover_state.get("last_update", 0)) > HOVER_TIMEOUT_TICKS
            if hover_state.get("active", False) and not hover_expired:
                # Hover still active: pause autoplay.
                return unchanged_frames + [hover_state]

            # Loop over the episode; guard against a zero-length episode (modulo by 0 -> NaN).
            offset = (interval_count * AUTOPLAY_STEP_S) % video_duration if video_duration > 0 else 0.0
            t = timestamps[0] + offset
            frame1 = get_video_frame(video_path_1, t)
            frame2 = get_video_frame(video_path_2, t)
            new_hover_state = {"active": False, "last_update": interval_count}
            if frame1 and frame2:
                return [frame1] * n_graphs + [frame2] * n_graphs + [new_hover_state]
            return unchanged_frames + [new_hover_state]

        return unchanged_frames + [hover_state]
    except Exception as e:
        print(f"update_video_frames回调函数异常: {e}")
        return unchanged_frames + [hover_state]


# ------------------ Start the app ------------------
if __name__ == '__main__':
    # Dash 3 removed ``run_server`` in favour of ``run``; fall back for older releases.
    run = getattr(app, "run", None) or app.run_server
    run(host="0.0.0.0", port=7860, debug=False)