diff --git a/DIC.py b/DIC.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bd67353e13c054ea320882506624ac0e2050a91
--- /dev/null
+++ b/DIC.py
@@ -0,0 +1,17 @@
+import torch
+from pathlib import Path
+
+
+model_dir = Path.home() / "tmp/resnet50/CUB2011/123456"
+dic = torch.load(model_dir / "SlDD_Selection_50.pt")
+
+print(dic)
+
+#if 'linear.selection' in dic.keys():
+#    print("key 'linear.selection' exists")
+#else:
+#    print("no such key")
+
+
+
+
diff --git a/FeatureDiversityLoss.py b/FeatureDiversityLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..be5745ae71dbe298244271c3a942c80c2b3e9867
--- /dev/null
+++ b/FeatureDiversityLoss.py
@@ -0,0 +1,59 @@
+import torch
+from torch import nn
+
+"""
+Feature Diversity Loss:
+Usage to replicate the paper:
+Call
+loss_function = FeatureDiversityLoss(0.196, linear)
+to initialize the loss with the linear layer of the model.
+At each mini-batch, get the feature maps (output of the final convolutional layer) and add to the loss:
+loss += loss_function(feature_maps, outputs)
+"""
+
+
+class FeatureDiversityLoss(nn.Module):
+ def __init__(self, scaling_factor, linear):
+ super().__init__()
+ self.scaling_factor = scaling_factor #* 0
+ print("Scaling Factor: ", self.scaling_factor)
+ self.linearLayer = linear
+
+ def initialize(self, linearLayer):
+ self.linearLayer = linearLayer
+
+ def get_weights(self, outputs):
+ weight_matrix = self.linearLayer.weight
+ weight_matrix = torch.abs(weight_matrix)
+ top_classes = torch.argmax(outputs, dim=1)
+ relevant_weights = weight_matrix[top_classes]
+ return relevant_weights
+
+ def forward(self, feature_maps, outputs):
+ relevant_weights = self.get_weights(outputs)
+ relevant_weights = norm_vector(relevant_weights)
+ feature_maps = preserve_avg_func(feature_maps)
+ flattened_feature_maps = feature_maps.flatten(2)
+ batch, features, map_size = flattened_feature_maps.size()
+ relevant_feature_maps = flattened_feature_maps * relevant_weights[..., None]
+ diversity_loss = torch.sum(
+ torch.amax(relevant_feature_maps, dim=1))
+ return -diversity_loss / batch * self.scaling_factor
+
+
+def norm_vector(x):
+ return x / (torch.norm(x, dim=1) + 1e-5)[:, None]
+
+
+def preserve_avg_func(x):
+ avgs = torch.mean(x, dim=[2, 3])
+ max_avgs = torch.max(avgs, dim=1)[0]
+ scaling_factor = avgs / torch.clamp(max_avgs[..., None], min=1e-6)
+ softmaxed_maps = softmax_feature_maps(x)
+ scaled_maps = softmaxed_maps * scaling_factor[..., None, None]
+ return scaled_maps
+
+
+def softmax_feature_maps(x):
+ return torch.softmax(x.reshape(x.size(0), x.size(1), -1), 2).view_as(x)
+
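+
+if __name__ == "__main__":
+    # Hedged usage sketch with random tensors; the shapes (batch 4, 2048 features,
+    # 14x14 maps, 200 classes) are illustrative only and mirror the docstring above.
+    linear = nn.Linear(2048, 200)
+    loss_function = FeatureDiversityLoss(0.196, linear)
+    feature_maps = torch.rand(4, 2048, 14, 14)  # output of the final convolutional layer
+    outputs = torch.rand(4, 200)                # class logits
+    print(loss_function(feature_maps, outputs))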
diff --git a/__pycache__/get_data.cpython-310.pyc b/__pycache__/get_data.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db7e1ece40eba3119dbfd6595168c1599a3847bb
Binary files /dev/null and b/__pycache__/get_data.cpython-310.pyc differ
diff --git a/__pycache__/load_model.cpython-310.pyc b/__pycache__/load_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dc4bd3e12a89cd304e7c33e8cf190a0353a8698
Binary files /dev/null and b/__pycache__/load_model.cpython-310.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8dab90f6b16cfbfcba731f336feb59e92c9f0b7
--- /dev/null
+++ b/app.py
@@ -0,0 +1,143 @@
+import gradio as gr
+from load_model import extract_sel_mean_std_bias_assignemnt
+from pathlib import Path
+from architectures.model_mapping import get_model
+from configs.dataset_params import dataset_constants
+import torch
+import torchvision.transforms as transforms
+import pandas as pd
+import cv2
+import numpy as np
+
+def overlapping_features_on_input(model,output, feature_maps, input, target):
+ W=model.linear.layer.weight
+ output=output.detach().cpu().numpy()
+ feature_maps=feature_maps.detach().cpu().numpy().squeeze()
+
+    if target is not None:
+ label=target
+ else:
+ label=np.argmax(output)+1
+
+ Interpretable_Selection= W[label,:]
+ print("W",Interpretable_Selection)
+ input_np=np.array(input)
+ h,w= input.shape[:2]
+ print("h,w:",h,w)
+ Interpretable_Features=[]
+ Feature_image_list=[]
+ for S in range(len(Interpretable_Selection)):
+ if Interpretable_Selection[S] > 0:
+ Interpretable_Features.append(feature_maps[S])
+ Feature_image=cv2.resize(feature_maps[S],(w,h))
+ Feature_image=((Feature_image-np.min(Feature_image))/(np.max(Feature_image)-np.min(Feature_image)))*255
+ Feature_image=Feature_image.astype(np.uint8)
+ Feature_image=cv2.applyColorMap(Feature_image,cv2.COLORMAP_JET)
+ Feature_image=0.3*Feature_image+0.7*input_np
+ Feature_image=np.clip(Feature_image, 0, 255).astype(np.uint8)
+ Feature_image_list.append(Feature_image)
+ #path_to_featureimage=f"/home/qixuan/tmp/FeatureImage/FI{S}.jpg"
+ #cv2.imwrite(path_to_featureimage,Feature_image)
+ print("len of Features:",len(Interpretable_Features))
+
+ return Feature_image_list
+
+
+def generate_interpretable_output(input, dataset="CUB2011", arch="resnet50", seed=123456, model_type="qsenn", n_features=50, n_per_class=5, img_size=448, reduced_strides=False, folder=None):
+ n_classes = dataset_constants[dataset]["num_classes"]
+
+ model = get_model(arch, n_classes, reduced_strides)
+ tr=transforms.ToTensor()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if folder is None:
+ folder = Path(f"tmp/{arch}/{dataset}/{seed}/")
+
+ state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth")
+ selection= torch.load(folder / f"SlDD_Selection_50.pt")
+ state_dict['linear.selection']=selection
+
+ feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
+ model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
+ model.load_state_dict(state_dict)
+
+ input = tr(input)
+ input= input.unsqueeze(0)
+ input= input.to(device)
+ model = model.to(device)
+ output, feature_maps, final_features = model(input, with_feature_maps=True, with_final_features=True)
+ print("final features:",final_features)
+ output=output.detach().cpu().numpy()
+ output= np.argmax(output)+1
+
+
+ print("outputclass:",output)
+ data_dir=Path("tmp/Datasets/CUB200/CUB_200_2011/")
+ labels = pd.read_csv(data_dir/"image_class_labels.txt", sep=' ', names=['img_id', 'target'])
+ namelist=pd.read_csv(data_dir/"images.txt",sep=' ',names=['img_id','file_name'])
+ classlist=pd.read_csv(data_dir/"classes.txt",sep=' ',names=['cl_id','class_name'])
+ options_output=labels[labels['target']==output]
+ options_output=options_output.sample(1)
+ others=labels[labels['target']!=output]
+ options_others=others.sample(3)
+ options = pd.concat([options_others, options_output], ignore_index=True)
+ shuffled_options = options.sample(frac=1).reset_index(drop=True)
+ print("shuffled:",shuffled_options)
+ op=[]
+
+ for i in shuffled_options['img_id']:
+ print(i)
+ filenames=namelist.loc[namelist['img_id']==i,'file_name'].values[0]
+ targets=shuffled_options.loc[shuffled_options['img_id']==i,'target'].values[0]
+ print("targets",targets)
+ print("name",filenames)
+
+ classes=classlist.loc[classlist['cl_id']==targets, 'class_name'].values[0]
+ print(data_dir/f"images/{filenames}")
+
+        op_img=cv2.imread(str(data_dir/f"images/{filenames}"))
+
+ op_images=tr(op_img)
+ op_images=op_images.unsqueeze(0)
+ op_images=op_images.to(device)
+ OP, feature_maps_op =model(op_images,with_feature_maps=True,with_final_features=False)
+ print("OP:",OP,
+ "feature_maps_op:",feature_maps_op.shape)
+ opt= overlapping_features_on_input(model,OP, feature_maps_op,op_img,targets)
+ op+=opt
+
+ return op
+
+def post_next_image(op):
+ if len(op)<=1:
+ return [],None, "all done, thank you!"
+ else:
+        op=op[1:]
+ return op,op[0], "Is this feature also in your input?"
+
+def get_features_on_interface(input):
+    op=generate_interpretable_output(input,dataset="CUB2011",
+ arch="resnet50",seed=123456,
+ model_type="qsenn", n_features = 50,n_per_class=5,
+ img_size=448, reduced_strides=False, folder = None)
+ return op, op[0],"Is this feature also in your input?",gr.update(interactive=False)
+
+
+with gr.Blocks() as demo:
+
+ gr.Markdown("
Interiable Bird Classification
")
+ image_input=gr.Image()
+ image_output=gr.Image()
+ text_output=gr.Markdown()
+    but_generate=gr.Button("Get some interpretable features")
+ but_feedback_y=gr.Button("Yes")
+ but_feedback_n=gr.Button("No")
+ image_list = gr.State([])
+ but_generate.click(fn=get_features_on_interface, inputs=image_input, outputs=[image_list,image_output,text_output,but_generate])
+ but_feedback_y.click(fn=post_next_image, inputs=image_list, outputs=[image_list,image_output,text_output])
+ but_feedback_n.click(fn=post_next_image, inputs=image_list, outputs=[image_list,image_output,text_output])
+
+demo.launch()
+
+
+
+
\ No newline at end of file
diff --git a/architectures/FinalLayer.py b/architectures/FinalLayer.py
new file mode 100644
index 0000000000000000000000000000000000000000..af1a55a667c462ec8f256f9d28aefdc5e77d6cae
--- /dev/null
+++ b/architectures/FinalLayer.py
@@ -0,0 +1,36 @@
+import torch
+from torch import nn
+
+from architectures.SLDDLevel import SLDDLevel
+
+
+class FinalLayer():
+ def __init__(self, num_classes, n_features):
+ super().__init__()
+ self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
+ self.linear = nn.Linear(n_features, num_classes)
+ self.featureDropout = torch.nn.Dropout(0.2)
+ self.selection = None
+
+ def transform_output(self, feature_maps, with_feature_maps,
+ with_final_features):
+ if self.selection is not None:
+ feature_maps = feature_maps[:, self.selection]
+ x = self.avgpool(feature_maps)
+ pre_out = torch.flatten(x, 1)
+ final_features = self.featureDropout(pre_out)
+ final = self.linear(final_features)
+ final = [final]
+ if with_feature_maps:
+ final.append(feature_maps)
+ if with_final_features:
+ final.append(final_features)
+ if len(final) == 1:
+ final = final[0]
+ return final
+
+
+ def set_model_sldd(self, selection, weight, mean, std, bias = None):
+ self.selection = selection
+ self.linear = SLDDLevel(selection, weight, mean, std, bias)
+ self.featureDropout = torch.nn.Dropout(0.1)
\ No newline at end of file
diff --git a/architectures/SLDDLevel.py b/architectures/SLDDLevel.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc214c88f384690d29bda97d7ed82a8c01e866da
--- /dev/null
+++ b/architectures/SLDDLevel.py
@@ -0,0 +1,37 @@
+import torch.nn
+
+
+class SLDDLevel(torch.nn.Module):
+ def __init__(self, selection, weight_at_selection,mean, std, bias=None):
+ super().__init__()
+ self.register_buffer('selection', torch.tensor(selection, dtype=torch.long))
+ num_classes, n_features = weight_at_selection.shape
+ selected_mean = mean
+ selected_std = std
+ if len(selected_mean) != len(selection):
+ selected_mean = selected_mean[selection]
+ selected_std = selected_std[selection]
+ self.mean = torch.nn.Parameter(selected_mean)
+ self.std = torch.nn.Parameter(selected_std)
+ if bias is not None:
+ self.layer = torch.nn.Linear(n_features, num_classes)
+ self.layer.bias = torch.nn.Parameter(bias, requires_grad=False)
+ else:
+ self.layer = torch.nn.Linear(n_features, num_classes, bias=False)
+ self.layer.weight = torch.nn.Parameter(weight_at_selection, requires_grad=False)
+
+ @property
+ def weight(self):
+ return self.layer.weight
+
+ @property
+ def bias(self):
+ if self.layer.bias is None:
+ return torch.zeros(self.layer.out_features)
+ else:
+ return self.layer.bias
+
+
+ def forward(self, input):
+ input = (input - self.mean) / torch.clamp(self.std, min=1e-6)
+ return self.layer(input)
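+
+
+if __name__ == "__main__":
+    # Hedged sketch with made-up shapes: 5 features selected out of 8, 3 classes.
+    selection = [0, 2, 3, 5, 7]
+    weight = torch.randn(3, 5)
+    mean, std = torch.zeros(8), torch.ones(8)  # stats over all 8 features, sliced to the selection internally
+    level = SLDDLevel(selection, weight, mean, std)
+    print(level(torch.randn(4, 5)).shape)  # torch.Size([4, 3])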
diff --git a/architectures/__pycache__/FinalLayer.cpython-310.pyc b/architectures/__pycache__/FinalLayer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a61e00b4a63ecca65185c2f0157f001dbead798
Binary files /dev/null and b/architectures/__pycache__/FinalLayer.cpython-310.pyc differ
diff --git a/architectures/__pycache__/SLDDLevel.cpython-310.pyc b/architectures/__pycache__/SLDDLevel.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a5d50a5be90c92f0840009d5a78ce7a4a4821df
Binary files /dev/null and b/architectures/__pycache__/SLDDLevel.cpython-310.pyc differ
diff --git a/architectures/__pycache__/model_mapping.cpython-310.pyc b/architectures/__pycache__/model_mapping.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..515201d4169359545fb87981950d8c22a4181b40
Binary files /dev/null and b/architectures/__pycache__/model_mapping.cpython-310.pyc differ
diff --git a/architectures/__pycache__/resnet.cpython-310.pyc b/architectures/__pycache__/resnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93dccc8da3700ae7fdcbfe887f868ef16d935b95
Binary files /dev/null and b/architectures/__pycache__/resnet.cpython-310.pyc differ
diff --git a/architectures/__pycache__/utils.cpython-310.pyc b/architectures/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a224141ff65b31e88257f416c3f7bfc2aaafbfa
Binary files /dev/null and b/architectures/__pycache__/utils.cpython-310.pyc differ
diff --git a/architectures/model_mapping.py b/architectures/model_mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..52df91009029b653b420ff03562616b2389eaa68
--- /dev/null
+++ b/architectures/model_mapping.py
@@ -0,0 +1,7 @@
+from architectures.resnet import resnet50
+
+
+def get_model(arch, num_classes, changed_strides=True):
+    if arch == "resnet50":
+        model = resnet50(True, num_classes=num_classes, changed_strides=changed_strides)
+    else:
+        raise ValueError(f"Unsupported architecture: {arch}")
+    return model
\ No newline at end of file
diff --git a/architectures/resnet.py b/architectures/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaaa5d3c22e6ab85f9ac63b29462d20aec9594d3
--- /dev/null
+++ b/architectures/resnet.py
@@ -0,0 +1,420 @@
+import torch
+import torch.nn as nn
+from torch.hub import load_state_dict_from_url
+
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+ 'wide_resnet50_2', 'wide_resnet101_2',
+ 'wide_resnet50_3', 'wide_resnet50_4', 'wide_resnet50_5',
+ 'wide_resnet50_6', ]
+
+from architectures.FinalLayer import FinalLayer
+from architectures.utils import SequentialWithArgs
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+ __constants__ = ['downsample']
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None, features=None):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+
+ def forward(self, x, no_relu=False):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+
+
+ out += identity
+
+ if no_relu:
+ return out
+ return self.relu(out)
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+ __constants__ = ['downsample']
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None, features=None):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ if features is None:
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ else:
+ self.conv3 = conv1x1(width, features)
+ self.bn3 = norm_layer(features)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x, no_relu=False, early_exit=False):
+ identity = x
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ if no_relu:
+ return out
+ return self.relu(out)
+
+
+class ResNet(nn.Module, FinalLayer):
+
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
+ norm_layer=None, changed_strides=False,):
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+ widths = [64, 128, 256, 512]
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.sstride = 2
+ if changed_strides:
+ self.sstride = 1
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=self.sstride,
+ dilate=replace_stride_with_dilation[1])
+ self.stride = 2
+
+ if changed_strides:
+ self.stride = 1
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=self.stride,
+ dilate=replace_stride_with_dilation[2])
+ FinalLayer.__init__(self, num_classes, 512 * block.expansion)
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False, last_block_f=None):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ krepeep = None
+ if last_block_f is not None and _ == blocks - 1:
+ krepeep = last_block_f
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer, features=krepeep))
+
+ return SequentialWithArgs(*layers)
+
+ def _forward(self, x, with_feature_maps=False, with_final_features=False):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ feature_maps = self.layer4(x, no_relu=True)
+        feature_maps = torch.nn.functional.relu(feature_maps)
+ return self.transform_output( feature_maps, with_feature_maps,
+ with_final_features)
+
+ # Allow for accessing forward method in a inherited class
+ forward = _forward
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ if kwargs["num_classes"] == 1000:
+ state_dict["linear.weight"] = state_dict["fc.weight"]
+ state_dict["linear.bias"] = state_dict["fc.bias"]
+ model.load_state_dict(state_dict, strict=False)
+ return model
+
+
+def resnet18(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-18 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+ **kwargs)
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-34 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet101(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-101 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet152(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-152 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+ r"""ResNeXt-50 32x4d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 4
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+ r"""ResNeXt-101 32x8d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 8
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-2 model from
+ `"Wide Residual Networks" `_
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_3(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-3 model
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 3
+ return _resnet('wide_resnet50_3', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_4(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-4 model
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 4
+ return _resnet('wide_resnet50_4', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_5(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-5 model
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 5
+ return _resnet('wide_resnet50_5', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_6(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-50-6 model
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 6
+ return _resnet('wide_resnet50_6', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+ r"""Wide ResNet-101-2 model from
+ `"Wide Residual Networks" `_
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
diff --git a/architectures/utils.py b/architectures/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed4cc78fc2799c675098bc73f0a9fb1719fb64b1
--- /dev/null
+++ b/architectures/utils.py
@@ -0,0 +1,17 @@
+import torch
+
+
+
+class SequentialWithArgs(torch.nn.Sequential):
+ def forward(self, input, *args, **kwargs):
+ vs = list(self._modules.values())
+ l = len(vs)
+ for i in range(l):
+ if i == l-1:
+ input = vs[i](input, *args, **kwargs)
+ else:
+ input = vs[i](input)
+ return input
+
+
+
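+if __name__ == "__main__":
+    # Hedged sketch (illustrative only): extra args/kwargs are forwarded to the last module only.
+    class Scale(torch.nn.Module):
+        def forward(self, x, factor=1.0):
+            return x * factor
+
+    seq = SequentialWithArgs(Scale(), Scale())
+    # The first Scale uses factor=1.0, the last one receives factor=3.0 -> tensor([3., 3.])
+    print(seq(torch.ones(2), factor=3.0))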
diff --git a/configs/__pycache__/dataset_params.cpython-310.pyc b/configs/__pycache__/dataset_params.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f91c9298daab0217462fe2590a20695b80575e85
Binary files /dev/null and b/configs/__pycache__/dataset_params.cpython-310.pyc differ
diff --git a/configs/__pycache__/optim_params.cpython-310.pyc b/configs/__pycache__/optim_params.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0b3f696299233ff7e5b5d5936ea47d83c8f8b97
Binary files /dev/null and b/configs/__pycache__/optim_params.cpython-310.pyc differ
diff --git a/configs/architecture_params.py b/configs/architecture_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..21a5f2b7fb72dfbee2b4487cabca2d2e840ad938
--- /dev/null
+++ b/configs/architecture_params.py
@@ -0,0 +1 @@
+architecture_params = {"resnet50": {"beta":0.196}}
\ No newline at end of file
diff --git a/configs/dataset_params.py b/configs/dataset_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f227da674de51bf3c4ac0fe3a8faff2004775a6
--- /dev/null
+++ b/configs/dataset_params.py
@@ -0,0 +1,22 @@
+import torch
+
+from configs.optim_params import EvaluatedDict
+
+dataset_constants = {"CUB2011":{"num_classes":200},
+ "TravelingBirds":{"num_classes":200},
+ "ImageNet":{"num_classes":1000},
+ "StanfordCars":{"num_classes":196},
+ "FGVCAircraft": {"num_classes":100}}
+
+normalize_params = {"CUB2011":{"mean": torch.tensor([0.4853, 0.4964, 0.4295]),"std":torch.tensor([0.2300, 0.2258, 0.2625])},
+"TravelingBirds":{"mean": torch.tensor([0.4584, 0.4369, 0.3957]),"std":torch.tensor([0.2610, 0.2569, 0.2722])},
+ "ImageNet":{'mean': torch.tensor([0.485, 0.456, 0.406]),'std': torch.tensor([0.229, 0.224, 0.225])} ,
+"StanfordCars":{'mean': torch.tensor([0.4593, 0.4466, 0.4453]),'std': torch.tensor([0.2920, 0.2910, 0.2988])} ,
+ "FGVCAircraft":{'mean': torch.tensor([0.4827, 0.5130, 0.5352]),
+ 'std': torch.tensor([0.2236, 0.2170, 0.2478]),}
+ }
+
+
+dense_batch_size = EvaluatedDict({False: 16, True: 1024}, lambda x: x == "ImageNet")
+
+ft_batch_size = EvaluatedDict({False: 16, True: 1024}, lambda x: x == "ImageNet")  # Untested
\ No newline at end of file
diff --git a/configs/optim_params.py b/configs/optim_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0fad011caec798b4d51948b28b4d0885c414b59
--- /dev/null
+++ b/configs/optim_params.py
@@ -0,0 +1,22 @@
+# Order of entries: lr, weight_decay, step_lr, step_lr_gamma, epochs
+import math
+
+
+class EvaluatedDict:
+ def __init__(self, d, func):
+ self.dict = d
+ self.func = func
+
+ def __getitem__(self, key):
+ return self.dict[self.func(key)]
+
+dense_params = EvaluatedDict({False: [0.005, 0.0005, 30, 0.4, 150], True: [None, None, None, None, None]}, lambda x: x == "ImageNet")
+
+
+def calculate_lr_from_args(epochs, step_lr, start_lr, step_lr_decay):
+ # Gets the final learning rate after dense training with step_lr_schedule.
+ n_steps = math.floor((epochs - step_lr) / step_lr)
+ final_lr = start_lr * step_lr_decay ** n_steps
+ return final_lr
+
+ft_params = EvaluatedDict({False: [1e-4, 0.0005, 10, 0.4, 40], True: [calculate_lr_from_args(150, 30, 0.005, 0.4), 0.0005, 10, 0.4, 40]}, lambda x: x == "ImageNet")
+
+
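+# Hedged worked example of the formula above: calculate_lr_from_args(150, 30, 0.005, 0.4)
+# takes n_steps = floor((150 - 30) / 30) = 4 decay steps, giving a final lr of
+# 0.005 * 0.4 ** 4 = 0.000128, which is used as the ImageNet finetuning start lr above.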
diff --git a/configs/qsenn_training_params.py b/configs/qsenn_training_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ca03c994ee04c47016c89357ff5d4953f634281
--- /dev/null
+++ b/configs/qsenn_training_params.py
@@ -0,0 +1,11 @@
+from configs.sldd_training_params import OptimizationScheduler
+
+
+class QSENNScheduler(OptimizationScheduler):
+ def get_params(self):
+ params = super().get_params()
+ if self.n_calls >= 2:
+ params[0] = params[0] * 0.9**(self.n_calls-2)
+ if 2 <= self.n_calls <= 4:
+            params[-2] = 10  # Change num epochs to 10 for iterative finetuning
+ return params
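+
+
+if __name__ == "__main__":
+    # Hedged sketch (illustrative only): print the schedule for a non-ImageNet dataset.
+    # Call 0 returns the dense-training params; later calls return finetuning params,
+    # with the lr decayed by 0.9 per iteration and epochs set to 10 while n_calls is 2-4.
+    scheduler = QSENNScheduler("CUB2011")
+    for _ in range(5):
+        print(scheduler.get_params())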
diff --git a/configs/sldd_training_params.py b/configs/sldd_training_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a605602a1a399d0dd55e1f53d8cbaa8c5d73dc0
--- /dev/null
+++ b/configs/sldd_training_params.py
@@ -0,0 +1,17 @@
+from configs.optim_params import dense_params, ft_params
+
+
+class OptimizationScheduler:
+ def __init__(self, dataset):
+ self.dataset = dataset
+ self.n_calls = 0
+
+
+ def get_params(self):
+        if self.n_calls == 0:  # Return Dense Params
+            params = dense_params[self.dataset] + [False]
+        else:  # Return Finetuning Params
+            params = ft_params[self.dataset] + [True]
+ self.n_calls += 1
+ return params
+
diff --git a/dataset_classes/__pycache__/cub200.cpython-310.pyc b/dataset_classes/__pycache__/cub200.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b65386a367820ce47ae2ecf095fb28397d58df2a
Binary files /dev/null and b/dataset_classes/__pycache__/cub200.cpython-310.pyc differ
diff --git a/dataset_classes/__pycache__/stanfordcars.cpython-310.pyc b/dataset_classes/__pycache__/stanfordcars.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e54b5188bd94de8a82b0639237f4ddf557cd52a
Binary files /dev/null and b/dataset_classes/__pycache__/stanfordcars.cpython-310.pyc differ
diff --git a/dataset_classes/__pycache__/travelingbirds.cpython-310.pyc b/dataset_classes/__pycache__/travelingbirds.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7cc7e30e2027a3730c1873443049962b75998fb0
Binary files /dev/null and b/dataset_classes/__pycache__/travelingbirds.cpython-310.pyc differ
diff --git a/dataset_classes/__pycache__/utils.cpython-310.pyc b/dataset_classes/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bb96db5a1c1e6f6cd7e6efc268cffd6d6e0004b
Binary files /dev/null and b/dataset_classes/__pycache__/utils.cpython-310.pyc differ
diff --git a/dataset_classes/cub200.py b/dataset_classes/cub200.py
new file mode 100644
index 0000000000000000000000000000000000000000..b59a933605948ed45365ccba82486c2c433d4173
--- /dev/null
+++ b/dataset_classes/cub200.py
@@ -0,0 +1,96 @@
+# Dataset should lie under /root/
+# root is currently set to ~/tmp/Datasets/CUB200
+# If cropped images (as used by PIP-Net, ProtoPool, etc.) are used, then crop_root should be set to a folder containing the
+# cropped images in the expected structure, obtained by following ProtoTree's instructions.
+# https://github.com/M-Nauta/ProtoTree/blob/main/README.md#preprocessing-cub
+import os
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+from torch.utils.data import Dataset
+from torchvision.datasets.folder import default_loader
+
+from dataset_classes.utils import txt_load
+
+
+class CUB200Class(Dataset):
+ root = Path.home() / "tmp/Datasets/CUB200"
+ crop_root = Path.home() / "tmp/Datasets/PPCUB200"
+ base_folder = 'CUB_200_2011/images'
+ def __init__(self, train, transform, crop=True):
+ self.train = train
+ self.transform = transform
+ self.crop = crop
+ self._load_metadata()
+ self.loader = default_loader
+
+ if crop:
+ self.adapt_to_crop()
+
+ def _load_metadata(self):
+ images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
+ names=['img_id', 'filepath'])
+ image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
+ sep=' ', names=['img_id', 'target'])
+ train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
+ sep=' ', names=['img_id', 'is_training_img'])
+ data = images.merge(image_class_labels, on='img_id')
+ self.data = data.merge(train_test_split, on='img_id')
+ if self.train:
+ self.data = self.data[self.data.is_training_img == 1]
+ else:
+ self.data = self.data[self.data.is_training_img == 0]
+
+ def __len__(self):
+ return len(self.data)
+
+ def adapt_to_crop(self):
+ # ds_name = [x for x in self.cropped_dict.keys() if x in self.root][0]
+ self.root = self.crop_root
+ folder_name = "train" if self.train else "test"
+ folder_name = folder_name + "_cropped"
+ self.base_folder = 'CUB_200_2011/' + folder_name
+
+ def __getitem__(self, idx):
+ sample = self.data.iloc[idx]
+ path = os.path.join(self.root, self.base_folder, sample.filepath)
+ target = sample.target - 1 # Targets start at 1 by default, so shift to 0
+ img = self.loader(path)
+ img = self.transform(img)
+ return img, target
+
+    @classmethod
+    def get_image_attribute_labels(cls, train=False):
+        image_attribute_labels = pd.read_csv(
+            os.path.join(cls.root, 'CUB_200_2011', "attributes",
+                         'image_attribute_labels.txt'),
+            sep=' ', names=['img_id', 'attribute', "is_present", "certainty", "time"], on_bad_lines="skip")
+        train_test_split = pd.read_csv(os.path.join(cls.root, 'CUB_200_2011', 'train_test_split.txt'),
+                                       sep=' ', names=['img_id', 'is_training_img'])
+ merged = image_attribute_labels.merge(train_test_split, on="img_id")
+ filtered_data = merged[merged["is_training_img"] == train]
+ return filtered_data
+
+
+ @staticmethod
+ def filter_attribute_labels(labels, min_certainty=3):
+ is_invisible_present = labels[labels["certainty"] == 1]["is_present"].sum()
+ if is_invisible_present != 0:
+ raise ValueError("Invisible present")
+ labels["img_id"] -= min(labels["img_id"])
+ labels["img_id"] = fillholes_in_array(labels["img_id"])
+ labels[labels["certainty"] == 1]["certainty"] = 4
+ labels = labels[labels["certainty"] >= min_certainty]
+ labels["attribute"] -= min(labels["attribute"])
+ labels = labels[["img_id", "attribute", "is_present"]]
+ labels["is_present"] = labels["is_present"].astype(bool)
+ return labels
+
+
+
+def fillholes_in_array(array):
+ unique_values = np.unique(array)
+ mapping = {x: i for i, x in enumerate(unique_values)}
+ array = array.map(mapping)
+ return array
diff --git a/dataset_classes/stanfordcars.py b/dataset_classes/stanfordcars.py
new file mode 100644
index 0000000000000000000000000000000000000000..0be682a5d164a8b39cff5bd9cca82cc8cf5ebe53
--- /dev/null
+++ b/dataset_classes/stanfordcars.py
@@ -0,0 +1,121 @@
+import pathlib
+from typing import Callable, Optional, Any, Tuple
+
+import numpy as np
+import pandas as pd
+from PIL import Image
+from torchvision.datasets import VisionDataset
+from torchvision.datasets.utils import download_and_extract_archive, download_url
+
+
+class StanfordCarsClass(VisionDataset):
+ """`Stanford Cars `_ Dataset
+
+ The Cars dataset contains 16,185 images of 196 classes of cars. The data is
+ split into 8,144 training images and 8,041 testing images, where each class
+ has been split roughly in a 50-50 split
+
+ .. note::
+
+        This class needs `scipy <https://scipy.org>`_ to load target files from `.mat` format.
+
+ Args:
+ root (string): Root directory of dataset
+ split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
+ transform (callable, optional): A function/transform that takes in an PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again."""
+ root = pathlib.Path.home() / "tmp" / "Datasets" / "StanfordCars"
+ def __init__(
+ self,
+ train: bool = True,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = True,
+ ) -> None:
+
+ try:
+ import scipy.io as sio
+ except ImportError:
+ raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+ super().__init__(self.root, transform=transform, target_transform=target_transform)
+
+ self.train = train
+ self._base_folder = pathlib.Path(self.root) / "stanford_cars"
+ devkit = self._base_folder / "devkit"
+
+ if train:
+ self._annotations_mat_path = devkit / "cars_train_annos.mat"
+ self._images_base_path = self._base_folder / "cars_train"
+ else:
+ self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
+ self._images_base_path = self._base_folder / "cars_test"
+
+ if download:
+ self.download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ self.samples = [
+ (
+ str(self._images_base_path / annotation["fname"]),
+ annotation["class"] - 1, # Original target mapping starts from 1, hence -1
+ )
+ for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
+ ]
+ self.targets = np.array([x[1] for x in self.samples])
+ self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
+ self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
+
+ def __len__(self) -> int:
+ return len(self.samples)
+
+ def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+ """Returns pil_image and class_id for given index"""
+ image_path, target = self.samples[idx]
+ pil_image = Image.open(image_path).convert("RGB")
+
+ if self.transform is not None:
+ pil_image = self.transform(pil_image)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return pil_image, target
+
+ def download(self) -> None:
+ if self._check_exists():
+ return
+
+ download_and_extract_archive(
+ url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
+ download_root=str(self._base_folder),
+ md5="c3b158d763b6e2245038c8ad08e45376",
+ )
+ if self.train:
+ download_and_extract_archive(
+ url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
+ download_root=str(self._base_folder),
+ md5="065e5b463ae28d29e77c1b4b166cfe61",
+ )
+ else:
+ download_and_extract_archive(
+ url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
+ download_root=str(self._base_folder),
+ md5="4ce7ebf6a94d07f1952d94dd34c4d501",
+ )
+ download_url(
+ url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
+ root=str(self._base_folder),
+ md5="b0a2b23655a3edd16d84508592a98d10",
+ )
+
+ def _check_exists(self) -> bool:
+ if not (self._base_folder / "devkit").is_dir():
+ return False
+
+ return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
diff --git a/dataset_classes/travelingbirds.py b/dataset_classes/travelingbirds.py
new file mode 100644
index 0000000000000000000000000000000000000000..551ce1fd46b9b84e572ea18f4adc6ecd73cea00d
--- /dev/null
+++ b/dataset_classes/travelingbirds.py
@@ -0,0 +1,59 @@
+# TravelingBirds dataset needs to be downloaded from https://worksheets.codalab.org/bundles/0x518829de2aa440c79cd9d75ef6669f27
+# as it comes from https://github.com/yewsiang/ConceptBottleneck
+import os
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from dataset_classes.cub200 import CUB200Class
+from dataset_classes.utils import index_list_with_sorting, mask_list
+
+
+class TravelingBirds(CUB200Class):
+ init_base_folder = 'CUB_fixed'
+ root = Path.home() / "tmp/Datasets/TravelingBirds"
+ crop_root = Path.home() / "tmp/Datasets/PPTravelingBirds"
+ def get_all_samples_dir(self, dir):
+
+ self.base_folder = os.path.join(self.init_base_folder, dir)
+ main_dir = Path(self.root) / self.init_base_folder / dir
+ return self.get_all_sample(main_dir)
+
+ def adapt_to_crop(self):
+ self.root = self.crop_root
+ folder_name = "train" if self.train else "test"
+ folder_name = folder_name + "_cropped"
+ self.base_folder = 'CUB_fixed/' + folder_name
+
+ def get_all_sample(self, dir):
+ answer = []
+ for i, sub_dir in enumerate(sorted(os.listdir(dir))):
+ class_dir = dir / sub_dir
+ for single_img in os.listdir(class_dir):
+ answer.append([Path(sub_dir) / single_img, i + 1])
+ return answer
+ def _load_metadata(self):
+ train_test_split = pd.read_csv(
+ os.path.join(Path(self.root).parent / "CUB200", 'CUB_200_2011', 'train_test_split.txt'),
+ sep=' ', names=['img_id', 'is_training_img'])
+ data = pd.read_csv(
+ os.path.join(Path(self.root).parent / "CUB200", 'CUB_200_2011', 'images.txt'),
+ sep=' ', names=['img_id', "path"])
+ img_dict = {x[1]: x[0] for x in data.values}
+ # TravelingBirds has all train+test images in both folders, just with different backgrounds.
+ # They are separated by train_test_split of CUB200.
+ if self.train:
+ samples = self.get_all_samples_dir("train")
+ mask = train_test_split["is_training_img"] == 1
+ else:
+ samples = self.get_all_samples_dir("test")
+ mask = train_test_split["is_training_img"] == 0
+ ids = np.array([img_dict[str(x[0])] for x in samples])
+        sort_order = np.argsort(ids)
+        samples = index_list_with_sorting(samples, sort_order)
+ samples = mask_list(samples, mask)
+ filepaths = [x[0] for x in samples]
+ labels = [x[1] for x in samples]
+ samples = pd.DataFrame({"filepath": filepaths, "target": labels})
+ self.data = samples
diff --git a/dataset_classes/utils.py b/dataset_classes/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f0039ba93d0966230da88f2568d02bb7cebeebf
--- /dev/null
+++ b/dataset_classes/utils.py
@@ -0,0 +1,16 @@
+def index_list_with_sorting(list_to_sort, sorting_list):
+ answer = []
+ for entry in sorting_list:
+ answer.append(list_to_sort[entry])
+ return answer
+
+
+def mask_list(list_input, mask):
+ return [x for i, x in enumerate(list_input) if mask[i]]
+
+
+def txt_load(filename):
+ with open(filename, 'r') as f:
+ data = f.read()
+ return data
+
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e4e2f7b3680115f3e38c80511baede60fda0db03
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,117 @@
+name: QSENNEnv
+channels:
+ - pytorch
+ - nvidia
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - blas=1.0=mkl
+ - brotli-python=1.0.9=py310h6a678d5_7
+ - bzip2=1.0.8=h7b6447c_0
+ - ca-certificates=2023.12.12=h06a4308_0
+ - certifi=2023.11.17=py310h06a4308_0
+ - cffi=1.16.0=py310h5eee18b_0
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
+ - cryptography=41.0.7=py310hdda0065_0
+ - cuda-cudart=12.1.105=0
+ - cuda-cupti=12.1.105=0
+ - cuda-libraries=12.1.0=0
+ - cuda-nvrtc=12.1.105=0
+ - cuda-nvtx=12.1.105=0
+ - cuda-opencl=12.3.101=0
+ - cuda-runtime=12.1.0=0
+ - ffmpeg=4.3=hf484d3e_0
+ - filelock=3.13.1=py310h06a4308_0
+ - freetype=2.12.1=h4a9f257_0
+ - giflib=5.2.1=h5eee18b_3
+ - gmp=6.2.1=h295c915_3
+ - gmpy2=2.1.2=py310heeb90bb_0
+ - gnutls=3.6.15=he1e5248_0
+ - idna=3.4=py310h06a4308_0
+ - intel-openmp=2023.1.0=hdb19cb5_46306
+ - jinja2=3.1.2=py310h06a4308_0
+ - jpeg=9e=h5eee18b_1
+ - lame=3.100=h7b6447c_0
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - lerc=3.0=h295c915_0
+ - libcublas=12.1.0.26=0
+ - libcufft=11.0.2.4=0
+ - libcufile=1.8.1.2=0
+ - libcurand=10.3.4.107=0
+ - libcusolver=11.4.4.55=0
+ - libcusparse=12.0.2.55=0
+ - libdeflate=1.17=h5eee18b_1
+ - libffi=3.4.4=h6a678d5_0
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libiconv=1.16=h7f8727e_2
+ - libidn2=2.3.4=h5eee18b_0
+ - libjpeg-turbo=2.0.0=h9bf148f_0
+ - libnpp=12.0.2.50=0
+ - libnvjitlink=12.1.105=0
+ - libnvjpeg=12.1.1.14=0
+ - libpng=1.6.39=h5eee18b_0
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libtasn1=4.19.0=h5eee18b_0
+ - libtiff=4.5.1=h6a678d5_0
+ - libunistring=0.9.10=h27cfd23_0
+ - libuuid=1.41.5=h5eee18b_0
+ - libwebp=1.3.2=h11a3e52_0
+ - libwebp-base=1.3.2=h5eee18b_0
+ - llvm-openmp=14.0.6=h9e868ea_0
+ - lz4-c=1.9.4=h6a678d5_0
+ - markupsafe=2.1.3=py310h5eee18b_0
+ - mkl=2023.1.0=h213fc3f_46344
+ - mkl-service=2.4.0=py310h5eee18b_1
+ - mkl_fft=1.3.8=py310h5eee18b_0
+ - mkl_random=1.2.4=py310hdb19cb5_0
+ - mpc=1.1.0=h10f8cd9_1
+ - mpfr=4.0.2=hb69a4c5_1
+ - mpmath=1.3.0=py310h06a4308_0
+ - ncurses=6.4=h6a678d5_0
+ - nettle=3.7.3=hbbd107a_1
+ - networkx=3.1=py310h06a4308_0
+ - numpy=1.26.3=py310h5f9d8c6_0
+ - numpy-base=1.26.3=py310hb5e798b_0
+ - openh264=2.1.1=h4ff587b_0
+ - openjpeg=2.4.0=h3ad879b_0
+ - openssl=3.0.12=h7f8727e_0
+ - pillow=10.0.1=py310ha6cbd5a_0
+ - pip=23.3.1=py310h06a4308_0
+ - pycparser=2.21=pyhd3eb1b0_0
+ - pyopenssl=23.2.0=py310h06a4308_0
+ - pysocks=1.7.1=py310h06a4308_0
+ - python=3.10.13=h955ad1f_0
+ - pytorch=2.1.2=py3.10_cuda12.1_cudnn8.9.2_0
+ - pytorch-cuda=12.1=ha16c6d3_5
+ - pytorch-mutex=1.0=cuda
+ - pyyaml=6.0.1=py310h5eee18b_0
+ - readline=8.2=h5eee18b_0
+ - requests=2.31.0=py310h06a4308_0
+ - setuptools=68.2.2=py310h06a4308_0
+ - sqlite=3.41.2=h5eee18b_0
+ - sympy=1.12=py310h06a4308_0
+ - tbb=2021.8.0=hdb19cb5_0
+ - tk=8.6.12=h1ccaba5_0
+ - torchaudio=2.1.2=py310_cu121
+ - torchtriton=2.1.0=py310
+ - torchvision=0.16.2=py310_cu121
+ - typing_extensions=4.7.1=py310h06a4308_0
+ - urllib3=1.26.18=py310h06a4308_0
+ - wheel=0.41.2=py310h06a4308_0
+ - xz=5.4.5=h5eee18b_0
+ - yaml=0.2.5=h7b6447c_0
+ - zlib=1.2.13=h5eee18b_0
+ - zstd=1.5.5=hc292b87_0
+ - pip:
+ - fsspec==2023.12.2
+ - glm-saga==0.1.2
+ - pandas==2.1.4
+ - python-dateutil==2.8.2
+ - pytz==2023.3.post1
+ - six==1.16.0
+ - tqdm==4.66.1
+ - tzdata==2023.4
+prefix: /home/norrenbr/anaconda/tmp/envs/QSENN-Minimal
diff --git a/evaluation/Metrics/Dependence.py b/evaluation/Metrics/Dependence.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f1b26dfc19de0430925e38aac45ebcc33a94455
--- /dev/null
+++ b/evaluation/Metrics/Dependence.py
@@ -0,0 +1,21 @@
+import torch
+
+
+def compute_contribution_top_feature(features, outputs, weights, labels):
+ with torch.no_grad():
+ total_pre_softmax, predicted_classes = torch.max(outputs, dim=1)
+ feature_part = features * weights.to(features.device)[predicted_classes]
+ class_specific_feature_part = torch.zeros((weights.shape[0], features.shape[1],))
+ feature_class_part = torch.zeros((weights.shape[0], features.shape[1],))
+ for unique_class in predicted_classes.unique():
+ mask = predicted_classes == unique_class
+ class_specific_feature_part[unique_class] = feature_part[mask].mean(dim=0)
+ gt_mask = labels == unique_class
+ feature_class_part[unique_class] = feature_part[gt_mask].mean(dim=0)
+ abs_features = feature_part.abs()
+ abs_sum = abs_features.sum(dim=1)
+ fractions_abs = abs_features / abs_sum[:, None]
+ abs_max = fractions_abs.max(dim=1)[0]
+ mask = ~torch.isnan(abs_max)
+ abs_max = abs_max[mask]
+ return abs_max.mean()
\ No newline at end of file
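+
+
+if __name__ == "__main__":
+    # Hedged sketch with random tensors: 8 samples, 5 features, 3 classes. Labels are set to
+    # the predictions so every predicted class has ground-truth samples in this toy example.
+    features = torch.rand(8, 5)
+    weights = torch.randn(3, 5)
+    outputs = features @ weights.t()
+    labels = outputs.argmax(dim=1)
+    print(compute_contribution_top_feature(features, outputs, weights, labels))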
diff --git a/evaluation/Metrics/__pycache__/Dependence.cpython-310.pyc b/evaluation/Metrics/__pycache__/Dependence.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d59c7cc91e9d0c3e84e533ca3876e3ee9850c52
Binary files /dev/null and b/evaluation/Metrics/__pycache__/Dependence.cpython-310.pyc differ
diff --git a/evaluation/Metrics/__pycache__/cub_Alignment.cpython-310.pyc b/evaluation/Metrics/__pycache__/cub_Alignment.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0de63438f9c10a07e3524cca02c53276c1d7622
Binary files /dev/null and b/evaluation/Metrics/__pycache__/cub_Alignment.cpython-310.pyc differ
diff --git a/evaluation/Metrics/cub_Alignment.py b/evaluation/Metrics/cub_Alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b4b41e427668f86ec530baab1796ac9d0678489
--- /dev/null
+++ b/evaluation/Metrics/cub_Alignment.py
@@ -0,0 +1,30 @@
+import numpy as np
+
+from dataset_classes.cub200 import CUB200Class
+
+
+def get_cub_alignment_from_features(features_train_sorted):
+ metric_matrix = compute_metric_matrix(np.array(features_train_sorted), "train")
+    return np.mean(np.max(metric_matrix, axis=1))
+
+
+def compute_metric_matrix(features, mode):
+ image_attribute_labels = CUB200Class.get_image_attribute_labels(train=mode == "train")
+ image_attribute_labels = CUB200Class.filter_attribute_labels(image_attribute_labels)
+ matrix_shape = (
+ features.shape[1], max(image_attribute_labels["attribute"]) + 1)
+ accuracy_matrix = np.zeros(matrix_shape)
+ sensitivity_matrix = np.zeros_like(accuracy_matrix)
+ grouped_attributes = image_attribute_labels.groupby("attribute")
+ for attribute_id, group in grouped_attributes:
+ is_present = group[group["is_present"]]
+ not_present = group[~group["is_present"]]
+ is_present_avg = np.mean(features[is_present["img_id"]], axis=0)
+ not_present_avg = np.mean(features[not_present["img_id"]], axis=0)
+ sensitivity_matrix[:, attribute_id] = not_present_avg
+ accuracy_matrix[:, attribute_id] = is_present_avg
+ metric_matrix = accuracy_matrix - sensitivity_matrix
+ no_abs_features = features - np.min(features, axis=0)
+ no_abs_feature_mean = metric_matrix / no_abs_features.mean(axis=0)[:, None]
+ return no_abs_feature_mean
diff --git a/evaluation/__pycache__/diversity.cpython-310.pyc b/evaluation/__pycache__/diversity.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d65f89a660c7660f26ce7578e34557c87b970b66
Binary files /dev/null and b/evaluation/__pycache__/diversity.cpython-310.pyc differ
diff --git a/evaluation/__pycache__/helpers.cpython-310.pyc b/evaluation/__pycache__/helpers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f7034adbe8c30d420975414880bb581f2052080
Binary files /dev/null and b/evaluation/__pycache__/helpers.cpython-310.pyc differ
diff --git a/evaluation/__pycache__/qsenn_metrics.cpython-310.pyc b/evaluation/__pycache__/qsenn_metrics.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9460b6362b47ba4ddf699b74707b87b5a063ce73
Binary files /dev/null and b/evaluation/__pycache__/qsenn_metrics.cpython-310.pyc differ
diff --git a/evaluation/__pycache__/utils.cpython-310.pyc b/evaluation/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e70bce079f72b6f83344a1decd675b1312655b6
Binary files /dev/null and b/evaluation/__pycache__/utils.cpython-310.pyc differ
diff --git a/evaluation/diversity.py b/evaluation/diversity.py
new file mode 100644
index 0000000000000000000000000000000000000000..033679ce9cf4546b74b0d1d4bdb6b8590c5c8865
--- /dev/null
+++ b/evaluation/diversity.py
@@ -0,0 +1,111 @@
+import numpy as np
+import torch
+
+from evaluation.helpers import softmax_feature_maps
+
+
+class MultiKCrossChannelMaxPooledSum:
+ def __init__(self, top_k_range, weights, interactions, func="softmax"):
+ self.top_k_range = top_k_range
+ self.weights = weights
+ self.failed = False
+ self.max_ks = self.get_max_ks(weights)
+ self.locality_of_used_features = torch.zeros(len(top_k_range), device=weights.device)
+ self.locality_of_exclusely_used_features = torch.zeros(len(top_k_range), device=weights.device)
+ self.ns_k = torch.zeros(len(top_k_range), device=weights.device)
+ self.exclusive_ns = torch.zeros(len(top_k_range), device=weights.device)
+ self.interactions = interactions
+ self.func = func
+
+ def get_max_ks(self, weights):
+ nonzeros = torch.count_nonzero(torch.tensor(weights), 1)
+ return nonzeros
+
+ def get_top_n_locality(self, outputs, initial_feature_maps, k):
+ feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(outputs,
+ initial_feature_maps)
+ max_ks = self.max_ks[top_classes]
+ max_k_based_row_selection = max_ks >= k
+
+ result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size, feature_maps,
+ separated=True)
+ return result
+
+ def get_locality(self, outputs, initial_feature_maps, n):
+ answer = self.get_top_n_locality(outputs, initial_feature_maps, n)
+ return answer
+
+ def get_result(self):
+ # if torch.sum(self.exclusive_ns) ==0:
+ # end_idx = len(self.exclusive_ns) - 1
+ # else:
+
+ exclusive_array = torch.zeros_like(self.locality_of_exclusely_used_features)
+ local_array = torch.zeros_like(self.locality_of_used_features)
+ # if self.failed:
+ # return local_array, exclusive_array
+ cumulated = torch.cumsum(self.exclusive_ns, 0)
+ end_idx = torch.argmax(cumulated)
+ exclusivity_array = self.locality_of_exclusely_used_features[:end_idx + 1] / self.exclusive_ns[:end_idx + 1]
+ exclusivity_array[exclusivity_array != exclusivity_array] = 0
+ exclusive_array[:len(exclusivity_array)] = exclusivity_array
+ locality_array = self.locality_of_used_features[self.locality_of_used_features != 0] / self.ns_k[
+ self.locality_of_used_features != 0]
+ local_array[:len(locality_array)] = locality_array
+ return local_array, exclusive_array
+
+ def get_crosspooled(self, relevant_weights, mask, k, vector_size, feature_maps, separated=False):
+ relevant_indices = get_relevant_indices(relevant_weights, k)[mask]
+        # shape: batch x k x (feature map size squared)
+ indices = relevant_indices.unsqueeze(2).repeat(1, 1, vector_size)
+ sub_feature_maps = torch.gather(feature_maps[mask], 1, indices)
+ # shape batch x featuremapsquared: For each "pixel" the highest value
+ cross_pooled = torch.max(sub_feature_maps, 1)[0]
+ if separated:
+ return torch.sum(cross_pooled, 1) / k
+ else:
+ ns = len(cross_pooled)
+ result = torch.sum(cross_pooled) / (k)
+ # should be batch x map size
+
+ return ns, result
+
+ def adapt_feature_maps(self, outputs, initial_feature_maps):
+        if self.func == "softmax":
+            feature_maps = softmax_feature_maps(initial_feature_maps)
+        else:
+            # fall back to the raw maps if no normalization function is configured
+            feature_maps = initial_feature_maps
+        feature_maps = torch.flatten(feature_maps, 2)
+ vector_size = feature_maps.shape[2]
+ top_classes = torch.argmax(outputs, dim=1)
+ relevant_weights = self.weights[top_classes]
+ if relevant_weights.shape[1] != feature_maps.shape[1]:
+ feature_maps = self.interactions.get_localized_features(initial_feature_maps)
+ feature_maps = softmax_feature_maps(feature_maps)
+ feature_maps = torch.flatten(feature_maps, 2)
+ return feature_maps, relevant_weights, vector_size, top_classes
+
+ def calculate_locality(self, outputs, initial_feature_maps):
+ feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(outputs,
+ initial_feature_maps)
+ max_ks = self.max_ks[top_classes]
+ for k in self.top_k_range:
+ # relevant_k_s = max_ks[]
+ max_k_based_row_selection = max_ks >= k
+ if torch.sum(max_k_based_row_selection) == 0:
+ break
+
+ exclusive_k = max_ks == k
+ if torch.sum(exclusive_k) != 0:
+ ns, result = self.get_crosspooled(relevant_weights, exclusive_k, k, vector_size, feature_maps)
+ self.locality_of_exclusely_used_features[k - 1] += result
+ self.exclusive_ns[k - 1] += ns
+ ns, result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size, feature_maps)
+ self.ns_k[k - 1] += ns
+ self.locality_of_used_features[k - 1] += result
+
+ def __call__(self, outputs, initial_feature_maps):
+ self.calculate_locality(outputs, initial_feature_maps)
+
+
+def get_relevant_indices(weights, top_k):
+ top_k = weights.topk(top_k)[1]
+ return top_k
\ No newline at end of file
diff --git a/evaluation/helpers.py b/evaluation/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe4a9902103fe63df01994acb079127ab719c9f1
--- /dev/null
+++ b/evaluation/helpers.py
@@ -0,0 +1,6 @@
+import torch
+
+
+def softmax_feature_maps(x):
+    # applies softmax over the flattened spatial dimensions (dim 2 after reshape), then restores the original shape
+ return torch.softmax(x.reshape(x.size(0), x.size(1), -1), 2).view_as(x)
\ No newline at end of file
diff --git a/evaluation/qsenn_metrics.py b/evaluation/qsenn_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bb8f21b6f7dfe101c8e668e9c422c1d88ce8751
--- /dev/null
+++ b/evaluation/qsenn_metrics.py
@@ -0,0 +1,39 @@
+import numpy as np
+import torch
+
+from evaluation.Metrics.Dependence import compute_contribution_top_feature
+from evaluation.Metrics.cub_Alignment import get_cub_alignment_from_features
+from evaluation.diversity import MultiKCrossChannelMaxPooledSum
+from evaluation.utils import get_metrics_for_model
+
+
+def evaluateALLMetricsForComps(features_train, outputs_train, feature_maps_test,
+ outputs_test, linear_matrix, labels_train):
+ with torch.no_grad():
+ if len(features_train) < 7000: # recognize CUB and TravelingBirds
+ cub_alignment = get_cub_alignment_from_features(features_train)
+ else:
+ cub_alignment = 0
+ print("cub_alignment: ", cub_alignment)
+ localizer = MultiKCrossChannelMaxPooledSum(range(1, 6), linear_matrix, None)
+ batch_size = 300
+        for i in range(np.floor(len(feature_maps_test) / batch_size).astype(int)):
+            localizer(outputs_test[i * batch_size:(i + 1) * batch_size].to(linear_matrix.device),
+                      feature_maps_test[i * batch_size:(i + 1) * batch_size].to(linear_matrix.device))
+
+        locality, exclusive_locality = localizer.get_result()
+ diversity = locality[4]
+ print("diversity@5: ", diversity)
+ abs_frac_mean = compute_contribution_top_feature(
+ features_train,
+ outputs_train,
+ linear_matrix,
+ labels_train)
+ print("Dependence ", abs_frac_mean)
+ answer_dict = {"diversity": diversity.item(), "Dependence": abs_frac_mean.item(), "Alignment":cub_alignment}
+ return answer_dict
+
+def eval_model_on_all_qsenn_metrics(model, test_loader, train_loader):
+ return get_metrics_for_model(train_loader, test_loader, model, evaluateALLMetricsForComps)
+
+
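For orientation, a minimal sketch of how this entry point can be driven end to end, mirroring its use in load_model.py and main.py; the dataset and architecture choices below are only illustrative:

from architectures.model_mapping import get_model
from configs.dataset_params import dataset_constants
from evaluation.qsenn_metrics import eval_model_on_all_qsenn_metrics
from get_data import get_data

train_loader, test_loader = get_data("CUB2011", crop=True, img_size=448)
model = get_model("resnet50", dataset_constants["CUB2011"]["num_classes"], False)
# ... load trained/finetuned weights into `model` here ...
results = eval_model_on_all_qsenn_metrics(model, test_loader, train_loader)
# results contains Accuracy, NFfeatures, PerClass, diversity, Dependence and Alignment
print(results)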
diff --git a/evaluation/utils.py b/evaluation/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef1b679fc9dac88e2fb897d69c34d959c19b3101
--- /dev/null
+++ b/evaluation/utils.py
@@ -0,0 +1,57 @@
+import torch
+from tqdm import tqdm
+
+
+
+def get_metrics_for_model(train_loader, test_loader, model, metric_evaluator):
+ (features_train, feature_maps_train, outputs_train, features_test, feature_maps_test,
+ outputs_test, labels) = [], [], [], [], [], [], []
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.eval()
+ model = model.to(device)
+ training_transforms = train_loader.dataset.transform
+ train_loader.dataset.transform = test_loader.dataset.transform # Use test transform for train
+ train_loader = torch.utils.data.DataLoader(train_loader.dataset, batch_size=100, shuffle=False) # Turn off shuffling
+ print("Going in get metrics")
+ linear_matrix = model.linear.weight
+ entries = torch.nonzero(linear_matrix)
+ rel_features = torch.unique(entries[:, 1])
+ with torch.no_grad():
+ iterator = tqdm(enumerate(train_loader), total=len(train_loader))
+ for batch_idx, (data, target) in iterator:
+            xs1 = data.to(device)
+ output, feature_maps, final_features = model(xs1, with_feature_maps=True, with_final_features=True,)
+ outputs_train.append(output.to("cpu"))
+ features_train.append(final_features.to("cpu"))
+ labels.append(target.to("cpu"))
+ total = 0
+ correct = 0
+ iterator = tqdm(enumerate(test_loader), total=len(test_loader))
+ for batch_idx, (data, target) in iterator:
+            xs1 = data.to(device)
+ output, feature_maps, final_features = model(xs1, with_feature_maps=True,
+ with_final_features=True, )
+ feature_maps_test.append(feature_maps[:, rel_features].to("cpu"))
+ outputs_test.append(output.to("cpu"))
+ total += target.size(0)
+ _, predicted = output.max(1)
+            correct += predicted.eq(target.to(device)).sum().item()
+ print("test accuracy: ", correct / total)
+ features_train = torch.cat(features_train)
+ outputs_train = torch.cat(outputs_train)
+ feature_maps_test = torch.cat(feature_maps_test)
+ outputs_test = torch.cat(outputs_test)
+ labels = torch.cat(labels)
+ linear_matrix = linear_matrix[:, rel_features]
+ print("Shape of linear matrix: ", linear_matrix.shape)
+ all_metrics_dict = metric_evaluator(features_train, outputs_train,
+ feature_maps_test,
+ outputs_test, linear_matrix, labels)
+ result_dict = {"Accuracy": correct / total, "NFfeatures": linear_matrix.shape[1],
+ "PerClass": torch.nonzero(linear_matrix).shape[0] / linear_matrix.shape[0],
+ }
+ result_dict.update(all_metrics_dict)
+ print(result_dict)
+ # Reset Train transforms
+ train_loader.dataset.transform = training_transforms
+ return result_dict
diff --git a/fig/AutoML4FAS_Logo.jpeg b/fig/AutoML4FAS_Logo.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..35d4066fa5cf5967553960097b57f80c2ac8c580
Binary files /dev/null and b/fig/AutoML4FAS_Logo.jpeg differ
diff --git a/fig/Bund.png b/fig/Bund.png
new file mode 100644
index 0000000000000000000000000000000000000000..1c92a104515f9b3c61642f7cd3cc898163e5ef0e
Binary files /dev/null and b/fig/Bund.png differ
diff --git a/fig/LUH.png b/fig/LUH.png
new file mode 100644
index 0000000000000000000000000000000000000000..af168ab3e866e5c66c616b6a090ef9c4ac212e3b
Binary files /dev/null and b/fig/LUH.png differ
diff --git a/fig/birds.png b/fig/birds.png
new file mode 100644
index 0000000000000000000000000000000000000000..330ebdff52c39b989a5c0cd42e0a35fdbeb7c1ff
Binary files /dev/null and b/fig/birds.png differ
diff --git a/finetuning/map_function.py b/finetuning/map_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..7aa65c3fa6dee0dc55484bdaae3fb181786eed1b
--- /dev/null
+++ b/finetuning/map_function.py
@@ -0,0 +1,11 @@
+from finetuning.qsenn import finetune_qsenn
+from finetuning.sldd import finetune_sldd
+
+
+def finetune(key, model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule, per_class, n_features):
+ if key == 'sldd':
+ return finetune_sldd(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule,per_class, n_features)
+ elif key == 'qsenn':
+ return finetune_qsenn(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule,n_features,per_class, )
+ else:
+ raise ValueError(f"Unknown Finetuning key: {key}")
\ No newline at end of file
diff --git a/finetuning/qsenn.py b/finetuning/qsenn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce4dc8b65e6c703e51fe602c2ac897c97844897c
--- /dev/null
+++ b/finetuning/qsenn.py
@@ -0,0 +1,30 @@
+import os
+
+import torch
+
+from finetuning.utils import train_n_epochs
+from sparsification.qsenn import compute_qsenn_feature_selection_and_assignment
+
+
+def finetune_qsenn(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule ,n_features, n_per_class):
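+    # Alternate feature selection/assignment and finetuning: every iteration
+    # recomputes the sparse selection and per-class weight assignment, applies it
+    # to the model and finetunes; a cached "trained_model.pth" is reloaded instead
+    # (only advancing the optimization schedule) when it already exists.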
+ for iteration_epoch in range(4):
+ print(f"Starting iteration epoch {iteration_epoch}")
+ this_log_dir = log_dir / f"iteration_epoch_{iteration_epoch}"
+ this_log_dir.mkdir(parents=True, exist_ok=True)
+ feature_sel, sparse_layer,bias_sparse, current_mean, current_std = compute_qsenn_feature_selection_and_assignment(model, train_loader,
+ test_loader,
+ this_log_dir, n_classes, seed, n_features, n_per_class)
+ model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
+ if os.path.exists(this_log_dir / "trained_model.pth"):
+ model.load_state_dict(torch.load(this_log_dir / "trained_model.pth"))
+            _ = optimization_schedule.get_params()  # advance the schedule so the next iteration gets the correct lr
+ continue
+
+ model = train_n_epochs( model, beta, optimization_schedule, train_loader, test_loader)
+ torch.save(model.state_dict(), this_log_dir / "trained_model.pth")
+ print(f"Finished iteration epoch {iteration_epoch}")
+ return model
+
+
+
+
diff --git a/finetuning/sldd.py b/finetuning/sldd.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c8ac0034b14cbbf460f0bf59e25dfd8188ee94b
--- /dev/null
+++ b/finetuning/sldd.py
@@ -0,0 +1,22 @@
+import numpy as np
+import torch
+
+from FeatureDiversityLoss import FeatureDiversityLoss
+from finetuning.utils import train_n_epochs
+from sparsification.glmBasedSparsification import compute_feature_selection_and_assignment
+from sparsification.sldd import compute_sldd_feature_selection_and_assignment
+from train import train, test
+from training.optim import get_optimizer
+
+
+
+
+def finetune_sldd(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule,n_per_class, n_features, ):
+ feature_sel, weight, bias, mean, std = compute_sldd_feature_selection_and_assignment(model, train_loader,
+ test_loader,
+ log_dir, n_classes, seed,n_per_class, n_features)
+ model.set_model_sldd(feature_sel, weight, mean, std, bias)
+ model = train_n_epochs( model, beta, optimization_schedule, train_loader, test_loader)
+ return model
+
+
diff --git a/finetuning/utils.py b/finetuning/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af751e2094c5ba6b2f83adadb5059f692329db37
--- /dev/null
+++ b/finetuning/utils.py
@@ -0,0 +1,14 @@
+from FeatureDiversityLoss import FeatureDiversityLoss
+from train import train, test
+from training.optim import get_optimizer
+
+
+def train_n_epochs(model, beta,optimization_schedule, train_loader, test_loader):
+ optimizer, schedule, epochs = get_optimizer(model, optimization_schedule)
+ fdl = FeatureDiversityLoss(beta, model.linear)
+ for epoch in range(epochs):
+ model = train(model, train_loader, optimizer, fdl, epoch)
+ schedule.step()
+ if epoch % 5 == 0 or epoch+1 == epochs:
+ test(model, test_loader, epoch)
+ return model
\ No newline at end of file
diff --git a/flagged/input/1e670025e5206017965a/Western_Grebe_0090_36182.jpg b/flagged/input/1e670025e5206017965a/Western_Grebe_0090_36182.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6450f6174bdd37cace75c6b32a029bcfa8761ed7
Binary files /dev/null and b/flagged/input/1e670025e5206017965a/Western_Grebe_0090_36182.jpg differ
diff --git a/flagged/input/6a11e385290e9006bb0a/Black_Footed_Albatross_0003_796136.jpg b/flagged/input/6a11e385290e9006bb0a/Black_Footed_Albatross_0003_796136.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f9f6063f1c6130694ddd53c0231b317abe9ef03b
Binary files /dev/null and b/flagged/input/6a11e385290e9006bb0a/Black_Footed_Albatross_0003_796136.jpg differ
diff --git a/flagged/log.csv b/flagged/log.csv
new file mode 100644
index 0000000000000000000000000000000000000000..5af3d3f8c5830b52178c12538580c9cd038fd2e4
--- /dev/null
+++ b/flagged/log.csv
@@ -0,0 +1,3 @@
+input,output,flag,username,timestamp
+flagged/input/1e670025e5206017965a/Western_Grebe_0090_36182.jpg,,,,2024-10-21 12:37:51.541901
+flagged/input/6a11e385290e9006bb0a/Black_Footed_Albatross_0003_796136.jpg,"[{""image"": ""flagged/output/e2f704607c002e0c557d/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1b4541c3e93f034d746d/image.webp"", ""caption"": null}, {""image"": ""flagged/output/f8727dcfa3c59de0d873/image.webp"", ""caption"": null}, {""image"": ""flagged/output/c4b75e9fbc946f6ead6d/image.webp"", ""caption"": null}, {""image"": ""flagged/output/5b5ad2dd997a635f4917/image.webp"", ""caption"": null}, {""image"": ""flagged/output/b066004e4a0114aa705b/image.webp"", ""caption"": null}, {""image"": ""flagged/output/036072cdcc620de8cb65/image.webp"", ""caption"": null}, {""image"": ""flagged/output/218135cb251eb6cd0b2c/image.webp"", ""caption"": null}, {""image"": ""flagged/output/2a0671ba5ac1aa3bd2b9/image.webp"", ""caption"": null}, {""image"": ""flagged/output/595953adce3a654bbd33/image.webp"", ""caption"": null}, {""image"": ""flagged/output/f333c69915509927b2ff/image.webp"", ""caption"": null}, {""image"": ""flagged/output/a966f50f23644e5046e8/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1a8a9e53fd4990fe5231/image.webp"", ""caption"": null}, {""image"": ""flagged/output/d7bc2f0eb8d70a562542/image.webp"", ""caption"": null}, {""image"": ""flagged/output/53fd53c5eab644d30338/image.webp"", ""caption"": null}, {""image"": ""flagged/output/ddf6b8ddc855838cc3b5/image.webp"", ""caption"": null}, {""image"": ""flagged/output/41a99b70366ac01533b4/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1b4ae8362917e14cb7a7/image.webp"", ""caption"": null}, {""image"": ""flagged/output/b321456290561eacf170/image.webp"", ""caption"": null}, {""image"": ""flagged/output/42d34c69c2384bda376b/image.webp"", ""caption"": null}, {""image"": ""flagged/output/35d0e9ae554c0b863ef3/image.webp"", ""caption"": null}, {""image"": ""flagged/output/799f55238c434907570f/image.webp"", ""caption"": null}, {""image"": ""flagged/output/db82081afaabf2fb505b/image.webp"", ""caption"": null}, {""image"": ""flagged/output/fff73f12467314dce395/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1bd17ff3896c5045b453/image.webp"", ""caption"": null}, {""image"": ""flagged/output/e31f93405e1526fe3e55/image.webp"", ""caption"": null}, {""image"": ""flagged/output/e9c9ff1da0805da0c0d8/image.webp"", ""caption"": null}, {""image"": ""flagged/output/e6ef5ba2d6c65b3c1d21/image.webp"", ""caption"": null}, {""image"": ""flagged/output/f763a51fb4a6d8a13313/image.webp"", ""caption"": null}, {""image"": ""flagged/output/7bdb4562631122e4ced7/image.webp"", ""caption"": null}, {""image"": ""flagged/output/9f7495b7c7648ecb1a10/image.webp"", ""caption"": null}, {""image"": ""flagged/output/ecbe75612f5db6cc7370/image.webp"", ""caption"": null}, {""image"": ""flagged/output/31f824d9522d30106a44/image.webp"", ""caption"": null}, {""image"": ""flagged/output/e06b9103e0bf90cd398a/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1441b4f37340c2afa3d0/image.webp"", ""caption"": null}]",,,2024-10-21 23:01:32.158338
diff --git a/flagged/output/036072cdcc620de8cb65/image.webp b/flagged/output/036072cdcc620de8cb65/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..4e7a831b3e63d41cf7dc53178e8f19231f456648
Binary files /dev/null and b/flagged/output/036072cdcc620de8cb65/image.webp differ
diff --git a/flagged/output/1441b4f37340c2afa3d0/image.webp b/flagged/output/1441b4f37340c2afa3d0/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..46e20fcfccd3a763d3eae21a0fda7d2908c6f53b
Binary files /dev/null and b/flagged/output/1441b4f37340c2afa3d0/image.webp differ
diff --git a/flagged/output/1a8a9e53fd4990fe5231/image.webp b/flagged/output/1a8a9e53fd4990fe5231/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..30452bb9f913012c3a787e78f5af2a657bfc4a82
Binary files /dev/null and b/flagged/output/1a8a9e53fd4990fe5231/image.webp differ
diff --git a/flagged/output/1b4541c3e93f034d746d/image.webp b/flagged/output/1b4541c3e93f034d746d/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..41fb284622b8bf0e85dac87a497a4942011579f2
Binary files /dev/null and b/flagged/output/1b4541c3e93f034d746d/image.webp differ
diff --git a/flagged/output/1b4ae8362917e14cb7a7/image.webp b/flagged/output/1b4ae8362917e14cb7a7/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..3566fc9c4f4f8bc2d8be57ffbaf1fb0b84f6fed8
Binary files /dev/null and b/flagged/output/1b4ae8362917e14cb7a7/image.webp differ
diff --git a/flagged/output/1bd17ff3896c5045b453/image.webp b/flagged/output/1bd17ff3896c5045b453/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a2e8a49694c6a233177b8757e916860ec2c217cb
Binary files /dev/null and b/flagged/output/1bd17ff3896c5045b453/image.webp differ
diff --git a/flagged/output/218135cb251eb6cd0b2c/image.webp b/flagged/output/218135cb251eb6cd0b2c/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..986c085197db498a852f013f503db78b64b4f7c5
Binary files /dev/null and b/flagged/output/218135cb251eb6cd0b2c/image.webp differ
diff --git a/flagged/output/2a0671ba5ac1aa3bd2b9/image.webp b/flagged/output/2a0671ba5ac1aa3bd2b9/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..5e9a54c48df42fa656f0bced1d9580acd75cf7ba
Binary files /dev/null and b/flagged/output/2a0671ba5ac1aa3bd2b9/image.webp differ
diff --git a/flagged/output/31f824d9522d30106a44/image.webp b/flagged/output/31f824d9522d30106a44/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..c91c7a09d9b8611a1430afa699da601d7d0efe21
Binary files /dev/null and b/flagged/output/31f824d9522d30106a44/image.webp differ
diff --git a/flagged/output/35d0e9ae554c0b863ef3/image.webp b/flagged/output/35d0e9ae554c0b863ef3/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..2473cae43807f063aa4d3e568e06e17e4b569920
Binary files /dev/null and b/flagged/output/35d0e9ae554c0b863ef3/image.webp differ
diff --git a/flagged/output/41a99b70366ac01533b4/image.webp b/flagged/output/41a99b70366ac01533b4/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..4121b433b87b66dd3fbb58722c67818906c67411
Binary files /dev/null and b/flagged/output/41a99b70366ac01533b4/image.webp differ
diff --git a/flagged/output/42d34c69c2384bda376b/image.webp b/flagged/output/42d34c69c2384bda376b/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..7618c903c18dc2451d25e1f32656f4caf9fe6ddb
Binary files /dev/null and b/flagged/output/42d34c69c2384bda376b/image.webp differ
diff --git a/flagged/output/53fd53c5eab644d30338/image.webp b/flagged/output/53fd53c5eab644d30338/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..2abbb10f1dbe66b93e37422b2470a0f071dea7cf
Binary files /dev/null and b/flagged/output/53fd53c5eab644d30338/image.webp differ
diff --git a/flagged/output/595953adce3a654bbd33/image.webp b/flagged/output/595953adce3a654bbd33/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..881add82b80c09007934e0467acd081e1b5fd7ac
Binary files /dev/null and b/flagged/output/595953adce3a654bbd33/image.webp differ
diff --git a/flagged/output/5b5ad2dd997a635f4917/image.webp b/flagged/output/5b5ad2dd997a635f4917/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..e57d0e88fbfc7af54aeb69995fe44af657c0d8dd
Binary files /dev/null and b/flagged/output/5b5ad2dd997a635f4917/image.webp differ
diff --git a/flagged/output/799f55238c434907570f/image.webp b/flagged/output/799f55238c434907570f/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..85d8a8fe108f97bec9684ccb2c614db43035d88e
Binary files /dev/null and b/flagged/output/799f55238c434907570f/image.webp differ
diff --git a/flagged/output/7bdb4562631122e4ced7/image.webp b/flagged/output/7bdb4562631122e4ced7/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..0b046a9a2ca40f025b7cc77df1b4c4f0613a7659
Binary files /dev/null and b/flagged/output/7bdb4562631122e4ced7/image.webp differ
diff --git a/flagged/output/9f7495b7c7648ecb1a10/image.webp b/flagged/output/9f7495b7c7648ecb1a10/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..bc21593a6869f1cc00f78de4dd9ebf912d18d795
Binary files /dev/null and b/flagged/output/9f7495b7c7648ecb1a10/image.webp differ
diff --git a/flagged/output/a966f50f23644e5046e8/image.webp b/flagged/output/a966f50f23644e5046e8/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..ffb81c67f03b993798f710e71e65b0f43cd151ca
Binary files /dev/null and b/flagged/output/a966f50f23644e5046e8/image.webp differ
diff --git a/flagged/output/b066004e4a0114aa705b/image.webp b/flagged/output/b066004e4a0114aa705b/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..b861d88dda4c0c7b783a87abaabc29f94dc943b2
Binary files /dev/null and b/flagged/output/b066004e4a0114aa705b/image.webp differ
diff --git a/flagged/output/b321456290561eacf170/image.webp b/flagged/output/b321456290561eacf170/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a10280498c93346105b7da59ba0808494004024c
Binary files /dev/null and b/flagged/output/b321456290561eacf170/image.webp differ
diff --git a/flagged/output/c4b75e9fbc946f6ead6d/image.webp b/flagged/output/c4b75e9fbc946f6ead6d/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..106535e80842768f14da47245baa981cabeea71b
Binary files /dev/null and b/flagged/output/c4b75e9fbc946f6ead6d/image.webp differ
diff --git a/flagged/output/d7bc2f0eb8d70a562542/image.webp b/flagged/output/d7bc2f0eb8d70a562542/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..686f7aac20f6a5449db99a97dc43c01dfdd99551
Binary files /dev/null and b/flagged/output/d7bc2f0eb8d70a562542/image.webp differ
diff --git a/flagged/output/db82081afaabf2fb505b/image.webp b/flagged/output/db82081afaabf2fb505b/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..af93f3eacee5ef0a4903995aaf8d2e2e5921976d
Binary files /dev/null and b/flagged/output/db82081afaabf2fb505b/image.webp differ
diff --git a/flagged/output/ddf6b8ddc855838cc3b5/image.webp b/flagged/output/ddf6b8ddc855838cc3b5/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..f5d206e37f97d1f45b91aceb45df630ca9fae223
Binary files /dev/null and b/flagged/output/ddf6b8ddc855838cc3b5/image.webp differ
diff --git a/flagged/output/e06b9103e0bf90cd398a/image.webp b/flagged/output/e06b9103e0bf90cd398a/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..8b4510986e28e5a132f4cc197ce9063b072b113b
Binary files /dev/null and b/flagged/output/e06b9103e0bf90cd398a/image.webp differ
diff --git a/flagged/output/e2f704607c002e0c557d/image.webp b/flagged/output/e2f704607c002e0c557d/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a7b2dafc246639799cfabd97306b3c4ba426cba6
Binary files /dev/null and b/flagged/output/e2f704607c002e0c557d/image.webp differ
diff --git a/flagged/output/e31f93405e1526fe3e55/image.webp b/flagged/output/e31f93405e1526fe3e55/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..955afaa8ead0e0bc67be7722bbc791dbfe4f35be
Binary files /dev/null and b/flagged/output/e31f93405e1526fe3e55/image.webp differ
diff --git a/flagged/output/e6ef5ba2d6c65b3c1d21/image.webp b/flagged/output/e6ef5ba2d6c65b3c1d21/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..4e0429c00817f88878d5cfc460039b8ed169c74c
Binary files /dev/null and b/flagged/output/e6ef5ba2d6c65b3c1d21/image.webp differ
diff --git a/flagged/output/e9c9ff1da0805da0c0d8/image.webp b/flagged/output/e9c9ff1da0805da0c0d8/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..d73c97deb0388bfb5423ad36e686e5e3ca44ce8d
Binary files /dev/null and b/flagged/output/e9c9ff1da0805da0c0d8/image.webp differ
diff --git a/flagged/output/ecbe75612f5db6cc7370/image.webp b/flagged/output/ecbe75612f5db6cc7370/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..7608061bb4d13a3e87696620071a61112463dea9
Binary files /dev/null and b/flagged/output/ecbe75612f5db6cc7370/image.webp differ
diff --git a/flagged/output/f333c69915509927b2ff/image.webp b/flagged/output/f333c69915509927b2ff/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..766635e3fd9996a671a9fd9e09bd37901f37a20e
Binary files /dev/null and b/flagged/output/f333c69915509927b2ff/image.webp differ
diff --git a/flagged/output/f763a51fb4a6d8a13313/image.webp b/flagged/output/f763a51fb4a6d8a13313/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..65b0c4ddd52dac9f92f898bc64ed67c18722c6ac
Binary files /dev/null and b/flagged/output/f763a51fb4a6d8a13313/image.webp differ
diff --git a/flagged/output/f8727dcfa3c59de0d873/image.webp b/flagged/output/f8727dcfa3c59de0d873/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..175792b0035f9df9db88d62f56d40356d8afbbfe
Binary files /dev/null and b/flagged/output/f8727dcfa3c59de0d873/image.webp differ
diff --git a/flagged/output/fff73f12467314dce395/image.webp b/flagged/output/fff73f12467314dce395/image.webp
new file mode 100644
index 0000000000000000000000000000000000000000..0b51dcf8c9e672660529ad1adea9afe37e5a4f08
Binary files /dev/null and b/flagged/output/fff73f12467314dce395/image.webp differ
diff --git a/get_data.py b/get_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6e9414c933a64f1124f4eda6ec0faa8cd8ed2ee
--- /dev/null
+++ b/get_data.py
@@ -0,0 +1,119 @@
+from pathlib import Path
+
+import torch
+import torchvision
+from torchvision.transforms import transforms, TrivialAugmentWide
+
+from configs.dataset_params import normalize_params
+from dataset_classes.cub200 import CUB200Class
+from dataset_classes.stanfordcars import StanfordCarsClass
+from dataset_classes.travelingbirds import TravelingBirds
+
+
+def get_data(dataset, crop = True, img_size=448):
+ batchsize = 16
+ if dataset == "CUB2011":
+ train_transform = get_augmentation(0.1, img_size, True,not crop, True, True, normalize_params["CUB2011"])
+ test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["CUB2011"])
+ train_dataset = CUB200Class(True, train_transform, crop)
+ test_dataset = CUB200Class(False, test_transform, crop)
+ elif dataset == "TravelingBirds":
+ train_transform = get_augmentation(0.1, img_size, True, not crop, True, True, normalize_params["TravelingBirds"])
+ test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["TravelingBirds"])
+ train_dataset = TravelingBirds(True, train_transform, crop)
+ test_dataset = TravelingBirds(False, test_transform, crop)
+
+ elif dataset == "StanfordCars":
+ train_transform = get_augmentation(0.1, img_size, True, True, True, True, normalize_params["StanfordCars"])
+ test_transform = get_augmentation(0.1, img_size, False, True, True, True, normalize_params["StanfordCars"])
+ train_dataset = StanfordCarsClass(True, train_transform)
+ test_dataset = StanfordCarsClass(False, test_transform)
+ elif dataset == "FGVCAircraft":
+ raise NotImplementedError
+
+ elif dataset == "ImageNet":
+ # Defaults from the robustness package
+ if img_size != 224:
+            raise NotImplementedError("ImageNet is set up to only work with 224x224 images")
+ train_transform = transforms.Compose([
+ transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ColorJitter(
+ brightness=0.1,
+ contrast=0.1,
+ saturation=0.1
+ ),
+ transforms.ToTensor(),
+ Lighting(0.05, IMAGENET_PCA['eigval'],
+ IMAGENET_PCA['eigvec'])
+ ])
+ """
+ Standard training data augmentation for ImageNet-scale datasets: Random crop,
+ Random flip, Color Jitter, and Lighting Transform (see https://git.io/fhBOc)
+ """
+ test_transform = transforms.Compose([
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ ])
+ imgnet_root = Path.home()/ "tmp" /"Datasets"/ "imagenet"
+ train_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='train', transform=train_transform)
+ test_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='val', transform=test_transform)
+ batchsize = 64
+
+ train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=8)
+ test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=8)
+ return train_loader, test_loader
+
+def get_augmentation(jitter, size, training, random_center_crop, trivialAug, hflip, normalize):
+ augmentation = []
+ if random_center_crop:
+ augmentation.append(transforms.Resize(size))
+ else:
+ augmentation.append(transforms.Resize((size, size)))
+ if training:
+ if random_center_crop:
+ augmentation.append(transforms.RandomCrop(size, padding=4))
+ else:
+ if random_center_crop:
+ augmentation.append(transforms.CenterCrop(size))
+ if training:
+ if hflip:
+ augmentation.append(transforms.RandomHorizontalFlip())
+ if jitter:
+ augmentation.append(transforms.ColorJitter(jitter, jitter, jitter))
+ if trivialAug:
+ augmentation.append(TrivialAugmentWide())
+ augmentation.append(transforms.ToTensor())
+ augmentation.append(transforms.Normalize(**normalize))
+ return transforms.Compose(augmentation)
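+# Example: the default CUB2011 training call, get_augmentation(0.1, 448, True, False, True, True, ...),
+# yields Resize((448, 448)) -> RandomHorizontalFlip -> ColorJitter(0.1, 0.1, 0.1)
+# -> TrivialAugmentWide -> ToTensor -> Normalize.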
+
+class Lighting(object):
+ """
+ Lighting noise (see https://git.io/fhBOc)
+ """
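+    # AlexNet-style PCA colour augmentation: draws alpha ~ N(0, alphastd) per
+    # principal component, scales the ImageNet RGB eigenvectors by alpha and the
+    # eigenvalues, and adds the resulting RGB offset to the image (see IMAGENET_PCA).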
+
+ def __init__(self, alphastd, eigval, eigvec):
+ self.alphastd = alphastd
+ self.eigval = eigval
+ self.eigvec = eigvec
+
+ def __call__(self, img):
+ if self.alphastd == 0:
+ return img
+
+ alpha = img.new().resize_(3).normal_(0, self.alphastd)
+ rgb = self.eigvec.type_as(img).clone() \
+ .mul(alpha.view(1, 3).expand(3, 3)) \
+ .mul(self.eigval.view(1, 3).expand(3, 3)) \
+ .sum(1).squeeze()
+
+ return img.add(rgb.view(3, 1, 1).expand_as(img))
+IMAGENET_PCA = {
+ 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
+ 'eigvec': torch.Tensor([
+ [-0.5675, 0.7192, 0.4009],
+ [-0.5808, -0.0045, -0.8140],
+ [-0.5836, -0.6948, 0.4203],
+ ])
+}
diff --git a/load_model.py b/load_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..13e962a7497da9a895feed12708fe9c55a24dbdd
--- /dev/null
+++ b/load_model.py
@@ -0,0 +1,51 @@
+from argparse import ArgumentParser
+from pathlib import Path
+
+import torch
+
+from architectures.model_mapping import get_model
+from configs.dataset_params import dataset_constants
+from evaluation.qsenn_metrics import eval_model_on_all_qsenn_metrics
+from get_data import get_data
+
+def extract_sel_mean_std_bias_assignemnt(state_dict):
+ feature_sel = state_dict["linear.selection"]
+ #feature_sel = selection
+ weight_at_selection = state_dict["linear.layer.weight"]
+ mean = state_dict["linear.mean"]
+ std = state_dict["linear.std"]
+ bias = state_dict["linear.layer.bias"]
+ return feature_sel, weight_at_selection, mean, std, bias
+
+
+def eval_model(dataset, arch,seed=123456, model_type="qsenn",crop = True, n_features = 50, n_per_class=5, img_size=448, reduced_strides=False, folder = None):
+ n_classes = dataset_constants[dataset]["num_classes"]
+ train_loader, test_loader = get_data(dataset, crop=crop, img_size=img_size)
+ model = get_model(arch, n_classes, reduced_strides)
+ if folder is None:
+ folder = Path.home() / f"tmp/{arch}/{dataset}/{seed}/"
+ print(folder)
+ state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth")
+ selection= torch.load(folder / f"SlDD_Selection_50.pt")
+ state_dict['linear.selection']=selection
+ print(state_dict.keys())
+ feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
+ model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
+ model.load_state_dict(state_dict)
+ print(model)
+ metrics_finetuned = eval_model_on_all_qsenn_metrics(model, test_loader, train_loader)
+
+if __name__ == '__main__':
+ parser = ArgumentParser()
+ parser.add_argument('--dataset', default="CUB2011", type=str, help='dataset name', choices=["CUB2011", "ImageNet", "TravelingBirds", "StanfordCars"])
+ parser.add_argument('--arch', default="resnet50", type=str, help='Backbone Feature Extractor', choices=["resnet50", "resnet18"])
+ parser.add_argument('--model_type', default="qsenn", type=str, help='Type of Model', choices=["qsenn", "sldd"])
+ parser.add_argument('--seed', default=123456, type=int, help='seed, used for naming the folder and random processes. Could be useful to set to have multiple finetune runs (e.g. Q-SENN and SLDD) on the same dense model') # 769567, 552629
+ parser.add_argument('--cropGT', default=False, type=bool,
+ help='Whether to crop CUB/TravelingBirds based on GT Boundaries')
+ parser.add_argument('--n_features', default=50, type=int, help='How many features to select') #769567
+ parser.add_argument('--n_per_class', default=5, type=int, help='How many features to assign to each class')
+ parser.add_argument('--img_size', default=448, type=int, help='Image size')
+ parser.add_argument('--reduced_strides', default=False, type=bool, help='Whether to use reduced strides for resnets')
+ args = parser.parse_args()
+ eval_model(args.dataset, args.arch, args.seed, args.model_type,args.cropGT, args.n_features, args.n_per_class, args.img_size, args.reduced_strides)
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a340862967c2c1d8befc1eff79bd00122223f93
--- /dev/null
+++ b/main.py
@@ -0,0 +1,79 @@
+import os
+from argparse import ArgumentParser
+from pathlib import Path
+
+import numpy as np
+import torch
+from tqdm import trange
+
+from FeatureDiversityLoss import FeatureDiversityLoss
+from architectures.model_mapping import get_model
+from configs.architecture_params import architecture_params
+from configs.dataset_params import dataset_constants
+from evaluation.qsenn_metrics import eval_model_on_all_qsenn_metrics
+from finetuning.map_function import finetune
+from get_data import get_data
+from saving.logging import Tee
+from saving.utils import json_save
+from train import train, test
+from training.optim import get_optimizer, get_scheduler_for_model
+
+
+def main(dataset, arch,seed=None, model_type="qsenn", do_dense=True,crop = True, n_features = 50, n_per_class=5, img_size=448, reduced_strides=False):
+ # create random seed, if seed is None
+ if seed is None:
+ seed = np.random.randint(0, 1000000)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ dataset_key = dataset
+ if crop:
+ assert dataset in ["CUB2011","TravelingBirds"]
+ dataset_key += "_crop"
+ log_dir = Path.home()/f"tmp/{arch}/{dataset_key}/{seed}/"
+ log_dir.mkdir(parents=True, exist_ok=True)
+ tee = Tee(log_dir / "log.txt") # save log to file
+ n_classes = dataset_constants[dataset]["num_classes"]
+ train_loader, test_loader = get_data(dataset, crop=crop, img_size=img_size)
+ model = get_model(arch, n_classes, reduced_strides)
+ fdl = FeatureDiversityLoss(architecture_params[arch]["beta"], model.linear)
+ OptimizationSchedule = get_scheduler_for_model(model_type, dataset)
+ optimizer, schedule, dense_epochs =get_optimizer(model, OptimizationSchedule)
+ if not os.path.exists(log_dir / "Trained_DenseModel.pth"):
+ if do_dense:
+ for epoch in trange(dense_epochs):
+ model = train(model, train_loader, optimizer, fdl, epoch)
+ schedule.step()
+ if epoch % 5 == 0:
+ test(model, test_loader,epoch)
+ else:
+ print("Using pretrained model, only makes sense for ImageNet")
+ torch.save(model.state_dict(), os.path.join(log_dir, f"Trained_DenseModel.pth"))
+ else:
+ model.load_state_dict(torch.load(log_dir / "Trained_DenseModel.pth"))
+ if not os.path.exists( log_dir/f"Results_DenseModel.json"):
+ metrics_dense = eval_model_on_all_qsenn_metrics(model, test_loader, train_loader)
+ json_save(os.path.join(log_dir, f"Results_DenseModel.json"), metrics_dense)
+ final_model = finetune(model_type, model, train_loader, test_loader, log_dir, n_classes, seed, architecture_params[arch]["beta"], OptimizationSchedule, n_per_class, n_features)
+ torch.save(final_model.state_dict(), os.path.join(log_dir,f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth"))
+ metrics_finetuned = eval_model_on_all_qsenn_metrics(final_model, test_loader, train_loader)
+ json_save(os.path.join(log_dir, f"Results_{model_type}_{n_features}_{n_per_class}_FinetunedModel.json"), metrics_finetuned)
+ print("Done")
+ pass
+
+
+
+if __name__ == '__main__':
+ parser = ArgumentParser()
+ parser.add_argument('--dataset', default="CUB2011", type=str, help='dataset name', choices=["CUB2011", "ImageNet", "TravelingBirds", "StanfordCars"])
+ parser.add_argument('--arch', default="resnet50", type=str, help='Backbone Feature Extractor', choices=["resnet50", "resnet18"])
+ parser.add_argument('--model_type', default="qsenn", type=str, help='Type of Model', choices=["qsenn", "sldd"])
+ parser.add_argument('--seed', default=None, type=int, help='seed, used for naming the folder and random processes. Could be useful to set to have multiple finetune runs (e.g. Q-SENN and SLDD) on the same dense model') # 769567, 552629
+ parser.add_argument('--do_dense', default=True, type=bool, help='whether to train dense model. Should be true for all datasets except (maybe) ImageNet')
+ parser.add_argument('--cropGT', default=False, type=bool,
+ help='Whether to crop CUB/TravelingBirds based on GT Boundaries')
+ parser.add_argument('--n_features', default=50, type=int, help='How many features to select') #769567
+ parser.add_argument('--n_per_class', default=5, type=int, help='How many features to assign to each class')
+ parser.add_argument('--img_size', default=448, type=int, help='Image size')
+ parser.add_argument('--reduced_strides', default=False, type=bool, help='Whether to use reduced strides for resnets')
+ args = parser.parse_args()
+ main(args.dataset, args.arch, args.seed, args.model_type, args.do_dense,args.cropGT, args.n_features, args.n_per_class, args.img_size, args.reduced_strides)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4f8641a287620411061756d1e997027170cd2a33
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+torch
+torchvision
+opencv-python
\ No newline at end of file
diff --git a/saving/logging.py b/saving/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..377e31b33e06865bb588bfec32678c947e5c3bb3
--- /dev/null
+++ b/saving/logging.py
@@ -0,0 +1,27 @@
+import sys
+
+
+class Tee(object):
+ def __init__(self, name, file_only=False):
+ self.file = open(name, "a")
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+ sys.stdout = self
+ sys.stderr = self
+ self.file_only = file_only
+
+ def __del__(self):
+ sys.stdout = self.stdout
+ sys.stderr = self.stderr
+ self.file.close()
+
+ def write(self, data):
+ self.file.write(data)
+ if not self.file_only:
+ self.stdout.write(data)
+ self.flush()
+
+ def flush(self):
+ self.file.flush()
+
+
diff --git a/saving/utils.py b/saving/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bf35c3c5129e84ab6f13f85c4c04fba4f7e33b4
--- /dev/null
+++ b/saving/utils.py
@@ -0,0 +1,6 @@
+import json
+
+
+def json_save(filename, data):
+ with open(filename, "w") as f:
+ json.dump(data, f,indent=4)
\ No newline at end of file
diff --git a/sparsification/FeatureSelection.py b/sparsification/FeatureSelection.py
new file mode 100644
index 0000000000000000000000000000000000000000..885a7bc9fab8842e9eece0f07690636b1d623233
--- /dev/null
+++ b/sparsification/FeatureSelection.py
@@ -0,0 +1,473 @@
+from argparse import ArgumentParser
+import logging
+import math
+import os.path
+import sys
+import time
+import warnings
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from glm_saga.elasticnet import maximum_reg_loader, get_device, elastic_loss_and_acc_loader
+from torch import nn
+
+import torch as ch
+
+from sparsification.utils import safe_zip
+
+# Note: the "TODO checkout this change" comments mark modifications relative to the group version of glm-saga
+
+"""
+This module requires glm_saga to run.
+Usage to select 50 features with the parameters from the paper:
+metadata contains information about the precomputed train features in feature_loaders
+args contains the default arguments for glm-saga, as described at the bottom
+def get_glm_to_zero(feature_loaders, metadata, args, num_classes, device, train_ds, Ntotal):
+ num_features = metadata["X"]["num_features"][0]
+    fittingClass = FeatureSelectionFitting(num_features, num_classes, args, 0.8,
+ 50,
+ True,0.1,
+ lookback=3, tol=1e-4,
+ epsilon=1,)
+ to_drop, test_acc = fittingClass.fit(feature_loaders, metadata, device)
+ return to_drop
+
+to_drop is then used to remove the features from the downstream fitting and finetuning.
+"""
+
+
+class FeatureSelectionFitting:
+ def __init__(self, n_features, n_classes, args, selalpha, nKeep, lam_fac,out_dir, lookback=None, tol=None,
+ epsilon=None):
+ """
+        This is an adaptation of the group version of glm-saga (https://github.com/MadryLab/DebuggableDeepNetworks).
+        The function extended_mask_max covers the changed operator.
+ Args:
+ n_features:
+ n_classes:
+ args: default args for glmsaga
+ selalpha: alpha for elastic net
+ nKeep: target number features
+ lam_fac: discount factor for lambda
+ parameters of glmsaga
+ lookback:
+ tol:
+ epsilon:
+ """
+ self.selected_features = torch.zeros(n_features, dtype=torch.bool)
+ self.num_features = n_features
+ self.selalpha = selalpha
+ self.lam_Fac = lam_fac
+ self.out_dir = out_dir
+ self.n_classes = n_classes
+ self.nKeep = nKeep
+ self.args = self.extend_args(args, lookback, tol, epsilon)
+
+ # Extended Proximal Operator for Feature Selection
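+    # Unlike the standard group operator, which keeps every column whose norm
+    # exceeds the threshold, this variant keeps all already selected features and
+    # admits at most one new feature per proximal step: the above-threshold column
+    # with the largest norm among the not-yet-selected ones.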
+ def extended_mask_max(self, greater_to_keep, thresh):
+ prev = greater_to_keep[self.selected_features]
+ greater_to_keep[self.selected_features] = torch.min(greater_to_keep)
+ max_entry = torch.argmax(greater_to_keep)
+ greater_to_keep[self.selected_features] = prev
+ mask = torch.zeros_like(greater_to_keep)
+ mask[max_entry] = 1
+ final_mask = (greater_to_keep > thresh)
+ final_mask = final_mask * mask
+ allowed_to_keep = torch.logical_or(self.selected_features, final_mask)
+ return allowed_to_keep
+
+ def extend_args(self, args, lookback, tol, epsilon):
+ for key, entry in safe_zip(["lookbehind", "tol",
+ "lr_decay_factor", ], [lookback, tol, epsilon]):
+ if entry is not None:
+ setattr(args, key, entry)
+ return args
+
+ # Grouped L1 regularization
+ # proximal operator for f(weight) = lam * \|weight\|_2
+ # where the 2-norm is taken columnwise
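+    # i.e. column-wise group soft-thresholding: each column w_j is scaled by
+    # (1 - lam / ||w_j||_2), and only columns allowed by extended_mask_max
+    # (already selected features plus at most one new one) are kept.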
+ def group_threshold(self, weight, lam):
+ norm = weight.norm(p=2, dim=0) + 1e-6
+ # print(ch.sum((norm > lam)))
+ return (weight - lam * weight / norm) * self.extended_mask_max(norm, lam)
+
+ # Elastic net regularization with group sparsity
+ # proximal operator for f(x) = alpha * \|x\|_1 + beta * \|x\|_2^2
+ # where the 2-norm is taken columnwise
+ def group_threshold_with_shrinkage(self, x, alpha, beta):
+ y = self.group_threshold(x, alpha)
+ return y / (1 + beta)
+
+ def threshold(self, weight_new, lr, lam):
+ alpha = self.selalpha
+ if alpha == 1:
+ # Pure L1 regularization
+ weight_new = self.group_threshold(weight_new, lr * lam * alpha)
+ else:
+ # Elastic net regularization
+ weight_new = self.group_threshold_with_shrinkage(weight_new, lr * lam * alpha,
+ lr * lam * (1 - alpha))
+ return weight_new
+
+ # Train an elastic GLM with proximal SAGA
+ # Since SAGA stores a scalar for each example-class pair, either pass
+ # the number of examples and number of classes or calculate it with an
+ # initial pass over the loaders
+ def train_saga(self, linear, loader, lr, nepochs, lam, alpha, group=True, verbose=None,
+ state=None, table_device=None, n_ex=None, n_classes=None, tol=1e-4,
+ preprocess=None, lookbehind=None, family='multinomial', logger=None):
+ if logger is None:
+ logger = print
+ with ch.no_grad():
+ weight, bias = list(linear.parameters())
+ if table_device is None:
+ table_device = weight.device
+
+ # get total number of examples and initialize scalars
+ # for computing the gradients
+ if n_ex is None:
+ n_ex = sum(tensors[0].size(0) for tensors in loader)
+ if n_classes is None:
+ if family == 'multinomial':
+ n_classes = max(tensors[1].max().item() for tensors in loader) + 1
+ elif family == 'gaussian':
+ for batch in loader:
+ y = batch[1]
+ break
+ n_classes = y.size(1)
+
+ # Storage for scalar gradients and averages
+ if state is None:
+ a_table = ch.zeros(n_ex, n_classes).to(table_device)
+ w_grad_avg = ch.zeros_like(weight).to(weight.device)
+ b_grad_avg = ch.zeros_like(bias).to(weight.device)
+ else:
+ a_table = state["a_table"].to(table_device)
+ w_grad_avg = state["w_grad_avg"].to(weight.device)
+ b_grad_avg = state["b_grad_avg"].to(weight.device)
+
+ obj_history = []
+ obj_best = None
+ nni = 0
+ for t in range(nepochs):
+ total_loss = 0
+ for n_batch, batch in enumerate(loader):
+ if len(batch) == 3:
+ X, y, idx = batch
+ w = None
+ elif len(batch) == 4:
+ X, y, w, idx = batch
+ else:
+ raise ValueError(
+                            f"Loader must return (data, target, index) or (data, target, weight, index) but instead got a tuple of length {len(batch)}")
+
+ if preprocess is not None:
+ device = get_device(preprocess)
+ with ch.no_grad():
+ X = preprocess(X.to(device))
+ X = X.to(weight.device)
+ out = linear(X)
+
+ # split gradient on only the cross entropy term
+ # for efficient storage of gradient information
+ if family == 'multinomial':
+ if w is None:
+ loss = F.cross_entropy(out, y.to(weight.device), reduction='mean')
+ else:
+ loss = F.cross_entropy(out, y.to(weight.device), reduction='none')
+ loss = (loss * w).mean()
+ I = ch.eye(linear.weight.size(0))
+ target = I[y].to(weight.device) # change to OHE
+
+ # Calculate new scalar gradient
+ logits = F.softmax(linear(X))
+ elif family == 'gaussian':
+ if w is None:
+ loss = 0.5 * F.mse_loss(out, y.to(weight.device), reduction='mean')
+ else:
+ loss = 0.5 * F.mse_loss(out, y.to(weight.device), reduction='none')
+ loss = (loss * (w.unsqueeze(1))).mean()
+ target = y
+
+ # Calculate new scalar gradient
+ logits = linear(X)
+ else:
+ raise ValueError(f"Unknown family: {family}")
+ total_loss += loss.item() * X.size(0)
+
+ # BS x NUM_CLASSES
+ a = logits - target
+ if w is not None:
+ a = a * w.unsqueeze(1)
+ a_prev = a_table[idx].to(weight.device)
+
+ # weight parameter
+ w_grad = (a.unsqueeze(2) * X.unsqueeze(1)).mean(0)
+ w_grad_prev = (a_prev.unsqueeze(2) * X.unsqueeze(1)).mean(0)
+ w_saga = w_grad - w_grad_prev + w_grad_avg
+ weight_new = weight - lr * w_saga
+ weight_new = self.threshold(weight_new, lr, lam)
+ # bias parameter
+ b_grad = a.mean(0)
+ b_grad_prev = a_prev.mean(0)
+ b_saga = b_grad - b_grad_prev + b_grad_avg
+ bias_new = bias - lr * b_saga
+
+ # update table and averages
+ a_table[idx] = a.to(table_device)
+ w_grad_avg.add_((w_grad - w_grad_prev) * X.size(0) / n_ex)
+ b_grad_avg.add_((b_grad - b_grad_prev) * X.size(0) / n_ex)
+
+ if lookbehind is None:
+ dw = (weight_new - weight).norm(p=2)
+ db = (bias_new - bias).norm(p=2)
+ criteria = ch.sqrt(dw ** 2 + db ** 2)
+
+ if criteria.item() <= tol:
+ return {
+ "a_table": a_table.cpu(),
+ "w_grad_avg": w_grad_avg.cpu(),
+ "b_grad_avg": b_grad_avg.cpu()
+ }
+
+ weight.data = weight_new
+ bias.data = bias_new
+
+ saga_obj = total_loss / n_ex + lam * alpha * weight.norm(p=1) + 0.5 * lam * (1 - alpha) * (
+ weight ** 2).sum()
+
+ # save amount of improvement
+ obj_history.append(saga_obj.item())
+ if obj_best is None or saga_obj.item() + tol < obj_best:
+ obj_best = saga_obj.item()
+ nni = 0
+ else:
+ nni += 1
+
+                # Stop if no progress for lookbehind iterations
+ criteria = lookbehind is not None and (nni >= lookbehind)
+
+ nnz = (weight.abs() > 1e-5).sum().item()
+ total = weight.numel()
+ if verbose and (t % verbose) == 0:
+ if lookbehind is None:
+ logger(
+ f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) criteria {criteria:.4f} {dw} {db}")
+ else:
+ logger(
+ f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) obj_best {obj_best}")
+
+ if lookbehind is not None and criteria:
+ logger(
+ f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) obj_best {obj_best} [early stop at {t}]")
+ return {
+ "a_table": a_table.cpu(),
+ "w_grad_avg": w_grad_avg.cpu(),
+ "b_grad_avg": b_grad_avg.cpu()
+ }
+
+ logger(f"did not converge at {nepochs} iterations (criteria {criteria})")
+ return {
+ "a_table": a_table.cpu(),
+ "w_grad_avg": w_grad_avg.cpu(),
+ "b_grad_avg": b_grad_avg.cpu()
+ }
+
+ def glm_saga(self, linear, loader, max_lr, nepochs, alpha, dropout, tries,
+ table_device=None, preprocess=None, group=False,
+ verbose=None, state=None, n_ex=None, n_classes=None,
+ tol=1e-4, epsilon=0.001, k=100, checkpoint=None,
+ do_zero=True, lr_decay_factor=1, metadata=None,
+ val_loader=None, test_loader=None, lookbehind=None,
+ family='multinomial', encoder=None, tot_tries=1):
+ if encoder is not None:
+ warnings.warn("encoder argument is deprecated; please use preprocess instead", DeprecationWarning)
+ preprocess = encoder
+ device = get_device(linear)
+ checkpoint = self.out_dir
+ if preprocess is not None and (device != get_device(preprocess)):
+ raise ValueError(
+ f"Linear and preprocess must be on same device (got {get_device(linear)} and {get_device(preprocess)})")
+
+ if metadata is not None:
+ if n_ex is None:
+ n_ex = metadata['X']['num_examples']
+ if n_classes is None:
+ n_classes = metadata['y']['num_classes']
+ lam_fac = (1 + (tries - 1) / tot_tries)
+ print("Using lam_fac ", lam_fac)
+ max_lam = maximum_reg_loader(loader, group=group, preprocess=preprocess, metadata=metadata,
+ family=family) / max(
+ 0.001, alpha) * lam_fac
+ group_lam = maximum_reg_loader(loader, group=True, preprocess=preprocess, metadata=metadata,
+ family=family) / max(
+ 0.001, alpha) * lam_fac
+ min_lam = epsilon * max_lam
+ group_min_lam = epsilon * group_lam
+ # logspace is base 10 but log is base e so use log10
+ lams = ch.logspace(math.log10(max_lam), math.log10(min_lam), k)
+ lrs = ch.logspace(math.log10(max_lr), math.log10(max_lr / lr_decay_factor), k)
+ found = False
+ if do_zero:
+ lams = ch.cat([lams, lams.new_zeros(1)])
+ lrs = ch.cat([lrs, lrs.new_ones(1) * lrs[-1]])
+
+ path = []
+ best_val_loss = float('inf')
+
+ if checkpoint is not None:
+ os.makedirs(checkpoint, exist_ok=True)
+
+ file_handler = logging.FileHandler(filename=os.path.join(checkpoint, 'output.log'))
+ stdout_handler = logging.StreamHandler(sys.stdout)
+ handlers = [file_handler, stdout_handler]
+
+ logging.basicConfig(
+ level=logging.DEBUG,
+ format='[%(asctime)s] %(levelname)s - %(message)s',
+ handlers=handlers
+ )
+ logger = logging.getLogger('glm_saga').info
+ else:
+ logger = print
+ while self.selected_features.sum() < self.nKeep: # TODO checkout this change, one iteration per feature
+ n_feature_to_keep = self.selected_features.sum()
+ for i, (lam, lr) in enumerate(zip(lams, lrs)):
+ lam = lam * self.lam_Fac
+ start_time = time.time()
+ self.selected_features = self.selected_features.to(device)
+ state = self.train_saga(linear, loader, lr, nepochs, lam, alpha,
+ table_device=table_device, preprocess=preprocess, group=group, verbose=verbose,
+ state=state, n_ex=n_ex, n_classes=n_classes, tol=tol, lookbehind=lookbehind,
+ family=family, logger=logger)
+
+ with ch.no_grad():
+ loss, acc = elastic_loss_and_acc_loader(linear, loader, lam, alpha, preprocess=preprocess,
+ family=family)
+ loss, acc = loss.item(), acc.item()
+
+ loss_val, acc_val = -1, -1
+ if val_loader:
+ loss_val, acc_val = elastic_loss_and_acc_loader(linear, val_loader, lam, alpha,
+ preprocess=preprocess,
+ family=family)
+ loss_val, acc_val = loss_val.item(), acc_val.item()
+
+ loss_test, acc_test = -1, -1
+ if test_loader:
+ loss_test, acc_test = elastic_loss_and_acc_loader(linear, test_loader, lam, alpha,
+ preprocess=preprocess, family=family)
+ loss_test, acc_test = loss_test.item(), acc_test.item()
+
+ params = {
+ "lam": lam,
+ "lr": lr,
+ "alpha": alpha,
+ "time": time.time() - start_time,
+ "loss": loss,
+ "metrics": {
+ "loss_tr": loss,
+ "acc_tr": acc,
+ "loss_val": loss_val,
+ "acc_val": acc_val,
+ "loss_test": loss_test,
+ "acc_test": acc_test,
+ },
+ "weight": linear.weight.detach().cpu().clone(),
+ "bias": linear.bias.detach().cpu().clone()
+
+ }
+ path.append(params)
+ if loss_val is not None and loss_val < best_val_loss:
+ best_val_loss = loss_val
+ best_params = params
+ found = True
+ nnz = (linear.weight.abs() > 1e-5).sum().item()
+ total = linear.weight.numel()
+ if family == 'multinomial':
+ logger(
+ f"{n_feature_to_keep} Feature ({i}) lambda {lam:.4f}, loss {loss:.4f}, acc {acc:.4f} [val acc {acc_val:.4f}] [test acc {acc_test:.4f}], sparsity {nnz / total} [{nnz}/{total}], time {time.time() - start_time}, lr {lr:.4f}")
+ elif family == 'gaussian':
+ logger(
+ f"({i}) lambda {lam:.4f}, loss {loss:.4f} [val loss {loss_val:.4f}] [test loss {loss_test:.4f}], sparsity {nnz / total} [{nnz}/{total}], time {time.time() - start_time}, lr {lr:.4f}")
+
+ if self.check_new_feature(linear.weight): # TODO checkout this change, canceling if new feature is used
+ if checkpoint is not None:
+ ch.save(params, os.path.join(checkpoint, f"params{n_feature_to_keep}.pth"))
+ break
+ if found:
+ return {
+ 'path': path,
+ 'best': best_params,
+ 'state': state
+ }
+ else:
+ return False
+
+ def check_new_feature(self, weight):
+ # TODO checkout this change, checking if new feature is used
+        copied_weight = weight.detach().cpu().clone()
+ used_features = torch.unique(
+ torch.nonzero(copied_weight)[:, 1])
+ if len(used_features) > 0:
+ new_set = set(used_features.tolist())
+ old_set = set(torch.nonzero(self.selected_features)[:, 0].tolist())
+ diff = new_set - old_set
+ if len(diff) > 0:
+ self.selected_features[used_features] = True
+ return True
+ return False
+
+ def fit(self, feature_loaders, metadata, device):
+ # TODO checkout this change, glm saga code slightly adapted to return to_drop
+ print("Initializing linear model...")
+ linear = nn.Linear(self.num_features, self.n_classes).to(device)
+ for p in [linear.weight, linear.bias]:
+ p.data.zero_()
+
+ print("Preparing normalization preprocess and indexed dataloader")
+ preprocess = NormalizedRepresentation(feature_loaders['train'],
+ metadata=metadata,
+ device=linear.weight.device)
+
+ print("Calculating the regularization path")
+ mpl_logger = logging.getLogger("matplotlib")
+ mpl_logger.setLevel(logging.WARNING)
+ selected_features = self.glm_saga(linear,
+ feature_loaders['train'],
+ self.args.lr,
+ self.args.max_epochs,
+ self.selalpha, 0, 1,
+ val_loader=feature_loaders['val'],
+ test_loader=feature_loaders['test'],
+ n_classes=self.n_classes,
+ verbose=self.args.verbose,
+ tol=self.args.tol,
+ lookbehind=self.args.lookbehind,
+ lr_decay_factor=self.args.lr_decay_factor,
+ group=True,
+ epsilon=self.args.lam_factor,
+ metadata=metadata,
+ preprocess=preprocess, tot_tries=1)
+ to_drop = np.where(self.selected_features.cpu().numpy() == 0)[0]
+ test_acc = selected_features["path"][-1]["metrics"]["acc_test"]
+ torch.set_grad_enabled(True)
+ return to_drop, test_acc
+
+
+class NormalizedRepresentation(ch.nn.Module):
+ def __init__(self, loader, metadata, device='cuda', tol=1e-5):
+ super(NormalizedRepresentation, self).__init__()
+
+ assert metadata is not None
+ self.device = device
+ self.mu = metadata['X']['mean']
+ self.sigma = ch.clamp(metadata['X']['std'], tol)
+
+ def forward(self, X):
+ return (X - self.mu.to(self.device)) / self.sigma.to(self.device)
+
+
+
+
diff --git a/sparsification/data_helpers.py b/sparsification/data_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d48424564050c66238f9b731b433ca25d29d5a6b
--- /dev/null
+++ b/sparsification/data_helpers.py
@@ -0,0 +1,16 @@
+
+import torch
+
+
+class NormalizedRepresentation(torch.nn.Module):
+ def __init__(self, loader, metadata, device='cuda', tol=1e-5):
+ super(NormalizedRepresentation, self).__init__()
+
+ assert metadata is not None
+ self.device = device
+ self.mu = metadata['X']['mean']
+ self.sigma = torch.clamp(metadata['X']['std'], tol)
+
+ def forward(self, X):
+ return (X - self.mu.to(self.device)) / self.sigma.to(self.device)
+
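+# Hedged usage sketch (illustrative only; `train_feature_loader`, `mu`, `sigma` and `X` are assumptions):
+#   metadata = {"X": {"mean": mu, "std": sigma}}      # statistics of the training features
+#   preprocess = NormalizedRepresentation(train_feature_loader, metadata, device="cpu")
+#   X_norm = preprocess(X)   # standardizes X as (X - mu) / clamp(sigma, 1e-5)
+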
diff --git a/sparsification/feature_helpers.py b/sparsification/feature_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c11867077be5ab067548f498dad17fb299fa162
--- /dev/null
+++ b/sparsification/feature_helpers.py
@@ -0,0 +1,378 @@
+import math
+import os
+import sys
+
+import torch.cuda
+
+import sparsification.utils
+
+sys.path.append('')
+import numpy as np
+import torch as ch
+from torch.utils.data import Subset
+from tqdm import tqdm
+
+
+
+# From glm_saga
+def get_features_batch(batch, model, device='cuda'):
+ if not torch.cuda.is_available():
+ device = "cpu"
+ ims, targets = batch
+    output, latents = model(ims.to(device), with_final_features=True)
+ return latents, targets
+
+
+def compute_features(loader, model, dataset_type, pooled_output,
+ batch_size, num_workers,
+ shuffle=False, device='cpu', n_epoch=1,
+ filename=None, chunk_threshold=20000, balance=False):
+ """Compute deep features for a given dataset using a modeln and returnss
+ them as a pytorch dataset and loader.
+ Args:
+ loader : Torch data loader
+ model: Torch model
+ dataset_type (str): One of vision or language
+ pooled_output (bool): Whether or not to pool outputs
+ (only relevant for some language models)
+ batch_size (int): Batch size for output loader
+ num_workers (int): Number of workers to use for output loader
+        shuffle (bool): Whether or not to shuffle output data loader
+ device (str): Device on which to keep the model
+        filename (str): Optional directory in which to cache computed features. Recommended
+ for large dataset_classes like ImageNet.
+ chunk_threshold (int): Size of shard while caching
+ balance (bool): Whether or not to balance output data loader
+ (only relevant for some language models)
+ Returns:
+ feature_dataset: Torch dataset with deep features
+ feature_loader: Torch data loader with deep features
+ """
+ if torch.cuda.is_available():
+ device = "cuda"
+ print("mem_get_info before", torch.cuda.mem_get_info())
+ torch.cuda.empty_cache()
+ print("mem_get_info after", torch.cuda.mem_get_info())
+ model = model.to(device)
+ if filename is None or not os.path.exists(os.path.join(filename, f'0_features.npy')):
+ model.eval()
+ all_latents, all_targets, all_images = [], [], []
+ Nsamples, chunk_id = 0, 0
+ for idx_epoch in range(n_epoch):
+ for batch_idx, batch in tqdm(enumerate(loader), total=len(loader)):
+ with ch.no_grad():
+ latents, targets = get_features_batch(batch, model,
+ device=device)
+ if batch_idx == 0:
+ print("Latents shape", latents.shape)
+ Nsamples += latents.size(0)
+
+ all_latents.append(latents.cpu())
+ if len(targets.shape) > 1:
+ targets = targets[:, 0]
+ all_targets.append(targets.cpu())
+ # all_images.append(batch[0])
+ if filename is not None and Nsamples > chunk_threshold:
+ if not os.path.exists(filename): os.makedirs(filename)
+ np.save(os.path.join(filename, f'{chunk_id}_features.npy'), ch.cat(all_latents).numpy())
+ np.save(os.path.join(filename, f'{chunk_id}_labels.npy'), ch.cat(all_targets).numpy())
+
+ all_latents, all_targets, Nsamples = [], [], 0
+ chunk_id += 1
+
+ if filename is not None and Nsamples > 0:
+ if not os.path.exists(filename): os.makedirs(filename)
+ np.save(os.path.join(filename, f'{chunk_id}_features.npy'), ch.cat(all_latents).numpy())
+ np.save(os.path.join(filename, f'{chunk_id}_labels.npy'), ch.cat(all_targets).numpy())
+ # np.save(os.path.join(filename, f'{chunk_id}_images.npy'), ch.cat(all_images).numpy())
+ feature_dataset = load_features(filename) if filename is not None else \
+ ch.utils.data.TensorDataset(ch.cat(all_latents), ch.cat(all_targets))
+ if balance:
+ feature_dataset = balance_dataset(feature_dataset)
+
+ feature_loader = ch.utils.data.DataLoader(feature_dataset,
+ num_workers=num_workers,
+ batch_size=batch_size,
+ shuffle=shuffle)
+
+ return feature_dataset, feature_loader
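+
+# Hedged usage sketch (illustrative only; `backbone_model` and the cache path are assumptions):
+#   feats_ds, feats_loader = compute_features(train_loader, backbone_model,
+#                                              dataset_type="vision", pooled_output=None,
+#                                              batch_size=256, num_workers=4,
+#                                              device="cpu", filename="tmp/features_train")
+# Cached chunks land in tmp/features_train/{chunk_id}_features.npy and can later be
+# reloaded with load_features("tmp/features_train").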
+
+
+def load_feature_loader(out_dir_feats, val_frac, batch_size, num_workers, random_seed):
+ feature_loaders = {}
+ for mode in ['train', 'test']:
+ print(f"For {mode} set...")
+ sink_path = f"{out_dir_feats}/features_{mode}"
+ metadata_path = f"{out_dir_feats}/metadata_{mode}.pth"
+ feature_ds = load_features(sink_path)
+ feature_loader = ch.utils.data.DataLoader(feature_ds,
+ num_workers=num_workers,
+ batch_size=batch_size)
+ if mode == 'train':
+ metadata = calculate_metadata(feature_loader,
+ num_classes=2048,
+ filename=metadata_path)
+ split_datasets, split_loaders = split_dataset(feature_ds,
+ len(feature_ds),
+ val_frac=val_frac,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ random_seed=random_seed,
+ shuffle=True)
+ feature_loaders.update({mm: sparsification.utils.add_index_to_dataloader(split_loaders[mi])
+ for mi, mm in enumerate(['train', 'val'])})
+
+ else:
+ feature_loaders[mode] = feature_loader
+ return feature_loaders, metadata
+
+
+def balance_dataset(dataset):
+ """Balances a given dataset to have the same number of samples/class.
+ Args:
+ dataset : Torch dataset
+ Returns:
+ Torch dataset with equal number of samples/class
+ """
+
+ print("Balancing dataset...")
+ n = len(dataset)
+ labels = ch.Tensor([dataset[i][1] for i in range(n)]).int()
+ n0 = sum(labels).item()
+ I_pos = labels == 1
+
+ idx = ch.arange(n)
+ idx_pos = idx[I_pos]
+ ch.manual_seed(0)
+ I = ch.randperm(n - n0)[:n0]
+ idx_neg = idx[~I_pos][I]
+ idx_bal = ch.cat([idx_pos, idx_neg], dim=0)
+ return Subset(dataset, idx_bal)
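+
+# Note: balance_dataset assumes binary 0/1 labels; it keeps every positive sample and an
+# equally sized random subset of negatives, so the returned Subset has 2 * n_positive samples.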
+
+
+def load_metadata(feature_path):
+ return ch.load(os.path.join(feature_path, f'metadata_train.pth'))
+
+
+def get_mean_std(feature_path):
+ metadata = load_metadata(feature_path)
+ return metadata["X"]["mean"], metadata["X"]["std"]
+
+
+def load_features_dataset_mode(feature_path, mode='test',
+ num_workers=10, batch_size=128):
+ """Loads precomputed deep features corresponding to the
+    train/test set along with normalization statistics.
+    Args:
+        feature_path (str): Path to precomputed deep features
+        mode (str): One of train or test
+ num_workers (int): Number of workers to use for output loader
+ batch_size (int): Batch size for output loader
+
+ Returns:
+        feature_loader: Torch data loader over the recovered deep features
+ feature_mean: Mean of deep features
+ feature_std: Standard deviation of deep features
+ """
+ feature_dataset = load_features(os.path.join(feature_path, f'features_{mode}'))
+ feature_loader = ch.utils.data.DataLoader(feature_dataset,
+ num_workers=num_workers,
+ batch_size=batch_size,
+ shuffle=False)
+ feature_metadata = ch.load(os.path.join(feature_path, f'metadata_train.pth'))
+ feature_mean, feature_std = feature_metadata['X']['mean'], feature_metadata['X']['std']
+ return feature_loader, feature_mean, feature_std
+
+
+def load_joint_dataset(feature_path, mode='test',
+ num_workers=10, batch_size=128):
+ feature_dataset = load_features(os.path.join(feature_path, f'features_{mode}'))
+ feature_loader = ch.utils.data.DataLoader(feature_dataset,
+ num_workers=num_workers,
+ batch_size=batch_size,
+ shuffle=False)
+ features = []
+ labels = []
+ for _, (feature, label) in tqdm(enumerate(feature_loader), total=len(feature_loader)):
+ features.append(feature)
+ labels.append(label)
+ features = np.concatenate(features)
+ labels = np.concatenate(labels)
+ dataset = ch.utils.data.TensorDataset(torch.tensor(features), torch.tensor(labels))
+ return dataset
+
+
+def load_features_mode(feature_path, mode='test',
+ num_workers=10, batch_size=128):
+ """Loads precomputed deep features corresponding to the
+    train/test set along with normalization statistics.
+    Args:
+        feature_path (str): Path to precomputed deep features
+        mode (str): One of train or test
+ num_workers (int): Number of workers to use for output loader
+ batch_size (int): Batch size for output loader
+
+ Returns:
+ features (np.array): Recovered deep features
+ feature_mean: Mean of deep features
+ feature_std: Standard deviation of deep features
+ """
+ feature_dataset = load_features(os.path.join(feature_path, f'features_{mode}'))
+ feature_loader = ch.utils.data.DataLoader(feature_dataset,
+ num_workers=num_workers,
+ batch_size=batch_size,
+ shuffle=False)
+
+ feature_metadata = ch.load(os.path.join(feature_path, f'metadata_train.pth'))
+ feature_mean, feature_std = feature_metadata['X']['mean'], feature_metadata['X']['std']
+
+ features = []
+
+ for _, (feature, _) in tqdm(enumerate(feature_loader), total=len(feature_loader)):
+ features.append(feature)
+
+ features = ch.cat(features).numpy()
+ return features, feature_mean, feature_std
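+
+# Hedged usage sketch (the path is a placeholder): recover test features plus the training
+# normalization statistics stored next to them, then standardize manually if needed:
+#   X_test, mu, sigma = load_features_mode("tmp/features", mode="test")
+#   X_test_norm = (X_test - mu.numpy()) / sigma.numpy()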
+
+
+def load_features(feature_path):
+ """Loads precomputed deep features.
+ Args:
+ feature_path (str): Path to precomputed deep features
+
+ Returns:
+ Torch dataset with recovered deep features.
+ """
+ if not os.path.exists(os.path.join(feature_path, f"0_features.npy")):
+ raise ValueError(f"The provided location {feature_path} does not contain any representation files")
+
+ ds_list, chunk_id = [], 0
+ while os.path.exists(os.path.join(feature_path, f"{chunk_id}_features.npy")):
+ features = ch.from_numpy(np.load(os.path.join(feature_path, f"{chunk_id}_features.npy"))).float()
+ labels = ch.from_numpy(np.load(os.path.join(feature_path, f"{chunk_id}_labels.npy"))).long()
+ ds_list.append(ch.utils.data.TensorDataset(features, labels))
+ chunk_id += 1
+
+ print(f"==> loaded {chunk_id} files of representations...")
+ return ch.utils.data.ConcatDataset(ds_list)
+
+
+def calculate_metadata(loader, num_classes=None, filename=None):
+ """Calculates mean and standard deviation of the deep features over
+ a given set of images.
+ Args:
+ loader : torch data loader
+ num_classes (int): Number of classes in the dataset
+ filename (str): Optional filepath to cache metadata. Recommended
+ for large dataset_classes like ImageNet.
+
+ Returns:
+ metadata (dict): Dictionary with desired statistics.
+ """
+
+ if filename is not None and os.path.exists(filename):
+ print("loading Metadata from ", filename)
+ return ch.load(filename)
+
+ # Calculate number of classes if not given
+ if num_classes is None:
+ num_classes = 1
+ for batch in loader:
+ y = batch[1]
+ print(y)
+ num_classes = max(num_classes, y.max().item() + 1)
+
+ eye = ch.eye(num_classes)
+
+ X_bar, y_bar, y_max, n = 0, 0, 0, 0
+
+ # calculate means and maximum
+ print("Calculating means")
+ for ans in tqdm(loader, total=len(loader)):
+ X, y = ans[:2]
+ X_bar += X.sum(0)
+ y_bar += eye[y].sum(0)
+ y_max = max(y_max, y.max())
+ n += y.size(0)
+ X_bar = X_bar.float() / n
+ y_bar = y_bar.float() / n
+
+ # calculate std
+ X_std, y_std = 0, 0
+ print("Calculating standard deviations")
+ for ans in tqdm(loader, total=len(loader)):
+ X, y = ans[:2]
+ X_std += ((X - X_bar) ** 2).sum(0)
+ y_std += ((eye[y] - y_bar) ** 2).sum(0)
+ X_std = ch.sqrt(X_std.float() / n)
+ y_std = ch.sqrt(y_std.float() / n)
+
+ # calculate maximum regularization
+ inner_products = 0
+ print("Calculating maximum lambda")
+ for ans in tqdm(loader, total=len(loader)):
+ X, y = ans[:2]
+ y_map = (eye[y] - y_bar) / y_std
+ inner_products += X.t().mm(y_map) * y_std
+
+ inner_products_group = inner_products.norm(p=2, dim=1)
+
+ metadata = {
+ "X": {
+ "mean": X_bar,
+ "std": X_std,
+ "num_features": X.size()[1:],
+ "num_examples": n
+ },
+ "y": {
+ "mean": y_bar,
+ "std": y_std,
+ "num_classes": y_max + 1
+ },
+ "max_reg": {
+ "group": inner_products_group.abs().max().item() / n,
+ "nongrouped": inner_products.abs().max().item() / n
+ }
+ }
+
+ if filename is not None:
+ ch.save(metadata, filename)
+
+ return metadata
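+
+# Hedged usage sketch: the returned dict can be reused for normalization and to pick an
+# upper bound for the elastic-net regularization strength, e.g.
+#   metadata = calculate_metadata(feature_loader, num_classes=200)
+#   mu, sigma = metadata["X"]["mean"], metadata["X"]["std"]
+#   lam_max = metadata["max_reg"]["group"]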
+
+
+def split_dataset(dataset, Ntotal, val_frac,
+ batch_size, num_workers,
+ random_seed=0, shuffle=True, balance=False):
+ """Splits a given dataset into train and validation
+ Args:
+ dataset : Torch dataset
+ Ntotal: Total number of dataset samples
+ val_frac: Fraction to reserve for validation
+ batch_size (int): Batch size for output loader
+ num_workers (int): Number of workers to use for output loader
+ random_seed (int): Random seed
+        shuffle (bool): Whether or not to shuffle output data loader
+        balance (bool): Whether or not to balance the validation split
+            (assumes binary labels; see balance_dataset)
+
+ Returns:
+ split_datasets (list): List of dataset_classes (one each for train and val)
+ split_loaders (list): List of loaders (one each for train and val)
+ """
+
+ Nval = math.floor(Ntotal * val_frac)
+ train_ds, val_ds = ch.utils.data.random_split(dataset,
+ [Ntotal - Nval, Nval],
+ generator=ch.Generator().manual_seed(random_seed))
+ if balance:
+ val_ds = balance_dataset(val_ds)
+ split_datasets = [train_ds, val_ds]
+
+ split_loaders = []
+ for ds in split_datasets:
+ split_loaders.append(ch.utils.data.DataLoader(ds,
+ num_workers=num_workers,
+ batch_size=batch_size,
+ shuffle=shuffle))
+ return split_datasets, split_loaders
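+
+# Hedged usage sketch: carve a 10% validation split out of a cached feature dataset:
+#   (train_ds, val_ds), (train_ld, val_ld) = split_dataset(feature_ds, len(feature_ds),
+#                                                           val_frac=0.1, batch_size=256,
+#                                                           num_workers=4, random_seed=0)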
diff --git a/sparsification/glmBasedSparsification.py b/sparsification/glmBasedSparsification.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a681147b4394281069c3a6bf0596baf435ecae0
--- /dev/null
+++ b/sparsification/glmBasedSparsification.py
@@ -0,0 +1,130 @@
+import logging
+import os
+import shutil
+
+import numpy as np
+import pandas as pd
+import torch
+from glm_saga.elasticnet import glm_saga
+from torch import nn
+
+from sparsification.FeatureSelection import FeatureSelectionFitting
+from sparsification import data_helpers
+from sparsification.utils import get_default_args, compute_features_and_metadata, select_in_loader, get_feature_loaders
+
+
+def get_glm_selection(feature_loaders, metadata, args, num_classes, device, n_features_to_select, folder):
+ num_features = metadata["X"]["num_features"][0]
+    fittingClass = FeatureSelectionFitting(num_features, num_classes, args, 0.8,
+                                           n_features_to_select, 0.1, folder,
+                                           lookback=3, tol=1e-4, epsilon=1)
+ to_drop, test_acc = fittingClass.fit(feature_loaders, metadata, device)
+ selected_features = torch.tensor([i for i in range(num_features) if i not in to_drop])
+ return selected_features
+
+
+def compute_feature_selection_and_assignment(model, train_loader, test_loader, log_folder, num_classes, seed, select_features=50):
+    feature_loaders, metadata, device, args = get_feature_loaders(seed, log_folder, train_loader, test_loader, model, num_classes)
+
+ if os.path.exists(log_folder / f"SlDD_Selection_{select_features}.pt"):
+ feature_selection = torch.load(log_folder / f"SlDD_Selection_{select_features}.pt")
+ else:
+ used_features = model.linear.weight.shape[1]
+ if used_features != select_features:
+ selection_folder = log_folder / "sldd_selection" # overwrite with None to prevent saving
+ feature_selection = get_glm_selection(feature_loaders, metadata, args,
+ num_classes,
+ device,select_features, selection_folder
+ )
+ else:
+ feature_selection = model.linear.selection
+ torch.save(feature_selection, log_folder / f"SlDD_Selection_{select_features}.pt")
+ feature_loaders = select_in_loader(feature_loaders, feature_selection)
+ mean, std = metadata["X"]["mean"], metadata["X"]["std"]
+ mean_to_pass_in = mean
+ std_to_pass_in = std
+ if len(mean) != feature_selection.shape[0]:
+ mean_to_pass_in = mean[feature_selection]
+ std_to_pass_in = std[feature_selection]
+
+ sparse_matrices, biases = fit_glm(log_folder, mean_to_pass_in, std_to_pass_in, feature_loaders, num_classes, select_features)
+
+ return feature_selection, sparse_matrices, biases, mean, std
+
+
+def fit_glm(log_dir, mean, std, feature_loaders, num_classes, select_features=50):
+ output_folder = log_dir / "glm_path"
+ if not output_folder.exists() or len(list(output_folder.iterdir())) != 102:
+ shutil.rmtree(output_folder, ignore_errors=True)
+ output_folder.mkdir(exist_ok=True, parents=True)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ linear = nn.Linear(select_features, num_classes).to(device)
+ for p in [linear.weight, linear.bias]:
+ p.data.zero_()
+ print("Preparing normalization preprocess and indexed dataloader")
+ metadata = {"X": {"mean": mean, "std": std},}
+ preprocess = data_helpers.NormalizedRepresentation(feature_loaders['train'],
+ metadata=metadata,
+ device=linear.weight.device)
+
+ print("Calculating the regularization path")
+ mpl_logger = logging.getLogger("matplotlib")
+ mpl_logger.setLevel(logging.WARNING)
+ params = glm_saga(linear,
+ feature_loaders['train'],
+ 0.1,
+ 2000,
+ 0.99, k=100,
+ val_loader=feature_loaders['val'],
+ test_loader=feature_loaders['test'],
+ n_classes=num_classes,
+ checkpoint=str(output_folder),
+ verbose=200,
+ tol=1e-4, # Change for ImageNet
+ lookbehind=5,
+ lr_decay_factor=1,
+ group=False,
+ epsilon=0.001,
+ metadata=None, # To let it be recomputed
+ preprocess=preprocess, )
+ results = load_glm(output_folder)
+ sparse_matrices = results["weights"]
+ biases = results["biases"]
+
+ return sparse_matrices, biases
+
+def load_glm(result_dir):
+ Nlambda = max([int(f.split('params')[1].split('.pth')[0])
+ for f in os.listdir(result_dir) if 'params' in f]) + 1
+
+ print(f"Loading regularization path of length {Nlambda}")
+
+ params_dict = {i: torch.load(os.path.join(result_dir, f"params{i}.pth"),
+ map_location=torch.device('cpu')) for i in range(Nlambda)}
+
+ regularization_strengths = [params_dict[i]['lam'].item() for i in range(Nlambda)]
+ weights = [params_dict[i]['weight'] for i in range(Nlambda)]
+ biases = [params_dict[i]['bias'] for i in range(Nlambda)]
+
+ metrics = {'acc_tr': [], 'acc_val': [], 'acc_test': []}
+
+ for k in metrics.keys():
+ for i in range(Nlambda):
+ metrics[k].append(params_dict[i]['metrics'][k])
+ metrics[k] = 100 * np.stack(metrics[k])
+ metrics = pd.DataFrame(metrics)
+ metrics = metrics.rename(columns={'acc_tr': 'acc_train'})
+
+ # weights_stacked = ch.stack(weights)
+ # sparsity = ch.sum(weights_stacked != 0, dim=2).numpy()
+ sparsity = np.array([torch.sum(w != 0, dim=1).numpy() for w in weights])
+
+ return {'metrics': metrics,
+ 'regularization_strengths': regularization_strengths,
+ 'weights': weights,
+ 'biases': biases,
+ 'sparsity': sparsity,
+ 'weight_dense': weights[-1],
+ 'bias_dense': biases[-1]}
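+
+# Hedged usage sketch (model, loaders and a pathlib.Path log_folder are assumed to exist):
+#   selection, weights, biases, mean, std = compute_feature_selection_and_assignment(
+#       model, train_loader, test_loader, log_folder, num_classes=200, seed=0, select_features=50)
+# `weights`/`biases` hold the whole regularization path; downstream code (qsenn.py, sldd.py)
+# picks one entry and sparsifies it further.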
diff --git a/sparsification/qsenn.py b/sparsification/qsenn.py
new file mode 100644
index 0000000000000000000000000000000000000000..45eb1bde64846d26996962f7a1c3b8d8e0ffa6ab
--- /dev/null
+++ b/sparsification/qsenn.py
@@ -0,0 +1,63 @@
+import numpy as np
+import torch
+
+from sparsification.glmBasedSparsification import compute_feature_selection_and_assignment
+
+
+def compute_qsenn_feature_selection_and_assignment(model, train_loader, test_loader, log_folder, num_classes, seed, n_features, per_class=5):
+ feature_sel, sparse_matrices, biases, mean, std = compute_feature_selection_and_assignment(model, train_loader,
+ test_loader,
+ log_folder, num_classes, seed, n_features)
+ weight_sparse, bias_sparse = get_sparsified_weights_for_factor(sparse_matrices[:-1], biases[:-1], per_class) # Last one in regularisation path has no regularisation
+ print(f"Number of nonzeros in weight matrix: {torch.sum(weight_sparse != 0)}")
+    return feature_sel, weight_sparse, bias_sparse, mean, std
+
+
+def get_sparsified_weights_for_factor(weights, biases, factor):
+ no_reg_result_mat, no_reg_result_bias = weights[-1], biases[-1]
+ goal_nonzeros = factor * no_reg_result_mat.shape[0]
+ values = no_reg_result_mat.flatten()
+ values = values[values != 0]
+ values = -(torch.sort(-torch.abs(values))[0])
+ if goal_nonzeros < len(values):
+ threshold = (values[int(goal_nonzeros) - 1] + values[int(goal_nonzeros)]) / 2
+ else:
+ threshold = values[-1]
+ max_val = torch.max(torch.abs(values))
+ weight_sparse = discretize_2_bins_to_threshold(no_reg_result_mat, threshold, max_val)
+ sel_idx = len(weights) - 1
+ positive_weights_per_class = np.array(torch.sum(weight_sparse > 0, dim=1))
+ negative_weights_per_class = np.array(torch.sum(weight_sparse < 0, dim=1))
+ total_weight_count_per_class = positive_weights_per_class - negative_weights_per_class
+ max_bias = torch.max(torch.abs(biases[sel_idx]))
+ bias_sparse = torch.ones_like(biases[sel_idx]) * max_bias
+ diff_n_weight = total_weight_count_per_class - np.min(total_weight_count_per_class)
+ steps = np.max(diff_n_weight)
+ single_step = 2 * max_bias / steps
+ bias_sparse = bias_sparse - torch.tensor(diff_n_weight) * single_step
+ bias_sparse = torch.clamp(bias_sparse, -max_bias, max_bias)
+ return weight_sparse, bias_sparse
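+
+# Hedged numeric sketch of the pruning above: with per_class = 5 and a 200-class weight
+# matrix, goal_nonzeros = 1000, so the threshold sits halfway between the 1000th and
+# 1001st largest |weight|; discretize_2_bins_to_threshold then snaps entries below the
+# threshold to 0 and the surviving entries to a shared +/- magnitude (the mean absolute
+# value of the kept weights).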
+
+
+def discretize_2_bins_to_threshold(data, threshold, max_val):
+    boundaries = torch.tensor([-max_val, -threshold, threshold, max_val], device=data.device)
+    bucketized_tensor = torch.bucketize(data, boundaries)
+    means = torch.tensor([-max_val, 0, max_val], device=data.device)
+ for i in range(len(means)):
+ if means[i] == 0:
+ break
+ positive_index = int(len(means) / 2 + 1) + i
+ positive_bucket = data[bucketized_tensor == positive_index + 1]
+ negative_bucket = data[bucketized_tensor == i + 1]
+ sum = 0
+ total = 0
+ for bucket in [positive_bucket, negative_bucket]:
+ if len(bucket) == 0:
+ continue
+ sum += torch.sum(torch.abs(bucket))
+ total += len(bucket)
+ if total == 0:
+ continue
+ avg = sum / total
+ means[i] = -avg
+ means[positive_index] = avg
+ discretized_tensor = means.cpu()[bucketized_tensor.cpu() - 1].to(bucketized_tensor.device)
+ return discretized_tensor
\ No newline at end of file
diff --git a/sparsification/sldd.py b/sparsification/sldd.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eeb3733797950107c6a0987fe242a0bfe2e732a
--- /dev/null
+++ b/sparsification/sldd.py
@@ -0,0 +1,44 @@
+import numpy as np
+import torch
+
+from sparsification.glmBasedSparsification import compute_feature_selection_and_assignment
+
+
+def compute_sldd_feature_selection_and_assignment(model, train_loader, test_loader, log_folder, num_classes, seed,
+ per_class=5, select_features=50):
+ feature_sel, sparse_matrices, biases, mean, std = compute_feature_selection_and_assignment(model, train_loader,
+ test_loader,
+ log_folder, num_classes,
+ seed, select_features=select_features)
+    weight_sparse, bias_sparse = get_sparsified_weights_for_factor(sparse_matrices, biases,
+                                                                   per_class)  # the last entry in the regularization path has no regularization
+ return feature_sel, weight_sparse, bias_sparse, mean, std
+
+def get_sparsified_weights_for_factor(sparse_layer, biases, keep_per_class, drop_rate=0.5):
+ nonzero_entries = [torch.sum(torch.count_nonzero(sparse_layer[i])) for i in range(len(sparse_layer))]
+ mean_sparsity = np.array([nonzero_entries[i] / sparse_layer[i].shape[0] for i in range(len(sparse_layer))])
+    factor = keep_per_class / drop_rate
+ # Get layer with desired sparsity
+ sparse_enough = mean_sparsity <= factor
+ sel_idx = np.argmax(sparse_enough * mean_sparsity)
+ if sel_idx == 0 and np.sum(mean_sparsity) > 1: # sometimes first one is odd
+ sparse_enough[0] = False
+ sel_idx = np.argmax(sparse_enough * mean_sparsity)
+ selected_weight = sparse_layer[sel_idx]
+ selected_bias = biases[sel_idx]
+    # on average keep only keep_per_class weights per class
+    weight_5_per_matrix = set_lowest_percent_to_zero(selected_weight, keep_per_class)
+
+ return weight_5_per_matrix,selected_bias
+
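+# Hedged numeric sketch of the selection above: with keep_per_class = 5 and drop_rate = 0.5,
+# factor = 10, so the densest regularization-path entry with at most 10 nonzeros per class
+# (on average) is chosen; set_lowest_percent_to_zero then prunes it further so that only the
+# largest-magnitude entries survive (about 5 per class on average).
+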
+
+def set_lowest_percent_to_zero(matrix, keep_per):
+    # Keep only the keep_per * n_classes largest-magnitude nonzero entries; zero the rest.
+    nonzero_indices = torch.nonzero(matrix)
+    values = matrix[nonzero_indices[:, 0], nonzero_indices[:, 1]]
+    sorted_indices = torch.argsort(torch.abs(values))
+    total_allowed = int(matrix.shape[0] * keep_per)
+    indices_to_zero = nonzero_indices[sorted_indices[:-total_allowed]]
+    matrix[indices_to_zero[:, 0], indices_to_zero[:, 1]] = 0
+    return matrix
\ No newline at end of file
diff --git a/sparsification/utils.py b/sparsification/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e960e5c4131032242e41bf7ff5a8a02b571fc89
--- /dev/null
+++ b/sparsification/utils.py
@@ -0,0 +1,159 @@
+from argparse import ArgumentParser
+
+import torch
+
+#from sparsification.glm_saga import glm_saga
+from sparsification import feature_helpers
+
+
+def safe_zip(*args):
+ for iterable in args[1:]:
+ if len(iterable) != len(args[0]):
+ print("Unequally sized iterables to zip, printing lengths")
+ for i, entry in enumerate(args):
+ print(i, len(entry))
+ raise ValueError("Unequally sized iterables to zip")
+ return zip(*args)
+
+
+def compute_features_and_metadata(args, train_loader, test_loader, model, out_dir_feats, num_classes,
+ ):
+ print("Computing/loading deep features...")
+
+ Ntotal = len(train_loader.dataset)
+ feature_loaders = {}
+ # Compute Features for not augmented train and test set
+ train_loader_transforms = train_loader.dataset.transform
+ test_loader_transforms = test_loader.dataset.transform
+ train_loader.dataset.transform = test_loader_transforms
+ for mode, loader in zip(['train', 'test', ], [train_loader, test_loader, ]): #
+ print(f"For {mode} set...")
+
+ sink_path = f"{out_dir_feats}/features_{mode}"
+ metadata_path = f"{out_dir_feats}/metadata_{mode}.pth"
+
+ feature_ds, feature_loader = feature_helpers.compute_features(loader,
+ model,
+ dataset_type=args.dataset_type,
+ pooled_output=None,
+ batch_size=args.batch_size,
+ num_workers=0, # args.num_workers,
+ shuffle=(mode == 'test'),
+ device=args.device,
+ filename=sink_path, n_epoch=1,
+ balance=False,
+ ) # args.balance if mode == 'test' else False)
+
+ if mode == 'train':
+ metadata = feature_helpers.calculate_metadata(feature_loader,
+ num_classes=num_classes,
+ filename=metadata_path)
+ if metadata["max_reg"]["group"] == 0.0:
+ return None, False
+ split_datasets, split_loaders = feature_helpers.split_dataset(feature_ds,
+ Ntotal,
+ val_frac=args.val_frac,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ random_seed=args.random_seed,
+ shuffle=True,
+ balance=False)
+ feature_loaders.update({mm: add_index_to_dataloader(split_loaders[mi])
+ for mi, mm in enumerate(['train', 'val'])})
+
+ else:
+ feature_loaders[mode] = feature_loader
+ train_loader.dataset.transform = train_loader_transforms
+ return feature_loaders, metadata
+
+def get_feature_loaders(seed, log_folder, train_loader, test_loader, model, num_classes):
+    args = get_default_args()
+    args.random_seed = seed
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    feature_folder = log_folder / "features"
+    feature_loaders, metadata = compute_features_and_metadata(args, train_loader, test_loader, model,
+                                                               feature_folder, num_classes)
+    return feature_loaders, metadata, device, args
+
+
+def add_index_to_dataloader(loader, sample_weight=None):
+ return torch.utils.data.DataLoader(
+ IndexedDataset(loader.dataset, sample_weight=sample_weight),
+ batch_size=loader.batch_size,
+ sampler=loader.sampler,
+ num_workers=loader.num_workers,
+ collate_fn=loader.collate_fn,
+ pin_memory=loader.pin_memory,
+ drop_last=loader.drop_last,
+ timeout=loader.timeout,
+ worker_init_fn=loader.worker_init_fn,
+ multiprocessing_context=loader.multiprocessing_context
+ )
+
+
+class IndexedDataset(torch.utils.data.Dataset):
+ def __init__(self, ds, sample_weight=None):
+ super(torch.utils.data.Dataset, self).__init__()
+ self.dataset = ds
+ self.sample_weight = sample_weight
+
+ def __getitem__(self, index):
+ val = self.dataset[index]
+ if self.sample_weight is None:
+ return val + (index,)
+ else:
+ weight = self.sample_weight[index]
+ return val + (weight, index)
+
+ def __len__(self):
+ return len(self.dataset)
+
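+
+# Hedged usage sketch: wrapping a plain feature loader so each batch also yields the sample
+# indices, as the glm_saga-style fitting code expects:
+#   indexed_loader = add_index_to_dataloader(plain_loader)
+#   for X, y, idx in indexed_loader:
+#       ...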
+
+def get_default_args():
+ # Default args from glm_saga, https://github.com/MadryLab/glm_saga
+ parser = ArgumentParser()
+ parser.add_argument('--dataset', type=str, help='dataset name')
+ parser.add_argument('--dataset-type', type=str, help='One of ["language", "vision"]')
+ parser.add_argument('--dataset-path', type=str, help='path to dataset')
+ parser.add_argument('--model-path', type=str, help='path to model checkpoint')
+ parser.add_argument('--arch', type=str, help='model architecture type')
+ parser.add_argument('--out-path', help='location for saving results')
+ parser.add_argument('--cache', action='store_true', help='cache deep features')
+ parser.add_argument('--balance', action='store_true', help='balance classes for evaluation')
+
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--random-seed', default=0)
+ parser.add_argument('--num-workers', type=int, default=2)
+ parser.add_argument('--batch-size', type=int, default=256)
+ parser.add_argument('--val-frac', type=float, default=0.1)
+ parser.add_argument('--lr-decay-factor', type=float, default=1)
+ parser.add_argument('--lr', type=float, default=0.1)
+ parser.add_argument('--alpha', type=float, default=0.99)
+ parser.add_argument('--max-epochs', type=int, default=2000)
+ parser.add_argument('--verbose', type=int, default=200)
+ parser.add_argument('--tol', type=float, default=1e-4)
+ parser.add_argument('--lookbehind', type=int, default=3)
+ parser.add_argument('--lam-factor', type=float, default=0.001)
+ parser.add_argument('--group', action='store_true')
+    args = parser.parse_args()
+ return args
+
+
+def select_in_loader(feature_loaders, feature_selection):
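+    # The train loader wraps IndexedDataset -> Subset (random_split) -> ConcatDataset of
+    # TensorDatasets, hence the chained .dataset attributes; slicing the feature tensors in
+    # place restricts every split to the selected feature columns.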
+ for dataset in feature_loaders["train"].dataset.dataset.dataset.datasets: # Val is indexed via the same dataset as train
+ tensors = list(dataset.tensors)
+ if tensors[0].shape[1] == len(feature_selection):
+ continue
+ tensors[0] = tensors[0][:, feature_selection]
+ dataset.tensors = tensors
+ for dataset in feature_loaders["test"].dataset.datasets:
+ tensors = list(dataset.tensors)
+ if tensors[0].shape[1] == len(feature_selection):
+ continue
+ tensors[0] = tensors[0][:, feature_selection]
+ dataset.tensors = tensors
+ return feature_loaders
+
diff --git a/tmp/Datasets/CUB200/CUB_200_2011/README b/tmp/Datasets/CUB200/CUB_200_2011/README
new file mode 100644
index 0000000000000000000000000000000000000000..4cf4b8f6a8e963af922b4c320df3c9700af95914
--- /dev/null
+++ b/tmp/Datasets/CUB200/CUB_200_2011/README
@@ -0,0 +1,140 @@
+===========================================
+The Caltech-UCSD Birds-200-2011 Dataset
+===========================================
+
+For more information about the dataset, visit the project website:
+
+ http://www.vision.caltech.edu/visipedia
+
+If you use the dataset in a publication, please cite the dataset in
+the style described on the dataset website (see url above).
+
+Directory Information
+---------------------
+
+- images/
+ The images organized in subdirectories based on species. See
+ IMAGES AND CLASS LABELS section below for more info.
+- parts/
+ 15 part locations per image. See PART LOCATIONS section below
+ for more info.
+- attributes/
+    312 binary attribute labels from MTurk workers. See ATTRIBUTE LABELS
+ section below for more info.
+
+
+
+=========================
+IMAGES AND CLASS LABELS:
+=========================
+Images are contained in the directory images/, with 200 subdirectories (one for each bird species)
+
+------- List of image files (images.txt) ------
+The list of image file names is contained in the file images.txt, with each line corresponding to one image:
+
+<image_id> <image_name>
+------------------------------------------
+
+
+------- Train/test split (train_test_split.txt) ------
+The suggested train/test split is contained in the file train_test_split.txt, with each line corresponding to one image:
+
+<image_id> <is_training_image>
+
+where <image_id> corresponds to the ID in images.txt, and a value of 1 or 0 for <is_training_image> denotes that the file is in the training or test set, respectively.
+------------------------------------------------------
+
+
+------- List of class names (classes.txt) ------
+The list of class names (bird species) is contained in the file classes.txt, with each line corresponding to one class:
+
+<class_id> <class_name>
+--------------------------------------------
+
+
+------- Image class labels (image_class_labels.txt) ------
+The ground truth class labels (bird species labels) for each image are contained in the file image_class_labels.txt, with each line corresponding to one image:
+
+<image_id> <class_id>
+
+where <image_id> and <class_id> correspond to the IDs in images.txt and classes.txt, respectively.
+---------------------------------------------------------
+
+
+
+
+
+=========================
+BOUNDING BOXES:
+=========================
+
+Each image contains a single bounding box label. Bounding box labels are contained in the file bounding_boxes.txt, with each line corresponding to one image:
+
+<image_id> <x> <y> <width> <height>
+
+where <image_id> corresponds to the ID in images.txt, and <x>, <y>, <width>, and <height> are all measured in pixels
+
+
+
+
+=========================
+PART LOCATIONS:
+=========================
+
+------- List of part names (parts/parts.txt) ------
+The list of all part names is contained in the file parts/parts.txt, with each line corresponding to one part:
+
+<part_id> <part_name>
+------------------------------------------
+
+
+------- Part locations (parts/part_locs.txt) ------
+The set of all ground truth part locations is contained in the file parts/part_locs.txt, with each line corresponding to the annotation of a particular part in a particular image:
+
+<image_id> <part_id> <x> <y> <visible>
+
+where <image_id> and <part_id> correspond to the IDs in images.txt and parts/parts.txt, respectively. <x> and <y> denote the pixel location of the center of the part. <visible> is 0 if the part is not visible in the image and 1 otherwise.
+----------------------------------------------------------
+
+
+------- MTurk part locations (parts/part_click_locs.txt) ------
+A set of multiple part locations for each image and part, as perceived by multiple MTurk users is contained in parts/part_click_locs.txt, with each line corresponding to the annotation of a particular part in a particular image by a different MTurk worker:
+
+<image_id> <part_id> <x> <y> <visible> <time>