# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
import copy
import warnings
from typing import Any

import torch.utils.data
from torch.utils.data.distributed import DistributedSampler

from src.efficientvit.apps.data_provider.random_resolution import RRSController
from src.efficientvit.models.utils import val2tuple

__all__ = ["parse_image_size", "random_drop_data", "DataProvider"]


def parse_image_size(size: int | str) -> tuple[int, int]:
    if isinstance(size, str):  # a "-"-separated pair, e.g., "512-1024"
        size = [int(val) for val in size.split("-")]
        return size[0], size[1]
    else:
        return val2tuple(size, 2)
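
# Illustrative usage (assumed inputs):
#   parse_image_size("512-1024")  # -> (512, 1024)
#   parse_image_size(224)         # -> (224, 224), via val2tuple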


def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
    g = torch.Generator()
    g.manual_seed(seed)  # set random seed before sampling validation set
    rand_indexes = torch.randperm(len(dataset), generator=g).tolist()
    dropped_indexes = rand_indexes[:drop_size]
    remaining_indexes = rand_indexes[drop_size:]
    dropped_dataset = copy.deepcopy(dataset)
    for key in keys:
        setattr(
            dropped_dataset,
            key,
            [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes],
        )
        setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes])
    return dataset, dropped_dataset
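
# Illustrative usage (hypothetical `ds` exposing a `samples` list, as the
# default `keys` assumes):
#   ds, held_out = random_drop_data(ds, drop_size=5000, seed=DataProvider.VALID_SEED)
# `ds` is modified in place to keep the remainder; `held_out` is a deep copy
# holding the 5000 dropped samples.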


class DataProvider:
    data_keys = ("samples",)
    # standard ImageNet normalization statistics
    mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
    SUB_SEED = 937162211  # random seed for sampling subset
    VALID_SEED = 2147483647  # random seed for the validation set

    name: str

    def __init__(
        self,
        train_batch_size: int,
        test_batch_size: int | None,
        valid_size: int | float | None,
        n_worker: int,
        image_size: int | list[int] | str | list[str],
        num_replicas: int | None = None,
        rank: int | None = None,
        train_ratio: float | None = None,
        drop_last: bool = False,
    ):
        warnings.filterwarnings("ignore")
        super().__init__()
        # batch_size & valid_size
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size or self.train_batch_size
        self.valid_size = valid_size

        # image size
        if isinstance(image_size, list):
            self.image_size = [parse_image_size(size) for size in image_size]
            self.image_size.sort()  # sort in ascending order, e.g., 160 -> 224
            RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
        else:
            self.image_size = parse_image_size(image_size)
            RRSController.IMAGE_SIZE_LIST = [self.image_size]
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size

        # distributed configs
        self.num_replicas = num_replicas
        self.rank = rank
        # build datasets
        train_dataset, val_dataset, test_dataset = self.build_datasets()
        if train_ratio is not None and train_ratio < 1.0:
            assert 0 < train_ratio < 1
            _, train_dataset = random_drop_data(
                train_dataset,
                int(train_ratio * len(train_dataset)),
                self.SUB_SEED,
                self.data_keys,
            )

        # build data loaders (use self.test_batch_size so a None argument
        # falls back to the train batch size)
        self.train = self.build_dataloader(
            train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True
        )
        self.valid = self.build_dataloader(
            val_dataset, self.test_batch_size, n_worker, drop_last=False, train=False
        )
        self.test = self.build_dataloader(
            test_dataset, self.test_batch_size, n_worker, drop_last=False, train=False
        )
        if self.valid is None:
            self.valid = self.test  # fall back to the test set when no validation split exists
        self.sub_train = None

    @property
    def data_shape(self) -> tuple[int, ...]:
        return 3, self.active_image_size[0], self.active_image_size[1]  # (C, H, W)

    def build_valid_transform(self, image_size: tuple[int, int] | None = None) -> Any:
        raise NotImplementedError

    def build_train_transform(self, image_size: tuple[int, int] | None = None) -> Any:
        raise NotImplementedError

    def build_datasets(self) -> tuple[Any, Any, Any]:
        raise NotImplementedError
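
    # The three methods above form the subclass contract: concrete providers
    # supply the train/valid transforms and the datasets. A minimal sketch of
    # such a subclass appears at the end of this file.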

    def build_dataloader(
        self,
        dataset: Any | None,
        batch_size: int,
        n_worker: int,
        drop_last: bool,
        train: bool,
    ):
        if dataset is None:
            return None
        if isinstance(self.image_size, list) and train:
            # multi-scale training requires the random-resolution data loader
            from src.efficientvit.apps.data_provider.random_resolution._data_loader import \
                RRSDataLoader

            dataloader_class = RRSDataLoader
        else:
            dataloader_class = torch.utils.data.DataLoader
        if self.num_replicas is None:
            return dataloader_class(
                dataset=dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=n_worker,
                pin_memory=True,
                drop_last=drop_last,
            )
        else:
            sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
            return dataloader_class(
                dataset=dataset,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=n_worker,
                pin_memory=True,
                drop_last=drop_last,
            )

    def set_epoch(self, epoch: int) -> None:
        # keep the random-resolution schedule and distributed shuffling in sync
        RRSController.set_epoch(epoch, len(self.train))
        if isinstance(self.train.sampler, DistributedSampler):
            self.train.sampler.set_epoch(epoch)
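
    # Illustrative per-epoch call pattern (hypothetical training loop):
    #   for epoch in range(n_epochs):
    #       provider.set_epoch(epoch)
    #       for batch in provider.train:
    #           ...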

    def assign_active_image_size(self, new_size: int | tuple[int, int]) -> None:
        self.active_image_size = val2tuple(new_size, 2)
        new_transform = self.build_valid_transform(self.active_image_size)
        # change the transform of the valid and test sets
        self.valid.dataset.transform = self.test.dataset.transform = new_transform

    def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[Any, Any]:
        if self.valid_size is not None:
            if 0 < self.valid_size < 1:
                # a fractional valid_size is a share of the training set
                valid_size = int(self.valid_size * len(train_dataset))
            else:
                assert self.valid_size >= 1
                valid_size = int(self.valid_size)
            train_dataset, val_dataset = random_drop_data(
                train_dataset,
                valid_size,
                self.VALID_SEED,
                self.data_keys,
            )
            val_dataset.transform = valid_transform
        else:
            val_dataset = None
        return train_dataset, val_dataset
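
    # `sample_val_dataset` is meant to be called from a subclass's
    # `build_datasets` to carve the validation split out of the training set;
    # the sketch at the end of this file shows this wiring.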

    def build_sub_train_loader(self, n_samples: int, batch_size: int) -> Any:
        # used for resetting BN running statistics
        if self.sub_train is None:
            self.sub_train = {}
        if self.active_image_size in self.sub_train:
            return self.sub_train[self.active_image_size]

        # construct dataset and dataloader
        train_dataset = copy.deepcopy(self.train.dataset)
        if n_samples < len(train_dataset):
            _, train_dataset = random_drop_data(
                train_dataset,
                n_samples,
                self.SUB_SEED,
                self.data_keys,
            )
        RRSController.ACTIVE_SIZE = self.active_image_size
        train_dataset.transform = self.build_train_transform(
            image_size=self.active_image_size
        )
        data_loader = self.build_dataloader(
            train_dataset, batch_size, self.train.num_workers, drop_last=True, train=False
        )

        # pre-fetch data, repeating batches when n_samples exceeds the dataset size
        self.sub_train[self.active_image_size] = [
            data
            for data in data_loader
            for _ in range(max(1, n_samples // len(train_dataset)))
        ]
        return self.sub_train[self.active_image_size]
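

# ---------------------------------------------------------------------------
# Minimal illustrative subclass (a sketch, not part of the original API): it
# shows the hooks a concrete provider implements. `_ToyDataset` is a
# hypothetical stand-in exposing the `samples` attribute that the default
# `data_keys` expects.
# ---------------------------------------------------------------------------
class _ToyDataset(torch.utils.data.Dataset):
    def __init__(self, n: int):
        self.samples = list(range(n))
        self.transform = None

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        sample = self.samples[idx]
        return self.transform(sample) if self.transform is not None else sample


class _ToyDataProvider(DataProvider):
    name = "toy"

    def build_valid_transform(self, image_size: tuple[int, int] | None = None) -> Any:
        return None  # real providers return an eval transform here

    def build_train_transform(self, image_size: tuple[int, int] | None = None) -> Any:
        return None  # real providers return an augmentation pipeline here

    def build_datasets(self) -> tuple[Any, Any, Any]:
        train_dataset = _ToyDataset(1000)
        train_dataset, val_dataset = self.sample_val_dataset(
            train_dataset, self.build_valid_transform()
        )
        return train_dataset, val_dataset, _ToyDataset(200)


# e.g., _ToyDataProvider(train_batch_size=32, test_batch_size=32,
#                        valid_size=100, n_worker=0, image_size=224)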