zijian2022 committed on
Commit
d85386b
·
verified ·
1 Parent(s): b717ef8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -6
app.py CHANGED
@@ -7,9 +7,111 @@ import numpy as np
7
  import cv2
8
  import base64
9
  from scipy.ndimage import gaussian_filter1d
 
 
 
 
 
 
 
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # ------------------ 加载数据 ------------------
12
- df = pd.read_parquet("./data/clean_data/uni_boxing_object_vfm/data/chunk-000/episode_000000.parquet")
 
13
  columns = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
14
  timestamps = df["timestamp"].values
15
  delta_t = np.diff(timestamps)
@@ -17,9 +119,10 @@ time_for_plot = timestamps[1:]
17
  action_df = pd.DataFrame(df["action"].tolist(), columns=columns)
18
 
19
  # ------------------ 视频路径 ------------------
20
- video_path_1 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.laptop/episode_000000.mp4"
21
- video_path_2 = "./data/clean_data/uni_boxing_object_vfm/videos/chunk-000/observation.images.phone/episode_000000.mp4"
22
-
 
23
  # ------------------ Dash 初始化 ------------------
24
  app = dash.Dash(__name__)
25
  server = app.server
@@ -374,5 +477,5 @@ def update_video_frames(*args):
374
  return [no_update]*12 + [hover_state]
375
 
376
  # ------------------ 启动应用 ------------------
377
- if __name__ == '__main__':
378
- app.run_server(host="0.0.0.0", port=7860, debug=False)
 
7
  import cv2
8
  import base64
9
  from scipy.ndimage import gaussian_filter1d
10
+ import requests
11
+ import json
12
+ import tempfile
13
+ import os
14
+ from pathlib import Path
15
+ from typing import Tuple, Optional
16
+ from urllib.parse import urljoin
17
+ # ------------------ 下载数据 ------------------
18
 
