change to oss
Browse files- OSS_MIGRATION_SUMMARY.md +79 -0
- README.md +72 -12
- app.py +37 -14
- config.py +1 -1
- oss_utils.py +64 -0
- requirements.txt +1 -0
- simulation.py +126 -44
OSS_MIGRATION_SUMMARY.md
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# OSS 集成修改总结
|
2 |
+
|
3 |
+
## 修改内容
|
4 |
+
|
5 |
+
### 1. 新增文件
|
6 |
+
|
7 |
+
- **oss_utils.py**: 新建的 OSS 工具模块,包含:
|
8 |
+
- OSS 客户端初始化
|
9 |
+
- 文件列表获取 (`list_oss_files`)
|
10 |
+
- 文件下载 (`download_oss_file`)
|
11 |
+
- 文件存在检查 (`oss_file_exists`)
|
12 |
+
- 临时目录管理 (`get_user_tmp_dir`, `cleanup_user_tmp_dir`)
|
13 |
+
|
14 |
+
- **.env.example**: OSS 环境变量配置示例
|
15 |
+
|
16 |
+
### 2. 修改的文件
|
17 |
+
|
18 |
+
#### simulation.py
|
19 |
+
- **导入 OSS 模块**: 添加了 `oss_utils` 的导入
|
20 |
+
- **重写 `stream_simulation_results` 函数**:
|
21 |
+
- 支持从 OSS 读取图像文件
|
22 |
+
- 使用用户会话级别的临时目录
|
23 |
+
- 改进的错误处理和日志记录
|
24 |
+
- **重写 `create_video_segment` 函数**:
|
25 |
+
- 使用用户特定的临时目录
|
26 |
+
- 更好的目录管理
|
27 |
+
- **新增 `process_remaining_oss_images` 函数**:
|
28 |
+
- 处理 OSS 上剩余的图像文件
|
29 |
+
- **改进 `convert_to_h264` 函数**:
|
30 |
+
- 更好的 ffmpeg 路径查找
|
31 |
+
- 改进的错误处理
|
32 |
+
|
33 |
+
#### app.py
|
34 |
+
- **导入 OSS 模块**: 添加了 `oss_utils` 相关导入
|
35 |
+
- **重写 `run_simulation` 函数**:
|
36 |
+
- 不再检查本地目录是否存在
|
37 |
+
- 从 OSS 下载最终视频文件
|
38 |
+
- 使用临时目录管理
|
39 |
+
- **更新 `cleanup_session` 函数**:
|
40 |
+
- 添加了用户临时目录清理
|
41 |
+
|
42 |
+
#### requirements.txt
|
43 |
+
- **添加 OSS 依赖**: `oss2>=2.15.0`
|
44 |
+
|
45 |
+
#### README.md
|
46 |
+
- **完全重写**: 添加了 OSS 配置说明和使用指南
|
47 |
+
|
48 |
+
## 主要变化
|
49 |
+
|
50 |
+
### 1. 数据源变更
|
51 |
+
- **之前**: 从本地文件系统读取图像和视频
|
52 |
+
- **现在**: 从阿里云 OSS 读取数据
|
53 |
+
|
54 |
+
### 2. 临时文件管理
|
55 |
+
- **之前**: 使用固定的系统目录
|
56 |
+
- **现在**: 为每个用户会话创建独立的临时目录
|
57 |
+
|
58 |
+
### 3. 错误处理
|
59 |
+
- **之前**: 基本的错误处理
|
60 |
+
- **现在**: 更全面的错误处理和日志记录
|
61 |
+
|
62 |
+
### 4. 文件下载
|
63 |
+
- **之前**: 直接读取本地文件
|
64 |
+
- **现在**: 从 OSS 流式下载文件到临时目录
|
65 |
+
|
66 |
+
## 配置要求
|
67 |
+
|
68 |
+
使用此修改后的代码需要:
|
69 |
+
|
70 |
+
1. 安装 `oss2` Python 包
|
71 |
+
2. 配置以下环境变量:
|
72 |
+
- `OSS_ACCESS_KEY_ID`
|
73 |
+
- `OSS_ACCESS_KEY_SECRET`
|
74 |
+
- `OSS_ENDPOINT`
|
75 |
+
- `OSS_BUCKET_NAME`
|
76 |
+
|
77 |
+
## 向后兼容性
|
78 |
+
|
79 |
+
这些修改保持了 API 的兼容性,但改变了数据源。如果需要保持本地文件系统的支持,可以在 `oss_utils.py` 中添加回退逻辑。
|
README.md
CHANGED
@@ -1,12 +1,72 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# InternNav 评估演示
|
2 |
+
|
3 |
+
这是一个基于 Gradio 的 InternNav 模型推理演示应用,支持从阿里云 OSS 读取视频数据。
|
4 |
+
|
5 |
+
## 功能特性
|
6 |
+
|
7 |
+
- 🤖 支持多种导航模型 (rdp, cma)
|
8 |
+
- 🎯 支持多种模式 (vlnPE, vlnCE)
|
9 |
+
- 🎬 实时流式视频输出
|
10 |
+
- ☁️ 从阿里云 OSS 读取数据
|
11 |
+
- � 用户访问日志记录
|
12 |
+
- 🔒 IP 频率限制保护
|
13 |
+
|
14 |
+
## 环境配置
|
15 |
+
|
16 |
+
### 1. 安装依赖
|
17 |
+
|
18 |
+
```bash
|
19 |
+
pip install -r requirements.txt
|
20 |
+
```
|
21 |
+
|
22 |
+
### 2. 配置环境变量
|
23 |
+
|
24 |
+
复制 `.env.example` 为 `.env` 并填入您的 OSS 配置:
|
25 |
+
|
26 |
+
```bash
|
27 |
+
cp .env.example .env
|
28 |
+
```
|
29 |
+
|
30 |
+
编辑 `.env` 文件:
|
31 |
+
|
32 |
+
```env
|
33 |
+
OSS_ACCESS_KEY_ID=your_access_key_id
|
34 |
+
OSS_ACCESS_KEY_SECRET=your_access_key_secret
|
35 |
+
OSS_ENDPOINT=your_oss_endpoint
|
36 |
+
OSS_BUCKET_NAME=your_bucket_name
|
37 |
+
BACKEND_URL=http://47.95.6.204:51001
|
38 |
+
```
|
39 |
+
|
40 |
+
### 3. 运行应用
|
41 |
+
|
42 |
+
```bash
|
43 |
+
python app.py
|
44 |
+
```
|
45 |
+
|
46 |
+
## 文件结构
|
47 |
+
|
48 |
+
```
|
49 |
+
├── app.py # 主应用入口
|
50 |
+
├── config.py # 配置文件
|
51 |
+
├── backend_api.py # 后端 API 交互
|
52 |
+
├── simulation.py # 仿真和视频处理 (支持 OSS)
|
53 |
+
├── oss_utils.py # OSS 工具函数
|
54 |
+
├── logging_utils.py # 日志工具
|
55 |
+
├── ui_components.py # UI 组件
|
56 |
+
├── requirements.txt # Python 依赖
|
57 |
+
├── assets/ # 静态资源
|
58 |
+
└── tmp/ # 临时文件目录
|
59 |
+
```
|
60 |
+
|
61 |
+
## OSS 集成
|
62 |
+
|
63 |
+
应用现在完全支持从阿里云 OSS 读取数据:
|
64 |
+
|
65 |
+
- 图像文件从 OSS 流式下载
|
66 |
+
- 视频文件从 OSS 下载到本地临时目录
|
67 |
+
- 自动清理用户会话的临时文件
|
68 |
+
- 支持断点续传和错误恢复
|
69 |
+
|
70 |
+
## 配置参考
|
71 |
+
|
72 |
+
查看 Hugging Face Spaces 配置文档:https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -6,6 +6,7 @@ from backend_api import submit_to_backend, get_task_status, get_task_result
|
|
6 |
from logging_utils import log_access, log_submission, is_request_allowed
|
7 |
from simulation import stream_simulation_results, convert_to_h264
|
8 |
from ui_components import update_history_display, update_scene_display, update_log_display, get_scene_instruction
|
|
|
9 |
import os
|
10 |
from datetime import datetime
|
11 |
|
@@ -16,40 +17,62 @@ def run_simulation(scene, model, mode, prompt, history, request: gr.Request):
|
|
16 |
scene_desc = SCENE_CONFIGS.get(scene, {}).get("description", scene)
|
17 |
user_ip = request.client.host if request else "unknown"
|
18 |
session_id = request.session_hash
|
|
|
19 |
if not is_request_allowed(user_ip):
|
20 |
log_submission(scene, prompt, model, user_ip, "IP blocked temporarily")
|
21 |
raise gr.Error("Too many requests from this IP. Please wait and try again one minute later.")
|
22 |
-
|
23 |
-
#
|
24 |
submission_result = submit_to_backend(scene, prompt, mode, model, user_ip)
|
25 |
if submission_result.get("status") != "pending":
|
26 |
log_submission(scene, prompt, model, user_ip, "Submission failed")
|
27 |
raise gr.Error(f"Submission failed: {submission_result.get('message', 'unknown issue')}")
|
|
|
28 |
try:
|
29 |
task_id = submission_result["task_id"]
|
30 |
SESSION_TASKS[session_id] = task_id
|
31 |
gr.Info(f"Simulation started, task_id: {task_id}")
|
|
|
32 |
import time
|
33 |
time.sleep(5)
|
34 |
status = get_task_status(task_id)
|
35 |
-
|
|
|
|
|
36 |
except Exception as e:
|
37 |
log_submission(scene, prompt, model, user_ip, str(e))
|
38 |
raise gr.Error(f"error occurred when parsing submission result from backend: {str(e)}")
|
39 |
-
|
40 |
-
|
41 |
-
raise gr.Error(f"Result folder provided by backend doesn't exist: <PATH>{result_folder}")
|
42 |
try:
|
43 |
-
for video_path in stream_simulation_results(result_folder, task_id):
|
44 |
if video_path:
|
45 |
yield video_path, history
|
46 |
except Exception as e:
|
47 |
log_submission(scene, prompt, model, user_ip, str(e))
|
48 |
raise gr.Error(f"流式输出过程中出错: {str(e)}")
|
|
|
|
|
49 |
status = get_task_status(task_id)
|
50 |
if status.get("status") == "completed":
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
new_entry = {
|
54 |
"timestamp": timestamp,
|
55 |
"scene": scene,
|
@@ -70,11 +93,8 @@ def run_simulation(scene, model, mode, prompt, history, request: gr.Request):
|
|
70 |
yield None, history
|
71 |
elif status.get("status") == "terminated":
|
72 |
log_submission(scene, prompt, model, user_ip, "terminated")
|
73 |
-
|
74 |
-
|
75 |
-
return f"⚠️ 任务 {task_id} 被终止,已生成部分结果", video_path, history
|
76 |
-
else:
|
77 |
-
return f"⚠️ 任务 {task_id} 被终止,未生成结果", None, history
|
78 |
else:
|
79 |
log_submission(scene, prompt, model, user_ip, "missing task's status from backend")
|
80 |
raise gr.Error("missing task's status from backend")
|
@@ -90,6 +110,9 @@ def cleanup_session(request: gr.Request):
|
|
90 |
requests.post(f"{BACKEND_URL}/predict/terminate/{task_id}", timeout=3)
|
91 |
except Exception:
|
92 |
pass
|
|
|
|
|
|
|
93 |
|
94 |
def record_access(request: gr.Request):
|
95 |
user_ip = request.client.host if request else "unknown"
|
|
|
6 |
from logging_utils import log_access, log_submission, is_request_allowed
|
7 |
from simulation import stream_simulation_results, convert_to_h264
|
8 |
from ui_components import update_history_display, update_scene_display, update_log_display, get_scene_instruction
|
9 |
+
from oss_utils import download_oss_file, get_user_tmp_dir, cleanup_user_tmp_dir, oss_file_exists
|
10 |
import os
|
11 |
from datetime import datetime
|
12 |
|
|
|
17 |
scene_desc = SCENE_CONFIGS.get(scene, {}).get("description", scene)
|
18 |
user_ip = request.client.host if request else "unknown"
|
19 |
session_id = request.session_hash
|
20 |
+
|
21 |
if not is_request_allowed(user_ip):
|
22 |
log_submission(scene, prompt, model, user_ip, "IP blocked temporarily")
|
23 |
raise gr.Error("Too many requests from this IP. Please wait and try again one minute later.")
|
24 |
+
|
25 |
+
# 提交任务到后端
|
26 |
submission_result = submit_to_backend(scene, prompt, mode, model, user_ip)
|
27 |
if submission_result.get("status") != "pending":
|
28 |
log_submission(scene, prompt, model, user_ip, "Submission failed")
|
29 |
raise gr.Error(f"Submission failed: {submission_result.get('message', 'unknown issue')}")
|
30 |
+
|
31 |
try:
|
32 |
task_id = submission_result["task_id"]
|
33 |
SESSION_TASKS[session_id] = task_id
|
34 |
gr.Info(f"Simulation started, task_id: {task_id}")
|
35 |
+
|
36 |
import time
|
37 |
time.sleep(5)
|
38 |
status = get_task_status(task_id)
|
39 |
+
# OSS上的结果文件夹路径,不再检查本地路径是否存在
|
40 |
+
result_folder = status.get("result", f"gradio_demo/tasks/{task_id}")
|
41 |
+
|
42 |
except Exception as e:
|
43 |
log_submission(scene, prompt, model, user_ip, str(e))
|
44 |
raise gr.Error(f"error occurred when parsing submission result from backend: {str(e)}")
|
45 |
+
|
46 |
+
# 流式输出视频片段(从OSS读取)
|
|
|
47 |
try:
|
48 |
+
for video_path in stream_simulation_results(result_folder, task_id, request):
|
49 |
if video_path:
|
50 |
yield video_path, history
|
51 |
except Exception as e:
|
52 |
log_submission(scene, prompt, model, user_ip, str(e))
|
53 |
raise gr.Error(f"流式输出过程中出错: {str(e)}")
|
54 |
+
|
55 |
+
# 获取最终任务状态
|
56 |
status = get_task_status(task_id)
|
57 |
if status.get("status") == "completed":
|
58 |
+
# 从OSS下载最终视频文件
|
59 |
+
oss_video_path = os.path.join(result_folder, "output.mp4") # 或者根据实际的视频文件名调整
|
60 |
+
user_dir = get_user_tmp_dir(session_id)
|
61 |
+
local_video_path = os.path.join(user_dir, task_id, "output.mp4")
|
62 |
+
|
63 |
+
try:
|
64 |
+
# 尝试下载视频文件
|
65 |
+
if oss_file_exists(oss_video_path):
|
66 |
+
download_oss_file(oss_video_path, local_video_path)
|
67 |
+
video_path = convert_to_h264(local_video_path)
|
68 |
+
else:
|
69 |
+
# 如果OSS上没有最终视频,使用最后一个片段
|
70 |
+
gr.Info("Final video not found in OSS, using last segment")
|
71 |
+
video_path = None
|
72 |
+
|
73 |
+
except Exception as e:
|
74 |
+
print(f"Error downloading final video: {e}")
|
75 |
+
video_path = None
|
76 |
new_entry = {
|
77 |
"timestamp": timestamp,
|
78 |
"scene": scene,
|
|
|
93 |
yield None, history
|
94 |
elif status.get("status") == "terminated":
|
95 |
log_submission(scene, prompt, model, user_ip, "terminated")
|
96 |
+
# 对于终止的任务,不再检查本地文件
|
97 |
+
yield None, history
|
|
|
|
|
|
|
98 |
else:
|
99 |
log_submission(scene, prompt, model, user_ip, "missing task's status from backend")
|
100 |
raise gr.Error("missing task's status from backend")
|
|
|
110 |
requests.post(f"{BACKEND_URL}/predict/terminate/{task_id}", timeout=3)
|
111 |
except Exception:
|
112 |
pass
|
113 |
+
|
114 |
+
# 清理用户临时目录
|
115 |
+
cleanup_user_tmp_dir(session_id)
|
116 |
|
117 |
def record_access(request: gr.Request):
|
118 |
user_ip = request.client.host if request else "unknown"
|
config.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
# 配置相关:API、场景等
|
3 |
import os
|
4 |
|
5 |
-
BACKEND_URL = os.getenv("BACKEND_URL", "
|
6 |
API_ENDPOINTS = {
|
7 |
"submit_task": f"{BACKEND_URL}/predict/video",
|
8 |
"query_status": f"{BACKEND_URL}/predict/task",
|
|
|
2 |
# 配置相关:API、场景等
|
3 |
import os
|
4 |
|
5 |
+
BACKEND_URL = os.getenv("BACKEND_URL", "47.95.6.204:51001")
|
6 |
API_ENDPOINTS = {
|
7 |
"submit_task": f"{BACKEND_URL}/predict/video",
|
8 |
"query_status": f"{BACKEND_URL}/predict/task",
|
oss_utils.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# oss_utils.py
|
2 |
+
# OSS相关工具函数
|
3 |
+
import os
|
4 |
+
import oss2
|
5 |
+
from typing import List
|
6 |
+
import shutil
|
7 |
+
|
8 |
+
# OSS配置
|
9 |
+
OSS_CONFIG = {
|
10 |
+
"access_key_id": os.getenv("OSS_ACCESS_KEY_ID"),
|
11 |
+
"access_key_secret": os.getenv("OSS_ACCESS_KEY_SECRET"),
|
12 |
+
"endpoint": os.getenv("OSS_ENDPOINT"),
|
13 |
+
"bucket_name": os.getenv("OSS_BUCKET_NAME")
|
14 |
+
}
|
15 |
+
|
16 |
+
# 初始化OSS客户端
|
17 |
+
auth = oss2.Auth(OSS_CONFIG["access_key_id"], OSS_CONFIG["access_key_secret"])
|
18 |
+
bucket = oss2.Bucket(auth, OSS_CONFIG["endpoint"], OSS_CONFIG["bucket_name"])
|
19 |
+
|
20 |
+
# 临时文件根目录
|
21 |
+
TMP_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
|
22 |
+
os.makedirs(TMP_ROOT, exist_ok=True)
|
23 |
+
|
24 |
+
def list_oss_files(folder_path: str) -> List[str]:
|
25 |
+
"""列出OSS文件夹中的所有文件"""
|
26 |
+
files = []
|
27 |
+
try:
|
28 |
+
for obj in oss2.ObjectIterator(bucket, prefix=folder_path):
|
29 |
+
if not obj.key.endswith('/'): # 排除目录本身
|
30 |
+
files.append(obj.key)
|
31 |
+
return sorted(files, key=lambda x: os.path.splitext(x)[0])
|
32 |
+
except Exception as e:
|
33 |
+
print(f"Error listing OSS files: {str(e)}")
|
34 |
+
return []
|
35 |
+
|
36 |
+
def download_oss_file(oss_path: str, local_path: str):
|
37 |
+
"""从OSS下载文件到本地"""
|
38 |
+
try:
|
39 |
+
# 确保本地目录存在
|
40 |
+
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
41 |
+
bucket.get_object_to_file(oss_path, local_path)
|
42 |
+
except Exception as e:
|
43 |
+
print(f"Error downloading file {oss_path}: {str(e)}")
|
44 |
+
raise
|
45 |
+
|
46 |
+
def oss_file_exists(oss_path: str) -> bool:
|
47 |
+
"""检查OSS文件是否存在"""
|
48 |
+
try:
|
49 |
+
return bucket.object_exists(oss_path)
|
50 |
+
except Exception as e:
|
51 |
+
print(f"Error checking if file exists in OSS: {str(e)}")
|
52 |
+
return False
|
53 |
+
|
54 |
+
def get_user_tmp_dir(session_hash: str) -> str:
|
55 |
+
"""获取用户临时目录"""
|
56 |
+
user_dir = os.path.join(TMP_ROOT, str(session_hash))
|
57 |
+
os.makedirs(user_dir, exist_ok=True)
|
58 |
+
return user_dir
|
59 |
+
|
60 |
+
def cleanup_user_tmp_dir(session_hash: str):
|
61 |
+
"""清理用户临时目录"""
|
62 |
+
user_dir = os.path.join(TMP_ROOT, str(session_hash))
|
63 |
+
if os.path.exists(user_dir):
|
64 |
+
shutil.rmtree(user_dir)
|
requirements.txt
CHANGED
@@ -2,3 +2,4 @@ gradio>=4.0.0
|
|
2 |
requests>=2.28.0
|
3 |
opencv-python>=4.6.0
|
4 |
numpy>=1.21.0
|
|
|
|
2 |
requests>=2.28.0
|
3 |
opencv-python>=4.6.0
|
4 |
numpy>=1.21.0
|
5 |
+
oss2>=2.15.0
|
simulation.py
CHANGED
@@ -8,46 +8,94 @@ import numpy as np
|
|
8 |
from typing import List
|
9 |
import gradio as gr
|
10 |
from backend_api import get_task_status
|
|
|
11 |
|
12 |
-
def stream_simulation_results(result_folder: str, task_id: str, fps: int = 6):
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
frame_buffer: List[np.ndarray] = []
|
16 |
-
frames_per_segment = fps * 2
|
17 |
processed_files = set()
|
18 |
width, height = 0, 0
|
19 |
last_status_check = 0
|
20 |
-
status_check_interval = 5
|
21 |
max_time = 240
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
while max_time > 0:
|
23 |
max_time -= 1
|
24 |
current_time = time.time()
|
|
|
|
|
25 |
if current_time - last_status_check > status_check_interval:
|
26 |
status = get_task_status(task_id)
|
|
|
27 |
if status.get("status") == "completed":
|
28 |
-
|
|
|
29 |
if frame_buffer:
|
30 |
-
yield create_video_segment(frame_buffer, fps, width, height)
|
31 |
break
|
32 |
elif status.get("status") == "failed":
|
33 |
raise gr.Error(f"任务执行失败: {status.get('result', '未知错误')}")
|
34 |
elif status.get("status") == "terminated":
|
35 |
break
|
36 |
last_status_check = current_time
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
processed_files.add(filename)
|
52 |
has_new_frames = True
|
53 |
except Exception:
|
@@ -60,41 +108,73 @@ def stream_simulation_results(result_folder: str, task_id: str, fps: int = 6):
|
|
60 |
if max_time <= 0:
|
61 |
raise gr.Error("timeout 240s")
|
62 |
|
63 |
-
def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height: int) -> str:
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
66 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
67 |
out = cv2.VideoWriter(segment_name, fourcc, fps, (width, height))
|
|
|
68 |
for frame in frames:
|
69 |
out.write(frame)
|
70 |
out.release()
|
|
|
71 |
return segment_name
|
72 |
|
73 |
-
def
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
def convert_to_h264(video_path):
|
|
|
|
|
|
|
|
|
90 |
import shutil
|
|
|
|
|
91 |
base, ext = os.path.splitext(video_path)
|
92 |
video_path_h264 = f"{base}_h264.mp4"
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
if ffmpeg_bin is None:
|
97 |
raise RuntimeError("❌ 找不到 ffmpeg,请确保其已安装并在 PATH 中")
|
|
|
98 |
ffmpeg_cmd = [
|
99 |
ffmpeg_bin,
|
100 |
"-i", video_path,
|
@@ -105,11 +185,13 @@ def convert_to_h264(video_path):
|
|
105 |
"-movflags", "+faststart",
|
106 |
video_path_h264
|
107 |
]
|
108 |
-
|
109 |
try:
|
110 |
result = subprocess.run(ffmpeg_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
111 |
if not os.path.exists(video_path_h264):
|
112 |
raise FileNotFoundError(f"⚠️ H.264 文件未生成: {video_path_h264}")
|
113 |
return video_path_h264
|
|
|
|
|
114 |
except Exception as e:
|
115 |
-
raise
|
|
|
8 |
from typing import List
|
9 |
import gradio as gr
|
10 |
from backend_api import get_task_status
|
11 |
+
from oss_utils import list_oss_files, download_oss_file, get_user_tmp_dir
|
12 |
|
13 |
+
def stream_simulation_results(result_folder: str, task_id: str, request: gr.Request, fps: int = 6):
|
14 |
+
"""
|
15 |
+
流式输出仿真结果,从OSS读取图片
|
16 |
+
|
17 |
+
参数:
|
18 |
+
result_folder: OSS上包含生成图片的文件夹路径
|
19 |
+
task_id: 后端任务ID用于状态查询
|
20 |
+
request: Gradio请求对象
|
21 |
+
fps: 输出视频的帧率
|
22 |
+
|
23 |
+
生成:
|
24 |
+
生成的视频文件路径 (分段输出)
|
25 |
+
"""
|
26 |
+
# 初始化变量
|
27 |
+
image_folder = os.path.join(result_folder, "images")
|
28 |
frame_buffer: List[np.ndarray] = []
|
29 |
+
frames_per_segment = fps * 2 # 每2秒输出一段
|
30 |
processed_files = set()
|
31 |
width, height = 0, 0
|
32 |
last_status_check = 0
|
33 |
+
status_check_interval = 5 # 每5秒检查一次后端状态
|
34 |
max_time = 240
|
35 |
+
|
36 |
+
# 创建临时目录存储下载的图片
|
37 |
+
user_dir = get_user_tmp_dir(request.session_hash)
|
38 |
+
local_image_dir = os.path.join(user_dir, task_id, "images")
|
39 |
+
os.makedirs(local_image_dir, exist_ok=True)
|
40 |
+
|
41 |
while max_time > 0:
|
42 |
max_time -= 1
|
43 |
current_time = time.time()
|
44 |
+
|
45 |
+
# 定期检查后端状态
|
46 |
if current_time - last_status_check > status_check_interval:
|
47 |
status = get_task_status(task_id)
|
48 |
+
print(f"Session {request.session_hash}, status: {status}")
|
49 |
if status.get("status") == "completed":
|
50 |
+
# 确保处理完所有已生成的图片
|
51 |
+
process_remaining_oss_images(image_folder, local_image_dir, processed_files, frame_buffer)
|
52 |
if frame_buffer:
|
53 |
+
yield create_video_segment(frame_buffer, fps, width, height, request)
|
54 |
break
|
55 |
elif status.get("status") == "failed":
|
56 |
raise gr.Error(f"任务执行失败: {status.get('result', '未知错误')}")
|
57 |
elif status.get("status") == "terminated":
|
58 |
break
|
59 |
last_status_check = current_time
|
60 |
+
|
61 |
+
# 从OSS获取文件列表
|
62 |
+
try:
|
63 |
+
oss_files = list_oss_files(image_folder)
|
64 |
+
new_files = [f for f in oss_files if f not in processed_files]
|
65 |
+
has_new_frames = False
|
66 |
+
|
67 |
+
for oss_path in new_files:
|
68 |
+
try:
|
69 |
+
# 下载文件到本地
|
70 |
+
filename = os.path.basename(oss_path)
|
71 |
+
local_path = os.path.join(local_image_dir, filename)
|
72 |
+
download_oss_file(oss_path, local_path)
|
73 |
+
|
74 |
+
# 读取图片
|
75 |
+
frame = cv2.imread(local_path)
|
76 |
+
if frame is not None:
|
77 |
+
if width == 0: # 第一次获取图像尺寸
|
78 |
+
height, width = frame.shape[:2]
|
79 |
+
|
80 |
+
frame_buffer.append(frame)
|
81 |
+
processed_files.add(oss_path)
|
82 |
+
has_new_frames = True
|
83 |
+
except Exception as e:
|
84 |
+
print(f"Error processing {oss_path}: {e}")
|
85 |
+
|
86 |
+
# 如果有新帧且积累够指定帧数,输出视频片段
|
87 |
+
if has_new_frames and len(frame_buffer) >= frames_per_segment:
|
88 |
+
segment_frames = frame_buffer[:frames_per_segment]
|
89 |
+
frame_buffer = frame_buffer[frames_per_segment:]
|
90 |
+
yield create_video_segment(segment_frames, fps, width, height, request)
|
91 |
+
|
92 |
+
except Exception as e:
|
93 |
+
print(f"Error accessing OSS: {e}")
|
94 |
+
|
95 |
+
time.sleep(1) # 避免过于频繁检查
|
96 |
+
|
97 |
+
if max_time <= 0:
|
98 |
+
raise gr.Error("timeout 240s")
|
99 |
processed_files.add(filename)
|
100 |
has_new_frames = True
|
101 |
except Exception:
|
|
|
108 |
if max_time <= 0:
|
109 |
raise gr.Error("timeout 240s")
|
110 |
|
111 |
+
def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height: int, request: gr.Request) -> str:
|
112 |
+
"""创建视频片段"""
|
113 |
+
user_dir = get_user_tmp_dir(request.session_hash)
|
114 |
+
video_chunk_dir = os.path.join(user_dir, "video_chunks")
|
115 |
+
os.makedirs(video_chunk_dir, exist_ok=True)
|
116 |
+
|
117 |
+
segment_name = os.path.join(video_chunk_dir, f"output_{uuid.uuid4()}.mp4")
|
118 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
119 |
out = cv2.VideoWriter(segment_name, fourcc, fps, (width, height))
|
120 |
+
|
121 |
for frame in frames:
|
122 |
out.write(frame)
|
123 |
out.release()
|
124 |
+
|
125 |
return segment_name
|
126 |
|
127 |
+
def process_remaining_oss_images(oss_folder: str, local_dir: str, processed_files: set, frame_buffer: List[np.ndarray]):
|
128 |
+
"""处理OSS上剩余的图片"""
|
129 |
+
try:
|
130 |
+
oss_files = list_oss_files(oss_folder)
|
131 |
+
new_files = [f for f in oss_files if f not in processed_files]
|
132 |
+
|
133 |
+
for oss_path in new_files:
|
134 |
+
try:
|
135 |
+
# 下载文件到本地
|
136 |
+
filename = os.path.basename(oss_path)
|
137 |
+
local_path = os.path.join(local_dir, filename)
|
138 |
+
download_oss_file(oss_path, local_path)
|
139 |
+
|
140 |
+
# 读取图片
|
141 |
+
frame = cv2.imread(local_path)
|
142 |
+
if frame is not None:
|
143 |
+
frame_buffer.append(frame)
|
144 |
+
processed_files.add(oss_path)
|
145 |
+
except Exception as e:
|
146 |
+
print(f"Error processing remaining {oss_path}: {e}")
|
147 |
+
except Exception as e:
|
148 |
+
print(f"Error accessing OSS for remaining files: {e}")
|
149 |
|
150 |
def convert_to_h264(video_path):
|
151 |
+
"""
|
152 |
+
将视频转换为 H.264 编码的 MP4 格式
|
153 |
+
生成新文件路径在原路径基础上添加 _h264 后缀
|
154 |
+
"""
|
155 |
import shutil
|
156 |
+
import subprocess
|
157 |
+
|
158 |
base, ext = os.path.splitext(video_path)
|
159 |
video_path_h264 = f"{base}_h264.mp4"
|
160 |
+
|
161 |
+
# 查找ffmpeg
|
162 |
+
ffmpeg_bin = shutil.which("ffmpeg")
|
163 |
+
if ffmpeg_bin is None:
|
164 |
+
# 尝试常见的安装路径
|
165 |
+
possible_paths = [
|
166 |
+
"/root/anaconda3/envs/gradio/bin/ffmpeg",
|
167 |
+
"/usr/bin/ffmpeg",
|
168 |
+
"/usr/local/bin/ffmpeg"
|
169 |
+
]
|
170 |
+
for path in possible_paths:
|
171 |
+
if os.path.exists(path):
|
172 |
+
ffmpeg_bin = path
|
173 |
+
break
|
174 |
+
|
175 |
if ffmpeg_bin is None:
|
176 |
raise RuntimeError("❌ 找不到 ffmpeg,请确保其已安装并在 PATH 中")
|
177 |
+
|
178 |
ffmpeg_cmd = [
|
179 |
ffmpeg_bin,
|
180 |
"-i", video_path,
|
|
|
185 |
"-movflags", "+faststart",
|
186 |
video_path_h264
|
187 |
]
|
188 |
+
|
189 |
try:
|
190 |
result = subprocess.run(ffmpeg_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
191 |
if not os.path.exists(video_path_h264):
|
192 |
raise FileNotFoundError(f"⚠️ H.264 文件未生成: {video_path_h264}")
|
193 |
return video_path_h264
|
194 |
+
except subprocess.CalledProcessError as e:
|
195 |
+
raise gr.Error(f"FFmpeg 转换失败: {e.stderr}")
|
196 |
except Exception as e:
|
197 |
+
raise gr.Error(f"转换过程中发生错误: {str(e)}")
|