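# NOTE: the lines below appear to be header text captured from the page this
# file was copied from; they are kept as comments so the module parses.
# keyframe / app.py
# zijian2022's picture
# Rename w52.py to app.py
# 25c8cf5 verified
# raw
# history blame
# 13.9 kB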
# ------------------ 导入库 ------------------
import dash
from dash import dcc, html, Input, Output, State, callback_context, no_update
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import cv2
import base64
from scipy.ndimage import gaussian_filter1d
# ------------------ Load data ------------------
# Root directory of the episode being visualised; the parquet file and both
# camera videos live underneath it.
_DATASET_ROOT = "./data/clean_data/uni_boxing_object_vfm"

df = pd.read_parquet(f"{_DATASET_ROOT}/data/chunk-000/episode_000000.parquet")
# Names given, in order, to the components of each "action" row.
columns = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
timestamps = df["timestamp"].values
# Time step between consecutive samples; used to turn angle differences into velocities.
delta_t = np.diff(timestamps)
# np.diff drops one sample, so derived series are plotted against timestamps[1:].
time_for_plot = timestamps[1:]
# Expand the per-row "action" sequences into one column per joint name.
action_df = pd.DataFrame(df["action"].tolist(), columns=columns)

# ------------------ Video paths ------------------
# Two camera views ("laptop" and "phone") of the same episode.
_VIDEO_DIR = f"{_DATASET_ROOT}/videos/chunk-000"
video_path_1 = f"{_VIDEO_DIR}/observation.images.laptop/episode_000000.mp4"
video_path_2 = f"{_VIDEO_DIR}/observation.images.phone/episode_000000.mp4"

# ------------------ Dash initialisation ------------------
app = dash.Dash(__name__)
# Module-level handle on the app's server object (typically what a WSGI host imports).
server = app.server

# ------------------ Global shadow store ------------------
# joint name -> list of shadow dicts (see get_shadow_info); filled once at import time.
all_shadows = {}
# ------------------ Video frame extraction ------------------
def get_video_frame(video_path, time_in_seconds):
    """Grab the frame at ``time_in_seconds`` from a video as a JPEG data URI.

    Args:
        video_path: Path of the video file, opened with ``cv2.VideoCapture``.
        time_in_seconds: Position to seek to, in seconds from the start of the
            video. Negative values are clamped to the first frame.

    Returns:
        A ``data:image/jpeg;base64,...`` string usable as an ``<img>`` src, or
        ``None`` if the video cannot be opened, the frame cannot be read, or
        JPEG encoding fails.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"❌ 无法打开视频: {video_path}")
            return None
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Round instead of truncating: timestamps of the form k / fps (often
        # stored as float32) can land just below the integer, and int() would
        # then seek one frame too early.
        frame_num = max(0, int(round(time_in_seconds * fps)))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        success, frame = cap.read()
    finally:
        # Release the capture on every path, including early return and errors.
        cap.release()
    if not success:
        return None
    encoded_ok, buffer = cv2.imencode('.jpg', frame)
    if not encoded_ok:
        return None
    encoded = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded}"
def find_intervals(mask):
    """Return the maximal runs of truthy entries in ``mask``.

    Args:
        mask: Sized sequence of booleans (for example a NumPy bool array).

    Returns:
        List of ``(start, end)`` index pairs, both ends inclusive, in
        ascending order. An empty or all-false mask yields ``[]``.
    """
    starts = []
    ends = []
    prev = False
    for pos, flag in enumerate(mask):
        cur = bool(flag)
        if cur and not prev:
            # Rising edge: a run begins here.
            starts.append(pos)
        elif prev and not cur:
            # Falling edge: the run ended on the previous element.
            ends.append(pos - 1)
        prev = cur
    if prev:
        # A run is still open at the end of the mask.
        ends.append(len(mask) - 1)
    return list(zip(starts, ends))
def get_shadow_info(joint_name, *, vel_threshold=0.5, highlight_width=3, k=2, sigma=1):
    """Compute all red "shadow" intervals for one joint's action trace.

    Three kinds of shadow are produced. All indices refer to ``time_for_plot``
    (``timestamps[1:]``), because the velocity comes from ``np.diff``:

    * ``'low_speed'``: each run where ``|smoothed velocity| < vel_threshold``
      that is at most ``k`` samples long (longer runs are not shadowed).
    * ``'max_value'`` / ``'min_value'``: ``highlight_width`` samples on each
      side of the smoothed angle's maximum / minimum, clipped to the data range.

    Args:
        joint_name: Column of ``action_df`` to analyse.
        vel_threshold: Speed below which a sample counts as low speed, in
            action units per unit of ``timestamp``.
        highlight_width: Half-width, in samples, of the extremum shadows.
        k: Maximum length, in samples, of a low-speed run that is shadowed.
        sigma: Standard deviation of the Gaussian smoothing applied to both
            the velocity and the angle.

    Returns:
        List of dicts with keys ``'type'``, ``'start_time'``, ``'end_time'``,
        ``'start_idx'`` and ``'end_idx'`` (inclusive indices).
    """
    angles = action_df[joint_name].values
    velocity = np.diff(angles) / delta_t
    smoothed_velocity = gaussian_filter1d(velocity, sigma=sigma)
    # angles[1:] aligns the angle samples with the velocity / time_for_plot indices.
    smoothed_angle = gaussian_filter1d(angles[1:], sigma=sigma)

    shadows = []
    low_speed_mask = np.abs(smoothed_velocity) < vel_threshold
    for start, end in find_intervals(low_speed_mask):
        if end - start + 1 <= k:
            shadows.append(_make_shadow('low_speed', start, end))

    shadows.append(_extremum_shadow('max_value', np.argmax(smoothed_angle), highlight_width))
    shadows.append(_extremum_shadow('min_value', np.argmin(smoothed_angle), highlight_width))
    return shadows


def _make_shadow(kind, start, end):
    """Build one shadow record covering inclusive indices [start, end] of ``time_for_plot``."""
    return {
        'type': kind,
        'start_time': time_for_plot[start],
        'end_time': time_for_plot[end],
        'start_idx': start,
        'end_idx': end,
    }


def _extremum_shadow(kind, center, half_width):
    """Build a shadow spanning ``center +/- half_width`` samples, clipped to the valid index range."""
    start = max(0, center - half_width)
    end = min(len(time_for_plot) - 1, center + half_width)
    return _make_shadow(kind, start, end)
def is_hover_in_shadow(hover_time, shadows):
    """Tell whether ``hover_time`` falls inside any shadow (bounds inclusive).

    Args:
        hover_time: x-coordinate of the hover event.
        shadows: Iterable of dicts with ``'start_time'`` and ``'end_time'`` keys.

    Returns:
        ``True`` if some shadow contains ``hover_time``, otherwise ``False``.
    """
    return any(
        sh['start_time'] <= hover_time <= sh['end_time']
        for sh in shadows
    )
def find_shadows_in_range(shadows, start_time, end_time):
    """Return the shadows that overlap the closed range [start_time, end_time].

    A shadow is excluded only if it ends strictly before ``start_time`` or
    starts strictly after ``end_time``, so touching endpoints count as overlap.
    Input order is kept and the original dict objects are returned.
    """
    return [
        sh
        for sh in shadows
        if not (sh['end_time'] < start_time or sh['start_time'] > end_time)
    ]
# Precompute every joint's shadow intervals once, at import time, filling the
# shared all_shadows dict in place (in column order).
all_shadows.update({name: get_shadow_info(name) for name in columns})
# ------------------ Figure construction ------------------
def generate_joint_graph(joint_name, idx, highlighted_shadows=None):
    """Build the ``dcc.Graph`` for one joint, with its shadows drawn as rectangles.

    Args:
        joint_name: Column of ``action_df`` to plot; also the key into ``all_shadows``.
        idx: Row index, used to form the component id ``graph-{idx}``.
        highlighted_shadows: Optional list of shadow dicts to emphasise. A
            shadow of this joint is drawn blue at opacity 0.6 when some entry
            has the same ``start_time`` and ``end_time``. All others are drawn
            red at opacity 0.3.

    Returns:
        A ``dcc.Graph`` plotting the Gaussian-smoothed angle against
        ``time_for_plot``, with one rectangle shape per shadow.
    """
    angles = action_df[joint_name].values
    # Same smoothing and [1:] alignment as get_shadow_info, so the curve lines
    # up with the shadow positions.
    smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)

    highlights = highlighted_shadows or ()
    shapes = []
    for shadow in all_shadows[joint_name]:
        is_highlighted = any(
            shadow['start_time'] == h['start_time'] and shadow['end_time'] == h['end_time']
            for h in highlights
        )
        shapes.append({
            "type": "rect",
            "xref": "x",
            # Paper coordinates 0..1: the rectangle spans the full plot height.
            "yref": "paper",
            "x0": shadow['start_time'],
            "x1": shadow['end_time'],
            "y0": 0,
            "y1": 1,
            "fillcolor": "blue" if is_highlighted else "red",
            "opacity": 0.6 if is_highlighted else 0.3,
            "line": {"width": 0},
        })

    return dcc.Graph(
        id=f"graph-{idx}",
        figure={
            "data": [
                go.Scatter(
                    x=time_for_plot,
                    y=smoothed_angle,
                    name="Angle",
                    line=dict(color='orange'),
                )
            ],
            "layout": go.Layout(
                title=joint_name,
                xaxis={"title": "Time (s)"},
                yaxis={"title": "Angle (deg)"},
                shapes=shapes,
                hovermode="x unified",
                height=250,
                margin=dict(t=30, b=30, l=50, r=50),
                showlegend=False,
            ),
        },
        style={"height": "250px"},
    )
# ------------------ Layout ------------------
def _joint_row(i, joint):
    """Build one layout row: the joint's graph (60% width) beside two camera frames (38%)."""
    graph_col = html.Div(
        generate_joint_graph(joint, i),
        style={"width": "60%", "display": "inline-block", "verticalAlign": "top"},
    )
    frame_style = {"width": "49%", "height": "180px", "objectFit": "contain", "display": "inline-block"}
    video_col = html.Div(
        # One <img> per camera, ids video1-{i} and video2-{i}; each gets its own style dict.
        [html.Img(id=f"video{cam}-{i}", style=dict(frame_style)) for cam in (1, 2)],
        style={"width": "38%", "display": "inline-block", "paddingLeft": "2%"},
    )
    return html.Div([graph_col, video_col], style={"marginBottom": "15px"})


# One row per joint, then the 300 ms timer that drives auto-playback and the
# store holding the hover/playback state read by the video callback.
rows = [_joint_row(i, joint) for i, joint in enumerate(columns)]
rows.append(dcc.Interval(id="video-playback-interval", interval=300, n_intervals=0))
rows.append(dcc.Store(id="hover-state-store", data={"active": False, "last_update": 0}))
app.layout = html.Div(rows)
# ------------------ Callback: highlight shadows near the hovered point ------------------
def _unhighlighted_figures():
    """Return the default (nothing highlighted) figure of every joint, in column order."""
    return [generate_joint_graph(joint, i).figure for i, joint in enumerate(columns)]


@app.callback(
    [Output(f"graph-{i}", "figure") for i in range(6)],
    [Input(f"graph-{i}", "hoverData") for i in range(6)],
    [State(f"graph-{i}", "figure") for i in range(6)],
)
def update_shadow_highlighting(*args):
    """Redraw all six joint figures in response to a hover on any of them.

    If the hovered x-position lies inside one of the triggering joint's own
    shadows, every joint's shadows that overlap a window of 20 samples on
    either side of that position are drawn highlighted. Otherwise, or when the
    hover data carries no points, all figures are reset to unhighlighted.

    Args:
        *args: Six ``hoverData`` values followed by six ``figure`` states. The
            states are registered by the decorator but not read here.

    Returns:
        A list of six figures, or six ``no_update`` values when no hover
        triggered the callback or an error occurred.
    """
    highlight_window = 20  # samples on each side of the hover position
    hover_datas = args[:6]
    ctx = dash.callback_context
    if not ctx.triggered:
        return [no_update] * 6
    trigger_id = ctx.triggered[0]['prop_id']
    # Only hover events are handled.
    if 'hoverData' not in trigger_id:
        return [no_update] * 6

    # prop_id looks like "graph-3.hoverData": take the number after the last '-'.
    graph_idx = int(trigger_id.split('.')[0].rsplit('-', 1)[1])
    hover_data = hover_datas[graph_idx]
    # NOTE(review): dcc.Graph keeps its last hoverData unless clear_on_unhover=True,
    # so this reset path may rarely fire; confirm the intended unhover behaviour.
    if not hover_data or not hover_data.get("points"):
        return _unhighlighted_figures()

    try:
        hover_time = float(hover_data["points"][0]["x"])
        triggered_joint = columns[graph_idx]
        # Hovering outside the triggering joint's own shadows clears every highlight.
        if not is_hover_in_shadow(hover_time, all_shadows[triggered_joint]):
            return _unhighlighted_figures()

        # Window of +/- highlight_window samples around the hover position,
        # clipped to the valid index range.
        hover_idx = np.searchsorted(time_for_plot, hover_time)
        start_idx = max(0, hover_idx - highlight_window)
        end_idx = min(len(time_for_plot) - 1, hover_idx + highlight_window)
        start_time = time_for_plot[start_idx]
        end_time = time_for_plot[end_idx]

        return [
            generate_joint_graph(
                joint, i, find_shadows_in_range(all_shadows[joint], start_time, end_time)
            ).figure
            for i, joint in enumerate(columns)
        ]
    except Exception as e:
        # Callback boundary: log and leave the figures unchanged.
        print(f"处理阴影高亮异常: {e}")
        return [no_update] * 6
# ------------------ Callback: show video frames for hover / auto-playback ------------------
# Episode length in timestamp units; auto-playback loops over this span.
video_duration = timestamps[-1] - timestamps[0]
# Seconds advanced per playback tick; matches the 300 ms dcc.Interval in the layout.
PLAYBACK_STEP_S = 0.3
# Playback ticks after a hover before auto-playback resumes.
HOVER_TIMEOUT_TICKS = 3


@app.callback(
    [Output(f"video1-{i}", "src") for i in range(6)] +
    [Output(f"video2-{i}", "src") for i in range(6)] +
    [Output("hover-state-store", "data")],
    [Input(f"graph-{i}", "hoverData") for i in range(6)] +
    [Input("video-playback-interval", "n_intervals")],
    [State("hover-state-store", "data")]
)
def update_video_frames(*args):
    """Update both camera frames in every row from a hover or a playback tick.

    On a hover, both videos are seeked to the hovered x value. The store is
    then marked active with the current tick, which pauses auto-playback. On
    a playback tick, if no hover is active or the last hover is more than
    ``HOVER_TIMEOUT_TICKS`` ticks old, the videos advance to
    ``timestamps[0] + (tick * PLAYBACK_STEP_S) % video_duration``.

    NOTE(review): that value is an absolute dataset timestamp, while
    get_video_frame treats its argument as seconds from the video start. The
    two agree only if timestamps[0] == 0; confirm for this dataset.

    Args:
        *args: Six ``hoverData`` values, the interval's ``n_intervals``, then
            the hover-state store data.

    Returns:
        Six video-1 srcs, six video-2 srcs, and the new store data
        (``no_update`` for frames that are not refreshed).
    """
    hover_datas = args[:-2]
    interval_count = args[-2]
    # Treat a missing store value as "no hover yet".
    hover_state = args[-1] or {"active": False, "last_update": 0}
    unchanged = [no_update] * 12 + [hover_state]
    ctx = dash.callback_context
    try:
        if not ctx.triggered:
            return unchanged
        trigger_id = ctx.triggered[0]['prop_id']

        if 'hoverData' in trigger_id:
            # prop_id looks like "graph-3.hoverData": take the number after the last '-'.
            graph_idx = int(trigger_id.split('.')[0].rsplit('-', 1)[1])
            hover_data = hover_datas[graph_idx]
            if hover_data and hover_data.get("points"):
                try:
                    hover_time = float(hover_data["points"][0]["x"])
                    frame1 = get_video_frame(video_path_1, hover_time)
                    frame2 = get_video_frame(video_path_2, hover_time)
                    if frame1 and frame2:
                        # Record the tick of this hover so playback pauses for a while.
                        new_hover_state = {"active": True, "last_update": interval_count}
                        return [frame1] * 6 + [frame2] * 6 + [new_hover_state]
                except Exception as e:
                    print(f"处理hover数据异常: {e}")
            return unchanged

        if 'video-playback-interval' in trigger_id:
            hover_expired = (interval_count - hover_state.get("last_update", 0)) > HOVER_TIMEOUT_TICKS
            if hover_state.get("active", False) and not hover_expired:
                # A recent hover is still on screen: hold auto-playback.
                return unchanged
            if video_duration > 0:
                t = timestamps[0] + (interval_count * PLAYBACK_STEP_S) % video_duration
            else:
                # Zero-length episode: avoid a modulo by zero (NaN time) and
                # keep showing the single frame.
                t = timestamps[0]
            frame1 = get_video_frame(video_path_1, t)
            frame2 = get_video_frame(video_path_2, t)
            new_hover_state = {"active": False, "last_update": interval_count}
            if frame1 and frame2:
                return [frame1] * 6 + [frame2] * 6 + [new_hover_state]
            return [no_update] * 12 + [new_hover_state]

        return unchanged
    except Exception as e:
        print(f"update_video_frames回调函数异常: {e}")
        return unchanged
# ------------------ Launch the app ------------------
if __name__ == '__main__':
    # Dash.run replaces run_server, which was deprecated and removed in Dash 3;
    # fall back to run_server only on older Dash 2.x releases that lack run.
    launch = getattr(app, "run", None) or app.run_server
    launch(host="0.0.0.0", port=7860, debug=False)