from argparse import ArgumentParser

import torch

# from sparsification.glm_saga import glm_saga
from sparsification import feature_helpers


def safe_zip(*args):
    # zip() that fails loudly when the iterables have different lengths,
    # instead of silently truncating to the shortest one.
    for iterable in args[1:]:
        if len(iterable) != len(args[0]):
            print("Unequally sized iterables to zip, printing lengths")
            for i, entry in enumerate(args):
                print(i, len(entry))
            raise ValueError("Unequally sized iterables to zip")
    return zip(*args)
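
# Illustrative doctest-style sketch of safe_zip (not part of the original file):
#   >>> list(safe_zip([1, 2], ['a', 'b']))
#   [(1, 'a'), (2, 'b')]
#   >>> safe_zip([1, 2], ['a'])  # prints the lengths, then raises
#   ValueError: Unequally sized iterables to zip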


def compute_features_and_metadata(args, train_loader, test_loader, model, out_dir_feats, num_classes):
    print("Computing/loading deep features...")
    Ntotal = len(train_loader.dataset)
    feature_loaders = {}
    # Compute features for the non-augmented train and test sets: the test-time
    # transforms are temporarily applied to the train set and restored below.
    train_loader_transforms = train_loader.dataset.transform
    test_loader_transforms = test_loader.dataset.transform
    train_loader.dataset.transform = test_loader_transforms
    for mode, loader in zip(['train', 'test'], [train_loader, test_loader]):
        print(f"For {mode} set...")
        sink_path = f"{out_dir_feats}/features_{mode}"
        metadata_path = f"{out_dir_feats}/metadata_{mode}.pth"
        feature_ds, feature_loader = feature_helpers.compute_features(
            loader,
            model,
            dataset_type=args.dataset_type,
            pooled_output=None,
            batch_size=args.batch_size,
            num_workers=0,  # args.num_workers,
            shuffle=(mode == 'test'),
            device=args.device,
            filename=sink_path,
            n_epoch=1,
            balance=False,  # args.balance if mode == 'test' else False
        )
        if mode == 'train':
            metadata = feature_helpers.calculate_metadata(feature_loader,
                                                          num_classes=num_classes,
                                                          filename=metadata_path)
            if metadata["max_reg"]["group"] == 0.0:
                # Degenerate regularization path; signal failure to the caller.
                return None, False
            split_datasets, split_loaders = feature_helpers.split_dataset(feature_ds,
                                                                          Ntotal,
                                                                          val_frac=args.val_frac,
                                                                          batch_size=args.batch_size,
                                                                          num_workers=args.num_workers,
                                                                          random_seed=args.random_seed,
                                                                          shuffle=True,
                                                                          balance=False)
            # Wrap the train/val loaders so batches also carry sample indices.
            feature_loaders.update({mm: add_index_to_dataloader(split_loaders[mi])
                                    for mi, mm in enumerate(['train', 'val'])})
        else:
            feature_loaders[mode] = feature_loader
    # Restore the original training-time transforms.
    train_loader.dataset.transform = train_loader_transforms
    return feature_loaders, metadata
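
# For orientation, a hedged sketch of the return shape (assuming feature_helpers
# follows the MadryLab glm_saga conventions; verify against your copy):
#   feature_loaders: {'train': DataLoader, 'val': DataLoader, 'test': DataLoader}
#   metadata: dict of feature statistics with a 'max_reg' entry whose 'group'
#   value is checked above.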


def get_feature_loaders(seed, log_folder, train_loader, test_loader, model, num_classes):
    args = get_default_args()
    args.random_seed = seed
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    feature_folder = log_folder / "features"
    feature_loaders, metadata = compute_features_and_metadata(args, train_loader, test_loader, model,
                                                              feature_folder, num_classes)
    return feature_loaders, metadata, device, args


def add_index_to_dataloader(loader, sample_weight=None):
    # Rebuild the loader around an IndexedDataset so every batch also yields
    # the indices of its samples (and optional per-sample weights).
    return torch.utils.data.DataLoader(
        IndexedDataset(loader.dataset, sample_weight=sample_weight),
        batch_size=loader.batch_size,
        sampler=loader.sampler,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        timeout=loader.timeout,
        worker_init_fn=loader.worker_init_fn,
        multiprocessing_context=loader.multiprocessing_context,
    )


class IndexedDataset(torch.utils.data.Dataset):
    def __init__(self, ds, sample_weight=None):
        super().__init__()
        self.dataset = ds
        self.sample_weight = sample_weight

    def __getitem__(self, index):
        val = self.dataset[index]
        if self.sample_weight is None:
            return val + (index,)
        else:
            weight = self.sample_weight[index]
            return val + (weight, index)

    def __len__(self):
        return len(self.dataset)
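
# Minimal sketch of the wrapper's behaviour (illustrative only): wrapping a
# TensorDataset appends the sample index as the last element of each item.
#   >>> base = torch.utils.data.TensorDataset(torch.zeros(3, 2), torch.zeros(3))
#   >>> x, y, idx = IndexedDataset(base)[1]
#   >>> idx
#   1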


def get_default_args():
    # Default args from glm_saga, https://github.com/MadryLab/glm_saga
    parser = ArgumentParser()
    parser.add_argument('--dataset', type=str, help='dataset name')
    parser.add_argument('--dataset-type', type=str, help='One of ["language", "vision"]')
    parser.add_argument('--dataset-path', type=str, help='path to dataset')
    parser.add_argument('--model-path', type=str, help='path to model checkpoint')
    parser.add_argument('--arch', type=str, help='model architecture type')
    parser.add_argument('--out-path', help='location for saving results')
    parser.add_argument('--cache', action='store_true', help='cache deep features')
    parser.add_argument('--balance', action='store_true', help='balance classes for evaluation')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--random-seed', default=0)
    parser.add_argument('--num-workers', type=int, default=2)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--val-frac', type=float, default=0.1)
    parser.add_argument('--lr-decay-factor', type=float, default=1)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--alpha', type=float, default=0.99)
    parser.add_argument('--max-epochs', type=int, default=2000)
    parser.add_argument('--verbose', type=int, default=200)
    parser.add_argument('--tol', type=float, default=1e-4)
    parser.add_argument('--lookbehind', type=int, default=3)
    parser.add_argument('--lam-factor', type=float, default=0.001)
    parser.add_argument('--group', action='store_true')
    # Parse an empty argument list so only the defaults above are returned,
    # independent of whatever happens to be in sys.argv.
    args = parser.parse_args([])
    return args
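
# Illustrative usage (fields are overridden programmatically, as
# get_feature_loaders does with random_seed):
#   args = get_default_args()
#   args.val_frac = 0.2  # e.g. use a larger validation split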


def select_in_loader(feature_loaders, feature_selection):
    # Restrict all feature tensors to the selected feature columns, in place.
    # Train and val share the same underlying data, so unwrapping the train
    # loader (IndexedDataset -> split wrappers -> per-chunk TensorDatasets)
    # updates both.
    for dataset in feature_loaders["train"].dataset.dataset.dataset.datasets:  # val is indexed via the same dataset as train
        tensors = list(dataset.tensors)
        if tensors[0].shape[1] == len(feature_selection):
            continue  # selection has already been applied
        tensors[0] = tensors[0][:, feature_selection]
        dataset.tensors = tensors
    # The test loader wraps its concatenated TensorDatasets directly.
    for dataset in feature_loaders["test"].dataset.datasets:
        tensors = list(dataset.tensors)
        if tensors[0].shape[1] == len(feature_selection):
            continue
        tensors[0] = tensors[0][:, feature_selection]
        dataset.tensors = tensors
    return feature_loaders
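

# Hedged, self-contained sketch (not part of the original pipeline; all names
# below are demo-only): wrap a toy TensorDataset loader with
# add_index_to_dataloader and check that each batch carries sample indices.
if __name__ == "__main__":
    _features = torch.randn(8, 4)
    _labels = torch.randint(0, 2, (8,))
    _base_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(_features, _labels), batch_size=4)
    _indexed_loader = add_index_to_dataloader(_base_loader)
    for _x, _y, _idx in _indexed_loader:
        print(_x.shape, _y.shape, _idx)  # _idx: tensor of sample indices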