from argparse import ArgumentParser

import logging
import math
import os.path
import sys
import time
import warnings

import numpy as np
import torch
import torch as ch
import torch.nn.functional as F
from torch import nn

from glm_saga.elasticnet import maximum_reg_loader, get_device, elastic_loss_and_acc_loader

from sparsification.utils import safe_zip

""" |
|
This would need glm_saga to run |
|
usage to select 50 features with parameters as in paper: |
|
metadata contains information about the precomputed train features in feature_loaders |
|
args contains the default arguments for glm-saga, as described at the bottom |
|
def get_glm_to_zero(feature_loaders, metadata, args, num_classes, device, train_ds, Ntotal): |
|
num_features = metadata["X"]["num_features"][0] |
|
fittingClass = FeatureSelectionFitting(num_features, num_lasses, args, 0.8, |
|
50, |
|
True,0.1, |
|
lookback=3, tol=1e-4, |
|
epsilon=1,) |
|
to_drop, test_acc = fittingClass.fit(feature_loaders, metadata, device) |
|
return to_drop |
|
|
|
to_drop is then used to remove the features from the downstream fitting and finetuning. |
|
""" |
class FeatureSelectionFitting:
    def __init__(self, n_features, n_classes, args, selalpha, nKeep, lam_fac, out_dir, lookback=None, tol=None,
                 epsilon=None):
        """
        An adaptation of the group version of glm-saga (https://github.com/MadryLab/DebuggableDeepNetworks).
        The method extended_mask_max implements the changed thresholding operator.

        Args:
            n_features: number of input features
            n_classes: number of output classes
            args: default args for glm-saga
            selalpha: alpha for the elastic net
            nKeep: target number of features
            lam_fac: discount factor for lambda
            out_dir: output directory for checkpoints and logs

            Optional overrides for glm-saga parameters (see extend_args):
            lookback: overrides args.lookbehind
            tol: overrides args.tol
            epsilon: overrides args.lr_decay_factor
        """
        self.selected_features = torch.zeros(n_features, dtype=torch.bool)
        self.num_features = n_features
        self.selalpha = selalpha
        self.lam_Fac = lam_fac
        self.out_dir = out_dir
        self.n_classes = n_classes
        self.nKeep = nKeep
        self.args = self.extend_args(args, lookback, tol, epsilon)

    def extended_mask_max(self, greater_to_keep, thresh):
        """Return a mask that keeps all previously selected features plus at most one new one:
        the feature with the largest value in greater_to_keep, provided it exceeds thresh."""
        # Temporarily blank out the already selected entries so they cannot win the argmax.
        prev = greater_to_keep[self.selected_features]
        greater_to_keep[self.selected_features] = torch.min(greater_to_keep)
        max_entry = torch.argmax(greater_to_keep)
        greater_to_keep[self.selected_features] = prev
        # One-hot mask for the best new candidate, kept only if it passes the threshold.
        mask = torch.zeros_like(greater_to_keep)
        mask[max_entry] = 1
        final_mask = (greater_to_keep > thresh)
        final_mask = final_mask * mask
        allowed_to_keep = torch.logical_or(self.selected_features, final_mask)
        return allowed_to_keep
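
    # Toy illustration (hypothetical values) of the operator above:
    #   selected_features = [True, False, False], greater_to_keep = [5.0, 2.0, 3.0], thresh = 2.5
    #   -> feature 2 has the largest norm among the not-yet-selected features and exceeds the
    #      threshold, so the returned mask is [True, False, True].
    #   With thresh = 3.5 no new feature passes, and the mask stays [True, False, False].
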
    def extend_args(self, args, lookback, tol, epsilon):
        for key, entry in safe_zip(["lookbehind", "tol", "lr_decay_factor"],
                                   [lookback, tol, epsilon]):
            if entry is not None:
                setattr(args, key, entry)
        return args

    def group_threshold(self, weight, lam):
        """Group soft-thresholding, restricted to the already selected features plus at most one new one."""
        norm = weight.norm(p=2, dim=0) + 1e-6
        return (weight - lam * weight / norm) * self.extended_mask_max(norm, lam)

    def group_threshold_with_shrinkage(self, x, alpha, beta):
        """Group soft-thresholding followed by the ridge shrinkage of the elastic net."""
        y = self.group_threshold(x, alpha)
        return y / (1 + beta)
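
    # The two methods above implement the proximal step of the group elastic net.  As a sketch
    # (using the notation of threshold() below), each column group w_g of the weight matrix is
    # updated roughly as
    #     w_g <- (1 - lr*lam*alpha / ||w_g||_2) * w_g / (1 + lr*lam*(1 - alpha)),
    # but only for groups that extended_mask_max allows to stay nonzero; all other groups are
    # zeroed out, which is what enforces the one-new-feature-at-a-time selection.
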
    def threshold(self, weight_new, lr, lam):
        alpha = self.selalpha
        if alpha == 1:
            # Pure group lasso: soft-threshold only.
            weight_new = self.group_threshold(weight_new, lr * lam * alpha)
        else:
            # Elastic net: soft-threshold, then shrink by the ridge term.
            weight_new = self.group_threshold_with_shrinkage(weight_new, lr * lam * alpha,
                                                             lr * lam * (1 - alpha))
        return weight_new

    def train_saga(self, linear, loader, lr, nepochs, lam, alpha, group=True, verbose=None,
                   state=None, table_device=None, n_ex=None, n_classes=None, tol=1e-4,
                   preprocess=None, lookbehind=None, family='multinomial', logger=None):
        """SAGA solver for a single value of lam, adapted from glm-saga to use the modified
        group thresholding in threshold()."""
        if logger is None:
            logger = print
        with ch.no_grad():
            weight, bias = list(linear.parameters())
            if table_device is None:
                table_device = weight.device

            # Number of examples and classes, inferred from the loader if not given.
            if n_ex is None:
                n_ex = sum(tensors[0].size(0) for tensors in loader)
            if n_classes is None:
                if family == 'multinomial':
                    n_classes = max(tensors[1].max().item() for tensors in loader) + 1
                elif family == 'gaussian':
                    for batch in loader:
                        y = batch[1]
                        break
                    n_classes = y.size(1)

            # SAGA state: per-example residual table and running gradient averages.
            if state is None:
                a_table = ch.zeros(n_ex, n_classes).to(table_device)
                w_grad_avg = ch.zeros_like(weight).to(weight.device)
                b_grad_avg = ch.zeros_like(bias).to(weight.device)
            else:
                a_table = state["a_table"].to(table_device)
                w_grad_avg = state["w_grad_avg"].to(weight.device)
                b_grad_avg = state["b_grad_avg"].to(weight.device)

            obj_history = []
            obj_best = None
            nni = 0
            for t in range(nepochs):
                total_loss = 0
                for n_batch, batch in enumerate(loader):
                    if len(batch) == 3:
                        X, y, idx = batch
                        w = None
                    elif len(batch) == 4:
                        X, y, w, idx = batch
                    else:
                        raise ValueError(
                            f"Loader must return (data, target, index) or (data, target, weight, index) "
                            f"but instead got a tuple of length {len(batch)}")

                    if preprocess is not None:
                        device = get_device(preprocess)
                        with ch.no_grad():
                            X = preprocess(X.to(device))
                    X = X.to(weight.device)
                    out = linear(X)

                    # Loss and residuals (logits - target) for the SAGA update.
                    if family == 'multinomial':
                        if w is None:
                            loss = F.cross_entropy(out, y.to(weight.device), reduction='mean')
                        else:
                            loss = F.cross_entropy(out, y.to(weight.device), reduction='none')
                            loss = (loss * w).mean()
                        I = ch.eye(linear.weight.size(0))
                        target = I[y].to(weight.device)
                        logits = F.softmax(out, dim=1)
                    elif family == 'gaussian':
                        if w is None:
                            loss = 0.5 * F.mse_loss(out, y.to(weight.device), reduction='mean')
                        else:
                            loss = 0.5 * F.mse_loss(out, y.to(weight.device), reduction='none')
                            loss = (loss * (w.unsqueeze(1))).mean()
                        target = y
                        logits = out
                    else:
                        raise ValueError(f"Unknown family: {family}")
                    total_loss += loss.item() * X.size(0)

                    # SAGA gradient estimate: current per-example gradient minus the stored one,
                    # plus the running average.
                    a = logits - target
                    if w is not None:
                        a = a * w.unsqueeze(1)
                    a_prev = a_table[idx].to(weight.device)

                    w_grad = (a.unsqueeze(2) * X.unsqueeze(1)).mean(0)
                    w_grad_prev = (a_prev.unsqueeze(2) * X.unsqueeze(1)).mean(0)
                    w_saga = w_grad - w_grad_prev + w_grad_avg
                    weight_new = weight - lr * w_saga
                    # Proximal step with the modified group thresholding.
                    weight_new = self.threshold(weight_new, lr, lam)

                    b_grad = a.mean(0)
                    b_grad_prev = a_prev.mean(0)
                    b_saga = b_grad - b_grad_prev + b_grad_avg
                    bias_new = bias - lr * b_saga

                    # Update the SAGA table and the gradient averages.
                    a_table[idx] = a.to(table_device)
                    w_grad_avg.add_((w_grad - w_grad_prev) * X.size(0) / n_ex)
                    b_grad_avg.add_((b_grad - b_grad_prev) * X.size(0) / n_ex)

                    if lookbehind is None:
                        # Stop as soon as the iterates stop moving.
                        dw = (weight_new - weight).norm(p=2)
                        db = (bias_new - bias).norm(p=2)
                        criteria = ch.sqrt(dw ** 2 + db ** 2)

                        if criteria.item() <= tol:
                            return {
                                "a_table": a_table.cpu(),
                                "w_grad_avg": w_grad_avg.cpu(),
                                "b_grad_avg": b_grad_avg.cpu()
                            }

                    weight.data = weight_new
                    bias.data = bias_new

                saga_obj = total_loss / n_ex + lam * alpha * weight.norm(p=1) + 0.5 * lam * (1 - alpha) * (
                        weight ** 2).sum()

                # Early stopping: no improvement of the objective for `lookbehind` consecutive epochs.
                obj_history.append(saga_obj.item())
                if obj_best is None or saga_obj.item() + tol < obj_best:
                    obj_best = saga_obj.item()
                    nni = 0
                else:
                    nni += 1

                criteria = lookbehind is not None and (nni >= lookbehind)

                nnz = (weight.abs() > 1e-5).sum().item()
                total = weight.numel()
                if verbose and (t % verbose) == 0:
                    if lookbehind is None:
                        logger(
                            f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) criteria {criteria:.4f} {dw} {db}")
                    else:
                        logger(
                            f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) obj_best {obj_best}")

                if lookbehind is not None and criteria:
                    logger(
                        f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) obj_best {obj_best} [early stop at {t}]")
                    return {
                        "a_table": a_table.cpu(),
                        "w_grad_avg": w_grad_avg.cpu(),
                        "b_grad_avg": b_grad_avg.cpu()
                    }

            logger(f"did not converge at {nepochs} iterations (criteria {criteria})")
            return {
                "a_table": a_table.cpu(),
                "w_grad_avg": w_grad_avg.cpu(),
                "b_grad_avg": b_grad_avg.cpu()
            }

    def glm_saga(self, linear, loader, max_lr, nepochs, alpha, dropout, tries,
                 table_device=None, preprocess=None, group=False,
                 verbose=None, state=None, n_ex=None, n_classes=None,
                 tol=1e-4, epsilon=0.001, k=100, checkpoint=None,
                 do_zero=True, lr_decay_factor=1, metadata=None,
                 val_loader=None, test_loader=None, lookbehind=None,
                 family='multinomial', encoder=None, tot_tries=1):
        if encoder is not None:
            warnings.warn("encoder argument is deprecated; please use preprocess instead", DeprecationWarning)
            preprocess = encoder
        device = get_device(linear)
        # The checkpoint directory is always taken from out_dir, overriding the argument.
        checkpoint = self.out_dir
        if preprocess is not None and (device != get_device(preprocess)):
            raise ValueError(
                f"Linear and preprocess must be on same device (got {get_device(linear)} and {get_device(preprocess)})")

        if metadata is not None:
            if n_ex is None:
                n_ex = metadata['X']['num_examples']
            if n_classes is None:
                n_classes = metadata['y']['num_classes']

        # lam_fac scales the regularization path upward on later tries (1.0 for the first try).
        lam_fac = (1 + (tries - 1) / tot_tries)
        print("Using lam_fac ", lam_fac)
        max_lam = maximum_reg_loader(loader, group=group, preprocess=preprocess, metadata=metadata,
                                     family=family) / max(0.001, alpha) * lam_fac
        group_lam = maximum_reg_loader(loader, group=True, preprocess=preprocess, metadata=metadata,
                                       family=family) / max(0.001, alpha) * lam_fac
        min_lam = epsilon * max_lam
        # group_lam / group_min_lam are currently unused below.
        group_min_lam = epsilon * group_lam

        lams = ch.logspace(math.log10(max_lam), math.log10(min_lam), k)
        lrs = ch.logspace(math.log10(max_lr), math.log10(max_lr / lr_decay_factor), k)
        found = False
        if do_zero:
            lams = ch.cat([lams, lams.new_zeros(1)])
            lrs = ch.cat([lrs, lrs.new_ones(1) * lrs[-1]])

        path = []
        best_val_loss = float('inf')

        if checkpoint is not None:
            os.makedirs(checkpoint, exist_ok=True)

            file_handler = logging.FileHandler(filename=os.path.join(checkpoint, 'output.log'))
            stdout_handler = logging.StreamHandler(sys.stdout)
            handlers = [file_handler, stdout_handler]

            logging.basicConfig(
                level=logging.DEBUG,
                format='[%(asctime)s] %(levelname)s - %(message)s',
                handlers=handlers
            )
            logger = logging.getLogger('glm_saga').info
        else:
            logger = print

        # Walk down the regularization path; every time a new feature enters the support,
        # restart the path, until nKeep features have been selected.
        while self.selected_features.sum() < self.nKeep:
            n_feature_to_keep = self.selected_features.sum()
            for i, (lam, lr) in enumerate(zip(lams, lrs)):
                lam = lam * self.lam_Fac
                start_time = time.time()
                self.selected_features = self.selected_features.to(device)
                state = self.train_saga(linear, loader, lr, nepochs, lam, alpha,
                                        table_device=table_device, preprocess=preprocess, group=group,
                                        verbose=verbose, state=state, n_ex=n_ex, n_classes=n_classes,
                                        tol=tol, lookbehind=lookbehind, family=family, logger=logger)

                with ch.no_grad():
                    loss, acc = elastic_loss_and_acc_loader(linear, loader, lam, alpha, preprocess=preprocess,
                                                            family=family)
                    loss, acc = loss.item(), acc.item()

                    loss_val, acc_val = -1, -1
                    if val_loader:
                        loss_val, acc_val = elastic_loss_and_acc_loader(linear, val_loader, lam, alpha,
                                                                        preprocess=preprocess,
                                                                        family=family)
                        loss_val, acc_val = loss_val.item(), acc_val.item()

                    loss_test, acc_test = -1, -1
                    if test_loader:
                        loss_test, acc_test = elastic_loss_and_acc_loader(linear, test_loader, lam, alpha,
                                                                          preprocess=preprocess, family=family)
                        loss_test, acc_test = loss_test.item(), acc_test.item()

                    params = {
                        "lam": lam,
                        "lr": lr,
                        "alpha": alpha,
                        "time": time.time() - start_time,
                        "loss": loss,
                        "metrics": {
                            "loss_tr": loss,
                            "acc_tr": acc,
                            "loss_val": loss_val,
                            "acc_val": acc_val,
                            "loss_test": loss_test,
                            "acc_test": acc_test,
                        },
                        "weight": linear.weight.detach().cpu().clone(),
                        "bias": linear.bias.detach().cpu().clone()
                    }
                    path.append(params)
                    if loss_val is not None and loss_val < best_val_loss:
                        best_val_loss = loss_val
                        best_params = params
                        found = True
                    nnz = (linear.weight.abs() > 1e-5).sum().item()
                    total = linear.weight.numel()
                    if family == 'multinomial':
                        logger(
                            f"{n_feature_to_keep} Feature ({i}) lambda {lam:.4f}, loss {loss:.4f}, acc {acc:.4f} [val acc {acc_val:.4f}] [test acc {acc_test:.4f}], sparsity {nnz / total} [{nnz}/{total}], time {time.time() - start_time}, lr {lr:.4f}")
                    elif family == 'gaussian':
                        logger(
                            f"({i}) lambda {lam:.4f}, loss {loss:.4f} [val loss {loss_val:.4f}] [test loss {loss_test:.4f}], sparsity {nnz / total} [{nnz}/{total}], time {time.time() - start_time}, lr {lr:.4f}")

                    # A new feature entered the support: checkpoint and restart the path.
                    if self.check_new_feature(linear.weight):
                        if checkpoint is not None:
                            ch.save(params, os.path.join(checkpoint, f"params{n_feature_to_keep}.pth"))
                        break
        if found:
            return {
                'path': path,
                'best': best_params,
                'state': state
            }
        else:
            return False

    def check_new_feature(self, weight):
        """Check whether the current weight uses a feature that was not selected before.
        If so, mark all currently used features as selected and return True."""
        copied_weight = weight.detach().cpu().clone()
        used_features = torch.unique(torch.nonzero(copied_weight)[:, 1])
        if len(used_features) > 0:
            new_set = set(used_features.tolist())
            old_set = set(torch.nonzero(self.selected_features)[:, 0].tolist())
            diff = new_set - old_set
            if len(diff) > 0:
                self.selected_features[used_features] = True
                return True
        return False

    def fit(self, feature_loaders, metadata, device):
        """Run the feature selection and return the indices of the features to drop
        together with the test accuracy of the final model on the path."""
        print("Initializing linear model...")
        linear = nn.Linear(self.num_features, self.n_classes).to(device)
        for p in [linear.weight, linear.bias]:
            p.data.zero_()

        print("Preparing normalization preprocess and indexed dataloader")
        preprocess = NormalizedRepresentation(feature_loaders['train'],
                                              metadata=metadata,
                                              device=linear.weight.device)

        print("Calculating the regularization path")
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
        selected_features = self.glm_saga(linear,
                                          feature_loaders['train'],
                                          self.args.lr,
                                          self.args.max_epochs,
                                          self.selalpha, 0, 1,
                                          val_loader=feature_loaders['val'],
                                          test_loader=feature_loaders['test'],
                                          n_classes=self.n_classes,
                                          verbose=self.args.verbose,
                                          tol=self.args.tol,
                                          lookbehind=self.args.lookbehind,
                                          lr_decay_factor=self.args.lr_decay_factor,
                                          group=True,
                                          epsilon=self.args.lam_factor,
                                          metadata=metadata,
                                          preprocess=preprocess, tot_tries=1)
        to_drop = np.where(self.selected_features.cpu().numpy() == 0)[0]
        test_acc = selected_features["path"][-1]["metrics"]["acc_test"]
        torch.set_grad_enabled(True)
        return to_drop, test_acc


class NormalizedRepresentation(ch.nn.Module):
    def __init__(self, loader, metadata, device='cuda', tol=1e-5):
        super(NormalizedRepresentation, self).__init__()

        assert metadata is not None
        self.device = device
        self.mu = metadata['X']['mean']
        self.sigma = ch.clamp(metadata['X']['std'], tol)

    def forward(self, X):
        return (X - self.mu.to(self.device)) / self.sigma.to(self.device)
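

# The loaders passed to fit()/train_saga() must yield (X, y, idx) batches, where idx are the
# example indices used for the SAGA residual table, and `metadata` must expose the statistics
# read above.  The sketch below only illustrates that contract; the dataset/loader names are
# hypothetical and not part of this module, and reusing one loader for train/val/test is purely
# for brevity.
def example_feature_loaders_and_metadata(X, y):
    from torch.utils.data import DataLoader, Dataset

    class IndexedFeatures(Dataset):
        # Wraps precomputed features so each batch also carries example indices.
        def __init__(self, X, y):
            self.X, self.y = X, y

        def __len__(self):
            return self.X.size(0)

        def __getitem__(self, i):
            return self.X[i], self.y[i], i

    loader = DataLoader(IndexedFeatures(X, y), batch_size=256, shuffle=True)
    metadata = {
        "X": {
            "num_features": [X.size(1)],
            "num_examples": X.size(0),
            "mean": X.mean(dim=0),
            "std": X.std(dim=0),
        },
        "y": {"num_classes": int(y.max().item()) + 1},
    }
    return {"train": loader, "val": loader, "test": loader}, metadata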