Spaces:
Build error
Build error
| import torch | |
| from utils.hparams import hparams | |
| import numpy as np | |
| import os | |
class BaseDataset(torch.utils.data.Dataset):
    """Base class for datasets.

    Provided by this class:
        1. *ordered_indices*:
            if ``self.shuffle`` is true, start from a random permutation of
            the indices and, if ``self.sort_by_len`` is also true, stably
            sort that permutation by length; otherwise return the indices
            in their natural order.
        2. *size* / *num_tokens*:
            length of one example, clipped to ``hparams['max_frames']``.
        3. *_sizes*:
            the raw (unclipped) lengths stored in ``self.sizes``.

    Subclasses should define:
        1. *collater*:
            take the longest data, pad other data to the same length;
        2. *__getitem__*:
            the index function.

    ``self.sizes`` is initialised to ``None``; it has to be replaced by an
    indexable sequence of per-example lengths before ``len()``, ``size`` or
    ``ordered_indices`` can be used.
    """

    def __init__(self, shuffle):
        """Store the shuffle flag and read ``sort_by_len`` from ``hparams``.

        Args:
            shuffle: whether ``ordered_indices`` should randomise the order.
        """
        super().__init__()
        self.hparams = hparams
        self.shuffle = shuffle
        self.sort_by_len = hparams['sort_by_len']
        self.sizes = None

    # Every use inside this class treats ``_sizes`` as an attribute
    # (``len(self._sizes)``, ``self._sizes[index]``), so it must be a
    # property; as a plain method those expressions raise ``TypeError``.
    @property
    def _sizes(self):
        """Return ``self.sizes``, the per-example lengths."""
        return self.sizes

    def __getitem__(self, index):
        """Return the example at ``index``; must be implemented by subclasses."""
        raise NotImplementedError

    def collater(self, samples):
        """Merge ``samples`` into a batch; must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of examples, i.e. ``len(self._sizes)``."""
        return len(self._sizes)

    def num_tokens(self, index):
        """Return the clipped length of example ``index`` (same as ``size``)."""
        return self.size(index)

    def size(self, index):
        """Return the length of example ``index``, capped at ``max_frames``.

        Raises:
            KeyError: if ``hparams`` has no ``'max_frames'`` entry.
        """
        return min(self._sizes[index], hparams['max_frames'])

    def ordered_indices(self):
        """Return an ordered array of indices.

        Batches will be constructed based on this order.
        """
        if not self.shuffle:
            return np.arange(len(self))

        indices = np.random.permutation(len(self))
        if self.sort_by_len:
            # Shuffle first, then sort with a stable algorithm: examples of
            # equal length keep the relative order given by the random
            # permutation, so ties are still shuffled.
            lengths = np.array(self._sizes)[indices]
            indices = indices[np.argsort(lengths, kind='mergesort')]
        return indices

    def num_workers(self):
        """Return the data-loading worker count as an ``int``.

        The ``NUM_WORKERS`` environment variable takes precedence; when it is
        unset, ``hparams['ds_workers']`` is used.
        """
        return int(os.getenv('NUM_WORKERS', hparams['ds_workers']))