zijian2022 commited on
Commit
391ae4a
·
verified ·
1 Parent(s): d3b07ce

Upload 3 files

Browse files
Files changed (3) hide show
  1. dockerfile +25 -0
  2. requirements.txt +6 -0
  3. w52.py +378 -0
dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 使用轻量级的 Python 镜像
2
+ FROM python:3.9-slim
3
+
4
+ # 安装一些系统依赖,支持 OpenCV、Dash、视频处理等
5
+ RUN apt-get update && apt-get install -y \
6
+ ffmpeg \
7
+ libsm6 \
8
+ libxext6 \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ # 设置工作目录
12
+ WORKDIR /app
13
+
14
+ # 将当前目录下的所有文件复制到容器中
15
+ COPY . /app
16
+
17
+ # 安装 Python 依赖
18
+ RUN pip install --no-cache-dir -r requirements.txt
19
+
20
+ # 设置环境变量,指定容器监听的端口
21
+ ENV PORT=7860
22
+ ENV HOST=0.0.0.0
23
+
24
+ # 启动 Dash 应用(你也可以使用 gunicorn 等其他方式)
25
+ CMD ["python", "app.py"]
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ dash==3.1.1
2
+ numpy==2.3.1
3
+ opencv_python==4.11.0.86
4
+ pandas==2.3.1
5
+ plotly==5.24.1
6
+ scipy==1.16.0
w52.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------ 导入库 ------------------
2
+ import dash
3
+ from dash import dcc, html, Input, Output, State, callback_context, no_update
4
+ import plotly.graph_objects as go
5
+ import pandas as pd
6
+ import numpy as np
7
+ import cv2
8
+ import base64
9
+ from scipy.ndimage import gaussian_filter1d
10
+
11
+ # ------------------ 加载数据 ------------------
12
+ df = pd.read_parquet("./data/clean_data/uni_boxing_object_vfm/data/chunk-000/episode_000000.parquet")
13
+ columns = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
14
+ timestamps = df["timestamp"].values
15
+ delta_t = np.diff(timestamps)
16
+ time_for_plot = timestamps[1:]
17
+ action_df = pd.DataFrame(df["action"].tolist(), columns=columns)
18
+
19
+ # ------------------ 视频路径 ------------------
20
+ video_path_1 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.laptop/episode_000000.mp4"
21
+ video_path_2 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.phone/episode_000000.mp4"
22
+
23
+ # ------------------ Dash 初始化 ------------------
24
+ app = dash.Dash(__name__)
25
+ server = app.server
26
+
27
+ # ------------------ 全局变量存储阴影信息 ------------------
28
+ all_shadows = {} # 存储所有关节的阴影信息
29
+
30
+ # ------------------ 视频帧提取函数 ------------------
31
+ def get_video_frame(video_path, time_in_seconds):
32
+ cap = cv2.VideoCapture(video_path)
33
+ if not cap.isOpened():
34
+ print(f"❌ 无法打开视频: {video_path}")
35
+ return None
36
+ fps = cap.get(cv2.CAP_PROP_FPS)
37
+ frame_num = int(time_in_seconds * fps)
38
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
39
+ success, frame = cap.read()
40
+ cap.release()
41
+ if success:
42
+ _, buffer = cv2.imencode('.jpg', frame)
43
+ encoded = base64.b64encode(buffer).decode('utf-8')
44
+ return f"data:image/jpeg;base64,{encoded}"
45
+ else:
46
+ return None
47
+
48
+ def find_intervals(mask):
49
+ intervals = []
50
+ start = None
51
+ for i, val in enumerate(mask):
52
+ if val and start is None:
53
+ start = i
54
+ elif not val and start is not None:
55
+ intervals.append((start, i - 1))
56
+ start = None
57
+ if start is not None:
58
+ intervals.append((start, len(mask) - 1))
59
+ return intervals
60
+
61
+ def get_shadow_info(joint_name):
62
+ """获取特定关节的所有红色阴影信息"""
63
+ angles = action_df[joint_name].values
64
+ velocity = np.diff(angles) / delta_t
65
+
66
+ smoothed_velocity = gaussian_filter1d(velocity, sigma=1)
67
+ smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)
68
+
69
+ # 参数
70
+ vel_threshold = 0.5
71
+ highlight_width = 3
72
+ k = 2
73
+
74
+ shadows = []
75
+
76
+ # 低速区间阴影
77
+ low_speed_mask = np.abs(smoothed_velocity) < vel_threshold
78
+ low_speed_intervals = find_intervals(low_speed_mask)
79
+
80
+ for start, end in low_speed_intervals:
81
+ if end - start + 1 <= k:
82
+ shadows.append({
83
+ 'type': 'low_speed',
84
+ 'start_time': time_for_plot[start],
85
+ 'end_time': time_for_plot[end],
86
+ 'start_idx': start,
87
+ 'end_idx': end
88
+ })
89
+
90
+ # 最大值阴影
91
+ max_idx = np.argmax(smoothed_angle)
92
+ s_max = max(0, max_idx - highlight_width)
93
+ e_max = min(len(time_for_plot) - 1, max_idx + highlight_width)
94
+ shadows.append({
95
+ 'type': 'max_value',
96
+ 'start_time': time_for_plot[s_max],
97
+ 'end_time': time_for_plot[e_max],
98
+ 'start_idx': s_max,
99
+ 'end_idx': e_max
100
+ })
101
+
102
+ # 最小值阴影
103
+ min_idx = np.argmin(smoothed_angle)
104
+ s_min = max(0, min_idx - highlight_width)
105
+ e_min = min(len(time_for_plot) - 1, min_idx + highlight_width)
106
+ shadows.append({
107
+ 'type': 'min_value',
108
+ 'start_time': time_for_plot[s_min],
109
+ 'end_time': time_for_plot[e_min],
110
+ 'start_idx': s_min,
111
+ 'end_idx': e_min
112
+ })
113
+
114
+ return shadows
115
+
116
+ def is_hover_in_shadow(hover_time, shadows):
117
+ """检查hover时间是否在任何阴影内"""
118
+ for shadow in shadows:
119
+ if shadow['start_time'] <= hover_time <= shadow['end_time']:
120
+ return True
121
+ return False
122
+
123
+ def find_shadows_in_range(shadows, start_time, end_time):
124
+ """找到指定时间范围内的所有阴影"""
125
+ shadows_in_range = []
126
+ for shadow in shadows:
127
+ # 检查阴影是否与指定范围有重叠
128
+ if not (shadow['end_time'] < start_time or shadow['start_time'] > end_time):
129
+ shadows_in_range.append(shadow)
130
+ return shadows_in_range
131
+
132
+ # 预计算所有关节的阴影信息
133
+ for joint in columns:
134
+ all_shadows[joint] = get_shadow_info(joint)
135
+
136
+ # ------------------ 图表生成函数 ------------------
137
+ def generate_joint_graph(joint_name, idx, highlighted_shadows=None):
138
+ angles = action_df[joint_name].values
139
+ velocity = np.diff(angles) / delta_t
140
+
141
+ smoothed_velocity = gaussian_filter1d(velocity, sigma=1)
142
+ smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)
143
+
144
+ # 参数
145
+ vel_threshold = 0.5
146
+ highlight_width = 3
147
+ k = 2
148
+
149
+ # 找低速区间
150
+ low_speed_mask = np.abs(smoothed_velocity) < vel_threshold
151
+ low_speed_intervals = find_intervals(low_speed_mask)
152
+
153
+ # 找最大最小点
154
+ max_idx = np.argmax(smoothed_angle)
155
+ min_idx = np.argmin(smoothed_angle)
156
+
157
+ shapes = []
158
+
159
+ # 获取当前关节的阴影信息
160
+ current_shadows = all_shadows[joint_name]
161
+
162
+ # 正常的红色阴影
163
+ for shadow in current_shadows:
164
+ is_highlighted = False
165
+ if highlighted_shadows:
166
+ for h_shadow in highlighted_shadows:
167
+ if (shadow['start_time'] == h_shadow['start_time'] and
168
+ shadow['end_time'] == h_shadow['end_time']):
169
+ is_highlighted = True
170
+ break
171
+
172
+ color = "blue" if is_highlighted else "red"
173
+ opacity = 0.6 if is_highlighted else 0.3
174
+
175
+ shapes.append({
176
+ "type": "rect",
177
+ "xref": "x",
178
+ "yref": "paper",
179
+ "x0": shadow['start_time'],
180
+ "x1": shadow['end_time'],
181
+ "y0": 0,
182
+ "y1": 1,
183
+ "fillcolor": color,
184
+ "opacity": opacity,
185
+ "line": {"width": 0}
186
+ })
187
+
188
+ return dcc.Graph(
189
+ id=f"graph-{idx}",
190
+ figure={
191
+ "data": [
192
+ go.Scatter(
193
+ x=time_for_plot,
194
+ y=smoothed_angle,
195
+ name="Angle",
196
+ line=dict(color='orange')
197
+ )
198
+ ],
199
+ "layout": go.Layout(
200
+ title=joint_name,
201
+ xaxis={"title": "Time (s)"},
202
+ yaxis={"title": "Angle (deg)"},
203
+ shapes=shapes,
204
+ hovermode="x unified",
205
+ height=250,
206
+ margin=dict(t=30, b=30, l=50, r=50),
207
+ showlegend=False,
208
+ )
209
+ },
210
+ style={"height": "250px"}
211
+ )
212
+
213
+ # ------------------ 布局 ------------------
214
+ rows = []
215
+
216
+ # 关节图 + 双视频帧
217
+ for i, joint in enumerate(columns):
218
+ rows.append(html.Div([
219
+ html.Div(generate_joint_graph(joint, i), style={"width": "60%", "display": "inline-block", "verticalAlign": "top"}),
220
+ html.Div([
221
+ html.Img(id=f"video1-{i}", style={"width": "49%", "height": "180px", "objectFit": "contain", "display": "inline-block"}),
222
+ html.Img(id=f"video2-{i}", style={"width": "49%", "height": "180px", "objectFit": "contain", "display": "inline-block"})
223
+ ], style={"width": "38%", "display": "inline-block", "paddingLeft": "2%"})
224
+ ], style={"marginBottom": "15px"}))
225
+
226
+ # 添加定时器和存储组件
227
+ rows.append(dcc.Interval(id="video-playback-interval", interval=300, n_intervals=0))
228
+ rows.append(dcc.Store(id="hover-state-store", data={"active": False, "last_update": 0}))
229
+
230
+ # 设置 layout
231
+ app.layout = html.Div(rows)
232
+
233
+ # ------------------ 回调:监听 hoverData 并更新阴影高亮 ------------------
234
+ @app.callback(
235
+ [Output(f"graph-{i}", "figure") for i in range(6)],
236
+ [Input(f"graph-{i}", "hoverData") for i in range(6)],
237
+ [State(f"graph-{i}", "figure") for i in range(6)],
238
+ )
239
+ def update_shadow_highlighting(*args):
240
+ hover_datas = args[:6]
241
+ current_figures = args[6:]
242
+
243
+ ctx = dash.callback_context
244
+
245
+ # 检查是否有hover触发
246
+ if not ctx.triggered:
247
+ return [no_update] * 6
248
+
249
+ trigger_id = ctx.triggered[0]['prop_id']
250
+
251
+ # 如果不是hover触发,不更新
252
+ if 'hoverData' not in trigger_id:
253
+ return [no_update] * 6
254
+
255
+ # 提取触发的图表索引
256
+ graph_idx = int(trigger_id.split('-')[1].split('.')[0])
257
+ hover_data = hover_datas[graph_idx]
258
+
259
+ # 如果没有hover数据,恢复正常状态
260
+ if not hover_data or "points" not in hover_data or len(hover_data["points"]) == 0:
261
+ updated_figures = []
262
+ for i, joint in enumerate(columns):
263
+ updated_figures.append(generate_joint_graph(joint, i).figure)
264
+ return updated_figures
265
+
266
+ try:
267
+ hover_time = float(hover_data["points"][0]["x"])
268
+ triggered_joint = columns[graph_idx]
269
+
270
+ # 检查hover是否在红色阴影内
271
+ if not is_hover_in_shadow(hover_time, all_shadows[triggered_joint]):
272
+ # 如果不在阴影内,恢复正常状态
273
+ updated_figures = []
274
+ for i, joint in enumerate(columns):
275
+ updated_figures.append(generate_joint_graph(joint, i).figure)
276
+ return updated_figures
277
+
278
+ # 找到hover时间对应的时间戳索引
279
+ hover_idx = np.searchsorted(time_for_plot, hover_time)
280
+
281
+ # 计算前后10个时间戳的范围
282
+ start_idx = max(0, hover_idx - 20)
283
+ end_idx = min(len(time_for_plot) - 1, hover_idx + 20)
284
+ start_time = time_for_plot[start_idx]
285
+ end_time = time_for_plot[end_idx]
286
+
287
+ # 为每个关节生成更新的图表
288
+ updated_figures = []
289
+ for i, joint in enumerate(columns):
290
+ # 找到该关节在指定时间范围内的阴影
291
+ shadows_in_range = find_shadows_in_range(all_shadows[joint], start_time, end_time)
292
+
293
+ # 生成带有高亮的图表
294
+ updated_figure = generate_joint_graph(joint, i, shadows_in_range)
295
+ updated_figures.append(updated_figure.figure)
296
+
297
+ return updated_figures
298
+
299
+ except Exception as e:
300
+ print(f"处理阴影高亮异常: {e}")
301
+ return [no_update] * 6
302
+
303
+ # ------------------ 回调:监听 hoverData 更新视频帧 ------------------
304
+ video_duration = timestamps[-1] - timestamps[0]
305
+
306
+ @app.callback(
307
+ [Output(f"video1-{i}", "src") for i in range(6)] +
308
+ [Output(f"video2-{i}", "src") for i in range(6)] +
309
+ [Output("hover-state-store", "data")],
310
+ [Input(f"graph-{i}", "hoverData") for i in range(6)] +
311
+ [Input("video-playback-interval", "n_intervals")],
312
+ [State("hover-state-store", "data")]
313
+ )
314
+ def update_video_frames(*args):
315
+ hover_datas = args[:-2]
316
+ interval_count = args[-2]
317
+ hover_state = args[-1]
318
+
319
+ # 获取触发回调的上下文
320
+ ctx = dash.callback_context
321
+
322
+ try:
323
+ # 检查是否有hover触发了回调
324
+ if ctx.triggered:
325
+ trigger_id = ctx.triggered[0]['prop_id']
326
+
327
+ # 如果是图表hover触发的
328
+ if 'hoverData' in trigger_id:
329
+ # 从trigger_id中提取图表索引
330
+ graph_idx = int(trigger_id.split('-')[1].split('.')[0])
331
+ hover_data = hover_datas[graph_idx]
332
+
333
+ if hover_data and "points" in hover_data and len(hover_data["points"]) > 0:
334
+ try:
335
+ hover_time = float(hover_data["points"][0]["x"])
336
+ frame1 = get_video_frame(video_path_1, hover_time)
337
+ frame2 = get_video_frame(video_path_2, hover_time)
338
+
339
+ # 更新hover状态为活跃
340
+ new_hover_state = {"active": True, "last_update": interval_count}
341
+
342
+ # 如果成功获取帧,返回所有视频的帧
343
+ if frame1 and frame2:
344
+ return [frame1]*6 + [frame2]*6 + [new_hover_state]
345
+ except Exception as e:
346
+ print(f"处理hover数据异常: {e}")
347
+
348
+ # 如果是interval触发的
349
+ if 'video-playback-interval' in trigger_id:
350
+ # 检查hover状态是否过期(超过3个interval周期没有更新)
351
+ hover_expired = (interval_count - hover_state.get("last_update", 0)) > 3
352
+
353
+ if not hover_state.get("active", False) or hover_expired:
354
+ # 没有hover或hover已过期时才自动播放
355
+ t = timestamps[0] + (interval_count * 0.3) % video_duration
356
+ frame1 = get_video_frame(video_path_1, t)
357
+ frame2 = get_video_frame(video_path_2, t)
358
+
359
+ # 更新hover状态为非活跃
360
+ new_hover_state = {"active": False, "last_update": interval_count}
361
+
362
+ if frame1 and frame2:
363
+ return [frame1]*6 + [frame2]*6 + [new_hover_state]
364
+ else:
365
+ return [no_update]*12 + [new_hover_state]
366
+ else:
367
+ # hover仍然活跃时,暂停自动播放
368
+ return [no_update]*12 + [hover_state]
369
+
370
+ return [no_update]*12 + [hover_state]
371
+
372
+ except Exception as e:
373
+ print(f"update_video_frames回调函数异常: {e}")
374
+ return [no_update]*12 + [hover_state]
375
+
376
+ # ------------------ 启动应用 ------------------
377
+ if __name__ == '__main__':
378
+ app.run_server(host="0.0.0.0", port=7860, debug=False)