from argparse import ArgumentParser

import torch

# from sparsification.glm_saga import glm_saga
from sparsification import feature_helpers


def safe_zip(*args):
    """zip() that raises instead of silently truncating unequal iterables."""
    for iterable in args[1:]:
        if len(iterable) != len(args[0]):
            print("Unequally sized iterables to zip, printing lengths")
            for i, entry in enumerate(args):
                print(i, len(entry))
            raise ValueError("Unequally sized iterables to zip")
    return zip(*args)


def compute_features_and_metadata(args, train_loader, test_loader, model,
                                  out_dir_feats, num_classes):
    print("Computing/loading deep features...")
    Ntotal = len(train_loader.dataset)
    feature_loaders = {}

    # Compute features for the un-augmented train set and the test set:
    # temporarily swap the train transforms for the test transforms.
    train_loader_transforms = train_loader.dataset.transform
    test_loader_transforms = test_loader.dataset.transform
    train_loader.dataset.transform = test_loader_transforms

    for mode, loader in zip(['train', 'test'], [train_loader, test_loader]):
        sink_path = f"{out_dir_feats}/features_{mode}"
        metadata_path = f"{out_dir_feats}/metadata_{mode}.pth"
        feature_ds, feature_loader = feature_helpers.compute_features(
            loader,
            model,
            dataset_type=args.dataset_type,
            pooled_output=None,
            batch_size=args.batch_size,
            num_workers=0,  # args.num_workers
            shuffle=(mode == 'test'),
            device=args.device,
            filename=sink_path,
            n_epoch=1,
            balance=False,  # args.balance if mode == 'test' else False
        )
        if mode == 'train':
            metadata = feature_helpers.calculate_metadata(
                feature_loader,
                num_classes=num_classes,
                filename=metadata_path,
            )
            # A zero group-regularization bound means the regularization path
            # is degenerate; signal failure to the caller.
            if metadata["max_reg"]["group"] == 0.0:
                return None, False
            split_datasets, split_loaders = feature_helpers.split_dataset(
                feature_ds,
                Ntotal,
                val_frac=args.val_frac,
                batch_size=args.batch_size,
                num_workers=args.num_workers,
                random_seed=args.random_seed,
                shuffle=True,
                balance=False,
            )
            feature_loaders.update({
                mm: add_index_to_dataloader(split_loaders[mi])
                for mi, mm in enumerate(['train', 'val'])
            })
        else:
            feature_loaders[mode] = feature_loader

    # Restore the original (augmenting) train transforms.
    train_loader.dataset.transform = train_loader_transforms
    return feature_loaders, metadata


def get_feature_loaders(seed, log_folder, train_loader, test_loader, model,
                        num_classes):
    args = get_default_args()
    args.random_seed = seed
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Keep the feature pass on the same device we report back, so the
    # defaults also work on CPU-only machines.
    args.device = device
    feature_folder = log_folder / "features"
    feature_loaders, metadata = compute_features_and_metadata(
        args, train_loader, test_loader, model, feature_folder, num_classes)
    return feature_loaders, metadata, device, args


def add_index_to_dataloader(loader, sample_weight=None):
    """Rebuild `loader` so that every batch also yields the sample indices."""
    return torch.utils.data.DataLoader(
        IndexedDataset(loader.dataset, sample_weight=sample_weight),
        batch_size=loader.batch_size,
        sampler=loader.sampler,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        timeout=loader.timeout,
        worker_init_fn=loader.worker_init_fn,
        multiprocessing_context=loader.multiprocessing_context,
    )


class IndexedDataset(torch.utils.data.Dataset):
    """Wraps a dataset so each item also carries its index (and optional weight)."""

    def __init__(self, ds, sample_weight=None):
        super().__init__()
        self.dataset = ds
        self.sample_weight = sample_weight

    def __getitem__(self, index):
        val = self.dataset[index]
        if self.sample_weight is None:
            return val + (index,)
        weight = self.sample_weight[index]
        return val + (weight, index)

    def __len__(self):
        return len(self.dataset)
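
# Example of the wrapper's effect (an illustrative sketch, assuming a plain
# TensorDataset rather than the cached feature datasets used above): every
# item gains its dataset index, so shuffled batches stay attributable to
# individual samples.
#
#   base = torch.utils.data.TensorDataset(torch.randn(8, 4), torch.arange(8))
#   x, y, idx = IndexedDataset(base)[3]   # idx == 3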

def get_default_args():
    # Default args from glm_saga, https://github.com/MadryLab/glm_saga
    parser = ArgumentParser()
    parser.add_argument('--dataset', type=str, help='dataset name')
    parser.add_argument('--dataset-type', type=str,
                        help='One of ["language", "vision"]')
    parser.add_argument('--dataset-path', type=str, help='path to dataset')
    parser.add_argument('--model-path', type=str, help='path to model checkpoint')
    parser.add_argument('--arch', type=str, help='model architecture type')
    parser.add_argument('--out-path', help='location for saving results')
    parser.add_argument('--cache', action='store_true', help='cache deep features')
    parser.add_argument('--balance', action='store_true',
                        help='balance classes for evaluation')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--random-seed', type=int, default=0)
    parser.add_argument('--num-workers', type=int, default=2)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--val-frac', type=float, default=0.1)
    parser.add_argument('--lr-decay-factor', type=float, default=1)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--alpha', type=float, default=0.99)
    parser.add_argument('--max-epochs', type=int, default=2000)
    parser.add_argument('--verbose', type=int, default=200)
    parser.add_argument('--tol', type=float, default=1e-4)
    parser.add_argument('--lookbehind', type=int, default=3)
    parser.add_argument('--lam-factor', type=float, default=0.001)
    parser.add_argument('--group', action='store_true')
    # Parse an empty argument list so the declared defaults are returned even
    # when the surrounding script defines its own command-line interface.
    args = parser.parse_args([])
    return args


def select_in_loader(feature_loaders, feature_selection):
    """Restrict the feature dimension of all splits to `feature_selection`."""
    # Val is indexed via the same underlying dataset as train, so updating
    # the train datasets updates the val split as well.
    for dataset in feature_loaders["train"].dataset.dataset.dataset.datasets:
        tensors = list(dataset.tensors)
        if tensors[0].shape[1] == len(feature_selection):
            continue  # already restricted
        tensors[0] = tensors[0][:, feature_selection]
        dataset.tensors = tensors
    for dataset in feature_loaders["test"].dataset.datasets:
        tensors = list(dataset.tensors)
        if tensors[0].shape[1] == len(feature_selection):
            continue  # already restricted
        tensors[0] = tensors[0][:, feature_selection]
        dataset.tensors = tensors
    return feature_loaders
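
if __name__ == "__main__":
    # Smoke test (an illustrative sketch, not part of the original pipeline):
    # wrap a toy loader and confirm that each batch carries sample indices.
    # The toy dataset and batch size here are assumptions for demonstration.
    base = torch.utils.data.TensorDataset(torch.randn(8, 4), torch.arange(8))
    loader = torch.utils.data.DataLoader(base, batch_size=4)
    for x, y, idx in add_index_to_dataloader(loader):
        print(x.shape, y.tolist(), idx.tolist())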