Spaces:
Build error
Build error
File size: 4,867 Bytes
33f1db4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Data loader."""
import itertools
import numpy as np
import torch
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
from timesformer.datasets.multigrid_helper import ShortCycleBatchSampler
from . import utils as utils
from .build import build_dataset
def detection_collate(batch):
    """
    Collate function for detection task. Concatenate bboxes, labels and
    metadata from different samples in the first dimension instead of
    stacking them to have a batch-size dimension.
    Args:
        batch (tuple or list): data batch to collate. Each sample is an
            `(inputs, labels, video_idx, extra_data)` tuple, where
            `extra_data` is a dict.
    Returns:
        (tuple): collated detection data batch, laid out as
            `(inputs, labels, video_idx, collated_extra_data)`.
    """
    inputs, labels, video_idx, extra_data = zip(*batch)
    inputs, video_idx = default_collate(inputs), default_collate(video_idx)
    # Labels are joined along dim 0 rather than stacked, since samples may
    # contribute a different number of rows.
    labels = torch.tensor(np.concatenate(labels, axis=0)).float()

    collated_extra_data = {}
    # The keys of the first sample decide which fields get collated.
    for key in extra_data[0]:
        data = [d[key] for d in extra_data]
        if key in ("boxes", "ori_boxes"):
            # Prepend the sample's position in the batch as an extra first
            # column, so every box can still be traced back to its sample
            # after concatenation.
            bboxes = [
                np.concatenate(
                    [
                        np.full((sample_boxes.shape[0], 1), float(idx)),
                        sample_boxes,
                    ],
                    axis=1,
                )
                for idx, sample_boxes in enumerate(data)
            ]
            collated_extra_data[key] = torch.tensor(
                np.concatenate(bboxes, axis=0)
            ).float()
        elif key == "metadata":
            # Flatten the per-sample metadata lists, then reshape to rows of
            # two values each.
            collated_extra_data[key] = torch.tensor(
                list(itertools.chain.from_iterable(data))
            ).view(-1, 2)
        else:
            collated_extra_data[key] = default_collate(data)
    return inputs, labels, video_idx, collated_extra_data
def construct_loader(cfg, split, is_precise_bn=False):
    """
    Constructs the data loader for the given dataset.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
            NOTE(review): this module imports from `timesformer`, so the
            path above may be stale -- confirm.
        split (str): the split of the data loader. Options include `train`,
            `val`, and `test`.
        is_precise_bn (bool): when True, the short-cycle batch sampler is
            not used even if `cfg.MULTIGRID.SHORT_CYCLE` is enabled.
    Returns:
        (torch.utils.data.DataLoader): the loader for the requested split.
    """
    assert split in ["train", "val", "test"]
    # `val` reads the TRAIN dataset name and batch size; only `test` has its
    # own config section.
    if split == "test":
        dataset_name = cfg.TEST.DATASET
        total_batch_size = cfg.TEST.BATCH_SIZE
    else:
        dataset_name = cfg.TRAIN.DATASET
        total_batch_size = cfg.TRAIN.BATCH_SIZE
    # Per-process batch size: the configured size is divided by the GPU
    # count (treated as 1 when NUM_GPUS is 0).
    batch_size = int(total_batch_size / max(1, cfg.NUM_GPUS))
    # Only training shuffles and drops the trailing incomplete batch.
    is_train = split == "train"
    shuffle = is_train
    drop_last = is_train

    # Construct the dataset
    dataset = build_dataset(dataset_name, cfg, split)

    # Create a sampler for multi-process training
    # (presumably None when no dedicated sampler is needed -- verify in
    # utils.create_sampler).
    sampler = utils.create_sampler(dataset, shuffle, cfg)

    if cfg.MULTIGRID.SHORT_CYCLE and is_train and not is_precise_bn:
        batch_sampler = ShortCycleBatchSampler(
            sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
        )
        # Create a loader
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=cfg.DATA_LOADER.NUM_WORKERS,
            pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
            worker_init_fn=utils.loader_worker_init_fn(dataset),
        )
    else:
        # Create a loader
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            # DataLoader rejects `shuffle=True` together with an explicit
            # sampler. Compare against None instead of relying on
            # truthiness: a sampler that defines `__len__` is falsy when it
            # is empty, which would wrongly re-enable `shuffle`.
            shuffle=(shuffle if sampler is None else False),
            sampler=sampler,
            num_workers=cfg.DATA_LOADER.NUM_WORKERS,
            pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
            drop_last=drop_last,
            collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,
            worker_init_fn=utils.loader_worker_init_fn(dataset),
        )
    return loader
def shuffle_dataset(loader, cur_epoch):
    """
    Prepares the loader's sampler for a new epoch of shuffled data.
    Args:
        loader (loader): data loader whose sampler should be reshuffled.
        cur_epoch (int): number of the current epoch.
    """
    # A short-cycle batch sampler wraps the real sampler; unwrap it.
    if isinstance(loader.batch_sampler, ShortCycleBatchSampler):
        active_sampler = loader.batch_sampler.sampler
    else:
        active_sampler = loader.sampler

    supported_types = (RandomSampler, DistributedSampler)
    assert isinstance(
        active_sampler, supported_types
    ), "Sampler type '{}' not supported".format(type(active_sampler))

    # Only DistributedSampler needs the epoch number; for RandomSampler
    # nothing is done here.
    if isinstance(active_sampler, DistributedSampler):
        active_sampler.set_epoch(cur_epoch)
|