from argparse import ArgumentParser

import torch

#from sparsification.glm_saga import glm_saga
from sparsification import feature_helpers


def safe_zip(*args):
    """Zip iterables, raising a ValueError if they do not all have the same length."""
    for iterable in args[1:]:
        if len(iterable) != len(args[0]):
            print("Unequally sized iterables to zip, printing lengths")
            for i, entry in enumerate(args):
                print(i, len(entry))
            raise ValueError("Unequally sized iterables to zip")
    return zip(*args)


def compute_features_and_metadata(args, train_loader, test_loader, model, out_dir_feats, num_classes):
    """Compute (or load cached) deep features and metadata for the train/val/test splits.

    Returns (feature_loaders, metadata), or (None, False) if the metadata implies that
    no meaningful regularization path exists (maximal group penalty of zero).
    """
    print("Computing/loading deep features...")

    Ntotal = len(train_loader.dataset)
    feature_loaders = {}
    # Compute features for the non-augmented train and test sets: temporarily
    # apply the test-time transforms to the train set as well.
    train_loader_transforms = train_loader.dataset.transform
    test_loader_transforms = test_loader.dataset.transform
    train_loader.dataset.transform = test_loader_transforms
    for mode, loader in zip(['train', 'test'], [train_loader, test_loader]):
        print(f"For {mode} set...")

        sink_path = f"{out_dir_feats}/features_{mode}"
        metadata_path = f"{out_dir_feats}/metadata_{mode}.pth"

        feature_ds, feature_loader = feature_helpers.compute_features(
            loader,
            model,
            dataset_type=args.dataset_type,
            pooled_output=None,
            batch_size=args.batch_size,
            num_workers=0,  # args.num_workers,
            shuffle=(mode == 'test'),
            device=args.device,
            filename=sink_path,
            n_epoch=1,
            balance=False,  # args.balance if mode == 'test' else False
        )

        if mode == 'train':
            metadata = feature_helpers.calculate_metadata(feature_loader,
                                                          num_classes=num_classes,
                                                          filename=metadata_path)
            # The maximal group regularization is zero, so no meaningful regularization
            # path exists; restore the train transforms and signal failure.
            if metadata["max_reg"]["group"] == 0.0:
                train_loader.dataset.transform = train_loader_transforms
                return None, False
            split_datasets, split_loaders = feature_helpers.split_dataset(feature_ds,
                                                                          Ntotal,
                                                                          val_frac=args.val_frac,
                                                                          batch_size=args.batch_size,
                                                                          num_workers=args.num_workers,
                                                                          random_seed=args.random_seed,
                                                                          shuffle=True,
                                                                          balance=False)
            feature_loaders.update({mm: add_index_to_dataloader(split_loaders[mi])
                                    for mi, mm in enumerate(['train', 'val'])})

        else:
            feature_loaders[mode] = feature_loader
    train_loader.dataset.transform = train_loader_transforms
    return feature_loaders, metadata


def get_feature_loaders(seed, log_folder, train_loader, test_loader, model, num_classes):
    args = get_default_args()
    args.random_seed = seed
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Keep args consistent with the device that is actually available.
    args.device = device
    feature_folder = log_folder / "features"
    feature_loaders, metadata = compute_features_and_metadata(args, train_loader, test_loader, model,
                                                              feature_folder, num_classes)
    return feature_loaders, metadata, device, args


def add_index_to_dataloader(loader, sample_weight=None):
    """Rebuild a DataLoader so that each batch also yields sample indices
    (and, if given, per-sample weights) via IndexedDataset."""
    return torch.utils.data.DataLoader(
        IndexedDataset(loader.dataset, sample_weight=sample_weight),
        batch_size=loader.batch_size,
        sampler=loader.sampler,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        timeout=loader.timeout,
        worker_init_fn=loader.worker_init_fn,
        multiprocessing_context=loader.multiprocessing_context
    )


class IndexedDataset(torch.utils.data.Dataset):
    """Dataset wrapper that appends the sample index (and optional weight) to each item."""

    def __init__(self, ds, sample_weight=None):
        super().__init__()
        self.dataset = ds
        self.sample_weight = sample_weight

    def __getitem__(self, index):
        val = self.dataset[index]
        if self.sample_weight is None:
            return val + (index,)
        else:
            weight = self.sample_weight[index]
            return val + (weight, index)

    def __len__(self):
        return len(self.dataset)
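
# Illustration (hypothetical, not part of the original code): wrapping a
# TensorDataset of (feature, label) pairs makes every item also carry its
# index, so downstream code can look up per-sample state.
#
#     base = torch.utils.data.TensorDataset(torch.randn(4, 8),
#                                           torch.zeros(4, dtype=torch.long))
#     wrapped = IndexedDataset(base)
#     x, y, idx = wrapped[2]   # idx == 2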


def get_default_args():
    # Default args from glm_saga, https://github.com/MadryLab/glm_saga
    parser = ArgumentParser()
    parser.add_argument('--dataset', type=str, help='dataset name')
    parser.add_argument('--dataset-type', type=str, help='One of ["language", "vision"]')
    parser.add_argument('--dataset-path', type=str, help='path to dataset')
    parser.add_argument('--model-path', type=str, help='path to model checkpoint')
    parser.add_argument('--arch', type=str, help='model architecture type')
    parser.add_argument('--out-path', help='location for saving results')
    parser.add_argument('--cache', action='store_true', help='cache deep features')
    parser.add_argument('--balance', action='store_true', help='balance classes for evaluation')

    parser.add_argument('--device', default='cuda')
    parser.add_argument('--random-seed', type=int, default=0)
    parser.add_argument('--num-workers', type=int, default=2)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--val-frac', type=float, default=0.1)
    parser.add_argument('--lr-decay-factor', type=float, default=1)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--alpha', type=float, default=0.99)
    parser.add_argument('--max-epochs', type=int, default=2000)
    parser.add_argument('--verbose', type=int, default=200)
    parser.add_argument('--tol', type=float, default=1e-4)
    parser.add_argument('--lookbehind', type=int, default=3)
    parser.add_argument('--lam-factor', type=float, default=0.001)
    parser.add_argument('--group', action='store_true')
    args = parser.parse_args()
    return args


def select_in_loader(feature_loaders, feature_selection):
    """Restrict the feature tensors in all loaders to the selected feature columns (in place)."""
    # Train and val index into the same underlying dataset, so slicing the train
    # datasets also updates the val loader.
    for dataset in feature_loaders["train"].dataset.dataset.dataset.datasets:
        tensors = list(dataset.tensors)
        if tensors[0].shape[1] == len(feature_selection):
            continue
        tensors[0] = tensors[0][:, feature_selection]
        dataset.tensors = tensors
    for dataset in feature_loaders["test"].dataset.datasets:
        tensors = list(dataset.tensors)
        if tensors[0].shape[1] == len(feature_selection):
            continue
        tensors[0] = tensors[0][:, feature_selection]
        dataset.tensors = tensors
    return feature_loaders
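
# Hypothetical usage sketch (not part of the original module). It assumes the
# caller already has `train_loader`, `test_loader`, a trained `model`, and
# `num_classes`, and that `log_folder` is a pathlib.Path; it only illustrates
# how the helpers above fit together.
#
#     from pathlib import Path
#
#     feature_loaders, metadata, device, args = get_feature_loaders(
#         seed=0,
#         log_folder=Path("runs/example"),
#         train_loader=train_loader,
#         test_loader=test_loader,
#         model=model,
#         num_classes=num_classes,
#     )
#     if feature_loaders is not None:
#         # e.g. keep only the first 128 feature dimensions
#         feature_loaders = select_in_loader(feature_loaders, torch.arange(128))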