# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
from tqdm import tqdm

from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import write_episode_stats


def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
    """Decode a subsample of frames from one episode's video feature and return them as a numpy array."""
    ep_len = dataset.meta.episodes[episode_index]["length"]
    sampled_indices = sample_indices(ep_len)
    query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
    video_frames = dataset._query_videos(query_timestamps, episode_index)
    return video_frames[ft_key].numpy()


def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
    """Compute per-feature stats for a single episode and store them in dataset.meta.episodes_stats."""
    ep_start_idx = dataset.episode_data_index["from"][ep_idx]
    ep_end_idx = dataset.episode_data_index["to"][ep_idx]
    ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))

    ep_stats = {}
    for key, ft in dataset.features.items():
        if ft["dtype"] == "video":
            # We sample only for videos, to keep decoding cheap
            ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
        else:
            ep_ft_data = np.array(ep_data[key])

        # Visual features are reduced over batch, height and width; other features over the first axis only
        axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
        keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
        ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)

        if ft["dtype"] in ["image", "video"]:  # remove batch dim
            ep_stats[key] = {
                k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
            }

    dataset.meta.episodes_stats[ep_idx] = ep_stats


def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
    """Compute stats for every episode (optionally in parallel) and write them to the dataset root."""
    assert dataset.episodes is None
    print("Computing episodes stats")
    total_episodes = dataset.meta.total_episodes
    if num_workers > 0:
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = {
                executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
                for ep_idx in range(total_episodes)
            }
            for future in tqdm(as_completed(futures), total=total_episodes):
                future.result()
    else:
        for ep_idx in tqdm(range(total_episodes)):
            convert_episode_stats(dataset, ep_idx)

    for ep_idx in tqdm(range(total_episodes)):
        write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)


def check_aggregate_stats(
    dataset: LeRobotDataset,
    reference_stats: dict[str, dict[str, np.ndarray]],
    video_rtol_atol: tuple[float] = (1e-2, 1e-2),
    default_rtol_atol: tuple[float] = (5e-6, 6e-5),
):
    """Verifies that the aggregated stats from episodes_stats are close to reference stats."""
    agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
    for key, ft in dataset.features.items():
        # These tolerance values might need some fine-tuning
        if ft["dtype"] == "video":
            # looser tolerances to account for video frame sub-sampling
            rtol, atol = video_rtol_atol
        else:
            rtol, atol = default_rtol_atol

        for stat, val in agg_stats[key].items():
            if key in reference_stats and stat in reference_stats[key]:
                err_msg = f"feature='{key}' stats='{stat}'"
                np.testing.assert_allclose(
                    val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
                )
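

# A minimal usage sketch, not part of the original script: it assumes the dataset is
# already available under its repo id and that its legacy global stats can serve as the
# reference. The repo id below is hypothetical, and the real conversion entry point may
# wire these calls up differently (e.g. through a CLI argument parser).
if __name__ == "__main__":
    dataset = LeRobotDataset("lerobot/pusht")  # hypothetical repo id
    convert_stats(dataset, num_workers=4)
    check_aggregate_stats(dataset, reference_stats=dataset.meta.stats)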