liuyiyang01 committed on
Commit
faec3ed
·
1 Parent(s): e4ec769

Change structure (move utility functions to app_utils.py)

Browse files
Files changed (3) hide show
  1. app.py +17 -629
  2. app_utils.py +581 -2
  3. requirements.txt +2 -1
app.py CHANGED
@@ -1,604 +1,18 @@
1
  import gradio as gr
2
- import requests
3
- import json
4
  import os
5
- import subprocess
6
- import uuid
7
- import time
8
- import cv2
9
- from typing import Optional, List
10
- import numpy as np
11
- from datetime import datetime, timedelta
12
- from collections import defaultdict
13
- import shutil
14
- from urllib.parse import urljoin
15
-
16
- # os.environ["SPACES_QUEUE_ENABLED"] = "true"
17
 
18
  from app_utils import (
 
 
 
 
 
 
19
  TMP_ROOT,
 
 
20
  )
21
 
22
- # 后端API配置(可配置化)
23
- BACKEND_URL = os.getenv("BACKEND_URL", "http://47.95.6.204:51001/")
24
- API_ENDPOINTS = {
25
- "submit_task": f"{BACKEND_URL}/predict/video",
26
- "query_status": f"{BACKEND_URL}/predict/task",
27
- "terminate_task": f"{BACKEND_URL}/predict/terminate"
28
- }
29
-
30
- # 模拟场景配置
31
- SCENE_CONFIGS = {
32
- "scene_1": {
33
- "description": "scene_1",
34
- "objects": ["milk carton", "ceramic bowl", "mug"],
35
- "preview_image": "assets/scene_1.png"
36
- },
37
- }
38
-
39
- # 可用模型列表
40
- MODEL_CHOICES = [
41
- "gr1",
42
- # "GR00T-N1",
43
- # "GR00T-1.5",
44
- # "Pi0",
45
- # "DP+CLIP",
46
- # "AcT+CLIP"
47
- ]
48
-
49
- ###############################################################################
50
-
51
- SESSION_TASKS = {}
52
- IP_REQUEST_RECORDS = defaultdict(list)
53
- IP_LIMIT = 5 # 每分钟最多请求次数
54
-
55
- def is_request_allowed(ip: str) -> bool:
56
- now = datetime.now()
57
- IP_REQUEST_RECORDS[ip] = [t for t in IP_REQUEST_RECORDS[ip] if now - t < timedelta(minutes=1)]
58
- if len(IP_REQUEST_RECORDS[ip]) < IP_LIMIT:
59
- IP_REQUEST_RECORDS[ip].append(now)
60
- return True
61
- return False
62
- ###############################################################################
63
-
64
-
65
- # 日志文件路径
66
- LOG_DIR = "logs"
67
- os.makedirs(LOG_DIR, exist_ok=True)
68
- ACCESS_LOG = os.path.join(LOG_DIR, "access.log")
69
- SUBMISSION_LOG = os.path.join(LOG_DIR, "submissions.log")
70
-
71
- def log_access(user_ip: str = None, user_agent: str = None):
72
- """记录用户访问日志"""
73
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
74
- log_entry = {
75
- "timestamp": timestamp,
76
- "type": "access",
77
- "user_ip": user_ip or "unknown",
78
- "user_agent": user_agent or "unknown"
79
- }
80
-
81
- with open(ACCESS_LOG, "a") as f:
82
- f.write(json.dumps(log_entry) + "\n")
83
-
84
- def log_submission(scene: str, prompt: str, model: str, max_step: int, user: str = "anonymous", res: str = "unknown"):
85
- """记录用户提交日志"""
86
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
87
- log_entry = {
88
- "timestamp": timestamp,
89
- "type": "submission",
90
- "user": user,
91
- "scene": scene,
92
- "prompt": prompt,
93
- "model": model,
94
- "max_step": str(max_step),
95
- "res": res
96
- }
97
-
98
- with open(SUBMISSION_LOG, "a") as f:
99
- f.write(json.dumps(log_entry) + "\n")
100
-
101
- # 记录访问
102
- def record_access(request: gr.Request):
103
- user_ip = request.client.host if request else "unknown"
104
- user_agent = request.headers.get("user-agent", "unknown")
105
- log_access(user_ip, user_agent)
106
- return update_log_display()
107
-
108
- def read_logs(log_type: str = "all", max_entries: int = 50) -> list:
109
- """读取日志文件"""
110
- logs = []
111
-
112
- if log_type in ["all", "access"]:
113
- try:
114
- with open(ACCESS_LOG, "r") as f:
115
- for line in f:
116
- logs.append(json.loads(line.strip()))
117
- except FileNotFoundError:
118
- pass
119
-
120
- if log_type in ["all", "submission"]:
121
- try:
122
- with open(SUBMISSION_LOG, "r") as f:
123
- for line in f:
124
- logs.append(json.loads(line.strip()))
125
- except FileNotFoundError:
126
- pass
127
-
128
- # 按时间戳排序,最新的在前
129
- logs.sort(key=lambda x: x["timestamp"], reverse=True)
130
- return logs[:max_entries]
131
-
132
- def format_logs_for_display(logs: list) -> str:
133
- """格式化日志用于显示"""
134
- if not logs:
135
- return "暂无日志记录"
136
-
137
- markdown = "### 系统访问日志\n\n"
138
- markdown += "| 时间 | 类型 | 用户/IP | 详细信息 |\n"
139
- markdown += "|------|------|---------|----------|\n"
140
-
141
- for log in logs:
142
- timestamp = log.get("timestamp", "unknown")
143
- log_type = "访问" if log.get("type") == "access" else "提交"
144
-
145
- if log_type == "访问":
146
- user = log.get("user_ip", "unknown")
147
- details = f"User-Agent: {log.get('user_agent', 'unknown')}"
148
- else:
149
- user = log.get("user", "anonymous")
150
- result = log.get('res', 'unknown')
151
- if result != "success":
152
- if len(result) > 40: # Adjust this threshold as needed
153
- result = f"{result[:20]}...{result[-20:]}"
154
- details = f"场景: {log.get('scene', 'unknown')}, 指令: {log.get('prompt', '')}, 模型: {log.get('model', 'unknown')}, max step: {log.get('max_step', '300')}, result: {result}"
155
-
156
- markdown += f"| {timestamp} | {log_type} | {user} | {details} |\n"
157
-
158
- return markdown
159
-
160
-
161
-
162
- ###############################################################################
163
- def list_public_oss_files(base_url: str) -> List[str]:
164
- """列出公共OSS文件夹中的所有图片文件"""
165
- # 注意:这需要OSS支持目录列表功能,或者你有预先知道的文件命名规则
166
- # 如果OSS不支持目录列表,可能需要后端API提供文件列表
167
- # 这里假设可以直接通过HTTP访问
168
-
169
- # 实际情况可能需要根据你的OSS具体配置调整
170
- # 这里只是一个示例实现
171
- try:
172
- response = requests.get(base_url)
173
- if response.status_code == 200:
174
- # 这里需要根据OSS返回的实际内容解析文件列表
175
- # 可能需要使用HTML解析器或正则表达式
176
- # 以下只是示例
177
- import re
178
- files = re.findall(r'href="([^"]+\.(?:jpg|png|jpeg))"', response.text)
179
- return sorted([urljoin(base_url, f) for f in files])
180
- return []
181
- except Exception as e:
182
- print(f"Error listing public OSS files: {e}")
183
- return []
184
-
185
- def download_public_file(url: str, local_path: str):
186
- """下载公开可访问的文件"""
187
- try:
188
- response = requests.get(url, stream=True)
189
- if response.status_code == 200:
190
- with open(local_path, 'wb') as f:
191
- for chunk in response.iter_content(1024):
192
- f.write(chunk)
193
- return True
194
- return False
195
- except Exception as e:
196
- print(f"Error downloading public file {url}: {e}")
197
- return False
198
-
199
- def stream_simulation_results(result_folder: str, task_id: str, request: gr.Request, fps: int = 30):
200
- """
201
- 流式输出仿真结果,从公共OSS读取图片
202
-
203
- 参数:
204
- result_folder: OSS上包含生成图片的文件夹URL (从后端API返回)
205
- task_id: 后端任务ID用于状态查询
206
- request: Gradio请求对象
207
- fps: 输出视频的帧率
208
-
209
-
210
- 生成:
211
- 生成的视频文件路径 (分段输出)
212
- """
213
- # 初始化变量
214
- image_folder = urljoin(result_folder, "image/") # 确保以/结尾
215
- frame_buffer: List[np.ndarray] = []
216
- frames_per_segment = fps * 2 # 每2秒60帧
217
- processed_files = set()
218
- width, height = 0, 0
219
- last_status_check = 0
220
- status_check_interval = 5 # 每5秒检查一次后端状态
221
- max_time = 240
222
-
223
- # 创建临时目录存储下载的图片
224
- user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
225
- local_image_dir = os.path.join(user_dir, "tasks", "images")
226
- os.makedirs(local_image_dir, exist_ok=True)
227
-
228
- while max_time > 0:
229
- max_time -= 1
230
- current_time = time.time()
231
-
232
- # 定期检查后端状态
233
- if current_time - last_status_check > status_check_interval:
234
- status = get_task_status(task_id)
235
- print("status: ", status)
236
- if status.get("status") == "completed":
237
- # 确保处理完所有已生成的图片
238
- process_remaining_public_oss_images(image_folder, local_image_dir, processed_files, frame_buffer)
239
- if frame_buffer:
240
- yield create_video_segment(frame_buffer, fps, width, height, request)
241
- break
242
- elif status.get("status") == "failed":
243
- raise gr.Error(f"任务执行失败: {status.get('result', '未知错误')}")
244
- elif status.get("status") == "terminated":
245
- break
246
- last_status_check = current_time
247
-
248
- # 从公共OSS获取文件列表
249
- try:
250
- # 注意:这里假设可以直接列出OSS文件
251
- # 如果不行,可能需要后端API提供文件列表
252
- oss_files = list_public_oss_files(image_folder)
253
- new_files = [f for f in oss_files if f not in processed_files]
254
- has_new_frames = False
255
-
256
- for file_url in new_files:
257
- try:
258
- # 下载文件到本地
259
- filename = os.path.basename(file_url)
260
- local_path = os.path.join(local_image_dir, filename)
261
- if download_public_file(file_url, local_path):
262
- # 读取图片
263
- frame = cv2.imread(local_path)
264
- if frame is not None:
265
- if width == 0: # 第一次获取图像尺寸
266
- height, width = frame.shape[:2]
267
-
268
- frame_buffer.append(frame)
269
- processed_files.add(file_url)
270
- has_new_frames = True
271
- except Exception as e:
272
- print(f"Error processing {file_url}: {e}")
273
-
274
- # 如果有新帧且积累够60帧,输出视频片段
275
- if has_new_frames and len(frame_buffer) >= frames_per_segment:
276
- segment_frames = frame_buffer[:frames_per_segment]
277
- frame_buffer = frame_buffer[frames_per_segment:]
278
- yield create_video_segment(segment_frames, fps, width, height, request)
279
-
280
- except Exception as e:
281
- print(f"Error accessing public OSS: {e}")
282
-
283
- time.sleep(1) # 避免过于频繁检查
284
-
285
- if max_time <= 0:
286
- raise gr.Error("timeout 240s")
287
-
288
- def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height: int, req: gr.Request) -> str:
289
- """创建视频片段"""
290
- user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
291
- os.makedirs(user_dir, exist_ok=True)
292
- video_chunk_path = os.path.join(user_dir, "tasks/video_chunk")
293
- os.makedirs(video_chunk_path, exist_ok=True)
294
- segment_name = os.path.join(video_chunk_path, f"output_{uuid.uuid4()}.mp4")
295
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
296
- out = cv2.VideoWriter(segment_name, fourcc, fps, (width, height))
297
-
298
- for frame in frames:
299
- out.write(frame)
300
- out.release()
301
-
302
- return segment_name
303
-
304
- def process_remaining_public_oss_images(oss_folder: str, local_dir: str, processed_files: set, frame_buffer: List[np.ndarray]):
305
- """处理公共OSS上剩余的图片"""
306
- try:
307
- oss_files = list_public_oss_files(oss_folder)
308
- new_files = [f for f in oss_files if f not in processed_files]
309
-
310
- for file_url in new_files:
311
- try:
312
- # 下载文件到本地
313
- filename = os.path.basename(file_url)
314
- local_path = os.path.join(local_dir, filename)
315
- if download_public_file(file_url, local_path):
316
- # 读取图片
317
- frame = cv2.imread(local_path)
318
- if frame is not None:
319
- frame_buffer.append(frame)
320
- processed_files.add(file_url)
321
- except Exception as e:
322
- print(f"Error processing remaining {file_url}: {e}")
323
- except Exception as e:
324
- print(f"Error accessing public OSS for remaining files: {e}")
325
-
326
-
327
-
328
-
329
- ###############################################################################
330
-
331
-
332
-
333
- def submit_to_backend(
334
- scene: str,
335
- prompt: str,
336
- model: str,
337
- max_step: int,
338
- user: str = "Gradio-user",
339
- ) -> dict:
340
- job_id = str(uuid.uuid4())
341
-
342
- data = {
343
- "scene_type": scene,
344
- "instruction": prompt,
345
- "model_type": model,
346
- "max_step": str(max_step)
347
- }
348
-
349
- payload = {
350
- "user": user,
351
- "task": "robot_manipulation",
352
- "job_id": job_id,
353
- "data": json.dumps(data)
354
- }
355
-
356
- try:
357
- headers = {"Content-Type": "application/json"}
358
- response = requests.post(
359
- API_ENDPOINTS["submit_task"],
360
- json=payload,
361
- headers=headers,
362
- timeout=10
363
- )
364
- return response.json()
365
- except Exception as e:
366
- return {"status": "error", "message": str(e)}
367
-
368
- def get_task_status(task_id: str) -> dict:
369
- """
370
- 查询任务状态
371
- """
372
- try:
373
- response = requests.get(
374
- f"{API_ENDPOINTS['query_status']}/{task_id}",
375
- timeout=5
376
- )
377
- return response.json()
378
- except Exception as e:
379
- return {"status": "error", "message": str(e)}
380
-
381
- def terminate_task(task_id: str) -> Optional[dict]:
382
- """
383
- 终止任务
384
- """
385
- try:
386
- response = requests.post(
387
- f"{API_ENDPOINTS['terminate_task']}/{task_id}",
388
- timeout=3
389
- )
390
- return response.json()
391
- except Exception as e:
392
- print(f"Error terminate task: {e}")
393
- return None
394
-
395
- def convert_to_h264(video_path):
396
- """
397
- 将视频转换为 H.264 编码的 MP4 格式
398
- 生成新文件路径在原路径基础上添加 _h264 后缀)
399
- """
400
- base, ext = os.path.splitext(video_path)
401
- video_path_h264 = f"{base}_h264.mp4"
402
-
403
- try:
404
- # 构建 FFmpeg 命令
405
- ffmpeg_cmd = [
406
- "ffmpeg",
407
- "-i", video_path,
408
- "-c:v", "libx264",
409
- "-preset", "slow",
410
- "-crf", "23",
411
- "-c:a", "aac",
412
- "-movflags", "+faststart",
413
- video_path_h264
414
- ]
415
-
416
- # 执行 FFmpeg 命令
417
- subprocess.run(ffmpeg_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
418
-
419
- # 检查输出文件是否存在
420
- if not os.path.exists(video_path_h264):
421
- raise FileNotFoundError(f"H.264 编码文件未生成: {video_path_h264}")
422
-
423
- return video_path_h264
424
-
425
- except subprocess.CalledProcessError as e:
426
- raise gr.Error(f"FFmpeg 转换失败: {e.stderr}")
427
- except Exception as e:
428
- raise gr.Error(f"转换过程中发生错误: {str(e)}")
429
-
430
- def run_simulation(
431
- scene: str,
432
- prompt: str,
433
- model: str,
434
- max_step: int,
435
- history: list,
436
- request: gr.Request
437
- ):
438
- """运行仿真并更新历史记录"""
439
- # 获取当前时间
440
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
441
- scene_desc = SCENE_CONFIGS.get(scene, {}).get("description", scene)
442
-
443
- # 记录用户提交
444
- user_ip = request.client.host if request else "unknown"
445
- session_id = request.session_hash
446
-
447
- if not is_request_allowed(user_ip):
448
- log_submission(scene, prompt, model, max_step, user_ip, "IP blocked temporarily")
449
- raise gr.Error("Too many requests from this IP. Please wait and try again one minute later.")
450
-
451
- # 提交任务到后端
452
- submission_result = submit_to_backend(scene, prompt, model, max_step, user_ip)
453
- print("submission_result: ", submission_result)
454
-
455
- if submission_result.get("status") != "pending":
456
- log_submission(scene, prompt, model, max_step, user_ip, "Submission failed")
457
- raise gr.Error(f"Submission failed: {submission_result.get('message', 'unknown issue')}")
458
-
459
- try:
460
- task_id = submission_result["task_id"]
461
- SESSION_TASKS[session_id] = task_id
462
-
463
- gr.Info(f"Simulation started, task_id: {task_id}")
464
- time.sleep(5)
465
- # 获取任务状态
466
- status = get_task_status(task_id)
467
- print("first status: ", status)
468
- result_folder = status.get("result", "")
469
- except Exception as e:
470
- log_submission(scene, prompt, model, max_step, user_ip, str(e))
471
- raise gr.Error(f"error occurred when parsing submission result from backend: {str(e)}")
472
-
473
-
474
- if not os.path.exists(result_folder):
475
- log_submission(scene, prompt, model, max_step, user_ip, "Result folder provided by backend doesn't exist")
476
- raise gr.Error(f"Result folder provided by backend doesn't exist: <PATH>{result_folder}")
477
-
478
-
479
- # 流式输出视频片段
480
- try:
481
- for video_path in stream_simulation_results(result_folder, task_id):
482
- if video_path:
483
- yield video_path, history
484
- except Exception as e:
485
- log_submission(scene, prompt, model, max_step, user_ip, str(e))
486
- raise gr.Error(f"Error while streaming: {str(e)}")
487
-
488
- # 获取任务状态
489
- status = get_task_status(task_id)
490
- print("status: ", status)
491
- if status.get("status") == "completed":
492
- video_path = os.path.join(status.get("result"), "manipulation.mp4")
493
- print("video_path: ", video_path)
494
- video_path = convert_to_h264(video_path)
495
-
496
- # 创建新的历史记录条目
497
- new_entry = {
498
- "timestamp": timestamp,
499
- "scene": scene,
500
- "model": model,
501
- "prompt": prompt,
502
- "max_step": max_step,
503
- "video_path": video_path,
504
- "task_id": task_id
505
- }
506
-
507
- # 将新条目添加到历史记录顶部
508
- updated_history = history + [new_entry]
509
-
510
- # 限制历史记录数量,避免内存问题
511
- if len(updated_history) > 10:
512
- updated_history = updated_history[:10]
513
-
514
- print("updated_history:", updated_history)
515
- log_submission(scene, prompt, model, max_step, user_ip, "success")
516
- gr.Info("Simulation completed successfully!")
517
- yield None, updated_history
518
-
519
- elif status.get("status") == "failed":
520
- log_submission(scene, prompt, model, max_step, user_ip, status.get('result', 'backend error'))
521
- raise gr.Error(f"Task execution failed: {status.get('result', 'backend unknown issue')}")
522
- yield None, history
523
-
524
- elif status.get("status") == "terminated":
525
- log_submission(scene, prompt, model, max_step, user_ip, "user end terminated")
526
- yield None, history
527
-
528
- else:
529
- log_submission(scene, prompt, model, max_step, user_ip, "missing task's status from backend (Pending?)")
530
- raise gr.Error("missing task's status from backend (Pending?)")
531
- yield None, history
532
-
533
-
534
- ###############################################################################
535
-
536
-
537
- def update_history_display(history: list) -> list:
538
- """更新历史记录显示"""
539
- print("更新历史记录显示")
540
- updates = []
541
-
542
- for i in range(10):
543
- if i < len(history): # 如果有历史记录,更新对应槽位
544
- entry = history[i]
545
- updates.extend([
546
- gr.update(visible=True), # 更新 Column 可见性
547
- gr.update(visible=True, label=f"# {i+1} | {entry['scene']} | {entry['model']} | {entry['prompt']}", open=(i+1==len(history))), # 更新 Accordion
548
- gr.update(value=entry['video_path'], visible=True, autoplay=False), # 更新 Video
549
- gr.update(value=f"{entry['timestamp']}") # 更新详细 Markdown
550
- ])
551
- else: # 如果没有历史记录,隐藏槽位
552
- updates.extend([
553
- gr.update(visible=False), # 隐藏 Column
554
- gr.update(visible=False), # 隐藏 Accordion
555
- gr.update(value=None, visible=False), # 清空 Video
556
- gr.update(value="") # 清空详细 Markdown
557
- ])
558
- print("更新完成!")
559
- return updates
560
-
561
- def update_scene_display(scene: str) -> tuple[str, Optional[str]]:
562
- """更新场景描述和预览图"""
563
- config = SCENE_CONFIGS.get(scene, {})
564
- desc = config.get("description", "No description")
565
- objects = ", ".join(config.get("objects", []))
566
- image = config.get("preview_image", None)
567
-
568
- markdown = f"**{desc}** \nObjects in this scene: {objects}"
569
- return markdown, image
570
-
571
- def update_log_display():
572
- """更新日志显示"""
573
- logs = read_logs()
574
- return format_logs_for_display(logs)
575
-
576
- ###############################################################################
577
-
578
-
579
- def cleanup_session(req: gr.Request):
580
- session_id = req.session_hash
581
- task_id = SESSION_TASKS.pop(session_id, None)
582
-
583
- if task_id:
584
- try:
585
- status = get_task_status(task_id)
586
- print("clean up check status: ", status)
587
- if status.get("status") == "pending":
588
- res = terminate_task(task_id)
589
- if res.get("status") == "success":
590
- print(f"已终止任务 {task_id}")
591
- else:
592
- print(f"终止任务失败 {task_id}: {res.get('status', 'unknown issue')}")
593
- except Exception as e:
594
- print(f"终止任务失败 {task_id}: {e}")
595
-
596
- user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
597
- shutil.rmtree(user_dir)
598
-
599
-
600
- ###############################################################################
601
-
602
  header_html = """
603
  <div style="display: flex; justify-content: space-between; align-items: center; width: 100%; margin-bottom: 20px; padding: 20px; background: linear-gradient(135deg, #528bdb 0%, #a7b5d0 100%); border-radius: 8px; box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
604
  <div style="display: flex; align-items: center;">
@@ -626,10 +40,8 @@ header_html = """
626
  </div>
627
  """
628
 
629
- ###############################################################################
630
 
631
-
632
- # 自定义CSS样式
633
  custom_css = """
634
  #simulation-panel {
635
  border-radius: 8px;
@@ -688,29 +100,16 @@ def start_session(req: gr.Request):
688
  user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
689
  os.makedirs(user_dir, exist_ok=True)
690
 
691
-
692
-
693
-
694
-
695
- # 创建Gradio界面
696
  with gr.Blocks(title="InternManip Model Inference Demo", css=custom_css) as demo:
697
  gr.HTML(header_html)
698
-
699
- # # 标题和描述
700
- # gr.Markdown("""
701
- # # 🤖 InternManip Model Inference Demo
702
- # ### Model trained on InternManip framework
703
- # """)
704
 
705
- # 存储历史记录的组件变量
706
  history_state = gr.State([])
707
 
708
  with gr.Row():
709
- # 左侧控制面板
710
  with gr.Column(elem_id="simulation-panel"):
711
  gr.Markdown("### Simulation Settings")
712
 
713
- # 场景选择
714
  scene_dropdown = gr.Dropdown(
715
  label="Choose a scene",
716
  choices=list(SCENE_CONFIGS.keys()),
@@ -718,7 +117,6 @@ with gr.Blocks(title="InternManip Model Inference Demo", css=custom_css) as demo
718
  interactive=True
719
  )
720
 
721
- # 场景描述预览
722
  scene_description = gr.Markdown("")
723
  scene_preview = gr.Image(
724
  label="Scene Preview",
@@ -732,7 +130,6 @@ with gr.Blocks(title="InternManip Model Inference Demo", css=custom_css) as demo
732
  outputs=[scene_description, scene_preview]
733
  )
734
 
735
- # 操作指令输入
736
  prompt_input = gr.Textbox(
737
  label="Manipulation Prompt",
738
  value="Move the milk carton to the top of the ceramic bowl.",
@@ -741,7 +138,6 @@ with gr.Blocks(title="InternManip Model Inference Demo", css=custom_css) as demo
741
  max_lines=4
742
  )
743
 
744
- # 模型选择
745
  model_dropdown = gr.Dropdown(
746
  label="Chose a pretrained model",
747
  choices=MODEL_CHOICES,
@@ -757,16 +153,11 @@ with gr.Blocks(title="InternManip Model Inference Demo", css=custom_css) as demo
757
  label="Max Steps"
758
  )
759
 
760
- # 提交按钮
761
  submit_btn = gr.Button("Apply and Start Simulation", variant="primary")
762
 
763
- # 右侧结果面板
764
  with gr.Column(elem_id="result-panel"):
765
  gr.Markdown("### Result")
766
 
767
- # progress_instruction = gr.Markdown("### Please click the botton on the left column to start.")
768
-
769
- # 视频输出
770
  video_output = gr.Video(
771
  label="Live",
772
  interactive=False,
@@ -775,22 +166,21 @@ with gr.Blocks(title="InternManip Model Inference Demo", css=custom_css) as demo
775
  streaming=True
776
  )
777
 
778
- # 历史记录显示区域
779
  with gr.Column() as history_container:
780
  gr.Markdown("### History")
781
  gr.Markdown("#### History will be reset after refresh")
782
 
783
- # 预创建10个历史记录槽位
784
  history_slots = []
785
  for i in range(10):
786
  with gr.Column(visible=False) as slot:
787
  with gr.Accordion(visible=False, open=False) as accordion:
788
- video = gr.Video(interactive=False) # 用于播放视频
789
- detail_md = gr.Markdown() # 用于显示详细信息
790
- history_slots.append((slot, accordion, video, detail_md)) # 存储所有相关组件
791
 
792
- # 添加日志显示区域
793
- with gr.Accordion("查看系统访问日志(DEV ONLY)", open=False):
794
  logs_display = gr.Markdown()
795
  refresh_logs_btn = gr.Button("刷新日志", variant="secondary")
796
 
@@ -799,7 +189,7 @@ with gr.Blocks(title="InternManip Model Inference Demo", css=custom_css) as demo
799
  outputs=logs_display
800
  )