19
class RemoteDatasetLoader:
    """Load a single episode of a dataset hosted on the Hugging Face Hub.

    The loader reads ``meta/info.json`` and ``meta/episodes.jsonl`` from the
    dataset repository, downloads the episode's videos to local disk and reads
    the episode's parquet file into a DataFrame.
    """

    def __init__(self, repo_id: str, timeout: int = 30):
        """Remember the repository and derive its file-resolution base URL.

        Args:
            repo_id: Dataset repository id, e.g. ``"user/dataset"``.
            timeout: Timeout in seconds passed to every ``requests.get`` call.
        """
        self.repo_id = repo_id
        self.timeout = timeout
        self.base_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/"

    def _build_url(self, path_template: str, **fields) -> str:
        """Fill a path template from ``info.json`` and prefix it with the base URL."""
        return self.base_url + path_template.format(**fields)

    def _get_dataset_info(self) -> dict:
        """Fetch and decode ``meta/info.json``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        info_url = urljoin(self.base_url, "meta/info.json")
        response = requests.get(info_url, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def _get_episode_info(self, episode_id: int) -> dict:
        """Return the ``meta/episodes.jsonl`` record for ``episode_id``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            ValueError: If no record has a matching ``episode_index``.
        """
        episodes_url = urljoin(self.base_url, "meta/episodes.jsonl")
        response = requests.get(episodes_url, timeout=self.timeout)
        response.raise_for_status()
        for line in response.text.splitlines():
            if not line.strip():
                continue
            episode = json.loads(line)
            if episode.get("episode_index") == episode_id:
                return episode
        raise ValueError(f"Episode {episode_id} not found")

    def _download_video(self, video_url: str, save_path: str) -> str:
        """Stream ``video_url`` to ``save_path`` and return ``save_path``.

        The payload is written to a ``.part`` file first and moved into place
        only once complete, so a failed download never leaves a truncated
        video at ``save_path``.

        Raises:
            requests.RequestException: On network or HTTP errors.
            OSError: If the file cannot be written.
        """
        partial_path = save_path + ".part"
        try:
            # The context manager releases the streamed connection even when
            # writing fails part-way through.
            with requests.get(video_url, timeout=self.timeout, stream=True) as response:
                response.raise_for_status()
                target_dir = os.path.dirname(save_path)
                if target_dir:
                    # dirname is "" for a bare filename; makedirs("") would raise.
                    os.makedirs(target_dir, exist_ok=True)
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            os.replace(partial_path, save_path)
        finally:
            # After a successful replace the partial file no longer exists.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        return save_path

    def load_episode_data(self, episode_id: int,
                          video_keys: Optional[list] = None,
                          download_dir: Optional[str] = None) -> Tuple[list, pd.DataFrame]:
        """Download the videos and load the frame table of one episode.

        Args:
            episode_id: Index of the episode to load.
            video_keys: Feature keys of the videos to fetch. Defaults to every
                feature whose dtype is ``"video"`` in ``info.json``. Only the
                first two keys are used.
            download_dir: Directory for downloaded videos. A fresh temporary
                directory is created when omitted.

        Returns:
            ``(video_paths, df)``. ``video_paths`` holds one entry per used
            video key: the local file path, or the remote URL when the
            download failed. ``df`` is the episode's parquet data, or an
            empty DataFrame when it could not be read.

        Raises:
            requests.HTTPError: If the metadata files cannot be fetched.
            ValueError: If the episode is not listed in ``episodes.jsonl``.
        """
        dataset_info = self._get_dataset_info()
        # Called for validation only: raises ValueError for an unknown episode.
        self._get_episode_info(episode_id)

        if download_dir is None:
            download_dir = tempfile.mkdtemp(prefix="lerobot_videos_")

        if video_keys is None:
            video_keys = [key for key, feature in dataset_info["features"].items()
                          if feature["dtype"] == "video"]

        video_keys = video_keys[:2]
        video_paths = []
        chunks_size = dataset_info.get("chunks_size", 1000)
        episode_chunk = episode_id // chunks_size

        for i, video_key in enumerate(video_keys):
            video_url = self._build_url(
                dataset_info["video_path"],
                episode_chunk=episode_chunk,
                video_key=video_key,
                episode_index=episode_id,
            )
            video_filename = f"episode_{episode_id}_{video_key}.mp4"
            local_path = os.path.join(download_dir, video_filename)
            try:
                downloaded_path = self._download_video(video_url, local_path)
            except (requests.RequestException, OSError) as e:
                # Best effort: hand back the remote URL so the caller can
                # still try to open the video directly.
                print(f"Failed to download video {video_key}: {e}")
                video_paths.append(video_url)
            else:
                video_paths.append(downloaded_path)
                print(f"Downloaded video {i+1}: {downloaded_path}")

        data_url = self._build_url(
            dataset_info["data_path"],
            episode_chunk=episode_chunk,
            episode_index=episode_id,
        )
        try:
            df = pd.read_parquet(data_url)
            print(f"Loaded data with {len(df)} rows and {len(df.columns)} columns")
        except Exception as e:
            # Deliberately broad: reading a remote parquet file can fail in the
            # network, filesystem or parquet-engine layer. Callers must check
            # for an empty frame.
            print(f"Failed to load data: {e}")
            df = pd.DataFrame()

        return video_paths, df
97
+
98
+
99
def load_remote_dataset(repo_id: str,
                        episode_id: int = 0,
                        video_keys: Optional[list] = None,
                        download_dir: Optional[str] = None) -> Tuple[list, pd.DataFrame]:
    """Load one episode of a Hub dataset in a single call.

    Builds a ``RemoteDatasetLoader`` for ``repo_id`` (with its default
    timeout) and forwards the remaining arguments to ``load_episode_data``.

    Returns:
        The ``(video_paths, dataframe)`` pair produced by the loader.
    """
    return RemoteDatasetLoader(repo_id).load_episode_data(
        episode_id,
        video_keys,
        download_dir,
    )
105
+
106
+
107
# Fetch episode 0 (videos and per-frame table) from the Hub at import time.
video_paths, data_df = load_remote_dataset(
    repo_id="zijian2022/sortingtest",
    episode_id=0,
    download_dir="./downloaded_videos",
)
# ------------------ Load data ------------------
df = data_df
if df.empty:
    # The loader returns an empty DataFrame when the parquet file cannot be
    # read; stop here with a clear message instead of a KeyError below.
    raise RuntimeError(
        "Episode data could not be loaded from the remote dataset; "
        "see the 'Failed to load data' message above."
    )
columns = ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"]
timestamps = df["timestamp"].values
delta_t = np.diff(timestamps)
 
119
  action_df = pd.DataFrame(df["action"].tolist(), columns=columns)
120
 
121
# ------------------ Video paths ------------------
# Two paths are assigned here; fail with a clear message instead of a bare
# IndexError when the dataset yielded fewer than two streams.
if len(video_paths) < 2:
    raise RuntimeError(
        f"Expected 2 video streams for the episode, got {len(video_paths)}"
    )
video_path_1, video_path_2 = video_paths[0], video_paths[1]
126
  # ------------------ Dash 初始化 ------------------
127
  app = dash.Dash(__name__)
128
  server = app.server
 
477
  return [no_update]*12 + [hover_state]
478
 
479
  # ------------------ 启动应用 ------------------
480
if __name__ == "__main__":
    # Listen on all interfaces on port 7860, the values this entry point used
    # before; without them Dash falls back to its default local address.
    # Debug mode stays off so the interactive debugger and reloader are not
    # exposed on a reachable server.
    app.run(host="0.0.0.0", port=7860, debug=False)