Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import pickle

import numpy as np
class ShardedTensor(object):
    """A list of variable-length arrays stored as one flat array plus offsets.

    ``data`` holds all shards concatenated along axis 0. ``starts`` holds
    ``len(self) + 1`` non-decreasing offsets into ``data`` such that shard
    ``i`` is ``data[starts[i]:starts[i + 1]]``.
    """

    def __init__(self, data, starts):
        """Wrap an already-concatenated ``data`` array and its offsets.

        Args:
            data: array-like supporting ``len()`` and slicing on axis 0.
            starts: NumPy integer array of offsets; must begin at 0, end at
                ``len(data)`` and be non-negative and non-decreasing.

        Raises:
            AssertionError: if ``starts`` violates any of the invariants.
        """
        self.data = data
        self.starts = starts
        assert self.starts[0] == 0, "starts must begin at 0"
        assert self.starts[-1] == len(self.data), (
            "last offset must equal len(data)"
        )
        assert (self.starts[1:] >= self.starts[:-1]).all(), (
            "starts must be non-decreasing"
        )
        assert (self.starts > -1).all(), "starts must be non-negative"

    @staticmethod
    def from_list(xs):
        """Build a ShardedTensor from a sequence of arrays.

        Args:
            xs: sequence of NumPy arrays that can be concatenated along
                axis 0. Each array becomes one shard.

        Returns:
            A ShardedTensor whose ``i``-th item equals ``xs[i]``.

        Raises:
            ValueError: propagated from ``np.concatenate`` (e.g. when ``xs``
                is empty or the trailing dimensions do not match).
        """
        data = np.concatenate(xs, axis=0)
        # Use an explicit 64-bit dtype: ``np.long`` was removed in NumPy 1.24
        # and is a platform-dependent C long in NumPy 2.x.
        starts = np.zeros((len(xs) + 1,), dtype=np.int64)
        # Offset i + 1 is the running total of the first i + 1 shard lengths.
        starts[1:] = np.cumsum([x.shape[0] for x in xs], dtype=np.int64)
        return ShardedTensor(data, starts)

    def __getitem__(self, i):
        """Return shard ``i`` as a slice of ``data`` along axis 0."""
        return self.data[self.starts[i] : self.starts[i + 1]]

    def __len__(self):
        """Return the number of shards."""
        return len(self.starts) - 1

    def lengths(self):
        """Return an array with the axis-0 length of every shard."""
        return self.starts[1:] - self.starts[:-1]

    def save(self, path):
        """Write the tensor to ``path + "_starts.npy"`` and ``path + "_data.npy"``.

        ``np.save`` appends the ``.npy`` extension to each file name.
        """
        np.save(path + "_starts", self.starts)
        np.save(path + "_data", self.data)

    @staticmethod
    def load(path, mmap_mode=None):
        """Load a tensor previously written by ``save`` with the same ``path``.

        Args:
            path: the prefix that was passed to ``save``.
            mmap_mode: forwarded to ``np.load`` for both files; ``None``
                reads the arrays fully into memory.

        Returns:
            A ShardedTensor wrapping the loaded arrays.
        """
        starts = np.load(path + "_starts.npy", mmap_mode=mmap_mode)
        data = np.load(path + "_data.npy", mmap_mode=mmap_mode)
        return ShardedTensor(data, starts)