801
 
802
- # 示例
803
  gr.Examples(
804
  examples=[
805
  ["scene_1", "Move the milk carton to the top of the ceramic bowl.", "gr1", 300],
@@ -824,7 +214,6 @@ with gr.Blocks(title="InternManip Model Inference Demo", css=custom_css) as demo
824
  outputs=logs_display
825
  )
826
 
827
- # 初始化场景描述和日志
828
  demo.load(
829
  start_session
830
  ).then(
@@ -845,6 +234,5 @@ with gr.Blocks(title="InternManip Model Inference Demo", css=custom_css) as demo
845
  demo.unload(fn=cleanup_session)
846
 
847
 
848
- # 启动应用
849
  if __name__ == "__main__":
850
  demo.launch()
 
1
  import gradio as gr
 
 
2
  import os
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  from app_utils import (
5
+ update_scene_display,
6
+ update_log_display,
7
+ run_simulation,
8
+ update_history_display,
9
+ record_access,
10
+ cleanup_session,
11
  TMP_ROOT,
12
+ MODEL_CHOICES,
13
+ SCENE_CONFIGS,
14
  )
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  header_html = """
17
  <div style="display: flex; justify-content: space-between; align-items: center; width: 100%; margin-bottom: 20px; padding: 20px; background: linear-gradient(135deg, #528bdb 0%, #a7b5d0 100%); border-radius: 8px; box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
18
  <div style="display: flex; align-items: center;">
 
40
  </div>
41
  """
42
 
 
43
 
44
+ # CSS style
 
45
  custom_css = """
46
  #simulation-panel {
47
  border-radius: 8px;
 
100
  user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
101
  os.makedirs(user_dir, exist_ok=True)
102
 
103
+ # Gradio UI
 
 
 
 
104
  with gr.Blocks(title="InternManip Model Inference Demo", css=custom_css) as demo:
105
  gr.HTML(header_html)
 
 
 
 
 
 
106
 
 
107
  history_state = gr.State([])
108
 
109
  with gr.Row():
 
110
  with gr.Column(elem_id="simulation-panel"):
111
  gr.Markdown("### Simulation Settings")
112
 
 
113
  scene_dropdown = gr.Dropdown(
114
  label="Choose a scene",
115
  choices=list(SCENE_CONFIGS.keys()),
 
117
  interactive=True
118
  )
119
 
 
120
  scene_description = gr.Markdown("")
121
  scene_preview = gr.Image(
122
  label="Scene Preview",
 
130
  outputs=[scene_description, scene_preview]
131
  )
132
 
 
133
  prompt_input = gr.Textbox(
134
  label="Manipulation Prompt",
135
  value="Move the milk carton to the top of the ceramic bowl.",
 
138
  max_lines=4
139
  )
140
 
 
141
  model_dropdown = gr.Dropdown(
142
  label="Chose a pretrained model",
143
  choices=MODEL_CHOICES,
 
153
  label="Max Steps"
154
  )
155
 
 
156
  submit_btn = gr.Button("Apply and Start Simulation", variant="primary")
157
 
 
158
  with gr.Column(elem_id="result-panel"):
159
  gr.Markdown("### Result")
160
 
 
 
 
161
  video_output = gr.Video(
162
  label="Live",
163
  interactive=False,
 
166
  streaming=True
167
  )
168
 
 
169
  with gr.Column() as history_container:
170
  gr.Markdown("### History")
171
  gr.Markdown("#### History will be reset after refresh")
172
 
173
+ # Pre-create 10 history slots
174
  history_slots = []
175
  for i in range(10):
176
  with gr.Column(visible=False) as slot:
177
  with gr.Accordion(visible=False, open=False) as accordion:
178
+ video = gr.Video(interactive=False)
179
+ detail_md = gr.Markdown() # display detail info
180
+ history_slots.append((slot, accordion, video, detail_md))
181
 
182
+ # Dev-only log information
183
+ with gr.Accordion("查看系统访问日志(DEV ONLY)", open=False, visible=False):
184
  logs_display = gr.Markdown()
185
  refresh_logs_btn = gr.Button("刷新日志", variant="secondary")
186
 
 
189
  outputs=logs_display
190
  )
191
 
192
+ # Examples
193
  gr.Examples(
194
  examples=[
195
  ["scene_1", "Move the milk carton to the top of the ceramic bowl.", "gr1", 300],
 
214
  outputs=logs_display
215
  )
216
 
 
217
  demo.load(
218
  start_session
219
  ).then(
 
234
  demo.unload(fn=cleanup_session)
235
 
236
 
 
237
  if __name__ == "__main__":
238
  demo.launch()
app_utils.py CHANGED
@@ -1,6 +1,585 @@
1
  import gradio as gr
 
 
2
  import os
3
-
 
 
 
 
 
 
 
 
 
 
4
 
5
  TMP_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
6
- os.makedirs(TMP_ROOT, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import requests
3
+ import json
4
  import os
5
+ import subprocess
6
+ import uuid
7
+ import time
8
+ import cv2
9
+ from typing import Optional, List
10
+ import numpy as np
11
+ from datetime import datetime, timedelta
12
+ from collections import defaultdict
13
+ import shutil
14
+ from urllib.parse import urljoin
15
+ import oss2
16
 
17
  TMP_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
18
+ os.makedirs(TMP_ROOT, exist_ok=True)
19
+
20
+
21
+ # Backend API configuration (configurable via the BACKEND_URL environment variable)
22
+ BACKEND_URL = os.getenv("BACKEND_URL")
23
+ API_ENDPOINTS = {
24
+ "submit_task": f"{BACKEND_URL}/predict/video",
25
+ "query_status": f"{BACKEND_URL}/predict/task",
26
+ "terminate_task": f"{BACKEND_URL}/predict/terminate"
27
+ }
28
+
29
+ # Simulation scene configuration
30
+ SCENE_CONFIGS = {
31
+ "scene_1": {
32
+ "description": "scene_1",
33
+ "objects": ["milk carton", "ceramic bowl", "mug"],
34
+ "preview_image": "assets/scene_1.png"
35
+ },
36
+ }
37
+
38
+ # List of available models
39
+ MODEL_CHOICES = [
40
+ "gr1",
41
+ # "GR00T-N1",
42
+ # "GR00T-1.5",
43
+ # "Pi0",
44
+ # "DP+CLIP",
45
+ # "AcT+CLIP"
46
+ ]
47
+
48
+ ###############################################################################
49
+
50
+ SESSION_TASKS = {}
51
+ IP_REQUEST_RECORDS = defaultdict(list)
52
+ IP_LIMIT = 5 # 每分钟最多请求次数
53
+
54
+ def is_request_allowed(ip: str) -> bool:
55
+ now = datetime.now()
56
+ IP_REQUEST_RECORDS[ip] = [t for t in IP_REQUEST_RECORDS[ip] if now - t < timedelta(minutes=1)]
57
+ if len(IP_REQUEST_RECORDS[ip]) < IP_LIMIT:
58
+ IP_REQUEST_RECORDS[ip].append(now)
59
+ return True
60
+ return False
61
+ ###############################################################################
62
+
63
+
64
+ # Log file paths
65
+ LOG_DIR = "logs"
66
+ os.makedirs(LOG_DIR, exist_ok=True)
67
+ ACCESS_LOG = os.path.join(LOG_DIR, "access.log")
68
+ SUBMISSION_LOG = os.path.join(LOG_DIR, "submissions.log")
69
+
70
+ def log_access(user_ip: str = None, user_agent: str = None):
71
+ """记录用户访问日志"""
72
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
73
+ log_entry = {
74
+ "timestamp": timestamp,
75
+ "type": "access",
76
+ "user_ip": user_ip or "unknown",
77
+ "user_agent": user_agent or "unknown"
78
+ }
79
+
80
+ with open(ACCESS_LOG, "a") as f:
81
+ f.write(json.dumps(log_entry) + "\n")
82
+
83
def log_submission(scene: str, prompt: str, model: str, max_step: int, user: str = "anonymous", res: str = "unknown"):
    """Append one submission record to SUBMISSION_LOG as a single JSON line.

    Args:
        scene: scene identifier chosen by the user.
        prompt: instruction text.
        model: model name.
        max_step: step budget; stored as a string.
        user: user identifier (callers in this file pass the client IP).
        res: outcome text, e.g. "success" or an error description.
    """
    record = dict(
        timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        type="submission",
        user=user,
        scene=scene,
        prompt=prompt,
        model=model,
        max_step=str(max_step),
        res=res,
    )

    with open(SUBMISSION_LOG, "a") as handle:
        handle.write(f"{json.dumps(record)}\n")
99
+
100
# Record a page visit
def record_access(request: gr.Request):
    """Log the visitor's IP and User-Agent, then return the refreshed log view.

    Args:
        request: Gradio request object; may be falsy, in which case both
            fields are logged as "unknown".

    Returns:
        The Markdown produced by update_log_display().
    """
    user_ip = "unknown"
    user_agent = "unknown"
    if request:
        # `client` and `headers` are read defensively: the original guarded
        # only the IP lookup and would crash on a missing request/client.
        client = getattr(request, "client", None)
        if client is not None:
            user_ip = client.host
        headers = getattr(request, "headers", None)
        if headers is not None:
            user_agent = headers.get("user-agent", "unknown")
    log_access(user_ip, user_agent)
    return update_log_display()
106
+
107
def read_logs(log_type: str = "all", max_entries: int = 50) -> list:
    """Read JSON-lines log records, newest first.

    Args:
        log_type: "access", "submission" or "all" (both files). Any other
            value selects no file and yields an empty list.
        max_entries: maximum number of records returned.

    Returns:
        Up to `max_entries` record dicts sorted by their "timestamp" string in
        descending order. Missing files, blank lines, lines that are not valid
        JSON and JSON values that are not objects are skipped, so one corrupt
        line cannot break the whole log view.
    """
    paths = []
    if log_type in ("all", "access"):
        paths.append(ACCESS_LOG)
    if log_type in ("all", "submission"):
        paths.append(SUBMISSION_LOG)

    logs = []
    for path in paths:
        try:
            with open(path, "r") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        entry = json.loads(line)
                    except json.JSONDecodeError:
                        # e.g. a partially written line
                        continue
                    if isinstance(entry, dict):
                        logs.append(entry)
        except FileNotFoundError:
            # Nothing has been logged to this file yet.
            pass

    # Sort by timestamp, newest first; records without one sort last.
    logs.sort(key=lambda entry: str(entry.get("timestamp", "")), reverse=True)
    return logs[:max_entries]
130
+
131
def format_logs_for_display(logs: list) -> str:
    """Render log records as a Markdown table.

    Args:
        logs: record dicts as returned by read_logs(). A record whose "type"
            is "access" is rendered as an access row; anything else is
            rendered as a submission row.

    Returns:
        A Markdown string (heading plus table), or a placeholder message when
        `logs` is empty.
    """
    if not logs:
        return "暂无日志记录"

    def cell(value) -> str:
        # Prompts, User-Agent strings and error texts are user/backend
        # controlled: a raw "|" or line break would split the table row.
        text = str(value)
        return text.replace("|", "\\|").replace("\r", " ").replace("\n", " ")

    parts = [
        "### 系统访问日志\n\n",
        "| 时间 | 类型 | 用户/IP | 详细信息 |\n",
        "|------|------|---------|----------|\n",
    ]

    for log in logs:
        timestamp = log.get("timestamp", "unknown")

        if log.get("type") == "access":
            log_type = "访问"
            user = log.get("user_ip", "unknown")
            details = f"User-Agent: {log.get('user_agent', 'unknown')}"
        else:
            log_type = "提交"
            user = log.get("user", "anonymous")
            # `res` may be any JSON value (e.g. null); stringify before len().
            result = str(log.get('res', 'unknown'))
            if len(result) > 40:  # keep long error texts short: head...tail
                result = f"{result[:20]}...{result[-20:]}"
            details = f"场景: {log.get('scene', 'unknown')}, 指令: {log.get('prompt', '')}, 模型: {log.get('model', 'unknown')}, max step: {log.get('max_step', '300')}, result: {result}"

        parts.append(f"| {cell(timestamp)} | {log_type} | {cell(user)} | {cell(details)} |\n")

    return "".join(parts)
158
+
159
+ ###############################################################################
160
+
161
+
162
# OSS configuration: each field is read from the environment variable
# OSS_<FIELD NAME IN UPPER CASE>.
OSS_CONFIG = {
    field: os.getenv(f"OSS_{field.upper()}")
    for field in ("access_key_id", "access_key_secret", "endpoint", "bucket_name")
}

# Module-level OSS client objects used by the helpers below.
auth = oss2.Auth(OSS_CONFIG["access_key_id"], OSS_CONFIG["access_key_secret"])
bucket = oss2.Bucket(auth, OSS_CONFIG["endpoint"], OSS_CONFIG["bucket_name"])
172
+
173
+
174
def list_oss_files(folder_path: str) -> List[str]:
    """Return the sorted keys of all objects under the OSS prefix `folder_path`.

    Keys ending in '/' (the directory placeholders themselves) are excluded.
    """
    keys = [
        entry.key
        for entry in oss2.ObjectIterator(bucket, prefix=folder_path)
        if not entry.key.endswith('/')
    ]
    keys.sort()
    return keys
181
+
182
def download_oss_file(oss_path: str, local_path: str):
    """Download the OSS object `oss_path` to the local file `local_path`."""
    fetch_to_file = bucket.get_object_to_file
    fetch_to_file(oss_path, local_path)
185
+
186
def stream_simulation_results(result_folder: str, task_id: str, request: gr.Request, fps: int = 30):
    """
    Stream simulation output as short video segments built from frames on OSS.

    Args:
        result_folder: folder whose "image" sub-folder holds the generated
            frames; the joined path is used as the OSS key prefix.
        task_id: backend task id used for status polling.
        request: Gradio request; its session_hash selects the per-session
            temporary directory.
        fps: frame rate of the produced segments.

    Yields:
        Path of each video segment as soon as it has been written.

    Raises:
        gr.Error: when the backend reports "failed", or when the polling
            budget (240 iterations of at least one second) is used up without
            the task completing or being terminated.
    """
    image_folder = os.path.join(result_folder, "image")
    # NOTE(review): image_folder is passed to list_oss_files() as an OSS prefix;
    # creating it locally looks like a leftover from a local-disk version.
    # Kept to preserve the existing side effect - confirm before removing.
    os.makedirs(image_folder, exist_ok=True)
    frame_buffer: List[np.ndarray] = []
    frames_per_segment = fps * 2  # one segment covers two seconds of video
    processed_files = set()
    width, height = 0, 0
    last_status_check = 0
    status_check_interval = 5  # seconds between backend status queries
    max_iterations = 240

    # Per-session scratch directory for the downloaded frames.
    user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
    local_image_dir = os.path.join(user_dir, "tasks", "images")
    os.makedirs(local_image_dir, exist_ok=True)

    for _ in range(max_iterations):
        current_time = time.time()

        # Poll the backend periodically.
        if current_time - last_status_check > status_check_interval:
            status = get_task_status(task_id)
            print("status: ", status)
            state = status.get("status")
            if state == "completed":
                # Flush every frame that has been produced so far.
                process_remaining_oss_images(image_folder, local_image_dir, processed_files, frame_buffer)
                if frame_buffer:
                    if width == 0:
                        # All frames arrived through the flush above, so the
                        # size was never captured by the main loop.
                        height, width = frame_buffer[0].shape[:2]
                    yield create_video_segment(frame_buffer, fps, width, height, request)
                break
            elif state == "failed":
                raise gr.Error(f"任务执行失败: {status.get('result', '未知错误')}")
            elif state == "terminated":
                break
            last_status_check = current_time

        # Fetch the frames that appeared on OSS since the last pass.
        try:
            oss_files = list_oss_files(image_folder)
            new_files = [f for f in oss_files if f not in processed_files]
            has_new_frames = False

            for oss_path in new_files:
                try:
                    filename = os.path.basename(oss_path)
                    local_path = os.path.join(local_image_dir, filename)
                    download_oss_file(oss_path, local_path)

                    frame = cv2.imread(local_path)
                    # An unreadable file is not marked as processed, so it is
                    # downloaded again on the next pass.
                    if frame is not None:
                        if width == 0:  # first frame fixes the video size
                            height, width = frame.shape[:2]

                        frame_buffer.append(frame)
                        processed_files.add(oss_path)
                        has_new_frames = True
                except Exception as e:
                    print(f"Error processing {oss_path}: {e}")

            # Emit one segment once enough frames have accumulated.
            if has_new_frames and len(frame_buffer) >= frames_per_segment:
                segment_frames = frame_buffer[:frames_per_segment]
                frame_buffer = frame_buffer[frames_per_segment:]
                yield create_video_segment(segment_frames, fps, width, height, request)

        except Exception as e:
            print(f"Error accessing OSS: {e}")

        time.sleep(1)  # avoid hammering OSS and the backend
    else:
        # Reached only when the loop ran out without `break`; a task that
        # completes on the very last iteration is no longer reported as timeout.
        raise gr.Error("timeout 240s")
273
+
274
def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height: int, req: gr.Request) -> str:
    """Write `frames` to a new mp4v-encoded .mp4 file and return its path.

    Args:
        frames: frames as returned by cv2.imread.
        fps: frame rate of the segment.
        width: frame width in pixels; taken from the first frame when not positive.
        height: frame height in pixels; taken from the first frame when not positive.
        req: Gradio request; its session_hash selects the output directory
            TMP_ROOT/<session_hash>/tasks/video_chunk.

    Returns:
        Path of the written segment file (a fresh uuid4-based name per call).
    """
    if (width <= 0 or height <= 0) and frames:
        # A 0x0 writer cannot store frames; fall back to the real frame size.
        height, width = frames[0].shape[:2]

    video_chunk_path = os.path.join(TMP_ROOT, str(req.session_hash), "tasks/video_chunk")
    os.makedirs(video_chunk_path, exist_ok=True)  # also creates the session dir
    segment_name = os.path.join(video_chunk_path, f"output_{uuid.uuid4()}.mp4")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(segment_name, fourcc, fps, (width, height))

    try:
        for frame in frames:
            out.write(frame)
    finally:
        # Release the writer even if a write raises.
        out.release()

    return segment_name
289
+
290
def process_remaining_oss_images(oss_folder: str, local_dir: str, processed_files: set, frame_buffer: List[np.ndarray]):
    """Download and decode every not-yet-processed image under `oss_folder`.

    Each readable image is appended to `frame_buffer` and its key is added to
    `processed_files`; both arguments are modified in place. Errors are
    printed and never raised.
    """
    try:
        pending = [key for key in list_oss_files(oss_folder) if key not in processed_files]

        for oss_path in pending:
            try:
                # Download into the local directory, then decode.
                target = os.path.join(local_dir, os.path.basename(oss_path))
                download_oss_file(oss_path, target)

                image = cv2.imread(target)
                if image is None:
                    continue
                frame_buffer.append(image)
                processed_files.add(oss_path)
            except Exception as e:
                print(f"Error processing remaining {oss_path}: {e}")
    except Exception as e:
        print(f"Error accessing OSS for remaining files: {e}")
312
+
313
+
314
+
315
+
316
+
317
+ ###############################################################################
318
+
319
+
320
+
321
def submit_to_backend(
    scene: str,
    prompt: str,
    model: str,
    max_step: int,
    user: str = "Gradio-user",
) -> dict:
    """POST a new robot_manipulation task to the backend.

    Args:
        scene: scene identifier, sent as "scene_type".
        prompt: instruction text, sent as "instruction".
        model: model name, sent as "model_type".
        max_step: step budget, sent as a string.
        user: user identifier attached to the job.

    Returns:
        The backend's decoded JSON reply, or
        {"status": "error", "message": <text>} if the request or the decoding
        raises.
    """
    job_id = str(uuid.uuid4())

    # The task parameters travel as a JSON string inside the outer payload.
    task_spec = json.dumps({
        "scene_type": scene,
        "instruction": prompt,
        "model_type": model,
        "max_step": str(max_step)
    })
    payload = dict(
        user=user,
        task="robot_manipulation",
        job_id=job_id,
        data=task_spec,
    )

    try:
        reply = requests.post(
            API_ENDPOINTS["submit_task"],
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=10
        )
        return reply.json()
    except Exception as e:
        return {"status": "error", "message": str(e)}
355
+
356
def get_task_status(task_id: str) -> dict:
    """
    Query the backend for the status of `task_id`.

    Returns the decoded JSON reply, or {"status": "error", "message": <text>}
    if the request or the decoding raises.
    """
    try:
        status_url = f"{API_ENDPOINTS['query_status']}/{task_id}"
        reply = requests.get(status_url, timeout=5)
        return reply.json()
    except Exception as e:
        return {"status": "error", "message": str(e)}
368
+
369
def terminate_task(task_id: str) -> Optional[dict]:
    """
    Ask the backend to terminate `task_id`.

    Returns the decoded JSON reply, or None (after printing the error) if the
    request or the decoding raises.
    """
    try:
        terminate_url = f"{API_ENDPOINTS['terminate_task']}/{task_id}"
        reply = requests.post(terminate_url, timeout=3)
        return reply.json()
    except Exception as e:
        print(f"Error terminate task: {e}")
        return None
382
+
383
def convert_to_h264(video_path):
    """
    Convert a video to an H.264-encoded MP4 using the ffmpeg executable.

    Args:
        video_path: path of the source video.

    Returns:
        Path of the new file: the source path without its extension plus
        "_h264.mp4".

    Raises:
        gr.Error: if ffmpeg exits with an error, cannot be started, or the
            output file is missing afterwards.
    """
    base, _ext = os.path.splitext(video_path)
    video_path_h264 = f"{base}_h264.mp4"

    try:
        ffmpeg_cmd = [
            "ffmpeg",
            "-y",  # overwrite a previous conversion instead of refusing to run
            "-i", video_path,
            "-c:v", "libx264",
            "-preset", "slow",
            "-crf", "23",
            "-c:a", "aac",
            "-movflags", "+faststart",
            video_path_h264
        ]

        subprocess.run(ffmpeg_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        # ffmpeg reported success; make sure the file is really there.
        if not os.path.exists(video_path_h264):
            raise FileNotFoundError(f"H.264 编码文件未生成: {video_path_h264}")

        return video_path_h264

    except subprocess.CalledProcessError as e:
        # stderr is captured as bytes; decode it so the message is readable.
        stderr_text = e.stderr.decode("utf-8", errors="replace") if isinstance(e.stderr, bytes) else e.stderr
        raise gr.Error(f"FFmpeg 转换失败: {stderr_text}") from e
    except Exception as e:
        raise gr.Error(f"转换过程中发生错误: {str(e)}") from e
417
+
418
def run_simulation(
    scene: str,
    prompt: str,
    model: str,
    max_step: int,
    history: list,
    request: gr.Request
):
    """Run one simulation and stream its video, then update the history.

    Args:
        scene: scene identifier.
        prompt: instruction text.
        model: model name.
        max_step: step budget.
        history: list of earlier result entries, oldest first.
        request: Gradio request (client IP for rate limiting and logging,
            session_hash for per-session state).

    Yields:
        (video_path, history) for every streamed segment, then a final
        (None, history) pair; on success the history contains the new entry
        and is capped at the 10 most recent entries.

    Raises:
        gr.Error: rate limit exceeded, submission rejected, result folder
            missing, streaming failure, or a failed/unknown final status.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    user_ip = request.client.host if request else "unknown"
    session_id = request.session_hash

    if not is_request_allowed(user_ip):
        log_submission(scene, prompt, model, max_step, user_ip, "IP blocked temporarily")
        raise gr.Error("Too many requests from this IP. Please wait and try again one minute later.")

    # Submit the task to the backend.
    submission_result = submit_to_backend(scene, prompt, model, max_step, user_ip)
    print("submission_result: ", submission_result)

    if submission_result.get("status") != "pending":
        log_submission(scene, prompt, model, max_step, user_ip, "Submission failed")
        raise gr.Error(f"Submission failed: {submission_result.get('message', 'unknown issue')}")

    try:
        task_id = submission_result["task_id"]
        # Remembered so cleanup_session can terminate the task later.
        SESSION_TASKS[session_id] = task_id

        gr.Info(f"Simulation started, task_id: {task_id}")
        time.sleep(5)
        # The first status reply carries the result folder.
        status = get_task_status(task_id)
        print("first status: ", status)
        result_folder = status.get("result", "")
    except Exception as e:
        log_submission(scene, prompt, model, max_step, user_ip, str(e))
        raise gr.Error(f"error occurred when parsing submission result from backend: {str(e)}")

    if not os.path.exists(result_folder):
        log_submission(scene, prompt, model, max_step, user_ip, "Result folder provided by backend doesn't exist")
        raise gr.Error(f"Result folder provided by backend doesn't exist: <PATH>{result_folder}")

    # Stream the video segments. `request` is a required parameter of
    # stream_simulation_results; the original call omitted it, so streaming
    # always failed with a TypeError.
    try:
        for video_path in stream_simulation_results(result_folder, task_id, request):
            if video_path:
                yield video_path, history
    except Exception as e:
        log_submission(scene, prompt, model, max_step, user_ip, str(e))
        raise gr.Error(f"Error while streaming: {str(e)}")

    # Final status decides how the run is recorded.
    status = get_task_status(task_id)
    print("status: ", status)
    final_state = status.get("status")

    if final_state == "completed":
        video_path = os.path.join(status.get("result"), "manipulation.mp4")
        print("video_path: ", video_path)
        video_path = convert_to_h264(video_path)

        new_entry = {
            "timestamp": timestamp,
            "scene": scene,
            "model": model,
            "prompt": prompt,
            "max_step": max_step,
            "video_path": video_path,
            "task_id": task_id
        }

        # Entries are appended, so the newest is last. Keep the LAST 10: the
        # original `[:10]` discarded the entry that had just been added once
        # the history was full.
        updated_history = ((history or []) + [new_entry])[-10:]

        print("updated_history:", updated_history)
        log_submission(scene, prompt, model, max_step, user_ip, "success")
        gr.Info("Simulation completed successfully!")
        yield None, updated_history

    elif final_state == "failed":
        log_submission(scene, prompt, model, max_step, user_ip, status.get('result', 'backend error'))
        raise gr.Error(f"Task execution failed: {status.get('result', 'backend unknown issue')}")

    elif final_state == "terminated":
        log_submission(scene, prompt, model, max_step, user_ip, "user end terminated")
        yield None, history

    else:
        log_submission(scene, prompt, model, max_step, user_ip, "missing task's status from backend (Pending?)")
        raise gr.Error("missing task's status from backend (Pending?)")
520
+
521
+
522
+
523
+
524
+
525
def update_history_display(history: list) -> list:
    """Build the Gradio updates for the 10 history slots.

    Every slot contributes four updates, in order: Column, Accordion, Video,
    Markdown. Slots backed by a history entry are shown (only the last entry's
    accordion is open); the remaining slots are hidden and cleared.
    """
    print("更新历史记录显示")
    total = len(history)
    updates = []

    for slot in range(10):
        if slot >= total:
            # No entry for this slot: hide and clear it.
            updates += [
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(value=None, visible=False),
                gr.update(value=""),
            ]
            continue

        entry = history[slot]
        title = f"# {slot+1} | {entry['scene']} | {entry['model']} | {entry['prompt']}"
        is_last = (slot + 1 == total)
        updates += [
            gr.update(visible=True),
            gr.update(visible=True, label=title, open=is_last),
            gr.update(value=entry['video_path'], visible=True, autoplay=False),
            gr.update(value=f"{entry['timestamp']}"),
        ]

    print("更新完成!")
    return updates
548
+
549
def update_scene_display(scene: str) -> tuple[str, Optional[str]]:
    """Return (Markdown description, preview image path or None) for `scene`.

    An unknown scene yields "No description", an empty object list and no image.
    """
    scene_cfg = SCENE_CONFIGS.get(scene, {})
    title = scene_cfg.get("description", "No description")
    object_names = scene_cfg.get("objects", [])
    preview = scene_cfg.get("preview_image", None)

    summary = f"**{title}** \nObjects in this scene: {', '.join(object_names)}"
    return summary, preview
558
+
559
def update_log_display():
    """Return the current logs rendered as Markdown."""
    return format_logs_for_display(read_logs())
563
+
564
+ ###############################################################################
565
+
566
+
567
def cleanup_session(req: gr.Request):
    """Release everything tied to a session.

    Forgets the session's task id, asks the backend to terminate the task if
    it is still "pending", and deletes the session's directory under TMP_ROOT.

    Args:
        req: Gradio request whose session_hash identifies the session.
    """
    session_id = req.session_hash
    task_id = SESSION_TASKS.pop(session_id, None)

    if task_id:
        try:
            status = get_task_status(task_id)
            print("clean up check status: ", status)
            if status.get("status") == "pending":
                # terminate_task returns None when the request itself fails.
                res = terminate_task(task_id) or {}
                if res.get("status") == "success":
                    print(f"已终止任务 {task_id}")
                else:
                    print(f"终止任务失败 {task_id}: {res.get('status', 'unknown issue')}")
        except Exception as e:
            print(f"终止任务失败 {task_id}: {e}")

    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
    # The directory does not exist for sessions that never ran a task;
    # cleanup must not raise in that case.
    shutil.rmtree(user_dir, ignore_errors=True)
requirements.txt CHANGED
@@ -3,4 +3,5 @@ typing-extensions
3
  jsonlib-python3
4
  opencv-python
5
  numpy
6
- python-dateutil
 
 
3
  jsonlib-python3
4
  opencv-python
5
  numpy
6
+ python-dateutil
7
+ oss2