from argparse import ArgumentParser, BooleanOptionalAction
from pathlib import Path

import numpy as np
import torch
from tqdm import trange

from FeatureDiversityLoss import FeatureDiversityLoss
from architectures.model_mapping import get_model
from configs.architecture_params import architecture_params
from configs.dataset_params import dataset_constants
from evaluation.qsenn_metrics import eval_model_on_all_qsenn_metrics
from finetuning.map_function import finetune
from get_data import get_data
from saving.logging import Tee
from saving.utils import json_save
from train import train, test
from training.optim import get_optimizer, get_scheduler_for_model


def main(dataset, arch, seed=None, model_type="qsenn", do_dense=True, crop=True, n_features=50, n_per_class=5, img_size=448, reduced_strides=False):
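    """Train a dense model, then finetune it into a Q-SENN or SLDD model.

    Checkpoints and metric files are written to ~/tmp/<arch>/<dataset_key>/<seed>/.
    """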
    # Create a random seed if none was given
    if seed is None:
        seed = np.random.randint(0, 1000000)
    np.random.seed(seed)
    torch.manual_seed(seed)
    dataset_key = dataset
    if crop:
        assert dataset in ["CUB2011", "TravelingBirds"], "GT crops are only available for CUB2011 and TravelingBirds"
        dataset_key += "_crop"
    log_dir = Path.home()/f"tmp/{arch}/{dataset_key}/{seed}/"
    log_dir.mkdir(parents=True, exist_ok=True)
    tee = Tee(log_dir / "log.txt") # save log to file
    n_classes = dataset_constants[dataset]["num_classes"]
    train_loader, test_loader = get_data(dataset, crop=crop, img_size=img_size)
    model = get_model(arch, n_classes, reduced_strides)
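    # FeatureDiversityLoss is an auxiliary loss that encourages diverse features;
    # its weight ("beta") is configured per architecture in architecture_params.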
    fdl = FeatureDiversityLoss(architecture_params[arch]["beta"], model.linear)
    OptimizationSchedule = get_scheduler_for_model(model_type, dataset)
    optimizer, schedule, dense_epochs = get_optimizer(model, OptimizationSchedule)
    if not (log_dir / "Trained_DenseModel.pth").exists():
        if do_dense:
            for epoch in trange(dense_epochs):
                model = train(model, train_loader, optimizer, fdl, epoch)
                schedule.step()
                if epoch % 5 == 0:
                    test(model, test_loader, epoch)
        else:
            print("Skipping dense training and using the pretrained model; this only makes sense for ImageNet")
        torch.save(model.state_dict(), log_dir / "Trained_DenseModel.pth")
    else:
        model.load_state_dict(torch.load(log_dir / "Trained_DenseModel.pth"))
    if not (log_dir / "Results_DenseModel.json").exists():
        metrics_dense = eval_model_on_all_qsenn_metrics(model, test_loader, train_loader)
        json_save(log_dir / "Results_DenseModel.json", metrics_dense)
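    # Finetune the dense model into the chosen model type: select n_features
    # features in total and assign n_per_class of them to each class.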
    final_model = finetune(model_type, model, train_loader, test_loader, log_dir, n_classes, seed,
                           architecture_params[arch]["beta"], OptimizationSchedule, n_per_class, n_features)
    torch.save(final_model.state_dict(), log_dir / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth")
    metrics_finetuned = eval_model_on_all_qsenn_metrics(final_model, test_loader, train_loader)
    json_save(log_dir / f"Results_{model_type}_{n_features}_{n_per_class}_FinetunedModel.json", metrics_finetuned)
    print("Done")



if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--dataset', default="CUB2011", type=str, help='dataset name', choices=["CUB2011", "ImageNet", "TravelingBirds", "StanfordCars"])
    parser.add_argument('--arch', default="resnet50", type=str, help='Backbone Feature Extractor', choices=["resnet50", "resnet18"])
    parser.add_argument('--model_type', default="qsenn", type=str, help='Type of Model', choices=["qsenn", "sldd"])
    parser.add_argument('--seed', default=None, type=int, help='seed used for naming the log folder and for all random processes; setting it allows multiple finetuning runs (e.g. Q-SENN and SLDD) on top of the same dense model')  # 769567, 552629
    # Note: type=bool would parse any non-empty string (even "False") as True,
    # so boolean flags use BooleanOptionalAction (e.g. --do_dense / --no-do_dense).
    parser.add_argument('--do_dense', default=True, action=BooleanOptionalAction, help='whether to train the dense model; should be True for all datasets except (maybe) ImageNet')
    parser.add_argument('--cropGT', default=False, action=BooleanOptionalAction,
                        help='whether to crop CUB/TravelingBirds images based on ground-truth bounding boxes')
    parser.add_argument('--n_features', default=50, type=int, help='how many features to select')
    parser.add_argument('--n_per_class', default=5, type=int, help='how many features to assign to each class')
    parser.add_argument('--img_size', default=448, type=int, help='image size')
    parser.add_argument('--reduced_strides', default=False, action=BooleanOptionalAction, help='whether to use reduced strides for ResNets')
    args = parser.parse_args()
    main(args.dataset, args.arch, args.seed, args.model_type, args.do_dense, args.cropGT,
         args.n_features, args.n_per_class, args.img_size, args.reduced_strides)
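
# Example usage (assuming this script is saved as main.py and the dataset has
# been downloaded to wherever get_data expects it):
#   python main.py --dataset CUB2011 --arch resnet50 --model_type qsenn --cropGT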