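"""Feature selection and sparse GLM fitting on precomputed model features.

This module selects a small subset of features (e.g. 50), fits an elastic-net
regularization path over those features with ``glm_saga``, and loads the
resulting sparse weight matrices and biases back from the checkpoint folder.
"""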
import logging
import os
import shutil

import numpy as np
import pandas as pd
import torch
from glm_saga.elasticnet import glm_saga
from torch import nn

from sparsification.FeatureSelection import FeatureSelectionFitting
from sparsification import data_helpers
from sparsification.utils import (
    get_default_args,
    compute_features_and_metadata,
    select_in_loader,
    get_feature_loaders,
)
def get_glm_selection(feature_loaders, metadata, args, num_classes, device,
                      n_features_to_select, folder):
    """Select ``n_features_to_select`` features via FeatureSelectionFitting.

    Returns a tensor with the indices of the kept features.
    """
    num_features = metadata["X"]["num_features"][0]
    fitting_class = FeatureSelectionFitting(
        num_features, num_classes, args, 0.8,
        n_features_to_select,
        0.1, folder,
        lookback=3, tol=1e-4,
        epsilon=1,
    )
    to_drop, test_acc = fitting_class.fit(feature_loaders, metadata, device)
    # Keep every feature index that was not marked for dropping.
    selected_features = torch.tensor([i for i in range(num_features) if i not in to_drop])
    return selected_features
def compute_feature_selection_and_assignment(model, train_loader, test_loader, log_folder,
                                             num_classes, seed, select_features=50):
    """Compute (or load a cached) feature selection, then fit sparse GLMs on it."""
    feature_loaders, metadata, device, args = get_feature_loaders(
        seed, log_folder, train_loader, test_loader, model, num_classes)
    if os.path.exists(log_folder / f"SlDD_Selection_{select_features}.pt"):
        # Reuse a previously computed selection.
        feature_selection = torch.load(log_folder / f"SlDD_Selection_{select_features}.pt")
    else:
        used_features = model.linear.weight.shape[1]
        if used_features != select_features:
            selection_folder = log_folder / "sldd_selection"  # overwrite with None to prevent saving
            feature_selection = get_glm_selection(feature_loaders, metadata, args,
                                                  num_classes, device,
                                                  select_features, selection_folder)
        else:
            # The model already uses the requested number of features.
            feature_selection = model.linear.selection
        torch.save(feature_selection, log_folder / f"SlDD_Selection_{select_features}.pt")
    feature_loaders = select_in_loader(feature_loaders, feature_selection)
    mean, std = metadata["X"]["mean"], metadata["X"]["std"]
    # Restrict the normalization statistics to the selected features if needed.
    mean_to_pass_in = mean
    std_to_pass_in = std
    if len(mean) != feature_selection.shape[0]:
        mean_to_pass_in = mean[feature_selection]
        std_to_pass_in = std[feature_selection]
    sparse_matrices, biases = fit_glm(log_folder, mean_to_pass_in, std_to_pass_in,
                                      feature_loaders, num_classes, select_features)
    return feature_selection, sparse_matrices, biases, mean, std
def fit_glm(log_dir, mean, std, feature_loaders, num_classes, select_features=50):
    """Fit the elastic-net regularization path with glm_saga; returns weights and biases."""
    output_folder = log_dir / "glm_path"
    # Recompute the path unless a complete checkpoint folder already exists
    # (k=100 path checkpoints plus auxiliary files).
    if not output_folder.exists() or len(list(output_folder.iterdir())) != 102:
        shutil.rmtree(output_folder, ignore_errors=True)
        output_folder.mkdir(exist_ok=True, parents=True)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Start from an all-zero linear layer, as expected for the SAGA path.
        linear = nn.Linear(select_features, num_classes).to(device)
        for p in [linear.weight, linear.bias]:
            p.data.zero_()

        print("Preparing normalization preprocess and indexed dataloader")
        metadata = {"X": {"mean": mean, "std": std}}
        preprocess = data_helpers.NormalizedRepresentation(feature_loaders['train'],
                                                           metadata=metadata,
                                                           device=linear.weight.device)

        print("Calculating the regularization path")
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
        params = glm_saga(linear,
                          feature_loaders['train'],
                          0.1,    # learning rate
                          2000,   # max epochs per path point
                          0.99,   # alpha: balance between L1 and L2 penalty
                          k=100,  # number of points on the regularization path
                          val_loader=feature_loaders['val'],
                          test_loader=feature_loaders['test'],
                          n_classes=num_classes,
                          checkpoint=str(output_folder),
                          verbose=200,
                          tol=1e-4,  # Change for ImageNet
                          lookbehind=5,
                          lr_decay_factor=1,
                          group=False,
                          epsilon=0.001,
                          metadata=None,  # to let it be recomputed
                          preprocess=preprocess)
    results = load_glm(output_folder)
    sparse_matrices = results["weights"]
    biases = results["biases"]
    return sparse_matrices, biases
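# Illustrative sketch, not part of the original pipeline: given the dictionary
# returned by load_glm below, one way to pick a single point on the
# regularization path is to choose the index whose mean per-class sparsity is
# closest to a desired number of non-zero weights per class. The helper name
# and the default target are hypothetical.
def pick_path_index_by_sparsity(results, target_nonzeros_per_class=5):
    # results['sparsity'] has shape (Nlambda, num_classes); average over classes.
    mean_sparsity = results['sparsity'].mean(axis=1)
    return int(np.argmin(np.abs(mean_sparsity - target_nonzeros_per_class)))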
def load_glm(result_dir):
    """Load a regularization path saved by glm_saga from ``result_dir``."""
    # Checkpoints are named params0.pth ... params{Nlambda-1}.pth.
    Nlambda = max(int(f.split('params')[1].split('.pth')[0])
                  for f in os.listdir(result_dir) if 'params' in f) + 1

    print(f"Loading regularization path of length {Nlambda}")

    params_dict = {i: torch.load(os.path.join(result_dir, f"params{i}.pth"),
                                 map_location=torch.device('cpu')) for i in range(Nlambda)}

    regularization_strengths = [params_dict[i]['lam'].item() for i in range(Nlambda)]
    weights = [params_dict[i]['weight'] for i in range(Nlambda)]
    biases = [params_dict[i]['bias'] for i in range(Nlambda)]

    metrics = {'acc_tr': [], 'acc_val': [], 'acc_test': []}
    for k in metrics.keys():
        for i in range(Nlambda):
            metrics[k].append(params_dict[i]['metrics'][k])
        metrics[k] = 100 * np.stack(metrics[k])
    metrics = pd.DataFrame(metrics)
    metrics = metrics.rename(columns={'acc_tr': 'acc_train'})

    # Per-class count of non-zero weights at each point on the path.
    sparsity = np.array([torch.sum(w != 0, dim=1).numpy() for w in weights])

    return {'metrics': metrics,
            'regularization_strengths': regularization_strengths,
            'weights': weights,
            'biases': biases,
            'sparsity': sparsity,
            'weight_dense': weights[-1],
            'bias_dense': biases[-1]}
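# Usage sketch with hypothetical names: assuming `model` has a `.linear` head
# and `train_loader`/`test_loader` are the project's dataloaders, the pipeline
# could be driven as follows. The path and the argument values are examples.
#
#     from pathlib import Path
#
#     log_folder = Path("logs/run_0")
#     selection, sparse_matrices, biases, mean, std = (
#         compute_feature_selection_and_assignment(
#             model, train_loader, test_loader, log_folder,
#             num_classes=200, seed=0, select_features=50))
#     # sparse_matrices/biases hold one (weight, bias) pair per path point;
#     # denser solutions appear later in the list (see 'weight_dense' above).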