#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An online buffer for the online training loop in train.py | |
Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should | |
consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much | |
faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it | |
supports in-place slicing and mutation which is very handy for a dynamic buffer. | |
""" | |

import os
from pathlib import Path
from typing import Any

import numpy as np
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


def _make_memmap_safe(**kwargs) -> np.memmap:
    """Make a numpy memmap with checks on available disk space first.

    Expected kwargs are: "filename", "dtype" (must be np.dtype), "mode" and "shape".

    For information on dtypes:
    https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
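
    A minimal usage sketch (the path and shape here are illustrative only):

        _make_memmap_safe(
            filename=Path("/tmp/online_buffer/index"),
            dtype=np.dtype("int64"),
            mode="w+",
            shape=(1000,),
        )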
""" | |
if kwargs["mode"].startswith("w"): | |
required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes | |
stats = os.statvfs(Path(kwargs["filename"]).parent) | |
available_space = stats.f_bavail * stats.f_frsize # bytes | |
if required_space >= available_space * 0.8: | |
raise RuntimeError( | |
f"You're about to take up {required_space} of {available_space} bytes available." | |
) | |
return np.memmap(**kwargs) | |


class OnlineBuffer(torch.utils.data.Dataset):
    """FIFO data buffer for the online training loop in train.py.

    Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
    loop in the same way that a LeRobotDataset would be used.

    The underlying data structure will have data inserted in a circular fashion. Always insert after the
    last index, and when you reach the end, wrap around to the start.

    The data is stored in a numpy memmap.
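
    For example, with `buffer_capacity=5`, adding 3 frames and then another 4 frames wraps around: the
    last 2 of the new frames overwrite the 2 oldest frames at indices 0 and 1.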
""" | |
NEXT_INDEX_KEY = "_next_index" | |
OCCUPANCY_MASK_KEY = "_occupancy_mask" | |
INDEX_KEY = "index" | |
FRAME_INDEX_KEY = "frame_index" | |
EPISODE_INDEX_KEY = "episode_index" | |
TIMESTAMP_KEY = "timestamp" | |
IS_PAD_POSTFIX = "_is_pad" | |

    def __init__(
        self,
        write_dir: str | Path,
        data_spec: dict[str, Any] | None,
        buffer_capacity: int | None,
        fps: float | None = None,
        delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
    ):
""" | |
The online buffer can be provided from scratch or you can load an existing online buffer by passing | |
a `write_dir` associated with an existing buffer. | |
Args: | |
write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key. | |
Note that if the files already exist, they are opened in read-write mode (used for training | |
resumption.) | |
data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int], | |
"dtype": np.dtype}}. This should include all the data that you wish to record into the buffer, | |
but note that "index", "frame_index" and "episode_index" are already accounted for by this | |
class, so you don't need to include them. | |
buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your | |
system's available disk space when choosing this. | |
fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the | |
delta_timestamps logic. You can pass None if you are not using delta_timestamps. | |
delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally | |
converted to dict[str, np.ndarray] for optimization purposes. | |
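
        A construction sketch (the data key, shape, capacity, and `write_dir` here are illustrative, not
        values required by this class):

            buffer = OnlineBuffer(
                write_dir="outputs/online_buffer",
                data_spec={"action": {"shape": (2,), "dtype": np.dtype("float32")}},
                buffer_capacity=10_000,
                fps=10,
            )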
""" | |
self.set_delta_timestamps(delta_timestamps) | |
self._fps = fps | |
# Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from | |
# the requested frames. It is only used when `delta_timestamps` is provided. | |
# minus 1e-4 to account for possible numerical error | |
self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None | |
self._buffer_capacity = buffer_capacity | |
data_spec = self._make_data_spec(data_spec, buffer_capacity) | |
Path(write_dir).mkdir(parents=True, exist_ok=True) | |
self._data = {} | |
for k, v in data_spec.items(): | |
self._data[k] = _make_memmap_safe( | |
filename=Path(write_dir) / k, | |
dtype=v["dtype"] if v is not None else None, | |
mode="r+" if (Path(write_dir) / k).exists() else "w+", | |
shape=tuple(v["shape"]) if v is not None else None, | |
) | |

    @property
    def delta_timestamps(self) -> dict[str, np.ndarray] | None:
        return self._delta_timestamps

    def set_delta_timestamps(self, value: dict[str, list[float]] | None):
        """Set delta_timestamps converting the values to numpy arrays.

        The conversion is an optimization for __getitem__: the loop there is much slower if the values
        still need to be converted into numpy arrays.
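
        For example (the key and values are illustrative), {"action": [-0.1, 0.0, 0.1]} is stored as
        {"action": np.array([-0.1, 0.0, 0.1])}.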
""" | |
if value is not None: | |
self._delta_timestamps = {k: np.array(v) for k, v in value.items()} | |
else: | |
self._delta_timestamps = None | |

    def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
        """Makes the data spec for np.memmap."""
        if any(k.startswith("_") for k in data_spec):
            raise ValueError(
                "data_spec keys should not start with '_'. This prefix is reserved for internal logic."
            )
        preset_keys = {
            OnlineBuffer.INDEX_KEY,
            OnlineBuffer.FRAME_INDEX_KEY,
            OnlineBuffer.EPISODE_INDEX_KEY,
            OnlineBuffer.TIMESTAMP_KEY,
        }
        if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
            raise ValueError(
                f"data_spec should not contain any of {preset_keys} as these are handled internally. "
                f"The provided data_spec has {intersection}."
            )
        complete_data_spec = {
            # _next_index will be a pointer to the next index that we should start filling from when we add
            # more data.
            OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
            # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
            # with real data rather than the dummy initialization.
            OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
            OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
            OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
            OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
            OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
        }
        for k, v in data_spec.items():
            complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
        return complete_data_spec

    def add_data(self, data: dict[str, np.ndarray]):
        """Add new data to the buffer, which could potentially mean shifting old data out.

        The new data should contain all the frames (in order) of any number of episodes. The indices should
        start from 0 (note to the developer: this can easily be generalized). See the `rollout` and
        `eval_policy` functions in `eval.py` for more information on how the data is constructed.

        Shift the incoming data index and episode_index to continue on from the last frame. Note that this
        will be done in place!
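
        A sketch of the expected format for a single 2-frame episode, assuming a buffer whose only
        user-provided data key is the illustrative "action" key (all arrays must share the same length, and
        "index" and "episode_index" must start from 0):

            buffer.add_data(
                {
                    "index": np.array([0, 1]),
                    "frame_index": np.array([0, 1]),
                    "episode_index": np.array([0, 0]),
                    "timestamp": np.array([0.0, 0.1]),
                    "action": np.zeros((2, 2), dtype=np.float32),
                }
            )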
""" | |
if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0: | |
raise ValueError(f"Missing data keys: {missing_keys}") | |
new_data_length = len(data[self.data_keys[0]]) | |
if not all(len(data[k]) == new_data_length for k in self.data_keys): | |
raise ValueError("All data items should have the same length") | |
next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY] | |
# Sanity check to make sure that the new data indices start from 0. | |
assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0 | |
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0 | |
# Shift the incoming indices if necessary. | |
if self.num_frames > 0: | |
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1] | |
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1] | |
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1 | |
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1 | |
# Insert the new data starting from next_index. It may be necessary to wrap around to the start. | |
n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index)) | |
for k in self.data_keys: | |
if n_surplus == 0: | |
slc = slice(next_index, next_index + new_data_length) | |
self._data[k][slc] = data[k] | |
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True | |
else: | |
self._data[k][next_index:] = data[k][:-n_surplus] | |
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True | |
self._data[k][:n_surplus] = data[k][-n_surplus:] | |
if n_surplus == 0: | |
self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length | |
else: | |
self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus | |

    @property
    def data_keys(self) -> list[str]:
        keys = set(self._data)
        keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY)
        keys.remove(OnlineBuffer.NEXT_INDEX_KEY)
        return sorted(keys)

    @property
    def fps(self) -> float | None:
        return self._fps

    @property
    def num_episodes(self) -> int:
        return len(
            np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
        )

    @property
    def num_frames(self) -> int:
        return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])

    def __len__(self):
        return self.num_frames

    def _item_to_tensors(self, item: dict) -> dict:
        item_ = {}
        for k, v in item.items():
            if isinstance(v, torch.Tensor):
                item_[k] = v
            elif isinstance(v, np.ndarray):
                item_[k] = torch.from_numpy(v)
            else:
                item_[k] = torch.tensor(v)
        return item_

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        if idx >= len(self) or idx < -len(self):
            raise IndexError

        item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")}

        if self.delta_timestamps is None:
            return self._item_to_tensors(item)

        episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY]
        current_ts = item[OnlineBuffer.TIMESTAMP_KEY]
        episode_data_indices = np.where(
            np.bitwise_and(
                self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index,
                self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
            )
        )[0]
        episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]

        for data_key in self.delta_timestamps:
            # Note: The logic in this loop is copied from `load_previous_and_future_frames`.
            # Get timestamps used as query to retrieve data of previous/future frames.
            query_ts = current_ts + self.delta_timestamps[data_key]

            # Compute distances between each query timestamp and all timestamps of all the frames belonging
            # to the episode.
            dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
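            # `dist` has shape (num_query_timestamps, num_episode_frames), so the argmin over axis 1 picks,
            # for each query timestamp, the index of the closest frame within the episode.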
            argmin_ = np.argmin(dist, axis=1)
            min_ = dist[np.arange(dist.shape[0]), argmin_]

            is_pad = min_ > self.tolerance_s

            # Check violated query timestamps are all outside the episode range.
            assert (
                (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
            ).all(), (
                f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
                ") inside the episode range."
            )

            # Load frames for this data key.
            item[data_key] = self._data[data_key][episode_data_indices[argmin_]]

            item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad

        return self._item_to_tensors(item)

    def get_data_by_key(self, key: str) -> torch.Tensor:
        """Returns all data for a given data key as a Tensor."""
        return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])


def compute_sampler_weights(
    offline_dataset: LeRobotDataset,
    offline_drop_n_last_frames: int = 0,
    online_dataset: OnlineBuffer | None = None,
    online_sampling_ratio: float | None = None,
    online_drop_n_last_frames: int = 0,
) -> torch.Tensor:
    """Compute the sampling weights for the online training dataloader in train.py.

    Args:
        offline_dataset: The LeRobotDataset used for offline pre-training.
        offline_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
        online_dataset: The OnlineBuffer used in online training.
        online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
            online dataset is provided, this value must also be provided.
        online_drop_n_last_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
            dataset.

    Returns:
        Tensor of weights for [offline_dataset; online_dataset], normalized to 1.

    Notes to maintainers:
        - This duplicates some logic from EpisodeAwareSampler. We should consider converging to one
          approach.
        - When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
          `EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing
          feature is the ability to turn shuffling off.
        - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
          included here to avoid adding complexity.
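
    A usage sketch with a weighted sampler (the drop counts and dataset variables here are illustrative):

        weights = compute_sampler_weights(
            offline_dataset,
            offline_drop_n_last_frames=7,
            online_dataset=online_buffer,
            online_sampling_ratio=0.5,
            online_drop_n_last_frames=7,
        )
        sampler = torch.utils.data.WeightedRandomSampler(
            weights, num_samples=len(weights), replacement=True
        )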
""" | |
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0): | |
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.") | |
if (online_dataset is None) ^ (online_sampling_ratio is None): | |
raise ValueError( | |
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all." | |
) | |
    offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio

    weights = []

    if len(offline_dataset) > 0:
        offline_data_mask_indices = []
        for start_index, end_index in zip(
            offline_dataset.episode_data_index["from"],
            offline_dataset.episode_data_index["to"],
            strict=True,
        ):
            offline_data_mask_indices.extend(
                range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
            )
        offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
        offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
        weights.append(
            torch.full(
                size=(len(offline_dataset),),
                fill_value=offline_sampling_ratio / offline_data_mask.sum(),
            )
            * offline_data_mask
        )

    if online_dataset is not None and len(online_dataset) > 0:
        online_data_mask_indices = []
        episode_indices = online_dataset.get_data_by_key("episode_index")
        for episode_idx in torch.unique(episode_indices):
            where_episode = torch.where(episode_indices == episode_idx)
            start_index = where_episode[0][0]
            end_index = where_episode[0][-1] + 1
            online_data_mask_indices.extend(
                range(start_index.item(), end_index.item() - online_drop_n_last_frames)
            )
        online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
        online_data_mask[torch.tensor(online_data_mask_indices)] = True
        weights.append(
            torch.full(
                size=(len(online_dataset),),
                fill_value=online_sampling_ratio / online_data_mask.sum(),
            )
            * online_data_mask
        )

    weights = torch.cat(weights)

    if weights.sum() == 0:
        weights += 1 / len(weights)
    else:
        weights /= weights.sum()

    return weights