import numpy as np
import torch

from FeatureDiversityLoss import FeatureDiversityLoss
from finetuning.utils import train_n_epochs
from sparsification.glmBasedSparsification import compute_feature_selection_and_assignment
from sparsification.sldd import compute_sldd_feature_selection_and_assignment
from train import train, test
from training.optim import get_optimizer

# NOTE(review): np, torch, FeatureDiversityLoss,
# compute_feature_selection_and_assignment, train, test and get_optimizer are
# not referenced by finetune_sldd below; verify no other code relies on them
# being importable from this module before removing any.


def finetune_sldd(
    model,
    train_loader,
    test_loader,
    log_dir,
    n_classes,
    seed,
    beta,
    optimization_schedule,
    n_per_class,
    n_features,
):
    """Install an SLDD feature selection/assignment on ``model``, then fine-tune it.

    The function runs three steps:

      1. It calls ``compute_sldd_feature_selection_and_assignment`` with the
         model, both loaders, ``log_dir``, ``n_classes``, ``seed``,
         ``n_per_class`` and ``n_features``. That call yields five values: the
         selected features, a weight, a bias, a mean and a std.
      2. It applies those values to the model through ``model.set_model_sldd``.
      3. It trains the model with ``train_n_epochs``, using ``beta``,
         ``optimization_schedule`` and both loaders.

    The exact meaning of ``beta``, ``n_per_class`` and ``n_features`` is defined
    by the helpers they are forwarded to. Presumably ``n_features`` is the
    number of features to keep and ``n_per_class`` the number of features per
    class. TODO confirm against ``sparsification.sldd``.

    Returns:
        Whatever ``train_n_epochs`` returns. This is presumably the fine-tuned
        model; confirm against ``finetuning.utils``.
    """
    (
        selected_features,
        sldd_weight,
        sldd_bias,
        feature_mean,
        feature_std,
    ) = compute_sldd_feature_selection_and_assignment(
        model,
        train_loader,
        test_loader,
        log_dir,
        n_classes,
        seed,
        n_per_class,
        n_features,
    )

    # The helper yields (features, weight, bias, mean, std), but the values are
    # passed to set_model_sldd as (features, weight, mean, std, bias). This
    # reordering mirrors the original call.
    # NOTE(review): confirm it matches set_model_sldd's parameter order.
    model.set_model_sldd(
        selected_features, sldd_weight, feature_mean, feature_std, sldd_bias
    )

    return train_n_epochs(
        model, beta, optimization_schedule, train_loader, test_loader
    )