|
import os |
|
|
|
import torch |
|
|
|
from finetuning.utils import train_n_epochs |
|
from sparsification.qsenn import compute_qsenn_feature_selection_and_assignment |
|
|
|
|
|
def finetune_qsenn(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule, n_features, n_per_class, n_iterations=4):
    """Iteratively sparsify and finetune a Q-SENN model.

    Each iteration recomputes the sparse feature selection / class assignment,
    installs it on the model, and finetunes the result. Per-iteration
    checkpoints under ``log_dir`` allow interrupted runs to resume: if a
    trained checkpoint already exists for an iteration, it is loaded and the
    training step is skipped.

    Args:
        model: Model exposing ``set_model_sldd``/``load_state_dict``/``state_dict``.
        train_loader: Training data loader.
        test_loader: Evaluation data loader.
        log_dir: ``pathlib.Path``-like directory for per-iteration outputs
            (must support ``/`` and ``mkdir``).
        n_classes: Number of target classes.
        seed: Random seed forwarded to the selection/assignment step.
        beta: Loss weighting forwarded to training.
        optimization_schedule: Schedule object; ``get_params()`` is consumed
            once per skipped iteration to keep it in sync with training.
        n_features: Number of features to keep.
        n_per_class: Number of features assigned per class.
        n_iterations: Number of sparsify-finetune rounds (default 4, the
            previous hard-coded value).

    Returns:
        The finetuned model after the final iteration.
    """
    for iteration_epoch in range(n_iterations):
        print(f"Starting iteration epoch {iteration_epoch}")
        this_log_dir = log_dir / f"iteration_epoch_{iteration_epoch}"
        this_log_dir.mkdir(parents=True, exist_ok=True)
        feature_sel, sparse_layer, bias_sparse, current_mean, current_std = \
            compute_qsenn_feature_selection_and_assignment(
                model, train_loader, test_loader, this_log_dir, n_classes,
                seed, n_features, n_per_class)
        model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
        checkpoint = this_log_dir / "trained_model.pth"
        if checkpoint.exists():
            # Resume path: reuse the already-trained weights for this iteration.
            model.load_state_dict(torch.load(checkpoint))
            # Advance the schedule so its state matches what a full training
            # pass would have consumed — keeps later iterations in sync.
            _ = optimization_schedule.get_params()
            continue
        model = train_n_epochs(model, beta, optimization_schedule, train_loader, test_loader)
        torch.save(model.state_dict(), checkpoint)
        print(f"Finished iteration epoch {iteration_epoch}")
    return model
|
|
|
|
|
|
|
|
|
|