#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An online buffer for the online training loop in train.py

Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
supports in-place slicing and mutation which is very handy for a dynamic buffer.
"""

import os
from pathlib import Path
from typing import Any

import numpy as np
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


def _make_memmap_safe(**kwargs) -> np.memmap:
    """Make a numpy memmap with checks on available disk space first.

    Expected kwargs are: "filename", "dtype" (must be np.dtype), "mode" and "shape".

    For information on dtypes:
    https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
    """
    if kwargs["mode"].startswith("w"):
        required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"])  # bytes
        stats = os.statvfs(Path(kwargs["filename"]).parent)
        available_space = stats.f_bavail * stats.f_frsize  # bytes
        if required_space >= available_space * 0.8:
            raise RuntimeError(
                f"You're about to take up {required_space} of {available_space} bytes available."
            )
    return np.memmap(**kwargs)


class OnlineBuffer(torch.utils.data.Dataset):
    """FIFO data buffer for the online training loop in train.py.

    Follows the LeRobotDataset protocol as far as is required for the online training loop to use it in the
    same way it would use a LeRobotDataset.

    The underlying data structure will have data inserted in a circular fashion. Always insert after the
    last index, and when you reach the end, wrap around to the start.

    The data is stored in a numpy memmap.
    """

    NEXT_INDEX_KEY = "_next_index"
    OCCUPANCY_MASK_KEY = "_occupancy_mask"
    INDEX_KEY = "index"
    FRAME_INDEX_KEY = "frame_index"
    EPISODE_INDEX_KEY = "episode_index"
    TIMESTAMP_KEY = "timestamp"
    IS_PAD_POSTFIX = "_is_pad"

    def __init__(
        self,
        write_dir: str | Path,
        data_spec: dict[str, Any] | None,
        buffer_capacity: int | None,
        fps: float | None = None,
        delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
    ):
        """
        The online buffer can be created from scratch, or an existing buffer can be loaded by passing a
        `write_dir` associated with one.

        Args:
            write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
                Note that if the files already exist, they are opened in read-write mode (used for training
                resumption.)
            data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
                "dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
                but note that "index", "frame_index" and "episode_index" are already accounted for by this
                class, so you don't need to include them.
            buffer_capacity: The maximum number of frames to store in the buffer. Be aware of your system's
                available disk space when choosing this.
            fps: Same as the fps concept in LeRobotDataset. Here it needs to be provided for the
                delta_timestamps logic. You can pass None if you are not using delta_timestamps.
            delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
                converted to dict[str, np.ndarray] for optimization purposes.

        """
        self.set_delta_timestamps(delta_timestamps)
        self._fps = fps
        # Tolerance in seconds used to discard loaded frames when their timestamps are not close enough to
        # the requested frames. It is only used when `delta_timestamps` is provided.
        # The 1e-4 is subtracted to account for possible numerical error.
        self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
        self._buffer_capacity = buffer_capacity
        data_spec = self._make_data_spec(data_spec, buffer_capacity)
        Path(write_dir).mkdir(parents=True, exist_ok=True)
        self._data = {}
        for k, v in data_spec.items():
            self._data[k] = _make_memmap_safe(
                filename=Path(write_dir) / k,
                dtype=v["dtype"] if v is not None else None,
                mode="r+" if (Path(write_dir) / k).exists() else "w+",
                shape=tuple(v["shape"]) if v is not None else None,
            )
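
    # Usage sketch (illustrative only, not part of this module's API): the data keys, shapes and dtypes
    # below are hypothetical examples of what a caller might record.
    #
    #     buffer = OnlineBuffer(
    #         write_dir="output/online_buffer",
    #         data_spec={
    #             "observation.state": {"shape": (7,), "dtype": np.dtype("float32")},
    #             "action": {"shape": (7,), "dtype": np.dtype("float32")},
    #             "next.reward": {"shape": (), "dtype": np.dtype("float32")},
    #         },
    #         buffer_capacity=10_000,
    #         fps=10,
    #         delta_timestamps={"observation.state": [-0.1, 0.0]},
    #     )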

    @property
    def delta_timestamps(self) -> dict[str, np.ndarray] | None:
        return self._delta_timestamps

    def set_delta_timestamps(self, value: dict[str, list[float]] | None):
        """Set delta_timestamps converting the values to numpy arrays.

        The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays
        need to be converted into numpy arrays.
        """
        if value is not None:
            self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
        else:
            self._delta_timestamps = None

    def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
        """Makes the data spec for np.memmap."""
        if any(k.startswith("_") for k in data_spec):
            raise ValueError(
                "data_spec keys should not start with '_'. This prefix is reserved for internal logic."
            )
        preset_keys = {
            OnlineBuffer.INDEX_KEY,
            OnlineBuffer.FRAME_INDEX_KEY,
            OnlineBuffer.EPISODE_INDEX_KEY,
            OnlineBuffer.TIMESTAMP_KEY,
        }
        if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
            raise ValueError(
                f"data_spec should not contain any of {preset_keys} as these are handled internally. "
                f"The provided data_spec has {intersection}."
            )
        complete_data_spec = {
            # _next_index will be a pointer to the next index that we should start filling from when we add
            # more data.
            OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
            # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
            # with real data rather than the dummy initialization.
            OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
            OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
            OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
            OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
            OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
        }
        for k, v in data_spec.items():
            complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
        return complete_data_spec
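
    # Sketch of the resulting spec (hypothetical key, assuming buffer_capacity=100): a user entry
    # {"action": {"shape": (7,), "dtype": np.dtype("float32")}} becomes
    # {"action": {"dtype": np.dtype("float32"), "shape": (100, 7)}}, alongside the internal
    # "_next_index", "_occupancy_mask", "index", "frame_index", "episode_index" and "timestamp" entries.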

    def add_data(self, data: dict[str, np.ndarray]):
        """Add new data to the buffer, which could potentially mean shifting old data out.

        The new data should contain all the frames (in order) of any number of episodes. The indices should
        start from 0 (note to the developer: this can easily be generalized). See the `rollout` and
        `eval_policy` functions in `eval.py` for more information on how the data is constructed.

        Shift the incoming data index and episode_index to continue on from the last frame. Note that this
        will be done in place!
        """
        if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0:
            raise ValueError(f"Missing data keys: {missing_keys}")
        new_data_length = len(data[self.data_keys[0]])
        if not all(len(data[k]) == new_data_length for k in self.data_keys):
            raise ValueError("All data items should have the same length")

        next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY]

        # Sanity check to make sure that the new data indices start from 0.
        assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0
        assert data[OnlineBuffer.INDEX_KEY][0].item() == 0

        # Shift the incoming indices if necessary.
        if self.num_frames > 0:
            last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
            last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
            data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
            data[OnlineBuffer.INDEX_KEY] += last_data_index + 1

        # Insert the new data starting from next_index. It may be necessary to wrap around to the start.
        n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index))
        for k in self.data_keys:
            if n_surplus == 0:
                slc = slice(next_index, next_index + new_data_length)
                self._data[k][slc] = data[k]
                self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True
            else:
                self._data[k][next_index:] = data[k][:-n_surplus]
                self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True
                self._data[k][:n_surplus] = data[k][-n_surplus:]
        if n_surplus == 0:
            self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length
        else:
            self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus
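
    # Worked wrap-around example (illustrative numbers, assumed for the sake of the sketch): with
    # `buffer_capacity=10`, `next_index=8` and 5 incoming frames, `n_surplus = max(0, 5 - (10 - 8)) = 3`.
    # The first 2 frames fill slots 8-9, the last 3 frames overwrite slots 0-2, and `_next_index` becomes 3.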

    @property
    def data_keys(self) -> list[str]:
        keys = set(self._data)
        keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY)
        keys.remove(OnlineBuffer.NEXT_INDEX_KEY)
        return sorted(keys)

    @property
    def fps(self) -> float | None:
        return self._fps

    @property
    def num_episodes(self) -> int:
        return len(
            np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
        )

    @property
    def num_frames(self) -> int:
        return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])

    def __len__(self):
        return self.num_frames

    def _item_to_tensors(self, item: dict) -> dict:
        item_ = {}
        for k, v in item.items():
            if isinstance(v, torch.Tensor):
                item_[k] = v
            elif isinstance(v, np.ndarray):
                item_[k] = torch.from_numpy(v)
            else:
                item_[k] = torch.tensor(v)
        return item_

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        if idx >= len(self) or idx < -len(self):
            raise IndexError

        item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")}

        if self.delta_timestamps is None:
            return self._item_to_tensors(item)

        episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY]
        current_ts = item[OnlineBuffer.TIMESTAMP_KEY]
        episode_data_indices = np.where(
            np.bitwise_and(
                self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index,
                self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
            )
        )[0]
        episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]

        for data_key in self.delta_timestamps:
            # Note: The logic in this loop is copied from `load_previous_and_future_frames`.
            # Get timestamps used as query to retrieve data of previous/future frames.
            query_ts = current_ts + self.delta_timestamps[data_key]

            # Compute distances between each query timestamp and all timestamps of all the frames belonging to
            # the episode.
            dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
            argmin_ = np.argmin(dist, axis=1)
            min_ = dist[np.arange(dist.shape[0]), argmin_]

            is_pad = min_ > self.tolerance_s

            # Check violated query timestamps are all outside the episode range.
            assert (
                (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
            ).all(), (
                f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
                ") inside the episode range."
            )

            # Load frames for this data key.
            item[data_key] = self._data[data_key][episode_data_indices[argmin_]]

            item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad

        return self._item_to_tensors(item)
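
    # delta_timestamps sketch (illustrative values, not defined by this class): with `fps=10` and
    # `delta_timestamps={"observation.state": [-0.1, 0.0]}`, indexing the buffer returns
    # "observation.state" stacked along a new leading dimension of size 2 (previous frame, current frame),
    # plus an "observation.state_is_pad" boolean mask marking queries that fell outside the episode and
    # were filled with the nearest available frame.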

    def get_data_by_key(self, key: str) -> torch.Tensor:
        """Returns all data for a given data key as a Tensor."""
        return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])


def compute_sampler_weights(
    offline_dataset: LeRobotDataset,
    offline_drop_n_last_frames: int = 0,
    online_dataset: OnlineBuffer | None = None,
    online_sampling_ratio: float | None = None,
    online_drop_n_last_frames: int = 0,
) -> torch.Tensor:
    """Compute the sampling weights for the online training dataloader in train.py.

    Args:
        offline_dataset: The LeRobotDataset used for offline pre-training.
        offline_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
        online_dataset: The OnlineBuffer used in online training.
        online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
            online dataset is provided, this value must also be provided.
        online_drop_n_last_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
            dataset.
    Returns:
        Tensor of weights for [offline_dataset; online_dataset], normalized so that they sum to 1.

    Notes to maintainers:
        - This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
        - When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
          `EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
          is the ability to turn shuffling off.
        - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
          included here to avoid adding complexity.
    """
    if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
        raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
    if (online_dataset is None) ^ (online_sampling_ratio is None):
        raise ValueError(
            "`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
        )
    offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio

    weights = []

    if len(offline_dataset) > 0:
        offline_data_mask_indices = []
        for start_index, end_index in zip(
            offline_dataset.episode_data_index["from"],
            offline_dataset.episode_data_index["to"],
            strict=True,
        ):
            offline_data_mask_indices.extend(
                range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
            )
        offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
        offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
        weights.append(
            torch.full(
                size=(len(offline_dataset),),
                fill_value=offline_sampling_ratio / offline_data_mask.sum(),
            )
            * offline_data_mask
        )

    if online_dataset is not None and len(online_dataset) > 0:
        online_data_mask_indices = []
        episode_indices = online_dataset.get_data_by_key("episode_index")
        for episode_idx in torch.unique(episode_indices):
            where_episode = torch.where(episode_indices == episode_idx)
            start_index = where_episode[0][0]
            end_index = where_episode[0][-1] + 1
            online_data_mask_indices.extend(
                range(start_index.item(), end_index.item() - online_drop_n_last_frames)
            )
        online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
        online_data_mask[torch.tensor(online_data_mask_indices)] = True
        weights.append(
            torch.full(
                size=(len(online_dataset),),
                fill_value=online_sampling_ratio / online_data_mask.sum(),
            )
            * online_data_mask
        )

    weights = torch.cat(weights)

    if weights.sum() == 0:
        weights += 1 / len(weights)
    else:
        weights /= weights.sum()

    return weights
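

# Illustrative usage sketch (not executed as part of this module; the dataset variables and the specific
# argument values are assumptions): the returned weights can be passed to
# `torch.utils.data.WeightedRandomSampler`, e.g.
#
#     weights = compute_sampler_weights(
#         offline_dataset,
#         offline_drop_n_last_frames=7,
#         online_dataset=online_buffer,
#         online_sampling_ratio=0.5,
#         online_drop_n_last_frames=7,
#     )
#     sampler = torch.utils.data.WeightedRandomSampler(
#         weights, num_samples=len(weights), replacement=True
#     )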