Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from sparsification.glmBasedSparsification import compute_feature_selection_and_assignment | |
def compute_qsenn_feature_selection_and_assignment(
    model,
    train_loader,
    test_loader,
    log_folder,
    num_classes,
    seed,
    n_features,
    per_class=5,
):
    """Run the GLM-based feature selection and quantise its sparse layer.

    All arguments except ``per_class`` are forwarded unchanged to
    ``compute_feature_selection_and_assignment``. ``per_class`` is passed
    to ``get_sparsified_weights_for_factor`` as its ``factor`` argument.

    Returns:
        Tuple ``(feature_sel, weight_sparse, bias_sparse, mean, std)``.
        ``feature_sel``, ``mean`` and ``std`` are handed through from the
        GLM step; ``weight_sparse`` and ``bias_sparse`` come from
        ``get_sparsified_weights_for_factor``.
    """
    glm_result = compute_feature_selection_and_assignment(
        model, train_loader, test_loader, log_folder, num_classes, seed, n_features
    )
    feature_sel, sparse_matrices, biases, mean, std = glm_result

    # Last one in regularisation path has no regularisation, so the final
    # entry of both lists is dropped before sparsifying.
    regularised_matrices = sparse_matrices[:-1]
    regularised_biases = biases[:-1]
    weight_sparse, bias_sparse = get_sparsified_weights_for_factor(
        regularised_matrices, regularised_biases, per_class
    )

    print(f"Number of nonzeros in weight matrix: {torch.sum(weight_sparse != 0)}")
    return feature_sel, weight_sparse, bias_sparse, mean, std
def get_sparsified_weights_for_factor(weights, biases, factor,):
    """Quantise the last weight matrix to ``factor`` nonzeros per row on average.

    Only the last entries of ``weights`` and ``biases`` are used. The
    magnitude threshold is chosen so that roughly ``factor * n_rows``
    entries survive, and ``discretize_2_bins_to_threshold`` then maps the
    matrix onto three levels (negative, zero, positive).

    The bias is rebuilt from scratch: each row gets a value in
    ``[-max_bias, max_bias]`` that decreases linearly with the row's net
    count (positive minus negative surviving weights), so the row with the
    smallest net count receives ``+max_bias``.

    Args:
        weights: Sequence of 2-D weight tensors; ``weights[-1]`` is used.
            (Presumably rows are classes -- confirm against the caller.)
        biases: Sequence of bias tensors; ``biases[-1]`` is used.
        factor: Target average number of nonzero weights per row.

    Returns:
        Tuple ``(weight_sparse, bias_sparse)``.

    Raises:
        IndexError: If ``weights[-1]`` contains no nonzero entry.
    """
    dense_weights, dense_bias = weights[-1], biases[-1]
    goal_nonzeros = factor * dense_weights.shape[0]

    # Nonzero magnitudes, largest first.
    magnitudes = torch.abs(dense_weights.flatten())
    magnitudes = magnitudes[magnitudes != 0]
    magnitudes = torch.sort(magnitudes, descending=True)[0]

    if goal_nonzeros < len(magnitudes):
        # Cut halfway between the last kept and the first dropped magnitude.
        cut = int(goal_nonzeros)
        threshold = (magnitudes[cut - 1] + magnitudes[cut]) / 2
    else:
        threshold = magnitudes[-1]
    max_val = torch.max(magnitudes)
    weight_sparse = discretize_2_bins_to_threshold(dense_weights, threshold, max_val)

    # Net signed count per row, computed in torch so it also works when the
    # tensors do not live on the CPU.
    positive_per_row = torch.sum(weight_sparse > 0, dim=1)
    negative_per_row = torch.sum(weight_sparse < 0, dim=1)
    net_count = positive_per_row - negative_per_row

    max_bias = torch.max(torch.abs(dense_bias))
    bias_sparse = torch.ones_like(dense_bias) * max_bias

    offset = (net_count - torch.min(net_count)).to(bias_sparse.device)
    steps = torch.max(offset)
    # When every row has the same net count there is nothing to spread:
    # dividing by zero here used to turn the whole bias into NaN.
    if steps > 0:
        single_step = 2 * max_bias / steps
        bias_sparse = bias_sparse - offset * single_step
    bias_sparse = torch.clamp(bias_sparse, -max_bias, max_bias)
    return weight_sparse, bias_sparse
def discretize_2_bins_to_threshold(data, treshold, max):
    """Map ``data`` onto three levels: ``-m``, ``0`` and ``+m``.

    Entries with ``data <= -treshold`` become ``-m``, entries with
    ``data > treshold`` become ``+m`` and everything in between becomes
    ``0``. ``m`` is the mean absolute value of all entries that fall into
    the two outer bins. If no entry falls into an outer bin, the result is
    all zeros.

    Args:
        data: Tensor to discretise.
        treshold: Magnitude threshold separating the zero bin from the
            outer bins (spelling kept for compatibility with callers).
        max: Initial magnitude of the outer levels; if it is ``0`` the
            levels are not re-estimated from ``data``.

    Returns:
        Tensor with the shape of ``data``, on the device of ``data``,
        holding only the three level values.
    """
    # Only the inner boundaries are needed. Adding +-max as outer boundaries
    # put a value equal to -max into an extra bucket that was looked up as
    # the *positive* level, and made values above max index out of range.
    boundaries = torch.tensor([-treshold, treshold], device=data.device)
    # 0: data <= -treshold, 1: -treshold < data <= treshold, 2: data > treshold
    bin_index = torch.bucketize(data, boundaries)

    levels = torch.tensor([-max, 0, max], device=data.device)
    outer_values = data[bin_index != 1]
    if levels[0] != 0 and outer_values.numel() > 0:
        magnitude = torch.sum(torch.abs(outer_values)) / outer_values.numel()
        levels[0] = -magnitude
        levels[2] = magnitude

    # Lookup is done on the CPU and the result moved back afterwards.
    return levels.cpu()[bin_index.cpu()].to(bin_index.device)