from argparse import ArgumentParser

import torch

from sparsification import feature_helpers


def safe_zip(*args):
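    """zip() that checks all iterables have equal length, printing the lengths and raising ValueError otherwise."""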
    for iterable in args[1:]:
        if len(iterable) != len(args[0]):
            print("Unequally sized iterables to zip, printing lengths")
            for i, entry in enumerate(args):
                print(i, len(entry))
            raise ValueError("Unequally sized iterables to zip")
    return zip(*args)


def compute_features_and_metadata(args, train_loader, test_loader, model, out_dir_feats, num_classes):
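    """Compute (or load cached) deep features for the train and test loaders,
    calculate regularization metadata on the train features, and return a dict of
    indexed train/val/test feature loaders together with that metadata.

    Returns (None, False) if the maximum group regularization in the metadata is zero.
    """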
    print("Computing/loading deep features...")

    Ntotal = len(train_loader.dataset)
    feature_loaders = {}

    # Temporarily apply the test-time (non-augmenting) transforms to the train set
    # so features are extracted deterministically; restored at the end.
    train_loader_transforms = train_loader.dataset.transform
    test_loader_transforms = test_loader.dataset.transform
    train_loader.dataset.transform = test_loader_transforms

    for mode, loader in zip(['train', 'test'], [train_loader, test_loader]):
        print(f"For {mode} set...")

        sink_path = f"{out_dir_feats}/features_{mode}"
        metadata_path = f"{out_dir_feats}/metadata_{mode}.pth"

        feature_ds, feature_loader = feature_helpers.compute_features(
            loader,
            model,
            dataset_type=args.dataset_type,
            pooled_output=None,
            batch_size=args.batch_size,
            num_workers=0,
            shuffle=(mode == 'test'),
            device=args.device,
            filename=sink_path,
            n_epoch=1,
            balance=False,
        )

        if mode == 'train':
            metadata = feature_helpers.calculate_metadata(feature_loader,
                                                          num_classes=num_classes,
                                                          filename=metadata_path)
            if metadata["max_reg"]["group"] == 0.0:
                # Restore the original transforms before bailing out.
                train_loader.dataset.transform = train_loader_transforms
                return None, False

            split_datasets, split_loaders = feature_helpers.split_dataset(feature_ds,
                                                                          Ntotal,
                                                                          val_frac=args.val_frac,
                                                                          batch_size=args.batch_size,
                                                                          num_workers=args.num_workers,
                                                                          random_seed=args.random_seed,
                                                                          shuffle=True,
                                                                          balance=False)
            feature_loaders.update({mm: add_index_to_dataloader(split_loaders[mi])
                                    for mi, mm in enumerate(['train', 'val'])})
        else:
            feature_loaders[mode] = feature_loader

    # Switch the train set back to its original (train-time) transforms.
    train_loader.dataset.transform = train_loader_transforms
    return feature_loaders, metadata


def get_feature_loaders(seed, log_folder, train_loader, test_loader, model, num_classes):
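    """Build default args for the given seed, compute feature loaders and metadata for the model,
    and return (feature_loaders, metadata, device, args)."""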
    args = get_default_args()
    args.random_seed = seed
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Keep args.device in sync with the detected device so CPU-only machines still work.
    args.device = device
    feature_folder = log_folder / "features"
    feature_loaders, metadata = compute_features_and_metadata(args, train_loader, test_loader, model,
                                                              feature_folder, num_classes)
    return feature_loaders, metadata, device, args


def add_index_to_dataloader(loader, sample_weight=None):
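    """Wrap a DataLoader so each item also yields the sample index (and optional
    per-sample weight), preserving the original loader's settings."""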
    return torch.utils.data.DataLoader(
        IndexedDataset(loader.dataset, sample_weight=sample_weight),
        batch_size=loader.batch_size,
        sampler=loader.sampler,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        timeout=loader.timeout,
        worker_init_fn=loader.worker_init_fn,
        multiprocessing_context=loader.multiprocessing_context,
    )


class IndexedDataset(torch.utils.data.Dataset):
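    """Dataset wrapper that appends each sample's index (and, if given, its weight) to the wrapped dataset's item."""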
    def __init__(self, ds, sample_weight=None):
        super().__init__()
        self.dataset = ds
        self.sample_weight = sample_weight

    def __getitem__(self, index):
        val = self.dataset[index]
        if self.sample_weight is None:
            return val + (index,)
        else:
            weight = self.sample_weight[index]
            return val + (weight, index)

    def __len__(self):
        return len(self.dataset)


def get_default_args():
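    """Build the argument parser for dataset, model, and optimization settings and return the parsed args."""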
    parser = ArgumentParser()
    parser.add_argument('--dataset', type=str, help='dataset name')
    parser.add_argument('--dataset-type', type=str, help='One of ["language", "vision"]')
    parser.add_argument('--dataset-path', type=str, help='path to dataset')
    parser.add_argument('--model-path', type=str, help='path to model checkpoint')
    parser.add_argument('--arch', type=str, help='model architecture type')
    parser.add_argument('--out-path', help='location for saving results')
    parser.add_argument('--cache', action='store_true', help='cache deep features')
    parser.add_argument('--balance', action='store_true', help='balance classes for evaluation')

    parser.add_argument('--device', default='cuda')
    parser.add_argument('--random-seed', type=int, default=0)
    parser.add_argument('--num-workers', type=int, default=2)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--val-frac', type=float, default=0.1)
    parser.add_argument('--lr-decay-factor', type=float, default=1)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--alpha', type=float, default=0.99)
    parser.add_argument('--max-epochs', type=int, default=2000)
    parser.add_argument('--verbose', type=int, default=200)
    parser.add_argument('--tol', type=float, default=1e-4)
    parser.add_argument('--lookbehind', type=int, default=3)
    parser.add_argument('--lam-factor', type=float, default=0.001)
    parser.add_argument('--group', action='store_true')

    args = parser.parse_args()
    return args


def select_in_loader(feature_loaders, feature_selection):
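    """Slice the feature tensors in the train and test feature loaders down to the
    selected feature columns, modifying the underlying datasets in place."""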
    for dataset in feature_loaders["train"].dataset.dataset.dataset.datasets:
        tensors = list(dataset.tensors)
        if tensors[0].shape[1] == len(feature_selection):
            continue
        tensors[0] = tensors[0][:, feature_selection]
        dataset.tensors = tensors
    for dataset in feature_loaders["test"].dataset.datasets:
        tensors = list(dataset.tensors)
        if tensors[0].shape[1] == len(feature_selection):
            continue
        tensors[0] = tensors[0][:, feature_selection]
        dataset.tensors = tensors
    return feature_loaders