import logging
import os
import shutil

import numpy as np
import pandas as pd
import torch
from glm_saga.elasticnet import glm_saga
from torch import nn

from sparsification.FeatureSelection import FeatureSelectionFitting
from sparsification import data_helpers
from sparsification.utils import get_default_args, compute_features_and_metadata, select_in_loader, get_feature_loaders
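
# NOTE (inferred from the usage below): `feature_loaders` is assumed to be a dict
# with "train", "val" and "test" loaders over precomputed feature batches, and
# `metadata["X"]` is assumed to provide "num_features", "mean" and "std" for that
# feature representation. Folder arguments (`log_folder`, `log_dir`, `folder`)
# are assumed to be pathlib.Path objects, since they are joined with "/".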


def get_glm_selection(feature_loaders, metadata, args, num_classes, device, n_features_to_select, folder):
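    """Select `n_features_to_select` features with FeatureSelectionFitting.

    Runs the feature-selection fit on the precomputed feature loaders and
    returns a 1-D tensor with the indices of the features that were kept
    (i.e. all indices the fit did not mark for dropping).
    """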
    num_features = metadata["X"]["num_features"][0]
    fitting = FeatureSelectionFitting(num_features, num_classes, args, 0.8,
                                      n_features_to_select,
                                      0.1, folder,
                                      lookback=3, tol=1e-4,
                                      epsilon=1)
    to_drop, _test_acc = fitting.fit(feature_loaders, metadata, device)
    # Keep every feature index that the fit did not mark for dropping.
    selected_features = torch.tensor([i for i in range(num_features) if i not in to_drop])
    return selected_features


def compute_feature_selection_and_assignment(model, train_loader, test_loader, log_folder, num_classes, seed, select_features=50):
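    """Determine the feature selection and fit the sparse GLM on top of it.

    Loads a cached selection from `SlDD_Selection_{select_features}.pt` if it
    exists, otherwise computes (or reuses the model's own) selection and caches
    it. Returns the selected feature indices, the weight matrices and biases of
    the glm_saga regularization path, and the feature mean/std.
    """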
    feature_loaders, metadata, device, args = get_feature_loaders(seed, log_folder, train_loader, test_loader, model, num_classes)

    if os.path.exists(log_folder / f"SlDD_Selection_{select_features}.pt"):
        # Reuse a previously cached feature selection.
        feature_selection = torch.load(log_folder / f"SlDD_Selection_{select_features}.pt")
    else:
        used_features = model.linear.weight.shape[1]
        if used_features != select_features:
            selection_folder = log_folder / "sldd_selection"
            feature_selection = get_glm_selection(feature_loaders, metadata, args,
                                                  num_classes, device,
                                                  select_features, selection_folder)
        else:
            # The model already works on exactly `select_features` features; reuse its selection.
            feature_selection = model.linear.selection
        torch.save(feature_selection, log_folder / f"SlDD_Selection_{select_features}.pt")
    feature_loaders = select_in_loader(feature_loaders, feature_selection)

    mean, std = metadata["X"]["mean"], metadata["X"]["std"]
    mean_to_pass_in = mean
    std_to_pass_in = std
    if len(mean) != feature_selection.shape[0]:
        # Restrict the normalization statistics to the selected features.
        mean_to_pass_in = mean[feature_selection]
        std_to_pass_in = std[feature_selection]

    sparse_matrices, biases = fit_glm(log_folder, mean_to_pass_in, std_to_pass_in, feature_loaders, num_classes, select_features)
    return feature_selection, sparse_matrices, biases, mean, std


def fit_glm(log_dir, mean, std, feature_loaders, num_classes, select_features=50):
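    """Fit the regularization path of a sparse linear layer with glm_saga.

    Reuses the checkpoints in `log_dir / "glm_path"` when the path looks
    complete, otherwise recomputes it. Returns the lists of weight matrices and
    bias vectors along the path; the last entry is the densest solution.
    """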
    output_folder = log_dir / "glm_path"
    # A complete run below (k=100) is expected to leave 102 files in the checkpoint
    # folder; any other count is treated as a stale or partial path and recomputed.
    if not output_folder.exists() or len(list(output_folder.iterdir())) != 102:
        shutil.rmtree(output_folder, ignore_errors=True)
        output_folder.mkdir(exist_ok=True, parents=True)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        linear = nn.Linear(select_features, num_classes).to(device)
        for p in [linear.weight, linear.bias]:
            p.data.zero_()
        print("Preparing normalization preprocess and indexed dataloader")
        metadata = {"X": {"mean": mean, "std": std}}
        preprocess = data_helpers.NormalizedRepresentation(feature_loaders['train'],
                                                           metadata=metadata,
                                                           device=linear.weight.device)

        print("Calculating the regularization path")
        # Silence matplotlib's verbose logging during the fit.
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
        params = glm_saga(linear,
                          feature_loaders['train'],
                          0.1,    # max learning rate
                          2000,   # max number of epochs
                          0.99,   # elastic-net mixing parameter alpha
                          k=100,  # number of points on the regularization path
                          val_loader=feature_loaders['val'],
                          test_loader=feature_loaders['test'],
                          n_classes=num_classes,
                          checkpoint=str(output_folder),
                          verbose=200,
                          tol=1e-4,
                          lookbehind=5,
                          lr_decay_factor=1,
                          group=False,
                          epsilon=0.001,
                          metadata=None,
                          preprocess=preprocess)
    results = load_glm(output_folder)
    sparse_matrices = results["weights"]
    biases = results["biases"]
    return sparse_matrices, biases


def load_glm(result_dir):
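    """Load a glm_saga regularization path from `result_dir`.

    Reads every `params{i}.pth` checkpoint and returns the per-step accuracies,
    regularization strengths, weights, biases and per-class sparsity, plus the
    densest (last) weight/bias pair.
    """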
    Nlambda = max([int(f.split('params')[1].split('.pth')[0])
                   for f in os.listdir(result_dir) if 'params' in f]) + 1

    print(f"Loading regularization path of length {Nlambda}")

    params_dict = {i: torch.load(os.path.join(result_dir, f"params{i}.pth"),
                                 map_location=torch.device('cpu')) for i in range(Nlambda)}

    regularization_strengths = [params_dict[i]['lam'].item() for i in range(Nlambda)]
    weights = [params_dict[i]['weight'] for i in range(Nlambda)]
    biases = [params_dict[i]['bias'] for i in range(Nlambda)]

    metrics = {'acc_tr': [], 'acc_val': [], 'acc_test': []}
    for k in metrics.keys():
        for i in range(Nlambda):
            metrics[k].append(params_dict[i]['metrics'][k])
        metrics[k] = 100 * np.stack(metrics[k])
    metrics = pd.DataFrame(metrics)
    metrics = metrics.rename(columns={'acc_tr': 'acc_train'})

    # Number of non-zero weights per class at every step of the path.
    sparsity = np.array([torch.sum(w != 0, dim=1).numpy() for w in weights])

    return {'metrics': metrics,
            'regularization_strengths': regularization_strengths,
            'weights': weights,
            'biases': biases,
            'sparsity': sparsity,
            'weight_dense': weights[-1],
            'bias_dense': biases[-1]}
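

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `model`, `train_loader` and `test_loader`
# are hypothetical stand-ins: `model` is assumed to expose its final layer as
# `model.linear`, the loaders are assumed to match what `get_feature_loaders`
# expects, and the log folder is a pathlib.Path; the concrete values below
# (path, num_classes, seed) are placeholders.
#
#     from pathlib import Path
#
#     log_folder = Path("runs/example_run")
#     selection, weight_path, bias_path, mean, std = compute_feature_selection_and_assignment(
#         model, train_loader, test_loader, log_folder,
#         num_classes=10, seed=42, select_features=50,
#     )
#     # weight_path[-1] / bias_path[-1] are the densest entries on the
#     # glm_saga regularization path.
# ---------------------------------------------------------------------------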