Update app.py
app.py
CHANGED
@@ -7,9 +7,111 @@ import numpy as np
 import cv2
 import base64
 from scipy.ndimage import gaussian_filter1d
+import requests
+import json
+import tempfile
+import os
+from pathlib import Path
+from typing import Tuple, Optional
+from urllib.parse import urljoin
+# ------------------ Download data ------------------
 
+class RemoteDatasetLoader:
+    """Loads a dataset remotely from the Hugging Face Hub."""
+
+    def __init__(self, repo_id: str, timeout: int = 30):
+        self.repo_id = repo_id
+        self.timeout = timeout
+        self.base_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+
+    def _get_dataset_info(self) -> dict:
+        info_url = urljoin(self.base_url, "meta/info.json")
+        response = requests.get(info_url, timeout=self.timeout)
+        response.raise_for_status()
+        return response.json()
+
+    def _get_episode_info(self, episode_id: int) -> dict:
+        episodes_url = urljoin(self.base_url, "meta/episodes.jsonl")
+        response = requests.get(episodes_url, timeout=self.timeout)
+        response.raise_for_status()
+        episodes = [json.loads(line) for line in response.text.splitlines() if line.strip()]
+        for episode in episodes:
+            if episode.get("episode_index") == episode_id:
+                return episode
+        raise ValueError(f"Episode {episode_id} not found")
+
+    def _download_video(self, video_url: str, save_path: str) -> str:
+        response = requests.get(video_url, timeout=self.timeout, stream=True)
+        response.raise_for_status()
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+        with open(save_path, 'wb') as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+        return save_path
+
+    def load_episode_data(self, episode_id: int,
+                          video_keys: Optional[list] = None,
+                          download_dir: Optional[str] = None) -> Tuple[list, pd.DataFrame]:
+        dataset_info = self._get_dataset_info()
+        episode_info = self._get_episode_info(episode_id)
+
+        if download_dir is None:
+            download_dir = tempfile.mkdtemp(prefix="lerobot_videos_")
+
+        if video_keys is None:
+            video_keys = [key for key, feature in dataset_info["features"].items()
+                          if feature["dtype"] == "video"]
+
+        video_keys = video_keys[:2]
+        video_paths = []
+        chunks_size = dataset_info.get("chunks_size", 1000)
+
+        for i, video_key in enumerate(video_keys):
+            video_url = self.base_url + dataset_info["video_path"].format(
+                episode_chunk=episode_id // chunks_size,
+                video_key=video_key,
+                episode_index=episode_id
+            )
+            video_filename = f"episode_{episode_id}_{video_key}.mp4"
+            local_path = os.path.join(download_dir, video_filename)
+            try:
+                downloaded_path = self._download_video(video_url, local_path)
+                video_paths.append(downloaded_path)
+                print(f"Downloaded video {i+1}: {downloaded_path}")
+            except Exception as e:
+                print(f"Failed to download video {video_key}: {e}")
+                video_paths.append(video_url)
+
+        data_url = self.base_url + dataset_info["data_path"].format(
+            episode_chunk=episode_id // chunks_size,
+            episode_index=episode_id
+        )
+        try:
+            df = pd.read_parquet(data_url)
+            print(f"Loaded data with {len(df)} rows and {len(df.columns)} columns")
+        except Exception as e:
+            print(f"Failed to load data: {e}")
+            df = pd.DataFrame()
+
+        return video_paths, df
+
+
+def load_remote_dataset(repo_id: str,
+                        episode_id: int = 0,
+                        video_keys: Optional[list] = None,
+                        download_dir: Optional[str] = None) -> Tuple[list, pd.DataFrame]:
+    loader = RemoteDatasetLoader(repo_id)
+    return loader.load_episode_data(episode_id, video_keys, download_dir)
+
+
+video_paths, data_df = load_remote_dataset(
+    repo_id="zijian2022/sortingtest",
+    episode_id=0,
+    download_dir="./downloaded_videos"
+)
 # ------------------ Load data ------------------
-df = pd.read_parquet("./data/clean_data/uni_boxing_object_vfm/data/chunk-000/episode_000000.parquet")
+#df = pd.read_parquet("./data/clean_data/uni_boxing_object_vfm/data/chunk-000/episode_000000.parquet")
+df = data_df
 columns = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
 timestamps = df["timestamp"].values
 delta_t = np.diff(timestamps)
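For reference, every URL the loader builds is derived from the `data_path` and `video_path` format templates published in the dataset's `meta/info.json`. A minimal sketch of how that formatting resolves, assuming template strings typical of the LeRobot v2 layout (illustrative; the real templates are read from `meta/info.json` at runtime):

base_url = "https://huggingface.co/datasets/zijian2022/sortingtest/resolve/main/"
# Illustrative templates, not fetched from the actual repo.
info = {
    "chunks_size": 1000,
    "data_path": "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet",
    "video_path": "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4",
}
episode_id = 0
chunk = episode_id // info["chunks_size"]  # episode 0 falls in chunk 0
data_url = base_url + info["data_path"].format(episode_chunk=chunk, episode_index=episode_id)
video_url = base_url + info["video_path"].format(
    episode_chunk=chunk, video_key="observation.images.laptop", episode_index=episode_id,
)
print(data_url)   # .../data/chunk-000/episode_000000.parquet
print(video_url)  # .../videos/chunk-000/observation.images.laptop/episode_000000.mp4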
@@ -17,9 +119,10 @@ time_for_plot = timestamps[1:]
 action_df = pd.DataFrame(df["action"].tolist(), columns=columns)
 
 # ------------------ Video paths ------------------
-video_path_1 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.laptop/episode_000000.mp4"
-video_path_2 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.phone/episode_000000.mp4"
-
+#video_path_1 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.laptop/episode_000000.mp4"
+#video_path_2 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.phone/episode_000000.mp4"
+video_path_1 = video_paths[0]
+video_path_2 = video_paths[1]
 # ------------------ Dash initialization ------------------
 app = dash.Dash(__name__)
 server = app.server
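Note that `load_episode_data` keeps the raw URL as a fallback when a download fails and truncates to at most two streams, so `video_paths` can contain a mix of local files and remote URLs. A hypothetical guard (not part of the app) before indexing into it:

import cv2

# Hypothetical sanity check: the app reads video_paths[0] and video_paths[1] directly,
# so fail fast if fewer than two streams came back or a stream cannot be opened.
if len(video_paths) < 2:
    raise RuntimeError(f"Expected 2 video streams, got {len(video_paths)}")
for path in video_paths:
    cap = cv2.VideoCapture(path)  # handles local files; URLs too if OpenCV was built with FFmpeg
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video stream: {path}")
    cap.release()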
@@ -374,5 +477,5 @@ def update_video_frames(*args):
     return [no_update]*12 + [hover_state]
 
 # ------------------ Launch app ------------------
-if __name__ == "__main__":
-    app.run_server(debug=True)
+if __name__ == "__main__":
+    app.run(debug=True)
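The launch block switches to `app.run`, the name Dash 2 introduced as the replacement for the deprecated `app.run_server`. A minimal local-run sketch, assuming the Space itself is served through the WSGI entry point `server = app.server` and the conventional Spaces port:

# Local development entry point; on Spaces the WSGI server picks up `server = app.server`,
# so this block only matters when running `python app.py` directly.
if __name__ == "__main__":
    app.run(debug=True, host="0.0.0.0", port=7860)  # 7860 is the usual Spaces port (assumption)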