import gradio as gr
import requests
import json
import os
import subprocess
import uuid
import time
import cv2
from typing import Optional, List
import numpy as np
from datetime import datetime, timedelta
from collections import defaultdict
import shutil


from app_utils import (
    TMP_ROOT,
)


# Strip a trailing slash so the endpoint f-strings below do not produce "//".
BACKEND_URL = os.getenv("BACKEND_URL", "http://47.95.6.204:51001/").rstrip("/")
API_ENDPOINTS = {
    "submit_task": f"{BACKEND_URL}/predict/video",
    "query_status": f"{BACKEND_URL}/predict/task",
    "terminate_task": f"{BACKEND_URL}/predict/terminate"
}
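
# How the endpoints above are used in this file (response fields are inferred
# from the call sites below, not from backend documentation):
#   POST {submit_task}               JSON payload -> {"status": "pending", "task_id": ...}
#   GET  {query_status}/<task_id>    -> {"status": "pending|completed|failed|terminated", "result": <folder>}
#   POST {terminate_task}/<task_id>  -> {"status": "success", ...}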


SCENE_CONFIGS = {
    "scene_1": {
        "description": "scene_1",
        "objects": ["milk carton", "ceramic bowl", "mug"],
        "preview_image": "assets/scene_1.png"
    },
}


MODEL_CHOICES = [
    "gr1",
]


SESSION_TASKS = {}
IP_REQUEST_RECORDS = defaultdict(list)
IP_LIMIT = 5


def is_request_allowed(ip: str) -> bool:
    """Sliding-window rate limit: allow at most IP_LIMIT requests per IP within the last minute."""
    now = datetime.now()
    IP_REQUEST_RECORDS[ip] = [t for t in IP_REQUEST_RECORDS[ip] if now - t < timedelta(minutes=1)]
    if len(IP_REQUEST_RECORDS[ip]) < IP_LIMIT:
        IP_REQUEST_RECORDS[ip].append(now)
        return True
    return False


LOG_DIR = "/opt/gradio-frontend/logs"
os.makedirs(LOG_DIR, exist_ok=True)
ACCESS_LOG = os.path.join(LOG_DIR, "access.log")
SUBMISSION_LOG = os.path.join(LOG_DIR, "submissions.log")
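
# Both log files are append-only JSON Lines: log_access/log_submission write one
# JSON object per line, e.g. (illustrative values only):
#   {"timestamp": "2024-01-01 12:00:00", "type": "access", "user_ip": "1.2.3.4", "user_agent": "Mozilla/5.0"}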


def log_access(user_ip: Optional[str] = None, user_agent: Optional[str] = None):
    """Append a user-access record to the access log."""
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_entry = {
        "timestamp": timestamp,
        "type": "access",
        "user_ip": user_ip or "unknown",
        "user_agent": user_agent or "unknown"
    }

    with open(ACCESS_LOG, "a") as f:
        f.write(json.dumps(log_entry) + "\n")


def log_submission(scene: str, prompt: str, model: str, max_step: int, user: str = "anonymous", res: str = "unknown"):
    """Append a task-submission record to the submission log."""
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_entry = {
        "timestamp": timestamp,
        "type": "submission",
        "user": user,
        "scene": scene,
        "prompt": prompt,
        "model": model,
        "max_step": str(max_step),
        "res": res
    }

    with open(SUBMISSION_LOG, "a") as f:
        f.write(json.dumps(log_entry) + "\n")


def record_access(request: gr.Request):
    user_ip = request.client.host if request else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")
    log_access(user_ip, user_agent)
    return update_log_display()
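
# Note: Gradio injects the gr.Request argument automatically for event handlers
# that declare a parameter of that type; it is not part of the `inputs=` list.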


def read_logs(log_type: str = "all", max_entries: int = 50) -> list:
    """Read the log files and return the most recent entries, newest first."""
    logs = []

    if log_type in ["all", "access"]:
        try:
            with open(ACCESS_LOG, "r") as f:
                for line in f:
                    logs.append(json.loads(line.strip()))
        except FileNotFoundError:
            pass

    if log_type in ["all", "submission"]:
        try:
            with open(SUBMISSION_LOG, "r") as f:
                for line in f:
                    logs.append(json.loads(line.strip()))
        except FileNotFoundError:
            pass

    logs.sort(key=lambda x: x["timestamp"], reverse=True)
    return logs[:max_entries]


def format_logs_for_display(logs: list) -> str:
    """Format log entries as a markdown table for display."""
    if not logs:
        return "No log entries yet."

    markdown = "### System Access Logs\n\n"
    markdown += "| Time | Type | User/IP | Details |\n"
    markdown += "|------|------|---------|---------|\n"

    for log in logs:
        timestamp = log.get("timestamp", "unknown")
        log_type = "access" if log.get("type") == "access" else "submission"

        if log_type == "access":
            user = log.get("user_ip", "unknown")
            details = f"User-Agent: {log.get('user_agent', 'unknown')}"
        else:
            user = log.get("user", "anonymous")
            result = log.get('res', 'unknown')
            if result != "success":
                if len(result) > 40:
                    result = f"{result[:20]}...{result[-20:]}"
            details = f"scene: {log.get('scene', 'unknown')}, prompt: {log.get('prompt', '')}, model: {log.get('model', 'unknown')}, max step: {log.get('max_step', '300')}, result: {result}"

        markdown += f"| {timestamp} | {log_type} | {user} | {details} |\n"

    return markdown
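
# Streaming design used by stream_simulation_results below: the backend status is
# polled at most every `status_check_interval` seconds, newly written frames are
# batched into segments of `fps * 2` frames (about 2 s of video each), the loop
# sleeps 1 s per iteration, and it gives up after roughly 240 iterations.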


def stream_simulation_results(result_folder: str, task_id: str, fps: int = 30):
    """
    Stream simulation results by watching the image folder while polling the
    backend task status.

    Args:
        result_folder: folder containing the generated frames
        task_id: backend task id used for status queries
        fps: frame rate of the output video

    Yields:
        Paths of generated video segments (emitted incrementally).
    """

    result_folder = os.path.join(result_folder, "image")
    os.makedirs(result_folder, exist_ok=True)
    frame_buffer: List[np.ndarray] = []
    frames_per_segment = fps * 2
    processed_files = set()
    width, height = 0, 0
    last_status_check = 0
    status_check_interval = 5
    max_time = 240

    while max_time > 0:
        max_time -= 1
        current_time = time.time()

        if current_time - last_status_check > status_check_interval:
            status = get_task_status(task_id)
            print("status: ", status)
            if status.get("status") == "completed":
                process_remaining_images(result_folder, processed_files, frame_buffer)
                if frame_buffer:
                    yield create_video_segment(frame_buffer, fps, width, height)
                break
            elif status.get("status") == "failed":
                raise gr.Error(f"Task execution failed: {status.get('result', 'unknown error')}")
            elif status.get("status") == "terminated":
                break
            last_status_check = current_time

        current_files = sorted(
            [f for f in os.listdir(result_folder)
             if f.lower().endswith(('.png', '.jpg', '.jpeg'))],
            key=lambda x: os.path.splitext(x)[0]
        )

        new_files = [f for f in current_files if f not in processed_files]
        has_new_frames = False

        for filename in new_files:
            try:
                img_path = os.path.join(result_folder, filename)
                frame = cv2.imread(img_path)
                if frame is not None:
                    if width == 0:
                        height, width = frame.shape[:2]

                    frame_buffer.append(frame)
                    processed_files.add(filename)
                    has_new_frames = True
            except Exception as e:
                print(f"Error processing {filename}: {e}")

        if has_new_frames and len(frame_buffer) >= frames_per_segment:
            segment_frames = frame_buffer[:frames_per_segment]
            frame_buffer = frame_buffer[frames_per_segment:]
            yield create_video_segment(segment_frames, fps, width, height)

        time.sleep(1)

    if max_time <= 0:
        raise gr.Error("Timed out after 240 s while waiting for simulation frames.")


def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height: int,
                         req: Optional[gr.Request] = None) -> str:
    """Write the buffered frames to a short MP4 segment and return its path."""
    # stream_simulation_results calls this without a gr.Request, so fall back to
    # TMP_ROOT when no session directory is available.
    user_dir = os.path.join(TMP_ROOT, str(req.session_hash)) if req else TMP_ROOT
    os.makedirs(user_dir, exist_ok=True)
    video_chunk_path = os.path.join(user_dir, "tasks/video_chunk")
    os.makedirs(video_chunk_path, exist_ok=True)
    segment_name = os.path.join(video_chunk_path, f"output_{uuid.uuid4()}.mp4")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(segment_name, fourcc, fps, (width, height))

    for frame in frames:
        out.write(frame)
    out.release()

    return segment_name
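
# Note: segments are written with the mp4v codec for speed; the final full video is
# re-encoded to H.264 in convert_to_h264() below, presumably for broader browser support.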


def process_remaining_images(result_folder: str, processed_files: set, frame_buffer: List[np.ndarray]):
    """Collect any remaining unprocessed frames into the frame buffer."""
    current_files = sorted(
        [f for f in os.listdir(result_folder)
         if f.lower().endswith(('.png', '.jpg', '.jpeg'))],
        key=lambda x: os.path.splitext(x)[0]
    )

    new_files = [f for f in current_files if f not in processed_files]

    for filename in new_files:
        try:
            img_path = os.path.join(result_folder, filename)
            frame = cv2.imread(img_path)
            if frame is not None:
                frame_buffer.append(frame)
                processed_files.add(filename)
        except Exception as e:
            print(f"Error processing remaining {filename}: {e}")
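
# Example of the payload sent to the backend (illustrative values; the "data" field
# is itself a JSON-encoded string, as submit_to_backend builds it below):
#   {
#     "user": "Gradio-user",
#     "task": "robot_manipulation",
#     "job_id": "<uuid4>",
#     "data": "{\"scene_type\": \"scene_1\", \"instruction\": \"...\", \"model_type\": \"gr1\", \"max_step\": \"300\"}"
#   }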


def submit_to_backend(
    scene: str,
    prompt: str,
    model: str,
    max_step: int,
    user: str = "Gradio-user",
) -> dict:
    """Submit a manipulation task to the backend and return its JSON response."""
    job_id = str(uuid.uuid4())

    data = {
        "scene_type": scene,
        "instruction": prompt,
        "model_type": model,
        "max_step": str(max_step)
    }

    payload = {
        "user": user,
        "task": "robot_manipulation",
        "job_id": job_id,
        "data": json.dumps(data)
    }

    try:
        headers = {"Content-Type": "application/json"}
        response = requests.post(
            API_ENDPOINTS["submit_task"],
            json=payload,
            headers=headers,
            timeout=10
        )
        return response.json()
    except Exception as e:
        return {"status": "error", "message": str(e)}


def get_task_status(task_id: str) -> dict:
    """Query the status of a task."""
    try:
        response = requests.get(
            f"{API_ENDPOINTS['query_status']}/{task_id}",
            timeout=5
        )
        return response.json()
    except Exception as e:
        return {"status": "error", "message": str(e)}


def terminate_task(task_id: str) -> Optional[dict]:
    """Request termination of a task; returns the backend response, or None on error."""
    try:
        response = requests.post(
            f"{API_ENDPOINTS['terminate_task']}/{task_id}",
            timeout=3
        )
        return response.json()
    except Exception as e:
        print(f"Error terminating task: {e}")
        return None


def convert_to_h264(video_path):
    """
    Convert a video to an H.264-encoded MP4.
    The new file is written next to the original with an "_h264" suffix.
    """
    base, ext = os.path.splitext(video_path)
    video_path_h264 = f"{base}_h264.mp4"

    try:
        ffmpeg_cmd = [
            "ffmpeg",
            "-i", video_path,
            "-c:v", "libx264",
            "-preset", "slow",
            "-crf", "23",
            "-c:a", "aac",
            "-movflags", "+faststart",
            video_path_h264
        ]

        subprocess.run(ffmpeg_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        if not os.path.exists(video_path_h264):
            raise FileNotFoundError(f"H.264-encoded file was not generated: {video_path_h264}")

        return video_path_h264

    except subprocess.CalledProcessError as e:
        # stderr is captured as bytes, so decode it for a readable error message.
        raise gr.Error(f"FFmpeg conversion failed: {e.stderr.decode(errors='ignore') if e.stderr else e}")
    except Exception as e:
        raise gr.Error(f"Error during conversion: {str(e)}")
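
# run_simulation is a generator so Gradio can stream partial results to the UI:
# it yields (video_segment_path, history) while the task is running, and a final
# (None, updated_history) once the task finishes.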


def run_simulation(
    scene: str,
    prompt: str,
    model: str,
    max_step: int,
    history: list,
    request: gr.Request
):
    """Run the simulation and update the history."""

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    scene_desc = SCENE_CONFIGS.get(scene, {}).get("description", scene)

    user_ip = request.client.host if request else "unknown"
    session_id = request.session_hash

    if not is_request_allowed(user_ip):
        log_submission(scene, prompt, model, max_step, user_ip, "IP blocked temporarily")
        raise gr.Error("Too many requests from this IP. Please wait a minute and try again.")

    submission_result = submit_to_backend(scene, prompt, model, max_step, user_ip)
    print("submission_result: ", submission_result)

    if submission_result.get("status") != "pending":
        log_submission(scene, prompt, model, max_step, user_ip, "Submission failed")
        raise gr.Error(f"Submission failed: {submission_result.get('message', 'unknown issue')}")

    try:
        task_id = submission_result["task_id"]
        SESSION_TASKS[session_id] = task_id

        gr.Info(f"Simulation started, task_id: {task_id}")
        time.sleep(5)

        status = get_task_status(task_id)
        print("first status: ", status)
        result_folder = status.get("result", "")
    except Exception as e:
        log_submission(scene, prompt, model, max_step, user_ip, str(e))
        raise gr.Error(f"Error occurred while parsing the submission result from the backend: {str(e)}")

    if not os.path.exists(result_folder):
        log_submission(scene, prompt, model, max_step, user_ip, "Result folder provided by backend doesn't exist")
        raise gr.Error(f"Result folder provided by backend doesn't exist: {result_folder}")

    try:
        for video_path in stream_simulation_results(result_folder, task_id):
            if video_path:
                yield video_path, history
    except Exception as e:
        log_submission(scene, prompt, model, max_step, user_ip, str(e))
        raise gr.Error(f"Error while streaming: {str(e)}")

    status = get_task_status(task_id)
    print("status: ", status)
    if status.get("status") == "completed":
        video_path = os.path.join(status.get("result"), "manipulation.mp4")
        print("video_path: ", video_path)
        video_path = convert_to_h264(video_path)

        new_entry = {
            "timestamp": timestamp,
            "scene": scene,
            "model": model,
            "prompt": prompt,
            "max_step": max_step,
            "video_path": video_path,
            "task_id": task_id
        }

        updated_history = history + [new_entry]

        # Keep only the 10 most recent entries.
        if len(updated_history) > 10:
            updated_history = updated_history[-10:]

        print("updated_history:", updated_history)
        log_submission(scene, prompt, model, max_step, user_ip, "success")
        gr.Info("Simulation completed successfully!")
        yield None, updated_history

    elif status.get("status") == "failed":
        log_submission(scene, prompt, model, max_step, user_ip, status.get('result', 'backend error'))
        raise gr.Error(f"Task execution failed: {status.get('result', 'backend unknown issue')}")

    elif status.get("status") == "terminated":
        log_submission(scene, prompt, model, max_step, user_ip, "user end terminated")
        yield None, history

    else:
        log_submission(scene, prompt, model, max_step, user_ip, "missing task status from backend (still pending?)")
        raise gr.Error("Missing task status from backend (still pending?)")
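
# update_history_display returns a flat list of 40 gr.update() objects
# (4 components per slot x 10 slots), matching the flattened `history_slots`
# list wired up as outputs in submit_btn.click() below.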


def update_history_display(history: list) -> list:
    """Update the history display slots."""
    print("Updating history display")
    updates = []

    for i in range(10):
        if i < len(history):
            entry = history[i]
            updates.extend([
                gr.update(visible=True),
                gr.update(visible=True, label=f"# {i+1} | {entry['scene']} | {entry['model']} | {entry['prompt']}", open=(i+1 == len(history))),
                gr.update(value=entry['video_path'], visible=True, autoplay=False),
                gr.update(value=f"{entry['timestamp']}")
            ])
        else:
            updates.extend([
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(value=None, visible=False),
                gr.update(value="")
            ])
    print("History display updated!")
    return updates


def update_scene_display(scene: str) -> tuple[str, Optional[str]]:
    """Update the scene description and preview image."""
    config = SCENE_CONFIGS.get(scene, {})
    desc = config.get("description", "No description")
    objects = ", ".join(config.get("objects", []))
    image = config.get("preview_image", None)

    markdown = f"**{desc}**  \nObjects in this scene: {objects}"
    return markdown, image


def update_log_display():
    """Refresh the log display."""
    logs = read_logs()
    return format_logs_for_display(logs)


def cleanup_session(request: gr.Request):
    session_id = request.session_hash
    task_id = SESSION_TASKS.pop(session_id, None)

    if task_id:
        try:
            status = get_task_status(task_id)
            print("clean up check status: ", status)
            if status.get("status") == "pending":
                res = terminate_task(task_id)
                if res and res.get("status") == "success":
                    print(f"Terminated task {task_id}")
                else:
                    print(f"Failed to terminate task {task_id}: {res.get('status', 'unknown issue') if res else 'no response'}")
        except Exception as e:
            print(f"Failed to terminate task {task_id}: {e}")


header_html = """
<div style="display: flex; justify-content: space-between; align-items: center; width: 100%; margin-bottom: 20px; padding: 20px; background: linear-gradient(135deg, #528bdb 0%, #a7b5d0 100%); border-radius: 8px; box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
    <div style="display: flex; align-items: center;">
        <img src="https://www.shlab.org.cn/static/img/index_14.685f6559.png" alt="Institution Logo" style="height: 60px; margin-right: 20px;">
        <div>
            <h1 style="margin: 0; color: #2c3e50; font-weight: 600;">🤖 InternManip Model Inference Demo</h1>
            <p style="margin: 4px 0 0 0; color: #5d6d7e; font-size: 0.9em;">Model trained on InternManip framework</p>
        </div>
    </div>
    <div style="display: flex; gap: 15px; align-items: center;">
        <a href="https://github.com/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
            <img src="https://github.githubassets.com/images/modules/logos_page/GitHub-Mark.png" alt="GitHub" style="height: 30px;">
        </a>
        <a href="https://huggingface.co/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
            <img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="HuggingFace" style="height: 30px;">
        </a>
        <a href="https://huggingface.co/spaces/OpenRobotLab/InternNav-eval-demo" target="_blank">
            <button style="padding: 8px 15px; background: #3498db; color: white; border: none; border-radius: 4px; cursor: pointer; font-weight: 500; transition: all 0.2s;"
                    onmouseover="this.style.backgroundColor='#2980b9'; this.style.transform='scale(1.05)'"
                    onmouseout="this.style.backgroundColor='#3498db'; this.style.transform='scale(1)'">
                Go to InternNav Demo
            </button>
        </a>
    </div>
</div>
"""


custom_css = """
#simulation-panel {
    border-radius: 8px;
    padding: 20px;
    background: #f9f9f9;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
#result-panel {
    border-radius: 8px;
    padding: 20px;
    background: #f0f8ff;
}
.dark #simulation-panel { background: #2a2a2a; }
.dark #result-panel { background: #1a2a3a; }

.history-container {
    max-height: 600px;
    overflow-y: auto;
    margin-top: 20px;
}

.history-accordion {
    margin-bottom: 10px;
}

.logs-container {
    max-height: 500px;
    overflow-y: auto;
    margin-top: 20px;
    padding: 15px;
    background: #f5f5f5;
    border-radius: 8px;
}

.dark .logs-container {
    background: #2a2a2a;
}

.log-table {
    width: 100%;
    border-collapse: collapse;
}

.log-table th, .log-table td {
    padding: 8px 12px;
    border: 1px solid #ddd;
    text-align: left;
}

.dark .log-table th, .dark .log-table td {
    border-color: #444;
}
"""


def start_session(req: gr.Request):
    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)


def end_session(req: gr.Request):
    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
    # ignore_errors avoids a crash if the directory was never created or is already gone.
    shutil.rmtree(user_dir, ignore_errors=True)
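
# UI layout: a settings column (scene, prompt, model, max steps) next to a streaming
# result column, followed by up to 10 collapsible history slots, a dev-only log viewer,
# and example inputs. Session directories are created on load and removed on unload.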


with gr.Blocks(title="InternManip Model Inference Demo", css=custom_css) as demo:
    gr.HTML(header_html)

    history_state = gr.State([])

    with gr.Row():

        with gr.Column(elem_id="simulation-panel"):
            gr.Markdown("### Simulation Settings")

            scene_dropdown = gr.Dropdown(
                label="Choose a scene",
                choices=list(SCENE_CONFIGS.keys()),
                value="scene_1",
                interactive=True
            )

            scene_description = gr.Markdown("")
            scene_preview = gr.Image(
                label="Scene Preview",
                elem_classes=["scene-preview"],
                interactive=False
            )

            scene_dropdown.change(
                update_scene_display,
                inputs=scene_dropdown,
                outputs=[scene_description, scene_preview]
            )

            prompt_input = gr.Textbox(
                label="Manipulation Prompt",
                value="Move the milk carton to the top of the ceramic bowl.",
                placeholder="Example: 'Move the milk carton to the top of the ceramic bowl.'",
                lines=2,
                max_lines=4
            )

            model_dropdown = gr.Dropdown(
                label="Choose a pretrained model",
                choices=MODEL_CHOICES,
                value=MODEL_CHOICES[0]
            )

            with gr.Accordion("Advanced Settings", open=False):
                max_steps = gr.Slider(
                    minimum=50,
                    maximum=500,
                    value=300,
                    step=10,
                    label="Max Steps"
                )

            submit_btn = gr.Button("Apply and Start Simulation", variant="primary")

        with gr.Column(elem_id="result-panel"):
            gr.Markdown("### Result")

            video_output = gr.Video(
                label="Live",
                interactive=False,
                format="mp4",
                autoplay=True,
                streaming=True
            )

            with gr.Column() as history_container:
                gr.Markdown("### History")
                gr.Markdown("#### History will be reset after refresh")

                history_slots = []
                for i in range(10):
                    with gr.Column(visible=False) as slot:
                        with gr.Accordion(visible=False, open=False) as accordion:
                            video = gr.Video(interactive=False)
                            detail_md = gr.Markdown()
                    history_slots.append((slot, accordion, video, detail_md))

    with gr.Accordion("View system access logs (DEV ONLY)", open=False):
        logs_display = gr.Markdown()
        refresh_logs_btn = gr.Button("Refresh logs", variant="secondary")

        refresh_logs_btn.click(
            update_log_display,
            outputs=logs_display
        )

    gr.Examples(
        examples=[
            ["scene_1", "Move the milk carton to the top of the ceramic bowl.", "gr1", 300],
        ],
        inputs=[scene_dropdown, prompt_input, model_dropdown, max_steps],
        label="Examples"
    )

    submit_btn.click(
        fn=run_simulation,
        inputs=[scene_dropdown, prompt_input, model_dropdown, max_steps, history_state],
        outputs=[video_output, history_state],
        queue=True
    ).then(
        fn=update_history_display,
        inputs=history_state,
        outputs=[comp for slot in history_slots for comp in slot],
        queue=True
    ).then(
        fn=update_log_display,
        outputs=logs_display
    )

    demo.load(
        start_session
    ).then(
        fn=lambda: update_scene_display("scene_1"),
        outputs=[scene_description, scene_preview]
    ).then(
        fn=record_access,
        inputs=None,
        outputs=logs_display,
        queue=False
    ).then(
        fn=update_log_display,
        outputs=logs_display
    )

    demo.queue(default_concurrency_limit=8)

    demo.unload(fn=cleanup_session).then(end_session)
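
# When run directly, this starts the Gradio server with default settings; pass
# options such as server_name/server_port to demo.launch() if a specific host or
# port is needed (standard Gradio launch arguments, not set by this demo).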


if __name__ == "__main__":
    demo.launch()