from argparse import ArgumentParser
import logging
import math
import os.path
import sys
import time
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from glm_saga.elasticnet import maximum_reg_loader, get_device, elastic_loss_and_acc_loader
from torch import nn
import torch as ch

from sparsification.utils import safe_zip
# TODO checkout this change: marks changes to the group version of glm-saga
"""
This module requires glm_saga to run.
Usage to select 50 features with parameters as in the paper:
`metadata` contains information about the precomputed train features in `feature_loaders`,
`args` contains the default arguments for glm-saga, as described at the bottom.

def get_glm_to_zero(feature_loaders, metadata, args, num_classes, device, train_ds, Ntotal):
    num_features = metadata["X"]["num_features"][0]
    fittingClass = FeatureSelectionFitting(num_features, num_classes, args, 0.8,
                                           50,
                                           True, 0.1,
                                           lookback=3, tol=1e-4,
                                           epsilon=1)
    to_drop, test_acc = fittingClass.fit(feature_loaders, metadata, device)
    return to_drop

to_drop is then used to remove the features from the downstream fitting and finetuning.
"""


class FeatureSelectionFitting:
    def __init__(self, n_features, n_classes, args, selalpha, nKeep, lam_fac, out_dir, lookback=None, tol=None,
                 epsilon=None):
        """
        Adaptation of the group version of glm-saga (https://github.com/MadryLab/DebuggableDeepNetworks).
        The method extended_mask_max implements the changed proximal operator.
        Args:
            n_features: number of input features
            n_classes: number of classes
            args: default args for glm-saga
            selalpha: alpha for the elastic net
            nKeep: target number of features
            lam_fac: discount factor for lambda
            out_dir: output directory for checkpoints and logs
            lookback, tol, epsilon: optional overrides for glm-saga's lookbehind, tol and lr_decay_factor
        """
        self.selected_features = torch.zeros(n_features, dtype=torch.bool)
        self.num_features = n_features
        self.selalpha = selalpha
        self.lam_Fac = lam_fac
        self.out_dir = out_dir
        self.n_classes = n_classes
        self.nKeep = nKeep
        self.args = self.extend_args(args, lookback, tol, epsilon)

    # Extended proximal operator for feature selection
    def extended_mask_max(self, greater_to_keep, thresh):
        # Temporarily mask out already-selected features so that argmax only
        # considers candidates that are not selected yet.
        prev = greater_to_keep[self.selected_features]
        greater_to_keep[self.selected_features] = torch.min(greater_to_keep)
        max_entry = torch.argmax(greater_to_keep)
        greater_to_keep[self.selected_features] = prev
        # Keep at most one new feature: the largest candidate, and only if it exceeds the threshold.
        mask = torch.zeros_like(greater_to_keep)
        mask[max_entry] = 1
        final_mask = (greater_to_keep > thresh)
        final_mask = final_mask * mask
        # Previously selected features are always kept.
        allowed_to_keep = torch.logical_or(self.selected_features, final_mask)
        return allowed_to_keep
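
    # Illustrative sketch (not part of the original code): with selected_features = [True, False, False],
    # column norms greater_to_keep = [0.2, 0.9, 0.4] and thresh = 0.5, the argmax over the not-yet-selected
    # entries is index 1, which also exceeds the threshold, so extended_mask_max returns [True, True, False];
    # with greater_to_keep = [0.2, 0.3, 0.4] no candidate passes the threshold and only [True, False, False]
    # survives, i.e. at most one new feature is admitted per proximal step.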

    def extend_args(self, args, lookback, tol, epsilon):
        # Override the glm-saga defaults only for values that were explicitly provided.
        for key, entry in safe_zip(["lookbehind", "tol", "lr_decay_factor"], [lookback, tol, epsilon]):
            if entry is not None:
                setattr(args, key, entry)
        return args

    # Grouped L1 regularization
    # proximal operator for f(weight) = lam * \|weight\|_2
    # where the 2-norm is taken columnwise
    def group_threshold(self, weight, lam):
        norm = weight.norm(p=2, dim=0) + 1e-6
        # print(ch.sum((norm > lam)))
        return (weight - lam * weight / norm) * self.extended_mask_max(norm, lam)

    # Elastic net regularization with group sparsity
    # proximal operator for f(x) = alpha * \|x\|_1 + beta * \|x\|_2^2
    # where the 2-norm is taken columnwise
    def group_threshold_with_shrinkage(self, x, alpha, beta):
        y = self.group_threshold(x, alpha)
        return y / (1 + beta)
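
    # Illustrative sketch (not part of the original code): for a single column w with
    # ||w||_2 = 2.0 and lam = 0.5, group_threshold rescales the column by (1 - lam / ||w||_2) = 0.75
    # if the column is kept by extended_mask_max, and zeroes it otherwise;
    # group_threshold_with_shrinkage additionally divides the result by (1 + beta),
    # which is the shrinkage part of the elastic-net proximal operator.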

    def threshold(self, weight_new, lr, lam):
        alpha = self.selalpha
        if alpha == 1:
            # Pure L1 regularization
            weight_new = self.group_threshold(weight_new, lr * lam * alpha)
        else:
            # Elastic net regularization
            weight_new = self.group_threshold_with_shrinkage(weight_new, lr * lam * alpha,
                                                             lr * lam * (1 - alpha))
        return weight_new

    # Train an elastic net GLM with proximal SAGA.
    # Since SAGA stores a scalar for each example-class pair, either pass
    # the number of examples and number of classes or calculate it with an
    # initial pass over the loaders.
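    # Sketch of the update performed below (notation only, not original code): for a sampled batch,
    # the variance-reduced gradient is g_saga = g_new - g_old + g_avg, where g_new is the current
    # batch gradient, g_old is the gradient stored for the same examples in a_table, and g_avg is the
    # running average over all examples; the weights then take a step w <- prox(w - lr * g_saga),
    # with prox given by threshold() above.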
    def train_saga(self, linear, loader, lr, nepochs, lam, alpha, group=True, verbose=None,
                   state=None, table_device=None, n_ex=None, n_classes=None, tol=1e-4,
                   preprocess=None, lookbehind=None, family='multinomial', logger=None):
        if logger is None:
            logger = print
        with ch.no_grad():
            weight, bias = list(linear.parameters())
            if table_device is None:
                table_device = weight.device

            # get total number of examples and initialize scalars
            # for computing the gradients
            if n_ex is None:
                n_ex = sum(tensors[0].size(0) for tensors in loader)
            if n_classes is None:
                if family == 'multinomial':
                    n_classes = max(tensors[1].max().item() for tensors in loader) + 1
                elif family == 'gaussian':
                    for batch in loader:
                        y = batch[1]
                        break
                    n_classes = y.size(1)

            # Storage for scalar gradients and averages
            if state is None:
                a_table = ch.zeros(n_ex, n_classes).to(table_device)
                w_grad_avg = ch.zeros_like(weight).to(weight.device)
                b_grad_avg = ch.zeros_like(bias).to(weight.device)
            else:
                a_table = state["a_table"].to(table_device)
                w_grad_avg = state["w_grad_avg"].to(weight.device)
                b_grad_avg = state["b_grad_avg"].to(weight.device)

            obj_history = []
            obj_best = None
            nni = 0
            for t in range(nepochs):
                total_loss = 0
                for n_batch, batch in enumerate(loader):
                    if len(batch) == 3:
                        X, y, idx = batch
                        w = None
                    elif len(batch) == 4:
                        X, y, w, idx = batch
                    else:
                        raise ValueError(
                            f"Loader must return (data, target, index) or (data, target, weight, index) "
                            f"but instead got a tuple of length {len(batch)}")

                    if preprocess is not None:
                        device = get_device(preprocess)
                        with ch.no_grad():
                            X = preprocess(X.to(device))
                    X = X.to(weight.device)

                    out = linear(X)

                    # split gradient on only the cross entropy term
                    # for efficient storage of gradient information
                    if family == 'multinomial':
                        if w is None:
                            loss = F.cross_entropy(out, y.to(weight.device), reduction='mean')
                        else:
                            loss = F.cross_entropy(out, y.to(weight.device), reduction='none')
                            loss = (loss * w).mean()
                        I = ch.eye(linear.weight.size(0))
                        target = I[y].to(weight.device)  # convert labels to one-hot encoding
                        # Calculate new scalar gradient
                        logits = F.softmax(out, dim=1)
                    elif family == 'gaussian':
                        if w is None:
                            loss = 0.5 * F.mse_loss(out, y.to(weight.device), reduction='mean')
                        else:
                            loss = 0.5 * F.mse_loss(out, y.to(weight.device), reduction='none')
                            loss = (loss * (w.unsqueeze(1))).mean()
                        target = y
                        # Calculate new scalar gradient
                        logits = out
                    else:
                        raise ValueError(f"Unknown family: {family}")

                    total_loss += loss.item() * X.size(0)

                    # BS x NUM_CLASSES
                    a = logits - target
                    if w is not None:
                        a = a * w.unsqueeze(1)
                    a_prev = a_table[idx].to(weight.device)

                    # weight parameter
                    w_grad = (a.unsqueeze(2) * X.unsqueeze(1)).mean(0)
                    w_grad_prev = (a_prev.unsqueeze(2) * X.unsqueeze(1)).mean(0)
                    w_saga = w_grad - w_grad_prev + w_grad_avg

                    weight_new = weight - lr * w_saga
                    weight_new = self.threshold(weight_new, lr, lam)

                    # bias parameter
                    b_grad = a.mean(0)
                    b_grad_prev = a_prev.mean(0)
                    b_saga = b_grad - b_grad_prev + b_grad_avg
                    bias_new = bias - lr * b_saga

                    # update table and averages
                    a_table[idx] = a.to(table_device)
                    w_grad_avg.add_((w_grad - w_grad_prev) * X.size(0) / n_ex)
                    b_grad_avg.add_((b_grad - b_grad_prev) * X.size(0) / n_ex)

                    if lookbehind is None:
                        dw = (weight_new - weight).norm(p=2)
                        db = (bias_new - bias).norm(p=2)
                        criteria = ch.sqrt(dw ** 2 + db ** 2)
                        if criteria.item() <= tol:
                            return {
                                "a_table": a_table.cpu(),
                                "w_grad_avg": w_grad_avg.cpu(),
                                "b_grad_avg": b_grad_avg.cpu()
                            }

                    weight.data = weight_new
                    bias.data = bias_new

                saga_obj = total_loss / n_ex + lam * alpha * weight.norm(p=1) + 0.5 * lam * (1 - alpha) * (
                        weight ** 2).sum()

                # save amount of improvement
                obj_history.append(saga_obj.item())
                if obj_best is None or saga_obj.item() + tol < obj_best:
                    obj_best = saga_obj.item()
                    nni = 0
                else:
                    nni += 1

                # Stop if no progress for lookbehind iterations
                criteria = lookbehind is not None and (nni >= lookbehind)

                nnz = (weight.abs() > 1e-5).sum().item()
                total = weight.numel()
                if verbose and (t % verbose) == 0:
                    if lookbehind is None:
                        logger(
                            f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) criteria {criteria:.4f} {dw} {db}")
                    else:
                        logger(
                            f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) obj_best {obj_best}")

                if lookbehind is not None and criteria:
                    logger(
                        f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) obj_best {obj_best} [early stop at {t}]")
                    return {
                        "a_table": a_table.cpu(),
                        "w_grad_avg": w_grad_avg.cpu(),
                        "b_grad_avg": b_grad_avg.cpu()
                    }

            logger(f"did not converge at {nepochs} iterations (criteria {criteria})")
            return {
                "a_table": a_table.cpu(),
                "w_grad_avg": w_grad_avg.cpu(),
                "b_grad_avg": b_grad_avg.cpu()
            }

    def glm_saga(self, linear, loader, max_lr, nepochs, alpha, dropout, tries,
                 table_device=None, preprocess=None, group=False,
                 verbose=None, state=None, n_ex=None, n_classes=None,
                 tol=1e-4, epsilon=0.001, k=100, checkpoint=None,
                 do_zero=True, lr_decay_factor=1, metadata=None,
                 val_loader=None, test_loader=None, lookbehind=None,
                 family='multinomial', encoder=None, tot_tries=1):
        if encoder is not None:
            warnings.warn("encoder argument is deprecated; please use preprocess instead", DeprecationWarning)
            preprocess = encoder

        device = get_device(linear)
        checkpoint = self.out_dir
        if preprocess is not None and (device != get_device(preprocess)):
            raise ValueError(
                f"Linear and preprocess must be on same device (got {get_device(linear)} and {get_device(preprocess)})")

        if metadata is not None:
            if n_ex is None:
                n_ex = metadata['X']['num_examples']
            if n_classes is None:
                n_classes = metadata['y']['num_classes']

        lam_fac = (1 + (tries - 1) / tot_tries)
        print("Using lam_fac ", lam_fac)
        max_lam = maximum_reg_loader(loader, group=group, preprocess=preprocess, metadata=metadata,
                                     family=family) / max(0.001, alpha) * lam_fac
        group_lam = maximum_reg_loader(loader, group=True, preprocess=preprocess, metadata=metadata,
                                       family=family) / max(0.001, alpha) * lam_fac
        min_lam = epsilon * max_lam
        group_min_lam = epsilon * group_lam
        # logspace is base 10 but log is base e so use log10
        lams = ch.logspace(math.log10(max_lam), math.log10(min_lam), k)
        lrs = ch.logspace(math.log10(max_lr), math.log10(max_lr / lr_decay_factor), k)

        found = False
        if do_zero:
            lams = ch.cat([lams, lams.new_zeros(1)])
            lrs = ch.cat([lrs, lrs.new_ones(1) * lrs[-1]])
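
        # Sketch of the schedule built above (not part of the original code): lams is a log-spaced
        # regularization path from max_lam down to epsilon * max_lam (k steps), with an extra
        # lam = 0 entry appended when do_zero is set; lrs decays from max_lr to
        # max_lr / lr_decay_factor along the same path, and the zero-lambda entry reuses the final lr.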

        path = []
        best_val_loss = float('inf')

        if checkpoint is not None:
            os.makedirs(checkpoint, exist_ok=True)
            file_handler = logging.FileHandler(filename=os.path.join(checkpoint, 'output.log'))
            stdout_handler = logging.StreamHandler(sys.stdout)
            handlers = [file_handler, stdout_handler]
            logging.basicConfig(
                level=logging.DEBUG,
                format='[%(asctime)s] %(levelname)s - %(message)s',
                handlers=handlers
            )
            logger = logging.getLogger('glm_saga').info
        else:
            logger = print

        while self.selected_features.sum() < self.nKeep:  # TODO checkout this change, one iteration per feature
            n_feature_to_keep = self.selected_features.sum()
            for i, (lam, lr) in enumerate(zip(lams, lrs)):
                lam = lam * self.lam_Fac
                start_time = time.time()
                self.selected_features = self.selected_features.to(device)
                state = self.train_saga(linear, loader, lr, nepochs, lam, alpha,
                                        table_device=table_device, preprocess=preprocess, group=group,
                                        verbose=verbose, state=state, n_ex=n_ex, n_classes=n_classes,
                                        tol=tol, lookbehind=lookbehind, family=family, logger=logger)

                with ch.no_grad():
                    loss, acc = elastic_loss_and_acc_loader(linear, loader, lam, alpha, preprocess=preprocess,
                                                            family=family)
                    loss, acc = loss.item(), acc.item()

                    loss_val, acc_val = -1, -1
                    if val_loader:
                        loss_val, acc_val = elastic_loss_and_acc_loader(linear, val_loader, lam, alpha,
                                                                        preprocess=preprocess,
                                                                        family=family)
                        loss_val, acc_val = loss_val.item(), acc_val.item()

                    loss_test, acc_test = -1, -1
                    if test_loader:
                        loss_test, acc_test = elastic_loss_and_acc_loader(linear, test_loader, lam, alpha,
                                                                          preprocess=preprocess, family=family)
                        loss_test, acc_test = loss_test.item(), acc_test.item()

                    params = {
                        "lam": lam,
                        "lr": lr,
                        "alpha": alpha,
                        "time": time.time() - start_time,
                        "loss": loss,
                        "metrics": {
                            "loss_tr": loss,
                            "acc_tr": acc,
                            "loss_val": loss_val,
                            "acc_val": acc_val,
                            "loss_test": loss_test,
                            "acc_test": acc_test,
                        },
                        "weight": linear.weight.detach().cpu().clone(),
                        "bias": linear.bias.detach().cpu().clone()
                    }
                    path.append(params)

                    if loss_val is not None and loss_val < best_val_loss:
                        best_val_loss = loss_val
                        best_params = params
                        found = True

                    nnz = (linear.weight.abs() > 1e-5).sum().item()
                    total = linear.weight.numel()
                    if family == 'multinomial':
                        logger(
                            f"{n_feature_to_keep} Feature ({i}) lambda {lam:.4f}, loss {loss:.4f}, acc {acc:.4f} [val acc {acc_val:.4f}] [test acc {acc_test:.4f}], sparsity {nnz / total} [{nnz}/{total}], time {time.time() - start_time}, lr {lr:.4f}")
                    elif family == 'gaussian':
                        logger(
                            f"({i}) lambda {lam:.4f}, loss {loss:.4f} [val loss {loss_val:.4f}] [test loss {loss_test:.4f}], sparsity {nnz / total} [{nnz}/{total}], time {time.time() - start_time}, lr {lr:.4f}")

                if self.check_new_feature(linear.weight):  # TODO checkout this change, cancel once a new feature is used
                    if checkpoint is not None:
                        ch.save(params, os.path.join(checkpoint, f"params{n_feature_to_keep}.pth"))
                    break

        if found:
            return {
                'path': path,
                'best': best_params,
                'state': state
            }
        else:
            return False

    def check_new_feature(self, weight):
        # TODO checkout this change, checking if a new feature is used
        copied_weight = weight.detach().cpu().clone()
        used_features = torch.unique(
            torch.nonzero(copied_weight)[:, 1])
        if len(used_features) > 0:
            new_set = set(used_features.tolist())
            old_set = set(torch.nonzero(self.selected_features)[:, 0].tolist())
            diff = new_set - old_set
            if len(diff) > 0:
                self.selected_features[used_features] = True
                return True
        return False

    def fit(self, feature_loaders, metadata, device):
        # TODO checkout this change, glm-saga code slightly adapted to return to_drop
        print("Initializing linear model...")
        linear = nn.Linear(self.num_features, self.n_classes).to(device)
        for p in [linear.weight, linear.bias]:
            p.data.zero_()

        print("Preparing normalization preprocess and indexed dataloader")
        preprocess = NormalizedRepresentation(feature_loaders['train'],
                                              metadata=metadata,
                                              device=linear.weight.device)

        print("Calculating the regularization path")
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
        result = self.glm_saga(linear,
                               feature_loaders['train'],
                               self.args.lr,
                               self.args.max_epochs,
                               self.selalpha, 0, 1,
                               val_loader=feature_loaders['val'],
                               test_loader=feature_loaders['test'],
                               n_classes=self.n_classes,
                               verbose=self.args.verbose,
                               tol=self.args.tol,
                               lookbehind=self.args.lookbehind,
                               lr_decay_factor=self.args.lr_decay_factor,
                               group=True,
                               epsilon=self.args.lam_factor,
                               metadata=metadata,
                               preprocess=preprocess, tot_tries=1)
        to_drop = np.where(self.selected_features.cpu().numpy() == 0)[0]
        test_acc = result["path"][-1]["metrics"]["acc_test"]
        torch.set_grad_enabled(True)
        return to_drop, test_acc


class NormalizedRepresentation(ch.nn.Module):
    def __init__(self, loader, metadata, device='cuda', tol=1e-5):
        super(NormalizedRepresentation, self).__init__()
        assert metadata is not None
        self.device = device
        self.mu = metadata['X']['mean']
        self.sigma = ch.clamp(metadata['X']['std'], tol)

    def forward(self, X):
        return (X - self.mu.to(self.device)) / self.sigma.to(self.device)
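

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original code): the attribute names on `example_args` below
    # are exactly the ones read in fit() and extend_args(); the values are placeholders, not the settings
    # from the paper, and `feature_loaders`/`metadata` still have to be built from precomputed features
    # as outlined in the module docstring above.
    from argparse import Namespace

    example_args = Namespace(
        lr=0.1,                # max_lr passed to glm_saga
        max_epochs=2000,       # nepochs passed to glm_saga
        verbose=200,           # logging frequency inside train_saga
        tol=1e-4,              # convergence tolerance
        lookbehind=3,          # early-stopping window
        lr_decay_factor=1,     # lr decay along the regularization path
        lam_factor=0.001,      # epsilon, i.e. min_lam / max_lam
    )
    fitter = FeatureSelectionFitting(n_features=2048, n_classes=10, args=example_args,
                                     selalpha=0.8, nKeep=50, lam_fac=1.0, out_dir="./glm_selection")
    # With prepared loaders and metadata (see module docstring):
    # to_drop, test_acc = fitter.fit(feature_loaders, metadata, device="cuda")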