File size: 1,047 Bytes
9b896f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import numpy as np
import torch

from FeatureDiversityLoss import FeatureDiversityLoss
from finetuning.utils import train_n_epochs
from sparsification.glmBasedSparsification import compute_feature_selection_and_assignment
from sparsification.sldd import compute_sldd_feature_selection_and_assignment
from train import train, test
from training.optim import get_optimizer




def finetune_sldd(
    model,
    train_loader,
    test_loader,
    log_dir,
    n_classes,
    seed,
    beta,
    optimization_schedule,
    n_per_class,
    n_features,
):
    """Install an SLDD feature selection on ``model`` and fine-tune it.

    The function performs three steps:

    1. Calls ``compute_sldd_feature_selection_and_assignment`` with the model,
       both data loaders, ``log_dir``, ``n_classes``, ``seed``,
       ``n_per_class`` and ``n_features`` to obtain a feature selection,
       a weight, a bias, a mean and a std.
    2. Passes those values to ``model.set_model_sldd``.
    3. Runs ``train_n_epochs`` with ``beta`` and ``optimization_schedule`` on
       the same loaders.

    Args:
        model: Object exposing ``set_model_sldd``; it is modified in place by
            that call before training.
        train_loader: Forwarded to both the selection step and training.
        test_loader: Forwarded to both the selection step and training.
        log_dir: Forwarded to the selection step (presumably a cache/log
            location -- confirm against the helper).
        n_classes: Forwarded to the selection step.
        seed: Forwarded to the selection step.
        beta: Forwarded to ``train_n_epochs`` (presumably the weight of the
            feature diversity loss -- confirm against the helper).
        optimization_schedule: Forwarded to ``train_n_epochs``.
        n_per_class: Forwarded to the selection step.
        n_features: Forwarded to the selection step.

    Returns:
        Whatever ``train_n_epochs`` returns (the fine-tuned model).
    """
    sldd_result = compute_sldd_feature_selection_and_assignment(
        model,
        train_loader,
        test_loader,
        log_dir,
        n_classes,
        seed,
        n_per_class,
        n_features,
    )
    selected, layer_weight, layer_bias, feat_mean, feat_std = sldd_result

    # NOTE: set_model_sldd takes mean/std *before* bias, which differs from
    # the order in which the helper returns them above.
    model.set_model_sldd(selected, layer_weight, feat_mean, feat_std, layer_bias)

    return train_n_epochs(
        model, beta, optimization_schedule, train_loader, test_loader
    )