|
import logging |
|
import random |
|
|
|
from torch.utils.data import Dataset, Sampler |
|
|
|
|
|
logger = logging.getLogger(__name__)


class BucketSampler(Sampler):
    r"""
    PyTorch Sampler that groups 3D data by height, width and frames.

    Unlike a conventional sampler, this yields lists of dataset *items* (not indices). On every pass it
    iterates over ``data_source`` once and places each item into the bucket keyed by
    ``(num_frames, height, width)``, read from the item's ``"video_metadata"`` entry. A bucket is
    yielded as soon as it holds ``batch_size`` items. Every pass starts from empty buckets, so no
    items carry over from a previous (or abandoned) pass.

    Args:
        data_source (`VideoDataset`):
            A PyTorch dataset object that is an instance of `VideoDataset`. It must expose
            ``video_resolution_buckets`` (the allowed ``(num_frames, height, width)`` keys), and
            each item it yields must be a dict whose ``"video_metadata"`` holds ``"num_frames"``,
            ``"height"`` and ``"width"``. An item whose resolution is not one of the declared
            buckets raises ``KeyError`` during iteration.
        batch_size (`int`, defaults to `8`):
            The batch size to use for training. Must be at least 1.
        shuffle (`bool`, defaults to `True`):
            Whether or not to shuffle the data in each batch before dispatching to dataloader.
        drop_last (`bool`, defaults to `False`):
            Whether or not to drop incomplete buckets of data after completely iterating over all data
            in the dataset. If set to True, only batches that have `batch_size` number of entries will
            be yielded, and the leftover items are discarded for that pass. If set to False, it is
            guaranteed that all data in the dataset will be processed and batches that do not have
            `batch_size` number of entries will also be yielded.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(
        self, data_source: Dataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
    ) -> None:
        if batch_size < 1:
            # batch_size 0 would divide by zero in __len__ and never complete a bucket in __iter__.
            raise ValueError(f"`batch_size` must be a positive integer, got {batch_size}.")

        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        # One (initially empty) list of pending items per declared resolution bucket.
        self.buckets = {resolution: [] for resolution in data_source.video_resolution_buckets}

        self._raised_warning_for_drop_last = False

    def __len__(self):
        """Return an estimate of the number of batches per pass.

        The value is ``ceil(len(data_source) / batch_size)``. The true count depends on how items
        are spread across buckets, and it is not computed here because that would require loading
        every item.
        - With ``drop_last=False``, every bucket's remainder becomes its own batch, so the real count
          can be larger than this estimate.
        - With ``drop_last=True``, remainders are discarded, so the real count can be smaller. A
          one-time warning is logged in that case.
        """
        if self.drop_last and not self._raised_warning_for_drop_last:
            self._raised_warning_for_drop_last = True
            logger.warning(
                "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
            )
        return (len(self.data_source) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        """Yield lists of dataset items that share the same ``(num_frames, height, width)``."""
        # Start every pass with empty buckets. Otherwise leftovers from a previous pass (with
        # drop_last=True, or from an iteration abandoned midway) would leak into this pass. A
        # stale full bucket would then grow past batch_size and never be yielded on time.
        self.buckets = {resolution: [] for resolution in self.buckets}

        for data in self.data_source:
            video_metadata = data["video_metadata"]
            key = (video_metadata["num_frames"], video_metadata["height"], video_metadata["width"])

            bucket = self.buckets[key]
            bucket.append(data)
            if len(bucket) == self.batch_size:
                # Detach the full bucket before yielding so the sampler never touches it again,
                # even if the consumer stops iterating right here.
                self.buckets[key] = []
                if self.shuffle:
                    random.shuffle(bucket)
                yield bucket

        if self.drop_last:
            return

        # Flush the remaining, partially filled buckets.
        for key, bucket in list(self.buckets.items()):
            if not bucket:
                continue
            self.buckets[key] = []
            if self.shuffle:
                random.shuffle(bucket)
            yield bucket
|
|