from argparse import ArgumentParser
import logging
import math
import os.path
import sys
import time
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from glm_saga.elasticnet import maximum_reg_loader, get_device, elastic_loss_and_acc_loader
from torch import nn
import torch as ch

from sparsification.utils import safe_zip

# TODO checkout this change: Marks changes to the group version of glmsaga
"""
This module requires glm_saga to run.

Usage to select 50 features with the parameters from the paper:
`metadata` contains information about the precomputed train features in `feature_loaders`,
`args` holds the default arguments for glm-saga.

def get_glm_to_zero(feature_loaders, metadata, args, num_classes, device, train_ds, Ntotal):
    num_features = metadata["X"]["num_features"][0]
    fittingClass = FeatureSelectionFitting(num_features, num_classes, args, 0.8, 50,
                                           lam_fac=0.1, out_dir=out_dir,
                                           lookback=3, tol=1e-4, epsilon=1)
    to_drop, test_acc = fittingClass.fit(feature_loaders, metadata, device)
    return to_drop

`to_drop` is then used to remove the unselected features before the downstream fitting and finetuning.
"""


class FeatureSelectionFitting:
    def __init__(self, n_features, n_classes, args, selalpha, nKeep, lam_fac, out_dir,
                 lookback=None, tol=None, epsilon=None):
        """
        This is an adaptation of the group version of glm-saga
        (https://github.com/MadryLab/DebuggableDeepNetworks).
        The function extended_mask_max implements the changed proximal operator.
        Args:
            n_features: number of input features
            n_classes: number of output classes
            args: default arguments for glm-saga
            selalpha: alpha for the elastic net
            nKeep: target number of features to select
            lam_fac: discount factor for the lambda parameters of glm-saga
            out_dir: directory for checkpoints and the log file
            lookback: early-stopping patience, overrides args.lookbehind if given
            tol: convergence tolerance, overrides args.tol if given
            epsilon: learning-rate decay factor, overrides args.lr_decay_factor if given
        """
        self.selected_features = torch.zeros(n_features, dtype=torch.bool)
        self.num_features = n_features
        self.selalpha = selalpha
        self.lam_Fac = lam_fac
        self.out_dir = out_dir
        self.n_classes = n_classes
        self.nKeep = nKeep
        self.args = self.extend_args(args, lookback, tol, epsilon)

    # Extended Proximal Operator for Feature Selection
    def extended_mask_max(self, greater_to_keep, thresh):
        prev = greater_to_keep[self.selected_features]
        greater_to_keep[self.selected_features] = torch.min(greater_to_keep)
        max_entry = torch.argmax(greater_to_keep)
        greater_to_keep[self.selected_features] = prev
        mask = torch.zeros_like(greater_to_keep)
        mask[max_entry] = 1
        final_mask = (greater_to_keep > thresh)
        final_mask = final_mask * mask
        allowed_to_keep = torch.logical_or(self.selected_features, final_mask)
        return allowed_to_keep

    def extend_args(self, args, lookback, tol, epsilon):
        for key, entry in safe_zip(["lookbehind", "tol", "lr_decay_factor"], [lookback, tol, epsilon]):
            if entry is not None:
                setattr(args, key, entry)
        return args

    # Grouped L1 regularization
    # proximal operator for f(weight) = lam * \|weight\|_2
    # where the 2-norm is taken columnwise
    def group_threshold(self, weight, lam):
        norm = weight.norm(p=2, dim=0) + 1e-6
        return (weight - lam * weight / norm) * self.extended_mask_max(norm, lam)

    # Elastic net regularization with group sparsity
    # proximal operator for f(x) = alpha * \|x\|_1 + beta * \|x\|_2^2
    # where the 2-norm is taken columnwise
    def group_threshold_with_shrinkage(self, x, alpha, beta):
        y = self.group_threshold(x, alpha)
        return y / (1 + beta)

    def threshold(self, weight_new, lr, lam):
        alpha = self.selalpha
        if alpha == 1:
            # Pure L1 regularization
            weight_new = self.group_threshold(weight_new, lr * lam * alpha)
        else:
            # Elastic net regularization
            weight_new = self.group_threshold_with_shrinkage(weight_new, lr * lam * alpha,
                                                             lr * lam * (1 - alpha))
        return weight_new
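    # Illustrative worked example (not part of the original code) of the extended
    # group-threshold step implemented by group_threshold / extended_mask_max:
    # already selected columns always survive, and at most one new column can enter.
    #   selected_features = [True, False, False]
    #   column norms      = [0.8,  0.9,   0.5 ], threshold thresh = 0.4
    #   -> extended_mask_max keeps column 0 (already selected) and column 1, the
    #      largest-norm unselected column, because 0.9 > 0.4; column 2 is zeroed
    #      even though 0.5 > 0.4. Surviving columns are shrunk by (1 - thresh / norm).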
    # Train an elastic GLM with proximal SAGA
    # Since SAGA stores a scalar for each example-class pair, either pass
    # the number of examples and number of classes or calculate it with an
    # initial pass over the loaders
    def train_saga(self, linear, loader, lr, nepochs, lam, alpha, group=True, verbose=None,
                   state=None, table_device=None, n_ex=None, n_classes=None, tol=1e-4,
                   preprocess=None, lookbehind=None, family='multinomial', logger=None):
        if logger is None:
            logger = print
        with ch.no_grad():
            weight, bias = list(linear.parameters())
            if table_device is None:
                table_device = weight.device

            # get total number of examples and initialize scalars
            # for computing the gradients
            if n_ex is None:
                n_ex = sum(tensors[0].size(0) for tensors in loader)
            if n_classes is None:
                if family == 'multinomial':
                    n_classes = max(tensors[1].max().item() for tensors in loader) + 1
                elif family == 'gaussian':
                    for batch in loader:
                        y = batch[1]
                        break
                    n_classes = y.size(1)

            # Storage for scalar gradients and averages
            if state is None:
                a_table = ch.zeros(n_ex, n_classes).to(table_device)
                w_grad_avg = ch.zeros_like(weight).to(weight.device)
                b_grad_avg = ch.zeros_like(bias).to(weight.device)
            else:
                a_table = state["a_table"].to(table_device)
                w_grad_avg = state["w_grad_avg"].to(weight.device)
                b_grad_avg = state["b_grad_avg"].to(weight.device)

            obj_history = []
            obj_best = None
            nni = 0
            for t in range(nepochs):
                total_loss = 0
                for n_batch, batch in enumerate(loader):
                    if len(batch) == 3:
                        X, y, idx = batch
                        w = None
                    elif len(batch) == 4:
                        X, y, w, idx = batch
                    else:
                        raise ValueError(
                            f"Loader must return (data, target, index) or (data, target, weight, index), "
                            f"but instead got a tuple of length {len(batch)}")

                    if preprocess is not None:
                        device = get_device(preprocess)
                        with ch.no_grad():
                            X = preprocess(X.to(device))
                    X = X.to(weight.device)

                    out = linear(X)

                    # split gradient on only the cross entropy term
                    # for efficient storage of gradient information
                    if family == 'multinomial':
                        if w is None:
                            loss = F.cross_entropy(out, y.to(weight.device), reduction='mean')
                        else:
                            loss = F.cross_entropy(out, y.to(weight.device), reduction='none')
                            loss = (loss * w).mean()
                        I = ch.eye(linear.weight.size(0))
                        target = I[y].to(weight.device)  # change to one-hot encoding
                        # Calculate new scalar gradient
                        logits = F.softmax(out, dim=1)
                    elif family == 'gaussian':
                        if w is None:
                            loss = 0.5 * F.mse_loss(out, y.to(weight.device), reduction='mean')
                        else:
                            loss = 0.5 * F.mse_loss(out, y.to(weight.device), reduction='none')
                            loss = (loss * (w.unsqueeze(1))).mean()
                        target = y
                        # Calculate new scalar gradient
                        logits = linear(X)
                    else:
                        raise ValueError(f"Unknown family: {family}")
                    total_loss += loss.item() * X.size(0)

                    # BS x NUM_CLASSES
                    a = logits - target
                    if w is not None:
                        a = a * w.unsqueeze(1)
                    a_prev = a_table[idx].to(weight.device)
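                    # SAGA-style variance-reduced step (descriptive sketch of the code
                    # below, added comment, not from the original): the gradient estimate is
                    #   g = grad_batch - grad_stored + grad_running_average
                    # where grad_stored is rebuilt from the per-example scalars a_prev in
                    # a_table and the running averages are w_grad_avg / b_grad_avg.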
                    # weight parameter
                    w_grad = (a.unsqueeze(2) * X.unsqueeze(1)).mean(0)
                    w_grad_prev = (a_prev.unsqueeze(2) * X.unsqueeze(1)).mean(0)
                    w_saga = w_grad - w_grad_prev + w_grad_avg
                    weight_new = weight - lr * w_saga
                    weight_new = self.threshold(weight_new, lr, lam)

                    # bias parameter
                    b_grad = a.mean(0)
                    b_grad_prev = a_prev.mean(0)
                    b_saga = b_grad - b_grad_prev + b_grad_avg
                    bias_new = bias - lr * b_saga

                    # update table and averages
                    a_table[idx] = a.to(table_device)
                    w_grad_avg.add_((w_grad - w_grad_prev) * X.size(0) / n_ex)
                    b_grad_avg.add_((b_grad - b_grad_prev) * X.size(0) / n_ex)

                    if lookbehind is None:
                        dw = (weight_new - weight).norm(p=2)
                        db = (bias_new - bias).norm(p=2)
                        criteria = ch.sqrt(dw ** 2 + db ** 2)

                        if criteria.item() <= tol:
                            return {
                                "a_table": a_table.cpu(),
                                "w_grad_avg": w_grad_avg.cpu(),
                                "b_grad_avg": b_grad_avg.cpu()
                            }

                    weight.data = weight_new
                    bias.data = bias_new

                saga_obj = total_loss / n_ex + lam * alpha * weight.norm(p=1) \
                           + 0.5 * lam * (1 - alpha) * (weight ** 2).sum()

                # save amount of improvement
                obj_history.append(saga_obj.item())
                if obj_best is None or saga_obj.item() + tol < obj_best:
                    obj_best = saga_obj.item()
                    nni = 0
                else:
                    nni += 1

                # Stop if no progress for lookbehind iterations
                criteria = lookbehind is not None and (nni >= lookbehind)

                nnz = (weight.abs() > 1e-5).sum().item()
                total = weight.numel()
                if verbose and (t % verbose) == 0:
                    if lookbehind is None:
                        logger(
                            f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) "
                            f"criteria {criteria:.4f} {dw} {db}")
                    else:
                        logger(
                            f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) "
                            f"obj_best {obj_best}")

                if lookbehind is not None and criteria:
                    logger(
                        f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) "
                        f"obj_best {obj_best} [early stop at {t}]")
                    return {
                        "a_table": a_table.cpu(),
                        "w_grad_avg": w_grad_avg.cpu(),
                        "b_grad_avg": b_grad_avg.cpu()
                    }

            logger(f"did not converge at {nepochs} iterations (criteria {criteria})")
            return {
                "a_table": a_table.cpu(),
                "w_grad_avg": w_grad_avg.cpu(),
                "b_grad_avg": b_grad_avg.cpu()
            }

    def glm_saga(self, linear, loader, max_lr, nepochs, alpha, dropout, tries,
                 table_device=None, preprocess=None, group=False, verbose=None,
                 state=None, n_ex=None, n_classes=None, tol=1e-4, epsilon=0.001, k=100,
                 checkpoint=None, do_zero=True, lr_decay_factor=1, metadata=None,
                 val_loader=None, test_loader=None, lookbehind=None,
                 family='multinomial', encoder=None, tot_tries=1):
        if encoder is not None:
            warnings.warn("encoder argument is deprecated; please use preprocess instead", DeprecationWarning)
            preprocess = encoder

        device = get_device(linear)
        checkpoint = self.out_dir
        if preprocess is not None and (device != get_device(preprocess)):
            raise ValueError(
                f"Linear and preprocess must be on same device "
                f"(got {get_device(linear)} and {get_device(preprocess)})")

        if metadata is not None:
            if n_ex is None:
                n_ex = metadata['X']['num_examples']
            if n_classes is None:
                n_classes = metadata['y']['num_classes']

        lam_fac = (1 + (tries - 1) / tot_tries)
        print("Using lam_fac ", lam_fac)
        max_lam = maximum_reg_loader(loader, group=group, preprocess=preprocess,
                                     metadata=metadata, family=family) / max(0.001, alpha) * lam_fac
        group_lam = maximum_reg_loader(loader, group=True, preprocess=preprocess,
                                       metadata=metadata, family=family) / max(0.001, alpha) * lam_fac
        min_lam = epsilon * max_lam
        group_min_lam = epsilon * group_lam

        # logspace is base 10 but log is base e so use log10
        lams = ch.logspace(math.log10(max_lam), math.log10(min_lam), k)
        lrs = ch.logspace(math.log10(max_lr), math.log10(max_lr / lr_decay_factor), k)
        found = False
        if do_zero:
            lams = ch.cat([lams, lams.new_zeros(1)])
            lrs = ch.cat([lrs, lrs.new_ones(1) * lrs[-1]])

        path = []
        best_val_loss = float('inf')

        if checkpoint is not None:
            os.makedirs(checkpoint, exist_ok=True)

            file_handler = logging.FileHandler(filename=os.path.join(checkpoint, 'output.log'))
            stdout_handler = logging.StreamHandler(sys.stdout)
            handlers = [file_handler, stdout_handler]
            logging.basicConfig(
                level=logging.DEBUG,
                format='[%(asctime)s] %(levelname)s - %(message)s',
                handlers=handlers
            )
            logger = logging.getLogger('glm_saga').info
        else:
            logger = print
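        # Sketch of the control flow below (added comment, not from the original):
        # lams decays geometrically from max_lam to min_lam. The outer while loop
        # runs once per newly selected feature: the inner loop walks down the
        # regularization path and breaks as soon as one additional feature becomes
        # non-zero, until nKeep features have been selected.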
        while self.selected_features.sum() < self.nKeep:
            # TODO checkout this change, one iteration per feature
            n_feature_to_keep = int(self.selected_features.sum().item())
            for i, (lam, lr) in enumerate(zip(lams, lrs)):
                lam = lam * self.lam_Fac
                start_time = time.time()
                self.selected_features = self.selected_features.to(device)
                state = self.train_saga(linear, loader, lr, nepochs, lam, alpha,
                                        table_device=table_device, preprocess=preprocess, group=group,
                                        verbose=verbose, state=state, n_ex=n_ex, n_classes=n_classes,
                                        tol=tol, lookbehind=lookbehind, family=family, logger=logger)

                with ch.no_grad():
                    loss, acc = elastic_loss_and_acc_loader(linear, loader, lam, alpha,
                                                            preprocess=preprocess, family=family)
                    loss, acc = loss.item(), acc.item()

                    loss_val, acc_val = -1, -1
                    if val_loader:
                        loss_val, acc_val = elastic_loss_and_acc_loader(linear, val_loader, lam, alpha,
                                                                        preprocess=preprocess, family=family)
                        loss_val, acc_val = loss_val.item(), acc_val.item()

                    loss_test, acc_test = -1, -1
                    if test_loader:
                        loss_test, acc_test = elastic_loss_and_acc_loader(linear, test_loader, lam, alpha,
                                                                          preprocess=preprocess, family=family)
                        loss_test, acc_test = loss_test.item(), acc_test.item()

                    params = {
                        "lam": lam,
                        "lr": lr,
                        "alpha": alpha,
                        "time": time.time() - start_time,
                        "loss": loss,
                        "metrics": {
                            "loss_tr": loss,
                            "acc_tr": acc,
                            "loss_val": loss_val,
                            "acc_val": acc_val,
                            "loss_test": loss_test,
                            "acc_test": acc_test,
                        },
                        "weight": linear.weight.detach().cpu().clone(),
                        "bias": linear.bias.detach().cpu().clone()
                    }
                    path.append(params)
                    if loss_val is not None and loss_val < best_val_loss:
                        best_val_loss = loss_val
                        best_params = params
                        found = True

                    nnz = (linear.weight.abs() > 1e-5).sum().item()
                    total = linear.weight.numel()
                    if family == 'multinomial':
                        logger(
                            f"{n_feature_to_keep} Feature ({i}) lambda {lam:.4f}, loss {loss:.4f}, acc {acc:.4f} "
                            f"[val acc {acc_val:.4f}] [test acc {acc_test:.4f}], "
                            f"sparsity {nnz / total} [{nnz}/{total}], time {time.time() - start_time}, lr {lr:.4f}")
                    elif family == 'gaussian':
                        logger(
                            f"({i}) lambda {lam:.4f}, loss {loss:.4f} [val loss {loss_val:.4f}] "
                            f"[test loss {loss_test:.4f}], sparsity {nnz / total} [{nnz}/{total}], "
                            f"time {time.time() - start_time}, lr {lr:.4f}")

                if self.check_new_feature(linear.weight):
                    # TODO checkout this change, canceling if new feature is used
                    if checkpoint is not None:
                        ch.save(params, os.path.join(checkpoint, f"params{n_feature_to_keep}.pth"))
                    break

        if found:
            return {
                'path': path,
                'best': best_params,
                'state': state
            }
        else:
            return False

    def check_new_feature(self, weight):
        # TODO checkout this change, checking if new feature is used
        copied_weight = weight.detach().cpu().clone()
        used_features = torch.unique(torch.nonzero(copied_weight)[:, 1])
        if len(used_features) > 0:
            new_set = set(used_features.tolist())
            old_set = set(torch.nonzero(self.selected_features)[:, 0].tolist())
            diff = new_set - old_set
            if len(diff) > 0:
                self.selected_features[used_features] = True
                return True
        return False
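    # Illustrative example (added, not from the original code) of check_new_feature:
    # if selected_features currently marks columns {3, 7} and the refitted weight has
    # non-zero entries in columns {3, 7, 12}, the difference {12} is new, so column 12
    # is marked as selected, True is returned, and glm_saga breaks out of the lambda
    # loop to move on to the next feature.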
    def fit(self, feature_loaders, metadata, device):
        # TODO checkout this change, glm saga code slightly adapted to return to_drop
        print("Initializing linear model...")
        linear = nn.Linear(self.num_features, self.n_classes).to(device)
        for p in [linear.weight, linear.bias]:
            p.data.zero_()

        print("Preparing normalization preprocess and indexed dataloader")
        preprocess = NormalizedRepresentation(feature_loaders['train'], metadata=metadata,
                                              device=linear.weight.device)

        print("Calculating the regularization path")
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
        selected_features = self.glm_saga(linear, feature_loaders['train'], self.args.lr, self.args.max_epochs,
                                          self.selalpha, 0, 1, val_loader=feature_loaders['val'],
                                          test_loader=feature_loaders['test'],
                                          n_classes=self.n_classes, verbose=self.args.verbose, tol=self.args.tol,
                                          lookbehind=self.args.lookbehind, lr_decay_factor=self.args.lr_decay_factor,
                                          group=True, epsilon=self.args.lam_factor, metadata=metadata,
                                          preprocess=preprocess, tot_tries=1)
        to_drop = np.where(self.selected_features.cpu().numpy() == 0)[0]
        test_acc = selected_features["path"][-1]["metrics"]["acc_test"]
        torch.set_grad_enabled(True)
        return to_drop, test_acc


class NormalizedRepresentation(ch.nn.Module):
    def __init__(self, loader, metadata, device='cuda', tol=1e-5):
        super(NormalizedRepresentation, self).__init__()

        assert metadata is not None
        self.device = device
        self.mu = metadata['X']['mean']
        self.sigma = ch.clamp(metadata['X']['std'], tol)

    def forward(self, X):
        return (X - self.mu.to(self.device)) / self.sigma.to(self.device)
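
# Minimal usage sketch (added for illustration, not part of the original pipeline):
# NormalizedRepresentation only needs per-feature 'mean' and 'std' tensors under
# metadata['X']; the loader argument is unused here. The synthetic tensors below are
# stand-ins for precomputed features.
if __name__ == "__main__":
    _X = torch.randn(8, 4)
    _metadata = {"X": {"mean": _X.mean(dim=0), "std": _X.std(dim=0)}}
    _preprocess = NormalizedRepresentation(loader=None, metadata=_metadata, device="cpu")
    print(_preprocess(_X).mean(dim=0))  # per-feature means are approximately zero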