import torch
from tqdm import tqdm


def get_metrics_for_model(train_loader, test_loader, model, metric_evaluator):
    # Buffers for the tensors collected below; train and test splits feed
    # different arguments of the metric evaluator.
    features_train, outputs_train = [], []
    feature_maps_test, outputs_test = [], []
    labels = []
device = "cuda" if torch.cuda.is_available() else "cpu" | |
model.eval() | |
model = model.to(device) | |
training_transforms = train_loader.dataset.transform | |
train_loader.dataset.transform = test_loader.dataset.transform # Use test transform for train | |
train_loader = torch.utils.data.DataLoader(train_loader.dataset, batch_size=100, shuffle=False) # Turn off shuffling | |
print("Going in get metrics") | |
    # Restrict the analysis to the features the linear head actually uses:
    # rel_features holds the column indices with at least one nonzero weight.
    linear_matrix = model.linear.weight
    entries = torch.nonzero(linear_matrix)
    rel_features = torch.unique(entries[:, 1])
    with torch.no_grad():
        # Collect logits and final features over the training set.
        iterator = tqdm(enumerate(train_loader), total=len(train_loader))
        for batch_idx, (data, target) in iterator:
            xs1 = data.to(device)  # use the detected device, not hard-coded "cuda"
            output, _, final_features = model(xs1, with_feature_maps=True,
                                              with_final_features=True)
            outputs_train.append(output.to("cpu"))
            features_train.append(final_features.to("cpu"))
            labels.append(target.to("cpu"))
        # Collect logits and the relevant feature maps over the test set,
        # tracking top-1 accuracy along the way.
        total = 0
        correct = 0
        iterator = tqdm(enumerate(test_loader), total=len(test_loader))
        for batch_idx, (data, target) in iterator:
            xs1 = data.to(device)
            output, feature_maps, _ = model(xs1, with_feature_maps=True,
                                            with_final_features=True)
            feature_maps_test.append(feature_maps[:, rel_features].to("cpu"))
            outputs_test.append(output.to("cpu"))
            total += target.size(0)
            _, predicted = output.max(1)
            correct += predicted.eq(target.to(device)).sum().item()

    print("test accuracy: ", correct / total)
    features_train = torch.cat(features_train)
    outputs_train = torch.cat(outputs_train)
    feature_maps_test = torch.cat(feature_maps_test)
    outputs_test = torch.cat(outputs_test)
    labels = torch.cat(labels)

    # Drop the all-zero columns of the head so it matches the restricted features.
    linear_matrix = linear_matrix[:, rel_features]
    print("Shape of linear matrix: ", linear_matrix.shape)
    all_metrics_dict = metric_evaluator(features_train, outputs_train,
                                        feature_maps_test, outputs_test,
                                        linear_matrix, labels)
    result_dict = {
        "Accuracy": correct / total,
        # Number of features the head actually uses.
        "NFfeatures": linear_matrix.shape[1],
        # Average number of nonzero weights per class (row of the head).
        "PerClass": torch.nonzero(linear_matrix).shape[0] / linear_matrix.shape[0],
    }
    result_dict.update(all_metrics_dict)
    print(result_dict)
    # Reset train transforms before returning.
    train_loader.dataset.transform = training_transforms
    return result_dict
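

# --- Hedged usage sketch (not part of the original file) ---
# This illustrates the calling convention get_metrics_for_model assumes: a model
# whose forward accepts with_feature_maps / with_final_features and exposes a
# `linear` head, datasets with a mutable .transform attribute, and a
# metric_evaluator with the six-argument signature used above. DummyModel,
# ToyDataset, and toy_metric_evaluator are hypothetical stand-ins, not the
# project's real components.
if __name__ == "__main__":
    import torch.nn as nn

    class DummyModel(nn.Module):
        def __init__(self, n_features=16, n_classes=10):
            super().__init__()
            self.linear = nn.Linear(n_features, n_classes, bias=False)

        def forward(self, x, with_feature_maps=False, with_final_features=False):
            feats = x.flatten(1)  # stand-in for a backbone
            out = self.linear(feats)
            # Return (logits, "feature maps", final features) as expected above.
            return out, feats, feats

    class ToyDataset(torch.utils.data.Dataset):
        """Random data with the .transform attribute the function manipulates."""

        def __init__(self, n=200, n_features=16, n_classes=10):
            self.x = torch.randn(n, n_features)
            self.y = torch.randint(0, n_classes, (n,))
            self.transform = None

        def __len__(self):
            return len(self.y)

        def __getitem__(self, i):
            x = self.x[i]
            if self.transform is not None:
                x = self.transform(x)
            return x, self.y[i]

    def toy_metric_evaluator(features_train, outputs_train, feature_maps_test,
                             outputs_test, linear_matrix, labels):
        # Hypothetical evaluator: report the sparsity of the restricted head.
        return {"HeadSparsity": (linear_matrix == 0).float().mean().item()}

    model = DummyModel()
    train_loader = torch.utils.data.DataLoader(ToyDataset(), batch_size=50)
    test_loader = torch.utils.data.DataLoader(ToyDataset(), batch_size=50)
    metrics = get_metrics_for_model(train_loader, test_loader, model,
                                    toy_metric_evaluator)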