diff --git a/DIC.py b/DIC.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd67353e13c054ea320882506624ac0e2050a91 --- /dev/null +++ b/DIC.py @@ -0,0 +1,17 @@ +import torch +from pathlib import Path + + +dir=Path.home() / f"tmp/resnet50/CUB2011/123456/" +dic=torch.load(dir/ f"SlDD_Selection_50.pt") + +print (dic) + +#if 'linear.selection' in dic.keys(): + #print("key 'linear.selection' exist") +#else: + #print("no such key") + + + + diff --git a/FeatureDiversityLoss.py b/FeatureDiversityLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..be5745ae71dbe298244271c3a942c80c2b3e9867 --- /dev/null +++ b/FeatureDiversityLoss.py @@ -0,0 +1,59 @@ +import torch +from torch import nn + +""" +Feature Diversity Loss: +Usage to replicate paper: +Call +loss_function = FeatureDiversityLoss(0.196, linear) +to inititalize loss with linear layer of model. +At each mini batch get feature maps (Output of final convolutional layer) and add to Loss: +loss += loss_function(feature_maps, outputs) +""" + + +class FeatureDiversityLoss(nn.Module): + def __init__(self, scaling_factor, linear): + super().__init__() + self.scaling_factor = scaling_factor #* 0 + print("Scaling Factor: ", self.scaling_factor) + self.linearLayer = linear + + def initialize(self, linearLayer): + self.linearLayer = linearLayer + + def get_weights(self, outputs): + weight_matrix = self.linearLayer.weight + weight_matrix = torch.abs(weight_matrix) + top_classes = torch.argmax(outputs, dim=1) + relevant_weights = weight_matrix[top_classes] + return relevant_weights + + def forward(self, feature_maps, outputs): + relevant_weights = self.get_weights(outputs) + relevant_weights = norm_vector(relevant_weights) + feature_maps = preserve_avg_func(feature_maps) + flattened_feature_maps = feature_maps.flatten(2) + batch, features, map_size = flattened_feature_maps.size() + relevant_feature_maps = flattened_feature_maps * relevant_weights[..., None] + diversity_loss = torch.sum( + torch.amax(relevant_feature_maps, dim=1)) + return -diversity_loss / batch * self.scaling_factor + + +def norm_vector(x): + return x / (torch.norm(x, dim=1) + 1e-5)[:, None] + + +def preserve_avg_func(x): + avgs = torch.mean(x, dim=[2, 3]) + max_avgs = torch.max(avgs, dim=1)[0] + scaling_factor = avgs / torch.clamp(max_avgs[..., None], min=1e-6) + softmaxed_maps = softmax_feature_maps(x) + scaled_maps = softmaxed_maps * scaling_factor[..., None, None] + return scaled_maps + + +def softmax_feature_maps(x): + return torch.softmax(x.reshape(x.size(0), x.size(1), -1), 2).view_as(x) + diff --git a/__pycache__/get_data.cpython-310.pyc b/__pycache__/get_data.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db7e1ece40eba3119dbfd6595168c1599a3847bb Binary files /dev/null and b/__pycache__/get_data.cpython-310.pyc differ diff --git a/__pycache__/load_model.cpython-310.pyc b/__pycache__/load_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dc4bd3e12a89cd304e7c33e8cf190a0353a8698 Binary files /dev/null and b/__pycache__/load_model.cpython-310.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d8dab90f6b16cfbfcba731f336feb59e92c9f0b7 --- /dev/null +++ b/app.py @@ -0,0 +1,143 @@ +import gradio as gr +from load_model import extract_sel_mean_std_bias_assignemnt +from pathlib import Path +from architectures.model_mapping import get_model +from configs.dataset_params import dataset_constants +import torch 
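The docstring at the top of FeatureDiversityLoss.py above already states the intended call pattern. A minimal sketch of that pattern, assuming a standard classification training loop in which model, optimizer and train_loader are already defined (none of these names are part of this diff):

# Hypothetical usage sketch for FeatureDiversityLoss; model, optimizer and train_loader are assumed to exist.
import torch
from FeatureDiversityLoss import FeatureDiversityLoss

criterion = torch.nn.CrossEntropyLoss()
# 0.196 is the resnet50 beta listed in configs/architecture_params.py
loss_function = FeatureDiversityLoss(0.196, model.linear)

for images, targets in train_loader:
    # with_feature_maps=True makes the model also return the final convolutional feature maps
    outputs, feature_maps = model(images, with_feature_maps=True)
    loss = criterion(outputs, targets)
    loss = loss + loss_function(feature_maps, outputs)  # add the diversity term to the task loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()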
+import torchvision.transforms as transforms +import pandas as pd +import cv2 +import numpy as np + +def overlapping_features_on_input(model,output, feature_maps, input, target): + W=model.linear.layer.weight + output=output.detach().cpu().numpy() + feature_maps=feature_maps.detach().cpu().numpy().squeeze() + + if target !=None: + label=target + else: + label=np.argmax(output)+1 + + Interpretable_Selection= W[label,:] + print("W",Interpretable_Selection) + input_np=np.array(input) + h,w= input.shape[:2] + print("h,w:",h,w) + Interpretable_Features=[] + Feature_image_list=[] + for S in range(len(Interpretable_Selection)): + if Interpretable_Selection[S] > 0: + Interpretable_Features.append(feature_maps[S]) + Feature_image=cv2.resize(feature_maps[S],(w,h)) + Feature_image=((Feature_image-np.min(Feature_image))/(np.max(Feature_image)-np.min(Feature_image)))*255 + Feature_image=Feature_image.astype(np.uint8) + Feature_image=cv2.applyColorMap(Feature_image,cv2.COLORMAP_JET) + Feature_image=0.3*Feature_image+0.7*input_np + Feature_image=np.clip(Feature_image, 0, 255).astype(np.uint8) + Feature_image_list.append(Feature_image) + #path_to_featureimage=f"/home/qixuan/tmp/FeatureImage/FI{S}.jpg" + #cv2.imwrite(path_to_featureimage,Feature_image) + print("len of Features:",len(Interpretable_Features)) + + return Feature_image_list + + +def genreate_intepriable_output(input,dataset="CUB2011", arch="resnet50",seed=123456, model_type="qsenn", n_features = 50, n_per_class=5, img_size=448, reduced_strides=False, folder = None): + n_classes = dataset_constants[dataset]["num_classes"] + + model = get_model(arch, n_classes, reduced_strides) + tr=transforms.ToTensor() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if folder is None: + folder = Path(f"tmp/{arch}/{dataset}/{seed}/") + + state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth") + selection= torch.load(folder / f"SlDD_Selection_50.pt") + state_dict['linear.selection']=selection + + feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict) + model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse) + model.load_state_dict(state_dict) + + input = tr(input) + input= input.unsqueeze(0) + input= input.to(device) + model = model.to(device) + output, feature_maps, final_features = model(input, with_feature_maps=True, with_final_features=True) + print("final features:",final_features) + output=output.detach().cpu().numpy() + output= np.argmax(output)+1 + + + print("outputclass:",output) + data_dir=Path("tmp/Datasets/CUB200/CUB_200_2011/") + labels = pd.read_csv(data_dir/"image_class_labels.txt", sep=' ', names=['img_id', 'target']) + namelist=pd.read_csv(data_dir/"images.txt",sep=' ',names=['img_id','file_name']) + classlist=pd.read_csv(data_dir/"classes.txt",sep=' ',names=['cl_id','class_name']) + options_output=labels[labels['target']==output] + options_output=options_output.sample(1) + others=labels[labels['target']!=output] + options_others=others.sample(3) + options = pd.concat([options_others, options_output], ignore_index=True) + shuffled_options = options.sample(frac=1).reset_index(drop=True) + print("shuffled:",shuffled_options) + op=[] + + for i in shuffled_options['img_id']: + print(i) + filenames=namelist.loc[namelist['img_id']==i,'file_name'].values[0] + targets=shuffled_options.loc[shuffled_options['img_id']==i,'target'].values[0] + print("targets",targets) + print("name",filenames) + + 
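+ # Look up this candidate's class name, run the image through the model, and overlay its positively weighted feature maps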
classes=classlist.loc[classlist['cl_id']==targets, 'class_name'].values[0] + print(data_dir/f"images/{filenames}") + + op_img=cv2.imread(data_dir/f"images/{filenames}") + + op_images=tr(op_img) + op_images=op_images.unsqueeze(0) + op_images=op_images.to(device) + OP, feature_maps_op =model(op_images,with_feature_maps=True,with_final_features=False) + print("OP:",OP, + "feature_maps_op:",feature_maps_op.shape) + opt= overlapping_features_on_input(model,OP, feature_maps_op,op_img,targets) + op+=opt + + return op + +def post_next_image(op): + if len(op)<=1: + return [],None, "all done, thank you!" + else: + op=op[1:len(op)] + return op,op[0], "Is this feature also in your input?" + +def get_features_on_interface(input): + op=genreate_intepriable_output(input,dataset="CUB2011", + arch="resnet50",seed=123456, + model_type="qsenn", n_features = 50,n_per_class=5, + img_size=448, reduced_strides=False, folder = None) + return op, op[0],"Is this feature also in your input?",gr.update(interactive=False) + + +with gr.Blocks() as demo: + + gr.Markdown("

Interpretable Bird Classification

") + image_input=gr.Image() + image_output=gr.Image() + text_output=gr.Markdown() + but_generate=gr.Button("Get some interpriable Features") + but_feedback_y=gr.Button("Yes") + but_feedback_n=gr.Button("No") + image_list = gr.State([]) + but_generate.click(fn=get_features_on_interface, inputs=image_input, outputs=[image_list,image_output,text_output,but_generate]) + but_feedback_y.click(fn=post_next_image, inputs=image_list, outputs=[image_list,image_output,text_output]) + but_feedback_n.click(fn=post_next_image, inputs=image_list, outputs=[image_list,image_output,text_output]) + +demo.launch() + + + + \ No newline at end of file diff --git a/architectures/FinalLayer.py b/architectures/FinalLayer.py new file mode 100644 index 0000000000000000000000000000000000000000..af1a55a667c462ec8f256f9d28aefdc5e77d6cae --- /dev/null +++ b/architectures/FinalLayer.py @@ -0,0 +1,36 @@ +import torch +from torch import nn + +from architectures.SLDDLevel import SLDDLevel + + +class FinalLayer(): + def __init__(self, num_classes, n_features): + super().__init__() + self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) + self.linear = nn.Linear(n_features, num_classes) + self.featureDropout = torch.nn.Dropout(0.2) + self.selection = None + + def transform_output(self, feature_maps, with_feature_maps, + with_final_features): + if self.selection is not None: + feature_maps = feature_maps[:, self.selection] + x = self.avgpool(feature_maps) + pre_out = torch.flatten(x, 1) + final_features = self.featureDropout(pre_out) + final = self.linear(final_features) + final = [final] + if with_feature_maps: + final.append(feature_maps) + if with_final_features: + final.append(final_features) + if len(final) == 1: + final = final[0] + return final + + + def set_model_sldd(self, selection, weight, mean, std, bias = None): + self.selection = selection + self.linear = SLDDLevel(selection, weight, mean, std, bias) + self.featureDropout = torch.nn.Dropout(0.1) \ No newline at end of file diff --git a/architectures/SLDDLevel.py b/architectures/SLDDLevel.py new file mode 100644 index 0000000000000000000000000000000000000000..bc214c88f384690d29bda97d7ed82a8c01e866da --- /dev/null +++ b/architectures/SLDDLevel.py @@ -0,0 +1,37 @@ +import torch.nn + + +class SLDDLevel(torch.nn.Module): + def __init__(self, selection, weight_at_selection,mean, std, bias=None): + super().__init__() + self.register_buffer('selection', torch.tensor(selection, dtype=torch.long)) + num_classes, n_features = weight_at_selection.shape + selected_mean = mean + selected_std = std + if len(selected_mean) != len(selection): + selected_mean = selected_mean[selection] + selected_std = selected_std[selection] + self.mean = torch.nn.Parameter(selected_mean) + self.std = torch.nn.Parameter(selected_std) + if bias is not None: + self.layer = torch.nn.Linear(n_features, num_classes) + self.layer.bias = torch.nn.Parameter(bias, requires_grad=False) + else: + self.layer = torch.nn.Linear(n_features, num_classes, bias=False) + self.layer.weight = torch.nn.Parameter(weight_at_selection, requires_grad=False) + + @property + def weight(self): + return self.layer.weight + + @property + def bias(self): + if self.layer.bias is None: + return torch.zeros(self.layer.out_features) + else: + return self.layer.bias + + + def forward(self, input): + input = (input - self.mean) / torch.clamp(self.std, min=1e-6) + return self.layer(input) diff --git a/architectures/__pycache__/FinalLayer.cpython-310.pyc b/architectures/__pycache__/FinalLayer.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9a61e00b4a63ecca65185c2f0157f001dbead798 Binary files /dev/null and b/architectures/__pycache__/FinalLayer.cpython-310.pyc differ diff --git a/architectures/__pycache__/SLDDLevel.cpython-310.pyc b/architectures/__pycache__/SLDDLevel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a5d50a5be90c92f0840009d5a78ce7a4a4821df Binary files /dev/null and b/architectures/__pycache__/SLDDLevel.cpython-310.pyc differ diff --git a/architectures/__pycache__/model_mapping.cpython-310.pyc b/architectures/__pycache__/model_mapping.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..515201d4169359545fb87981950d8c22a4181b40 Binary files /dev/null and b/architectures/__pycache__/model_mapping.cpython-310.pyc differ diff --git a/architectures/__pycache__/resnet.cpython-310.pyc b/architectures/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93dccc8da3700ae7fdcbfe887f868ef16d935b95 Binary files /dev/null and b/architectures/__pycache__/resnet.cpython-310.pyc differ diff --git a/architectures/__pycache__/utils.cpython-310.pyc b/architectures/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a224141ff65b31e88257f416c3f7bfc2aaafbfa Binary files /dev/null and b/architectures/__pycache__/utils.cpython-310.pyc differ diff --git a/architectures/model_mapping.py b/architectures/model_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..52df91009029b653b420ff03562616b2389eaa68 --- /dev/null +++ b/architectures/model_mapping.py @@ -0,0 +1,7 @@ +from architectures.resnet import resnet50 + + +def get_model(arch, num_classes, changed_strides=True): + if arch == "resnet50": + model = resnet50(True, num_classes=num_classes, changed_strides=changed_strides) + return model \ No newline at end of file diff --git a/architectures/resnet.py b/architectures/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..eaaa5d3c22e6ab85f9ac63b29462d20aec9594d3 --- /dev/null +++ b/architectures/resnet.py @@ -0,0 +1,420 @@ +import copy +import time + +import torch +import torch.nn as nn +from torch.hub import load_state_dict_from_url +from torchvision.models import get_model + +# from scripts.modelExtensions.crossModelfunctions import init_experiment_stuff + + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2', + 'wide_resnet50_3', 'wide_resnet50_4', 'wide_resnet50_5', + 'wide_resnet50_6', ] + +from architectures.FinalLayer import FinalLayer +from architectures.utils import SequentialWithArgs + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def 
conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None, features=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + + def forward(self, x, no_relu=False): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + + + out += identity + + if no_relu: + return out + return self.relu(out) + + +class Bottleneck(nn.Module): + expansion = 4 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None, features=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + if features is None: + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + else: + self.conv3 = conv1x1(width, features) + self.bn3 = norm_layer(features) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x, no_relu=False, early_exit=False): + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + if no_relu: + return out + return self.relu(out) + + +class ResNet(nn.Module, FinalLayer): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None, changed_strides=False,): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + widths = [64, 128, 256, 512] + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if 
len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.sstride = 2 + if changed_strides: + self.sstride = 1 + self.layer3 = self._make_layer(block, 256, layers[2], stride=self.sstride, + dilate=replace_stride_with_dilation[1]) + self.stride = 2 + + if changed_strides: + self.stride = 1 + self.layer4 = self._make_layer(block, 512, layers[3], stride=self.stride, + dilate=replace_stride_with_dilation[2]) + FinalLayer.__init__(self, num_classes, 512 * block.expansion) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False, last_block_f=None): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + krepeep = None + if last_block_f is not None and _ == blocks - 1: + krepeep = last_block_f + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer, features=krepeep)) + + return SequentialWithArgs(*layers) + + def _forward(self, x, with_feature_maps=False, with_final_features=False): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + feature_maps = self.layer4(x, no_relu=True) + feature_maps = torch.functional.F.relu(feature_maps) + return self.transform_output( feature_maps, with_feature_maps, + with_final_features) + + # Allow for accessing forward method in a inherited class + forward = _forward + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + if kwargs["num_classes"] == 1000: + state_dict["linear.weight"] = state_dict["fc.weight"] + state_dict["linear.bias"] = state_dict["fc.bias"] + 
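+ # strict=False tolerates the leftover fc.* keys and any keys this variant does not define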
model.load_state_dict(state_dict, strict=False) + return model + + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained=False, progress=True, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained=False, progress=True, **kwargs): + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained=False, progress=True, **kwargs): + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_ + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_3(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-50-3 model + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 3 + return _resnet('wide_resnet50_3', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_4(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-50-4 model + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 4 + return _resnet('wide_resnet50_4', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_5(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-50-5 model + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 5 + return _resnet('wide_resnet50_5', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_6(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-50-6 model + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 6 + return _resnet('wide_resnet50_6', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_ + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) diff --git a/architectures/utils.py b/architectures/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ed4cc78fc2799c675098bc73f0a9fb1719fb64b1 --- /dev/null +++ b/architectures/utils.py @@ -0,0 +1,17 @@ +import torch + + + +class SequentialWithArgs(torch.nn.Sequential): + def forward(self, input, *args, **kwargs): + vs = list(self._modules.values()) + l = len(vs) + for i in range(l): + if i == l-1: + input = vs[i](input, *args, **kwargs) + else: + input = vs[i](input) + return input + + + diff --git a/configs/__pycache__/dataset_params.cpython-310.pyc b/configs/__pycache__/dataset_params.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f91c9298daab0217462fe2590a20695b80575e85 Binary files /dev/null and b/configs/__pycache__/dataset_params.cpython-310.pyc differ diff --git a/configs/__pycache__/optim_params.cpython-310.pyc b/configs/__pycache__/optim_params.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0b3f696299233ff7e5b5d5936ea47d83c8f8b97 Binary files /dev/null and b/configs/__pycache__/optim_params.cpython-310.pyc differ diff --git a/configs/architecture_params.py b/configs/architecture_params.py new file mode 100644 index 0000000000000000000000000000000000000000..21a5f2b7fb72dfbee2b4487cabca2d2e840ad938 --- /dev/null +++ b/configs/architecture_params.py @@ -0,0 +1 @@ +architecture_params = {"resnet50": {"beta":0.196}} \ No newline at end of file diff --git a/configs/dataset_params.py b/configs/dataset_params.py new file mode 100644 index 0000000000000000000000000000000000000000..3f227da674de51bf3c4ac0fe3a8faff2004775a6 --- /dev/null +++ b/configs/dataset_params.py @@ -0,0 +1,22 @@ +import torch + +from configs.optim_params import EvaluatedDict + +dataset_constants = {"CUB2011":{"num_classes":200}, + "TravelingBirds":{"num_classes":200}, + "ImageNet":{"num_classes":1000}, + "StanfordCars":{"num_classes":196}, + "FGVCAircraft": {"num_classes":100}} + +normalize_params = {"CUB2011":{"mean": torch.tensor([0.4853, 0.4964, 0.4295]),"std":torch.tensor([0.2300, 0.2258, 0.2625])}, +"TravelingBirds":{"mean": torch.tensor([0.4584, 0.4369, 0.3957]),"std":torch.tensor([0.2610, 0.2569, 0.2722])}, + "ImageNet":{'mean': torch.tensor([0.485, 0.456, 0.406]),'std': torch.tensor([0.229, 0.224, 0.225])} , +"StanfordCars":{'mean': torch.tensor([0.4593, 0.4466, 0.4453]),'std': torch.tensor([0.2920, 0.2910, 0.2988])} , + "FGVCAircraft":{'mean': torch.tensor([0.4827, 0.5130, 0.5352]), + 'std': torch.tensor([0.2236, 0.2170, 0.2478]),} + } + + +dense_batch_size = EvaluatedDict({False: 16,True: 1024,}, lambda x: x == "ImageNet") + +ft_batch_size = EvaluatedDict({False: 16,True: 1024,}, lambda x: x == "ImageNet")# Untested \ No newline at end of file diff --git a/configs/optim_params.py b/configs/optim_params.py new file mode 100644 index 0000000000000000000000000000000000000000..c0fad011caec798b4d51948b28b4d0885c414b59 --- /dev/null +++ b/configs/optim_params.py @@ -0,0 +1,22 @@ +# order: lr,weight_decay, step_lr, step_lr_gamma +import math + + +class EvaluatedDict: + def __init__(self, d, func): + self.dict = d + self.func = func + + def __getitem__(self, key): + return self.dict[self.func(key)] + +dense_params = 
EvaluatedDict({False: [0.005, 0.0005, 30, 0.4, 150],True: [None,None,None,None,None],}, lambda x: x == "ImageNet") +def calculate_lr_from_args( epochs, step_lr, start_lr, step_lr_decay): + # Gets the final learning rate after dense training with step_lr_schedule. + n_steps = math.floor((epochs - step_lr) / step_lr) + final_lr = start_lr * step_lr_decay ** n_steps + return final_lr + +ft_params =EvaluatedDict({False: [1e-4, 0.0005, 10, 0.4, 40],True:[[calculate_lr_from_args(150,30,0.005, 0.4), 0.0005, 10, 0.4, 40]]}, lambda x: x == "ImageNet") + + diff --git a/configs/qsenn_training_params.py b/configs/qsenn_training_params.py new file mode 100644 index 0000000000000000000000000000000000000000..5ca03c994ee04c47016c89357ff5d4953f634281 --- /dev/null +++ b/configs/qsenn_training_params.py @@ -0,0 +1,11 @@ +from configs.sldd_training_params import OptimizationScheduler + + +class QSENNScheduler(OptimizationScheduler): + def get_params(self): + params = super().get_params() + if self.n_calls >= 2: + params[0] = params[0] * 0.9**(self.n_calls-2) + if 2 <= self.n_calls <= 4: + params[-2] = 10# Change num epochs to 10 for iterative finetuning + return params diff --git a/configs/sldd_training_params.py b/configs/sldd_training_params.py new file mode 100644 index 0000000000000000000000000000000000000000..5a605602a1a399d0dd55e1f53d8cbaa8c5d73dc0 --- /dev/null +++ b/configs/sldd_training_params.py @@ -0,0 +1,17 @@ +from configs.optim_params import dense_params, ft_params + + +class OptimizationScheduler: + def __init__(self, dataset): + self.dataset = dataset + self.n_calls = 0 + + + def get_params(self): + if self.n_calls == 0: # Return Deńse Params + params = dense_params[self.dataset]+ [False] + else: # Return Finetuning Params + params = ft_params[self.dataset]+ [True] + self.n_calls += 1 + return params + diff --git a/dataset_classes/__pycache__/cub200.cpython-310.pyc b/dataset_classes/__pycache__/cub200.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b65386a367820ce47ae2ecf095fb28397d58df2a Binary files /dev/null and b/dataset_classes/__pycache__/cub200.cpython-310.pyc differ diff --git a/dataset_classes/__pycache__/stanfordcars.cpython-310.pyc b/dataset_classes/__pycache__/stanfordcars.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e54b5188bd94de8a82b0639237f4ddf557cd52a Binary files /dev/null and b/dataset_classes/__pycache__/stanfordcars.cpython-310.pyc differ diff --git a/dataset_classes/__pycache__/travelingbirds.cpython-310.pyc b/dataset_classes/__pycache__/travelingbirds.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cc7e30e2027a3730c1873443049962b75998fb0 Binary files /dev/null and b/dataset_classes/__pycache__/travelingbirds.cpython-310.pyc differ diff --git a/dataset_classes/__pycache__/utils.cpython-310.pyc b/dataset_classes/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bb96db5a1c1e6f6cd7e6efc268cffd6d6e0004b Binary files /dev/null and b/dataset_classes/__pycache__/utils.cpython-310.pyc differ diff --git a/dataset_classes/cub200.py b/dataset_classes/cub200.py new file mode 100644 index 0000000000000000000000000000000000000000..b59a933605948ed45365ccba82486c2c433d4173 --- /dev/null +++ b/dataset_classes/cub200.py @@ -0,0 +1,96 @@ +# Dataset should lie under /root/ +# root is currently set to ~/tmp/Datasets/CUB200 +# If cropped iamges, like for PIP-Net, ProtoPool, etc. 
are used, then the crop_root should be set to a folder containing the +# cropped images in the expected structure, obtained by following ProtoTree's instructions. +# https://github.com/M-Nauta/ProtoTree/blob/main/README.md#preprocessing-cub +import os +from pathlib import Path + +import numpy as np +import pandas as pd +from torch.utils.data import Dataset +from torchvision.datasets.folder import default_loader + +from dataset_classes.utils import txt_load + + +class CUB200Class(Dataset): + root = Path.home() / "tmp/Datasets/CUB200" + crop_root = Path.home() / "tmp/Datasets/PPCUB200" + base_folder = 'CUB_200_2011/images' + def __init__(self, train, transform, crop=True): + self.train = train + self.transform = transform + self.crop = crop + self._load_metadata() + self.loader = default_loader + + if crop: + self.adapt_to_crop() + + def _load_metadata(self): + images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ', + names=['img_id', 'filepath']) + image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'), + sep=' ', names=['img_id', 'target']) + train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'), + sep=' ', names=['img_id', 'is_training_img']) + data = images.merge(image_class_labels, on='img_id') + self.data = data.merge(train_test_split, on='img_id') + if self.train: + self.data = self.data[self.data.is_training_img == 1] + else: + self.data = self.data[self.data.is_training_img == 0] + + def __len__(self): + return len(self.data) + + def adapt_to_crop(self): + # ds_name = [x for x in self.cropped_dict.keys() if x in self.root][0] + self.root = self.crop_root + folder_name = "train" if self.train else "test" + folder_name = folder_name + "_cropped" + self.base_folder = 'CUB_200_2011/' + folder_name + + def __getitem__(self, idx): + sample = self.data.iloc[idx] + path = os.path.join(self.root, self.base_folder, sample.filepath) + target = sample.target - 1 # Targets start at 1 by default, so shift to 0 + img = self.loader(path) + img = self.transform(img) + return img, target + + @classmethod + def get_image_attribute_labels(self, train=False): + image_attribute_labels = pd.read_csv( + os.path.join('/home/norrenbr/tmp/Datasets/CUB200', 'CUB_200_2011', "attributes", + 'image_attribute_labels.txt'), + sep=' ', names=['img_id', 'attribute', "is_present", "certainty", "time"], on_bad_lines="skip") + train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'), + sep=' ', names=['img_id', 'is_training_img']) + merged = image_attribute_labels.merge(train_test_split, on="img_id") + filtered_data = merged[merged["is_training_img"] == train] + return filtered_data + + + @staticmethod + def filter_attribute_labels(labels, min_certainty=3): + is_invisible_present = labels[labels["certainty"] == 1]["is_present"].sum() + if is_invisible_present != 0: + raise ValueError("Invisible present") + labels["img_id"] -= min(labels["img_id"]) + labels["img_id"] = fillholes_in_array(labels["img_id"]) + labels[labels["certainty"] == 1]["certainty"] = 4 + labels = labels[labels["certainty"] >= min_certainty] + labels["attribute"] -= min(labels["attribute"]) + labels = labels[["img_id", "attribute", "is_present"]] + labels["is_present"] = labels["is_present"].astype(bool) + return labels + + + +def fillholes_in_array(array): + unique_values = np.unique(array) + mapping = {x: i for i, x in enumerate(unique_values)} + array = array.map(mapping) + return array diff --git 
a/dataset_classes/stanfordcars.py b/dataset_classes/stanfordcars.py new file mode 100644 index 0000000000000000000000000000000000000000..0be682a5d164a8b39cff5bd9cca82cc8cf5ebe53 --- /dev/null +++ b/dataset_classes/stanfordcars.py @@ -0,0 +1,121 @@ +import pathlib +from typing import Callable, Optional, Any, Tuple + +import numpy as np +import pandas as pd +from PIL import Image +from torchvision.datasets import VisionDataset +from torchvision.datasets.utils import download_and_extract_archive, download_url + + +class StanfordCarsClass(VisionDataset): + """`Stanford Cars `_ Dataset + + The Cars dataset contains 16,185 images of 196 classes of cars. The data is + split into 8,144 training images and 8,041 testing images, where each class + has been split roughly in a 50-50 split + + .. note:: + + This class needs `scipy `_ to load target files from `.mat` format. + + Args: + root (string): Root directory of dataset + split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again.""" + root = pathlib.Path.home() / "tmp" / "Datasets" / "StanfordCars" + def __init__( + self, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = True, + ) -> None: + + try: + import scipy.io as sio + except ImportError: + raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy") + + super().__init__(self.root, transform=transform, target_transform=target_transform) + + self.train = train + self._base_folder = pathlib.Path(self.root) / "stanford_cars" + devkit = self._base_folder / "devkit" + + if train: + self._annotations_mat_path = devkit / "cars_train_annos.mat" + self._images_base_path = self._base_folder / "cars_train" + else: + self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat" + self._images_base_path = self._base_folder / "cars_test" + + if download: + self.download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. 
You can use download=True to download it") + + self.samples = [ + ( + str(self._images_base_path / annotation["fname"]), + annotation["class"] - 1, # Original target mapping starts from 1, hence -1 + ) + for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"] + ] + self.targets = np.array([x[1] for x in self.samples]) + self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist() + self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, idx: int) -> Tuple[Any, Any]: + """Returns pil_image and class_id for given index""" + image_path, target = self.samples[idx] + pil_image = Image.open(image_path).convert("RGB") + + if self.transform is not None: + pil_image = self.transform(pil_image) + if self.target_transform is not None: + target = self.target_transform(target) + return pil_image, target + + def download(self) -> None: + if self._check_exists(): + return + + download_and_extract_archive( + url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz", + download_root=str(self._base_folder), + md5="c3b158d763b6e2245038c8ad08e45376", + ) + if self.train: + download_and_extract_archive( + url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz", + download_root=str(self._base_folder), + md5="065e5b463ae28d29e77c1b4b166cfe61", + ) + else: + download_and_extract_archive( + url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz", + download_root=str(self._base_folder), + md5="4ce7ebf6a94d07f1952d94dd34c4d501", + ) + download_url( + url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat", + root=str(self._base_folder), + md5="b0a2b23655a3edd16d84508592a98d10", + ) + + def _check_exists(self) -> bool: + if not (self._base_folder / "devkit").is_dir(): + return False + + return self._annotations_mat_path.exists() and self._images_base_path.is_dir() diff --git a/dataset_classes/travelingbirds.py b/dataset_classes/travelingbirds.py new file mode 100644 index 0000000000000000000000000000000000000000..551ce1fd46b9b84e572ea18f4adc6ecd73cea00d --- /dev/null +++ b/dataset_classes/travelingbirds.py @@ -0,0 +1,59 @@ +# TravelingBirds dataset needs to be downloaded from https://worksheets.codalab.org/bundles/0x518829de2aa440c79cd9d75ef6669f27 +# as it comes from https://github.com/yewsiang/ConceptBottleneck +import os +from pathlib import Path + +import numpy as np +import pandas as pd + +from dataset_classes.cub200 import CUB200Class +from dataset_classes.utils import index_list_with_sorting, mask_list + + +class TravelingBirds(CUB200Class): + init_base_folder = 'CUB_fixed' + root = Path.home() / "tmp/Datasets/TravelingBirds" + crop_root = Path.home() / "tmp/Datasets/PPTravelingBirds" + def get_all_samples_dir(self, dir): + + self.base_folder = os.path.join(self.init_base_folder, dir) + main_dir = Path(self.root) / self.init_base_folder / dir + return self.get_all_sample(main_dir) + + def adapt_to_crop(self): + self.root = self.crop_root + folder_name = "train" if self.train else "test" + folder_name = folder_name + "_cropped" + self.base_folder = 'CUB_fixed/' + folder_name + + def get_all_sample(self, dir): + answer = [] + for i, sub_dir in enumerate(sorted(os.listdir(dir))): + class_dir = dir / sub_dir + for single_img in os.listdir(class_dir): + answer.append([Path(sub_dir) / single_img, i + 1]) + return answer + def _load_metadata(self): + train_test_split = pd.read_csv( + os.path.join(Path(self.root).parent / 
"CUB200", 'CUB_200_2011', 'train_test_split.txt'), + sep=' ', names=['img_id', 'is_training_img']) + data = pd.read_csv( + os.path.join(Path(self.root).parent / "CUB200", 'CUB_200_2011', 'images.txt'), + sep=' ', names=['img_id', "path"]) + img_dict = {x[1]: x[0] for x in data.values} + # TravelingBirds has all train+test images in both folders, just with different backgrounds. + # They are separated by train_test_split of CUB200. + if self.train: + samples = self.get_all_samples_dir("train") + mask = train_test_split["is_training_img"] == 1 + else: + samples = self.get_all_samples_dir("test") + mask = train_test_split["is_training_img"] == 0 + ids = np.array([img_dict[str(x[0])] for x in samples]) + sorted = np.argsort(ids) + samples = index_list_with_sorting(samples, sorted) + samples = mask_list(samples, mask) + filepaths = [x[0] for x in samples] + labels = [x[1] for x in samples] + samples = pd.DataFrame({"filepath": filepaths, "target": labels}) + self.data = samples diff --git a/dataset_classes/utils.py b/dataset_classes/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5f0039ba93d0966230da88f2568d02bb7cebeebf --- /dev/null +++ b/dataset_classes/utils.py @@ -0,0 +1,16 @@ +def index_list_with_sorting(list_to_sort, sorting_list): + answer = [] + for entry in sorting_list: + answer.append(list_to_sort[entry]) + return answer + + +def mask_list(list_input, mask): + return [x for i, x in enumerate(list_input) if mask[i]] + + +def txt_load(filename): + with open(filename, 'r') as f: + data = f.read() + return data + diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..e4e2f7b3680115f3e38c80511baede60fda0db03 --- /dev/null +++ b/environment.yml @@ -0,0 +1,117 @@ +name: QSENNEnv +channels: + - pytorch + - nvidia + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - brotli-python=1.0.9=py310h6a678d5_7 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2023.12.12=h06a4308_0 + - certifi=2023.11.17=py310h06a4308_0 + - cffi=1.16.0=py310h5eee18b_0 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - cryptography=41.0.7=py310hdda0065_0 + - cuda-cudart=12.1.105=0 + - cuda-cupti=12.1.105=0 + - cuda-libraries=12.1.0=0 + - cuda-nvrtc=12.1.105=0 + - cuda-nvtx=12.1.105=0 + - cuda-opencl=12.3.101=0 + - cuda-runtime=12.1.0=0 + - ffmpeg=4.3=hf484d3e_0 + - filelock=3.13.1=py310h06a4308_0 + - freetype=2.12.1=h4a9f257_0 + - giflib=5.2.1=h5eee18b_3 + - gmp=6.2.1=h295c915_3 + - gmpy2=2.1.2=py310heeb90bb_0 + - gnutls=3.6.15=he1e5248_0 + - idna=3.4=py310h06a4308_0 + - intel-openmp=2023.1.0=hdb19cb5_46306 + - jinja2=3.1.2=py310h06a4308_0 + - jpeg=9e=h5eee18b_1 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libcublas=12.1.0.26=0 + - libcufft=11.0.2.4=0 + - libcufile=1.8.1.2=0 + - libcurand=10.3.4.107=0 + - libcusolver=11.4.4.55=0 + - libcusparse=12.0.2.55=0 + - libdeflate=1.17=h5eee18b_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.4=h5eee18b_0 + - libjpeg-turbo=2.0.0=h9bf148f_0 + - libnpp=12.0.2.50=0 + - libnvjitlink=12.1.105=0 + - libnvjpeg=12.1.1.14=0 + - libpng=1.6.39=h5eee18b_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=h6a678d5_0 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.41.5=h5eee18b_0 + - libwebp=1.3.2=h11a3e52_0 + - libwebp-base=1.3.2=h5eee18b_0 + - 
llvm-openmp=14.0.6=h9e868ea_0 + - lz4-c=1.9.4=h6a678d5_0 + - markupsafe=2.1.3=py310h5eee18b_0 + - mkl=2023.1.0=h213fc3f_46344 + - mkl-service=2.4.0=py310h5eee18b_1 + - mkl_fft=1.3.8=py310h5eee18b_0 + - mkl_random=1.2.4=py310hdb19cb5_0 + - mpc=1.1.0=h10f8cd9_1 + - mpfr=4.0.2=hb69a4c5_1 + - mpmath=1.3.0=py310h06a4308_0 + - ncurses=6.4=h6a678d5_0 + - nettle=3.7.3=hbbd107a_1 + - networkx=3.1=py310h06a4308_0 + - numpy=1.26.3=py310h5f9d8c6_0 + - numpy-base=1.26.3=py310hb5e798b_0 + - openh264=2.1.1=h4ff587b_0 + - openjpeg=2.4.0=h3ad879b_0 + - openssl=3.0.12=h7f8727e_0 + - pillow=10.0.1=py310ha6cbd5a_0 + - pip=23.3.1=py310h06a4308_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pyopenssl=23.2.0=py310h06a4308_0 + - pysocks=1.7.1=py310h06a4308_0 + - python=3.10.13=h955ad1f_0 + - pytorch=2.1.2=py3.10_cuda12.1_cudnn8.9.2_0 + - pytorch-cuda=12.1=ha16c6d3_5 + - pytorch-mutex=1.0=cuda + - pyyaml=6.0.1=py310h5eee18b_0 + - readline=8.2=h5eee18b_0 + - requests=2.31.0=py310h06a4308_0 + - setuptools=68.2.2=py310h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - sympy=1.12=py310h06a4308_0 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.12=h1ccaba5_0 + - torchaudio=2.1.2=py310_cu121 + - torchtriton=2.1.0=py310 + - torchvision=0.16.2=py310_cu121 + - typing_extensions=4.7.1=py310h06a4308_0 + - urllib3=1.26.18=py310h06a4308_0 + - wheel=0.41.2=py310h06a4308_0 + - xz=5.4.5=h5eee18b_0 + - yaml=0.2.5=h7b6447c_0 + - zlib=1.2.13=h5eee18b_0 + - zstd=1.5.5=hc292b87_0 + - pip: + - fsspec==2023.12.2 + - glm-saga==0.1.2 + - pandas==2.1.4 + - python-dateutil==2.8.2 + - pytz==2023.3.post1 + - six==1.16.0 + - tqdm==4.66.1 + - tzdata==2023.4 +prefix: /home/norrenbr/anaconda/tmp/envs/QSENN-Minimal diff --git a/evaluation/Metrics/Dependence.py b/evaluation/Metrics/Dependence.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1b26dfc19de0430925e38aac45ebcc33a94455 --- /dev/null +++ b/evaluation/Metrics/Dependence.py @@ -0,0 +1,21 @@ +import torch + + +def compute_contribution_top_feature(features, outputs, weights, labels): + with torch.no_grad(): + total_pre_softmax, predicted_classes = torch.max(outputs, dim=1) + feature_part = features * weights.to(features.device)[predicted_classes] + class_specific_feature_part = torch.zeros((weights.shape[0], features.shape[1],)) + feature_class_part = torch.zeros((weights.shape[0], features.shape[1],)) + for unique_class in predicted_classes.unique(): + mask = predicted_classes == unique_class + class_specific_feature_part[unique_class] = feature_part[mask].mean(dim=0) + gt_mask = labels == unique_class + feature_class_part[unique_class] = feature_part[gt_mask].mean(dim=0) + abs_features = feature_part.abs() + abs_sum = abs_features.sum(dim=1) + fractions_abs = abs_features / abs_sum[:, None] + abs_max = fractions_abs.max(dim=1)[0] + mask = ~torch.isnan(abs_max) + abs_max = abs_max[mask] + return abs_max.mean() \ No newline at end of file diff --git a/evaluation/Metrics/__pycache__/Dependence.cpython-310.pyc b/evaluation/Metrics/__pycache__/Dependence.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d59c7cc91e9d0c3e84e533ca3876e3ee9850c52 Binary files /dev/null and b/evaluation/Metrics/__pycache__/Dependence.cpython-310.pyc differ diff --git a/evaluation/Metrics/__pycache__/cub_Alignment.cpython-310.pyc b/evaluation/Metrics/__pycache__/cub_Alignment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0de63438f9c10a07e3524cca02c53276c1d7622 Binary files /dev/null and 
b/evaluation/Metrics/__pycache__/cub_Alignment.cpython-310.pyc differ diff --git a/evaluation/Metrics/cub_Alignment.py b/evaluation/Metrics/cub_Alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..9b4b41e427668f86ec530baab1796ac9d0678489 --- /dev/null +++ b/evaluation/Metrics/cub_Alignment.py @@ -0,0 +1,30 @@ +import numpy as np + +from dataset_classes.cub200 import CUB200Class + + +def get_cub_alignment_from_features(features_train_sorted): + metric_matrix = compute_metric_matrix(np.array(features_train_sorted), "train") + return np.mean(np.max(metric_matrix, axis=1)) + pass + + +def compute_metric_matrix(features, mode): + image_attribute_labels = CUB200Class.get_image_attribute_labels(train=mode == "train") + image_attribute_labels = CUB200Class.filter_attribute_labels(image_attribute_labels) + matrix_shape = ( + features.shape[1], max(image_attribute_labels["attribute"]) + 1) + accuracy_matrix = np.zeros(matrix_shape) + sensitivity_matrix = np.zeros_like(accuracy_matrix) + grouped_attributes = image_attribute_labels.groupby("attribute") + for attribute_id, group in grouped_attributes: + is_present = group[group["is_present"]] + not_present = group[~group["is_present"]] + is_present_avg = np.mean(features[is_present["img_id"]], axis=0) + not_present_avg = np.mean(features[not_present["img_id"]], axis=0) + sensitivity_matrix[:, attribute_id] = not_present_avg + accuracy_matrix[:, attribute_id] = is_present_avg + metric_matrix = accuracy_matrix - sensitivity_matrix + no_abs_features = features - np.min(features, axis=0) + no_abs_feature_mean = metric_matrix / no_abs_features.mean(axis=0)[:, None] + return no_abs_feature_mean diff --git a/evaluation/__pycache__/diversity.cpython-310.pyc b/evaluation/__pycache__/diversity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d65f89a660c7660f26ce7578e34557c87b970b66 Binary files /dev/null and b/evaluation/__pycache__/diversity.cpython-310.pyc differ diff --git a/evaluation/__pycache__/helpers.cpython-310.pyc b/evaluation/__pycache__/helpers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f7034adbe8c30d420975414880bb581f2052080 Binary files /dev/null and b/evaluation/__pycache__/helpers.cpython-310.pyc differ diff --git a/evaluation/__pycache__/qsenn_metrics.cpython-310.pyc b/evaluation/__pycache__/qsenn_metrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9460b6362b47ba4ddf699b74707b87b5a063ce73 Binary files /dev/null and b/evaluation/__pycache__/qsenn_metrics.cpython-310.pyc differ diff --git a/evaluation/__pycache__/utils.cpython-310.pyc b/evaluation/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e70bce079f72b6f83344a1decd675b1312655b6 Binary files /dev/null and b/evaluation/__pycache__/utils.cpython-310.pyc differ diff --git a/evaluation/diversity.py b/evaluation/diversity.py new file mode 100644 index 0000000000000000000000000000000000000000..033679ce9cf4546b74b0d1d4bdb6b8590c5c8865 --- /dev/null +++ b/evaluation/diversity.py @@ -0,0 +1,111 @@ +import numpy as np +import torch + +from evaluation.helpers import softmax_feature_maps + + +class MultiKCrossChannelMaxPooledSum: + def __init__(self, top_k_range, weights, interactions, func="softmax"): + self.top_k_range = top_k_range + self.weights = weights + self.failed = False + self.max_ks = self.get_max_ks(weights) + self.locality_of_used_features = torch.zeros(len(top_k_range), 
device=weights.device) + self.locality_of_exclusely_used_features = torch.zeros(len(top_k_range), device=weights.device) + self.ns_k = torch.zeros(len(top_k_range), device=weights.device) + self.exclusive_ns = torch.zeros(len(top_k_range), device=weights.device) + self.interactions = interactions + self.func = func + + def get_max_ks(self, weights): + nonzeros = torch.count_nonzero(torch.tensor(weights), 1) + return nonzeros + + def get_top_n_locality(self, outputs, initial_feature_maps, k): + feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(outputs, + initial_feature_maps) + max_ks = self.max_ks[top_classes] + max_k_based_row_selection = max_ks >= k + + result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size, feature_maps, + separated=True) + return result + + def get_locality(self, outputs, initial_feature_maps, n): + answer = self.get_top_n_locality(outputs, initial_feature_maps, n) + return answer + + def get_result(self): + # if torch.sum(self.exclusive_ns) ==0: + # end_idx = len(self.exclusive_ns) - 1 + # else: + + exclusive_array = torch.zeros_like(self.locality_of_exclusely_used_features) + local_array = torch.zeros_like(self.locality_of_used_features) + # if self.failed: + # return local_array, exclusive_array + cumulated = torch.cumsum(self.exclusive_ns, 0) + end_idx = torch.argmax(cumulated) + exclusivity_array = self.locality_of_exclusely_used_features[:end_idx + 1] / self.exclusive_ns[:end_idx + 1] + exclusivity_array[exclusivity_array != exclusivity_array] = 0 + exclusive_array[:len(exclusivity_array)] = exclusivity_array + locality_array = self.locality_of_used_features[self.locality_of_used_features != 0] / self.ns_k[ + self.locality_of_used_features != 0] + local_array[:len(locality_array)] = locality_array + return local_array, exclusive_array + + def get_crosspooled(self, relevant_weights, mask, k, vector_size, feature_maps, separated=False): + relevant_indices = get_relevant_indices(relevant_weights, k)[mask] + # this should have size batch x k x featuremapsize squared] + indices = relevant_indices.unsqueeze(2).repeat(1, 1, vector_size) + sub_feature_maps = torch.gather(feature_maps[mask], 1, indices) + # shape batch x featuremapsquared: For each "pixel" the highest value + cross_pooled = torch.max(sub_feature_maps, 1)[0] + if separated: + return torch.sum(cross_pooled, 1) / k + else: + ns = len(cross_pooled) + result = torch.sum(cross_pooled) / (k) + # should be batch x map size + + return ns, result + + def adapt_feature_maps(self, outputs, initial_feature_maps): + if self.func == "softmax": + feature_maps = softmax_feature_maps(initial_feature_maps) + feature_maps = torch.flatten(feature_maps, 2) + vector_size = feature_maps.shape[2] + top_classes = torch.argmax(outputs, dim=1) + relevant_weights = self.weights[top_classes] + if relevant_weights.shape[1] != feature_maps.shape[1]: + feature_maps = self.interactions.get_localized_features(initial_feature_maps) + feature_maps = softmax_feature_maps(feature_maps) + feature_maps = torch.flatten(feature_maps, 2) + return feature_maps, relevant_weights, vector_size, top_classes + + def calculate_locality(self, outputs, initial_feature_maps): + feature_maps, relevant_weights, vector_size, top_classes = self.adapt_feature_maps(outputs, + initial_feature_maps) + max_ks = self.max_ks[top_classes] + for k in self.top_k_range: + # relevant_k_s = max_ks[] + max_k_based_row_selection = max_ks >= k + if torch.sum(max_k_based_row_selection) == 0: + break + + 
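+ # Classes whose weight row has exactly k nonzero entries count towards the exclusive locality statistic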
exclusive_k = max_ks == k + if torch.sum(exclusive_k) != 0: + ns, result = self.get_crosspooled(relevant_weights, exclusive_k, k, vector_size, feature_maps) + self.locality_of_exclusely_used_features[k - 1] += result + self.exclusive_ns[k - 1] += ns + ns, result = self.get_crosspooled(relevant_weights, max_k_based_row_selection, k, vector_size, feature_maps) + self.ns_k[k - 1] += ns + self.locality_of_used_features[k - 1] += result + + def __call__(self, outputs, initial_feature_maps): + self.calculate_locality(outputs, initial_feature_maps) + + +def get_relevant_indices(weights, top_k): + top_k = weights.topk(top_k)[1] + return top_k \ No newline at end of file diff --git a/evaluation/helpers.py b/evaluation/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4a9902103fe63df01994acb079127ab719c9f1 --- /dev/null +++ b/evaluation/helpers.py @@ -0,0 +1,6 @@ +import torch + + +def softmax_feature_maps(x): + # done: verify that this applies softmax along first dimension + return torch.softmax(x.reshape(x.size(0), x.size(1), -1), 2).view_as(x) \ No newline at end of file diff --git a/evaluation/qsenn_metrics.py b/evaluation/qsenn_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..1bb8f21b6f7dfe101c8e668e9c422c1d88ce8751 --- /dev/null +++ b/evaluation/qsenn_metrics.py @@ -0,0 +1,39 @@ +import numpy as np +import torch + +from evaluation.Metrics.Dependence import compute_contribution_top_feature +from evaluation.Metrics.cub_Alignment import get_cub_alignment_from_features +from evaluation.diversity import MultiKCrossChannelMaxPooledSum +from evaluation.utils import get_metrics_for_model + + +def evaluateALLMetricsForComps(features_train, outputs_train, feature_maps_test, + outputs_test, linear_matrix, labels_train): + with torch.no_grad(): + if len(features_train) < 7000: # recognize CUB and TravelingBirds + cub_alignment = get_cub_alignment_from_features(features_train) + else: + cub_alignment = 0 + print("cub_alignment: ", cub_alignment) + localizer = MultiKCrossChannelMaxPooledSum(range(1, 6), linear_matrix, None) + batch_size = 300 + for i in range(np.floor(len(features_train) / batch_size).astype(int)): + localizer(outputs_test[i * batch_size:(i + 1) * batch_size].to("cuda"), + feature_maps_test[i * batch_size:(i + 1) * batch_size].to("cuda")) + + locality, exlusive_locality = localizer.get_result() + diversity = locality[4] + print("diversity@5: ", diversity) + abs_frac_mean = compute_contribution_top_feature( + features_train, + outputs_train, + linear_matrix, + labels_train) + print("Dependence ", abs_frac_mean) + answer_dict = {"diversity": diversity.item(), "Dependence": abs_frac_mean.item(), "Alignment":cub_alignment} + return answer_dict + +def eval_model_on_all_qsenn_metrics(model, test_loader, train_loader): + return get_metrics_for_model(train_loader, test_loader, model, evaluateALLMetricsForComps) + + diff --git a/evaluation/utils.py b/evaluation/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ef1b679fc9dac88e2fb897d69c34d959c19b3101 --- /dev/null +++ b/evaluation/utils.py @@ -0,0 +1,57 @@ +import torch +from tqdm import tqdm + + + +def get_metrics_for_model(train_loader, test_loader, model, metric_evaluator): + (features_train, feature_maps_train, outputs_train, features_test, feature_maps_test, + outputs_test, labels) = [], [], [], [], [], [], [] + device = "cuda" if torch.cuda.is_available() else "cpu" + model.eval() + model = model.to(device) + training_transforms = 
train_loader.dataset.transform + train_loader.dataset.transform = test_loader.dataset.transform # Use test transform for train + train_loader = torch.utils.data.DataLoader(train_loader.dataset, batch_size=100, shuffle=False) # Turn off shuffling + print("Going in get metrics") + linear_matrix = model.linear.weight + entries = torch.nonzero(linear_matrix) + rel_features = torch.unique(entries[:, 1]) + with torch.no_grad(): + iterator = tqdm(enumerate(train_loader), total=len(train_loader)) + for batch_idx, (data, target) in iterator: + xs1 = data.to("cuda") + output, feature_maps, final_features = model(xs1, with_feature_maps=True, with_final_features=True,) + outputs_train.append(output.to("cpu")) + features_train.append(final_features.to("cpu")) + labels.append(target.to("cpu")) + total = 0 + correct = 0 + iterator = tqdm(enumerate(test_loader), total=len(test_loader)) + for batch_idx, (data, target) in iterator: + xs1 = data.to("cuda") + output, feature_maps, final_features = model(xs1, with_feature_maps=True, + with_final_features=True, ) + feature_maps_test.append(feature_maps[:, rel_features].to("cpu")) + outputs_test.append(output.to("cpu")) + total += target.size(0) + _, predicted = output.max(1) + correct += predicted.eq(target.to("cuda")).sum().item() + print("test accuracy: ", correct / total) + features_train = torch.cat(features_train) + outputs_train = torch.cat(outputs_train) + feature_maps_test = torch.cat(feature_maps_test) + outputs_test = torch.cat(outputs_test) + labels = torch.cat(labels) + linear_matrix = linear_matrix[:, rel_features] + print("Shape of linear matrix: ", linear_matrix.shape) + all_metrics_dict = metric_evaluator(features_train, outputs_train, + feature_maps_test, + outputs_test, linear_matrix, labels) + result_dict = {"Accuracy": correct / total, "NFfeatures": linear_matrix.shape[1], + "PerClass": torch.nonzero(linear_matrix).shape[0] / linear_matrix.shape[0], + } + result_dict.update(all_metrics_dict) + print(result_dict) + # Reset Train transforms + train_loader.dataset.transform = training_transforms + return result_dict diff --git a/fig/AutoML4FAS_Logo.jpeg b/fig/AutoML4FAS_Logo.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..35d4066fa5cf5967553960097b57f80c2ac8c580 Binary files /dev/null and b/fig/AutoML4FAS_Logo.jpeg differ diff --git a/fig/Bund.png b/fig/Bund.png new file mode 100644 index 0000000000000000000000000000000000000000..1c92a104515f9b3c61642f7cd3cc898163e5ef0e Binary files /dev/null and b/fig/Bund.png differ diff --git a/fig/LUH.png b/fig/LUH.png new file mode 100644 index 0000000000000000000000000000000000000000..af168ab3e866e5c66c616b6a090ef9c4ac212e3b Binary files /dev/null and b/fig/LUH.png differ diff --git a/fig/birds.png b/fig/birds.png new file mode 100644 index 0000000000000000000000000000000000000000..330ebdff52c39b989a5c0cd42e0a35fdbeb7c1ff Binary files /dev/null and b/fig/birds.png differ diff --git a/finetuning/map_function.py b/finetuning/map_function.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa65c3fa6dee0dc55484bdaae3fb181786eed1b --- /dev/null +++ b/finetuning/map_function.py @@ -0,0 +1,11 @@ +from finetuning.qsenn import finetune_qsenn +from finetuning.sldd import finetune_sldd + + +def finetune(key, model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule, per_class, n_features): + if key == 'sldd': + return finetune_sldd(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule,per_class, n_features) 
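# (Editorial note, not part of the original diff.) The positional order differs between the two
# branches on purpose: finetune_sldd expects (..., n_per_class, n_features) whereas finetune_qsenn
# expects (..., n_features, n_per_class), so per_class and n_features are passed in swapped positions.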
+ elif key == 'qsenn': + return finetune_qsenn(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule,n_features,per_class, ) + else: + raise ValueError(f"Unknown Finetuning key: {key}") \ No newline at end of file diff --git a/finetuning/qsenn.py b/finetuning/qsenn.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4dc8b65e6c703e51fe602c2ac897c97844897c --- /dev/null +++ b/finetuning/qsenn.py @@ -0,0 +1,30 @@ +import os + +import torch + +from finetuning.utils import train_n_epochs +from sparsification.qsenn import compute_qsenn_feature_selection_and_assignment + + +def finetune_qsenn(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule ,n_features, n_per_class): + for iteration_epoch in range(4): + print(f"Starting iteration epoch {iteration_epoch}") + this_log_dir = log_dir / f"iteration_epoch_{iteration_epoch}" + this_log_dir.mkdir(parents=True, exist_ok=True) + feature_sel, sparse_layer,bias_sparse, current_mean, current_std = compute_qsenn_feature_selection_and_assignment(model, train_loader, + test_loader, + this_log_dir, n_classes, seed, n_features, n_per_class) + model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse) + if os.path.exists(this_log_dir / "trained_model.pth"): + model.load_state_dict(torch.load(this_log_dir / "trained_model.pth")) + _ = optimization_schedule.get_params() # count up, to have get correct lr + continue + + model = train_n_epochs( model, beta, optimization_schedule, train_loader, test_loader) + torch.save(model.state_dict(), this_log_dir / "trained_model.pth") + print(f"Finished iteration epoch {iteration_epoch}") + return model + + + + diff --git a/finetuning/sldd.py b/finetuning/sldd.py new file mode 100644 index 0000000000000000000000000000000000000000..7c8ac0034b14cbbf460f0bf59e25dfd8188ee94b --- /dev/null +++ b/finetuning/sldd.py @@ -0,0 +1,22 @@ +import numpy as np +import torch + +from FeatureDiversityLoss import FeatureDiversityLoss +from finetuning.utils import train_n_epochs +from sparsification.glmBasedSparsification import compute_feature_selection_and_assignment +from sparsification.sldd import compute_sldd_feature_selection_and_assignment +from train import train, test +from training.optim import get_optimizer + + + + +def finetune_sldd(model, train_loader, test_loader, log_dir, n_classes, seed, beta, optimization_schedule,n_per_class, n_features, ): + feature_sel, weight, bias, mean, std = compute_sldd_feature_selection_and_assignment(model, train_loader, + test_loader, + log_dir, n_classes, seed,n_per_class, n_features) + model.set_model_sldd(feature_sel, weight, mean, std, bias) + model = train_n_epochs( model, beta, optimization_schedule, train_loader, test_loader) + return model + + diff --git a/finetuning/utils.py b/finetuning/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af751e2094c5ba6b2f83adadb5059f692329db37 --- /dev/null +++ b/finetuning/utils.py @@ -0,0 +1,14 @@ +from FeatureDiversityLoss import FeatureDiversityLoss +from train import train, test +from training.optim import get_optimizer + + +def train_n_epochs(model, beta,optimization_schedule, train_loader, test_loader): + optimizer, schedule, epochs = get_optimizer(model, optimization_schedule) + fdl = FeatureDiversityLoss(beta, model.linear) + for epoch in range(epochs): + model = train(model, train_loader, optimizer, fdl, epoch) + schedule.step() + if epoch % 5 == 0 or epoch+1 == epochs: + test(model, test_loader, epoch) 
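# (Editorial note, not part of the original diff.) train_n_epochs drives one finetuning stage: it pulls
# the optimizer, LR schedule and epoch count from optimization_schedule, hands the FeatureDiversityLoss
# (weighted by beta) to train(), which is expected to add it to the classification loss, steps the
# schedule once per epoch, and evaluates on the test loader every fifth epoch and after the final epoch
# (the condition above) before returning the trained model.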
+ return model \ No newline at end of file diff --git a/flagged/input/1e670025e5206017965a/Western_Grebe_0090_36182.jpg b/flagged/input/1e670025e5206017965a/Western_Grebe_0090_36182.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6450f6174bdd37cace75c6b32a029bcfa8761ed7 Binary files /dev/null and b/flagged/input/1e670025e5206017965a/Western_Grebe_0090_36182.jpg differ diff --git a/flagged/input/6a11e385290e9006bb0a/Black_Footed_Albatross_0003_796136.jpg b/flagged/input/6a11e385290e9006bb0a/Black_Footed_Albatross_0003_796136.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f9f6063f1c6130694ddd53c0231b317abe9ef03b Binary files /dev/null and b/flagged/input/6a11e385290e9006bb0a/Black_Footed_Albatross_0003_796136.jpg differ diff --git a/flagged/log.csv b/flagged/log.csv new file mode 100644 index 0000000000000000000000000000000000000000..5af3d3f8c5830b52178c12538580c9cd038fd2e4 --- /dev/null +++ b/flagged/log.csv @@ -0,0 +1,3 @@ +input,output,flag,username,timestamp +flagged/input/1e670025e5206017965a/Western_Grebe_0090_36182.jpg,,,,2024-10-21 12:37:51.541901 +flagged/input/6a11e385290e9006bb0a/Black_Footed_Albatross_0003_796136.jpg,"[{""image"": ""flagged/output/e2f704607c002e0c557d/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1b4541c3e93f034d746d/image.webp"", ""caption"": null}, {""image"": ""flagged/output/f8727dcfa3c59de0d873/image.webp"", ""caption"": null}, {""image"": ""flagged/output/c4b75e9fbc946f6ead6d/image.webp"", ""caption"": null}, {""image"": ""flagged/output/5b5ad2dd997a635f4917/image.webp"", ""caption"": null}, {""image"": ""flagged/output/b066004e4a0114aa705b/image.webp"", ""caption"": null}, {""image"": ""flagged/output/036072cdcc620de8cb65/image.webp"", ""caption"": null}, {""image"": ""flagged/output/218135cb251eb6cd0b2c/image.webp"", ""caption"": null}, {""image"": ""flagged/output/2a0671ba5ac1aa3bd2b9/image.webp"", ""caption"": null}, {""image"": ""flagged/output/595953adce3a654bbd33/image.webp"", ""caption"": null}, {""image"": ""flagged/output/f333c69915509927b2ff/image.webp"", ""caption"": null}, {""image"": ""flagged/output/a966f50f23644e5046e8/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1a8a9e53fd4990fe5231/image.webp"", ""caption"": null}, {""image"": ""flagged/output/d7bc2f0eb8d70a562542/image.webp"", ""caption"": null}, {""image"": ""flagged/output/53fd53c5eab644d30338/image.webp"", ""caption"": null}, {""image"": ""flagged/output/ddf6b8ddc855838cc3b5/image.webp"", ""caption"": null}, {""image"": ""flagged/output/41a99b70366ac01533b4/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1b4ae8362917e14cb7a7/image.webp"", ""caption"": null}, {""image"": ""flagged/output/b321456290561eacf170/image.webp"", ""caption"": null}, {""image"": ""flagged/output/42d34c69c2384bda376b/image.webp"", ""caption"": null}, {""image"": ""flagged/output/35d0e9ae554c0b863ef3/image.webp"", ""caption"": null}, {""image"": ""flagged/output/799f55238c434907570f/image.webp"", ""caption"": null}, {""image"": ""flagged/output/db82081afaabf2fb505b/image.webp"", ""caption"": null}, {""image"": ""flagged/output/fff73f12467314dce395/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1bd17ff3896c5045b453/image.webp"", ""caption"": null}, {""image"": ""flagged/output/e31f93405e1526fe3e55/image.webp"", ""caption"": null}, {""image"": ""flagged/output/e9c9ff1da0805da0c0d8/image.webp"", ""caption"": null}, {""image"": ""flagged/output/e6ef5ba2d6c65b3c1d21/image.webp"", ""caption"": null}, 
{""image"": ""flagged/output/f763a51fb4a6d8a13313/image.webp"", ""caption"": null}, {""image"": ""flagged/output/7bdb4562631122e4ced7/image.webp"", ""caption"": null}, {""image"": ""flagged/output/9f7495b7c7648ecb1a10/image.webp"", ""caption"": null}, {""image"": ""flagged/output/ecbe75612f5db6cc7370/image.webp"", ""caption"": null}, {""image"": ""flagged/output/31f824d9522d30106a44/image.webp"", ""caption"": null}, {""image"": ""flagged/output/e06b9103e0bf90cd398a/image.webp"", ""caption"": null}, {""image"": ""flagged/output/1441b4f37340c2afa3d0/image.webp"", ""caption"": null}]",,,2024-10-21 23:01:32.158338 diff --git a/flagged/output/036072cdcc620de8cb65/image.webp b/flagged/output/036072cdcc620de8cb65/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..4e7a831b3e63d41cf7dc53178e8f19231f456648 Binary files /dev/null and b/flagged/output/036072cdcc620de8cb65/image.webp differ diff --git a/flagged/output/1441b4f37340c2afa3d0/image.webp b/flagged/output/1441b4f37340c2afa3d0/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..46e20fcfccd3a763d3eae21a0fda7d2908c6f53b Binary files /dev/null and b/flagged/output/1441b4f37340c2afa3d0/image.webp differ diff --git a/flagged/output/1a8a9e53fd4990fe5231/image.webp b/flagged/output/1a8a9e53fd4990fe5231/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..30452bb9f913012c3a787e78f5af2a657bfc4a82 Binary files /dev/null and b/flagged/output/1a8a9e53fd4990fe5231/image.webp differ diff --git a/flagged/output/1b4541c3e93f034d746d/image.webp b/flagged/output/1b4541c3e93f034d746d/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..41fb284622b8bf0e85dac87a497a4942011579f2 Binary files /dev/null and b/flagged/output/1b4541c3e93f034d746d/image.webp differ diff --git a/flagged/output/1b4ae8362917e14cb7a7/image.webp b/flagged/output/1b4ae8362917e14cb7a7/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..3566fc9c4f4f8bc2d8be57ffbaf1fb0b84f6fed8 Binary files /dev/null and b/flagged/output/1b4ae8362917e14cb7a7/image.webp differ diff --git a/flagged/output/1bd17ff3896c5045b453/image.webp b/flagged/output/1bd17ff3896c5045b453/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..a2e8a49694c6a233177b8757e916860ec2c217cb Binary files /dev/null and b/flagged/output/1bd17ff3896c5045b453/image.webp differ diff --git a/flagged/output/218135cb251eb6cd0b2c/image.webp b/flagged/output/218135cb251eb6cd0b2c/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..986c085197db498a852f013f503db78b64b4f7c5 Binary files /dev/null and b/flagged/output/218135cb251eb6cd0b2c/image.webp differ diff --git a/flagged/output/2a0671ba5ac1aa3bd2b9/image.webp b/flagged/output/2a0671ba5ac1aa3bd2b9/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..5e9a54c48df42fa656f0bced1d9580acd75cf7ba Binary files /dev/null and b/flagged/output/2a0671ba5ac1aa3bd2b9/image.webp differ diff --git a/flagged/output/31f824d9522d30106a44/image.webp b/flagged/output/31f824d9522d30106a44/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..c91c7a09d9b8611a1430afa699da601d7d0efe21 Binary files /dev/null and b/flagged/output/31f824d9522d30106a44/image.webp differ diff --git a/flagged/output/35d0e9ae554c0b863ef3/image.webp b/flagged/output/35d0e9ae554c0b863ef3/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..2473cae43807f063aa4d3e568e06e17e4b569920 
Binary files /dev/null and b/flagged/output/35d0e9ae554c0b863ef3/image.webp differ diff --git a/flagged/output/41a99b70366ac01533b4/image.webp b/flagged/output/41a99b70366ac01533b4/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..4121b433b87b66dd3fbb58722c67818906c67411 Binary files /dev/null and b/flagged/output/41a99b70366ac01533b4/image.webp differ diff --git a/flagged/output/42d34c69c2384bda376b/image.webp b/flagged/output/42d34c69c2384bda376b/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..7618c903c18dc2451d25e1f32656f4caf9fe6ddb Binary files /dev/null and b/flagged/output/42d34c69c2384bda376b/image.webp differ diff --git a/flagged/output/53fd53c5eab644d30338/image.webp b/flagged/output/53fd53c5eab644d30338/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..2abbb10f1dbe66b93e37422b2470a0f071dea7cf Binary files /dev/null and b/flagged/output/53fd53c5eab644d30338/image.webp differ diff --git a/flagged/output/595953adce3a654bbd33/image.webp b/flagged/output/595953adce3a654bbd33/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..881add82b80c09007934e0467acd081e1b5fd7ac Binary files /dev/null and b/flagged/output/595953adce3a654bbd33/image.webp differ diff --git a/flagged/output/5b5ad2dd997a635f4917/image.webp b/flagged/output/5b5ad2dd997a635f4917/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..e57d0e88fbfc7af54aeb69995fe44af657c0d8dd Binary files /dev/null and b/flagged/output/5b5ad2dd997a635f4917/image.webp differ diff --git a/flagged/output/799f55238c434907570f/image.webp b/flagged/output/799f55238c434907570f/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..85d8a8fe108f97bec9684ccb2c614db43035d88e Binary files /dev/null and b/flagged/output/799f55238c434907570f/image.webp differ diff --git a/flagged/output/7bdb4562631122e4ced7/image.webp b/flagged/output/7bdb4562631122e4ced7/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..0b046a9a2ca40f025b7cc77df1b4c4f0613a7659 Binary files /dev/null and b/flagged/output/7bdb4562631122e4ced7/image.webp differ diff --git a/flagged/output/9f7495b7c7648ecb1a10/image.webp b/flagged/output/9f7495b7c7648ecb1a10/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..bc21593a6869f1cc00f78de4dd9ebf912d18d795 Binary files /dev/null and b/flagged/output/9f7495b7c7648ecb1a10/image.webp differ diff --git a/flagged/output/a966f50f23644e5046e8/image.webp b/flagged/output/a966f50f23644e5046e8/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..ffb81c67f03b993798f710e71e65b0f43cd151ca Binary files /dev/null and b/flagged/output/a966f50f23644e5046e8/image.webp differ diff --git a/flagged/output/b066004e4a0114aa705b/image.webp b/flagged/output/b066004e4a0114aa705b/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..b861d88dda4c0c7b783a87abaabc29f94dc943b2 Binary files /dev/null and b/flagged/output/b066004e4a0114aa705b/image.webp differ diff --git a/flagged/output/b321456290561eacf170/image.webp b/flagged/output/b321456290561eacf170/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..a10280498c93346105b7da59ba0808494004024c Binary files /dev/null and b/flagged/output/b321456290561eacf170/image.webp differ diff --git a/flagged/output/c4b75e9fbc946f6ead6d/image.webp b/flagged/output/c4b75e9fbc946f6ead6d/image.webp new file mode 100644 index 
0000000000000000000000000000000000000000..106535e80842768f14da47245baa981cabeea71b Binary files /dev/null and b/flagged/output/c4b75e9fbc946f6ead6d/image.webp differ diff --git a/flagged/output/d7bc2f0eb8d70a562542/image.webp b/flagged/output/d7bc2f0eb8d70a562542/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..686f7aac20f6a5449db99a97dc43c01dfdd99551 Binary files /dev/null and b/flagged/output/d7bc2f0eb8d70a562542/image.webp differ diff --git a/flagged/output/db82081afaabf2fb505b/image.webp b/flagged/output/db82081afaabf2fb505b/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..af93f3eacee5ef0a4903995aaf8d2e2e5921976d Binary files /dev/null and b/flagged/output/db82081afaabf2fb505b/image.webp differ diff --git a/flagged/output/ddf6b8ddc855838cc3b5/image.webp b/flagged/output/ddf6b8ddc855838cc3b5/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..f5d206e37f97d1f45b91aceb45df630ca9fae223 Binary files /dev/null and b/flagged/output/ddf6b8ddc855838cc3b5/image.webp differ diff --git a/flagged/output/e06b9103e0bf90cd398a/image.webp b/flagged/output/e06b9103e0bf90cd398a/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..8b4510986e28e5a132f4cc197ce9063b072b113b Binary files /dev/null and b/flagged/output/e06b9103e0bf90cd398a/image.webp differ diff --git a/flagged/output/e2f704607c002e0c557d/image.webp b/flagged/output/e2f704607c002e0c557d/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..a7b2dafc246639799cfabd97306b3c4ba426cba6 Binary files /dev/null and b/flagged/output/e2f704607c002e0c557d/image.webp differ diff --git a/flagged/output/e31f93405e1526fe3e55/image.webp b/flagged/output/e31f93405e1526fe3e55/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..955afaa8ead0e0bc67be7722bbc791dbfe4f35be Binary files /dev/null and b/flagged/output/e31f93405e1526fe3e55/image.webp differ diff --git a/flagged/output/e6ef5ba2d6c65b3c1d21/image.webp b/flagged/output/e6ef5ba2d6c65b3c1d21/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..4e0429c00817f88878d5cfc460039b8ed169c74c Binary files /dev/null and b/flagged/output/e6ef5ba2d6c65b3c1d21/image.webp differ diff --git a/flagged/output/e9c9ff1da0805da0c0d8/image.webp b/flagged/output/e9c9ff1da0805da0c0d8/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..d73c97deb0388bfb5423ad36e686e5e3ca44ce8d Binary files /dev/null and b/flagged/output/e9c9ff1da0805da0c0d8/image.webp differ diff --git a/flagged/output/ecbe75612f5db6cc7370/image.webp b/flagged/output/ecbe75612f5db6cc7370/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..7608061bb4d13a3e87696620071a61112463dea9 Binary files /dev/null and b/flagged/output/ecbe75612f5db6cc7370/image.webp differ diff --git a/flagged/output/f333c69915509927b2ff/image.webp b/flagged/output/f333c69915509927b2ff/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..766635e3fd9996a671a9fd9e09bd37901f37a20e Binary files /dev/null and b/flagged/output/f333c69915509927b2ff/image.webp differ diff --git a/flagged/output/f763a51fb4a6d8a13313/image.webp b/flagged/output/f763a51fb4a6d8a13313/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..65b0c4ddd52dac9f92f898bc64ed67c18722c6ac Binary files /dev/null and b/flagged/output/f763a51fb4a6d8a13313/image.webp differ diff --git a/flagged/output/f8727dcfa3c59de0d873/image.webp 
b/flagged/output/f8727dcfa3c59de0d873/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..175792b0035f9df9db88d62f56d40356d8afbbfe Binary files /dev/null and b/flagged/output/f8727dcfa3c59de0d873/image.webp differ diff --git a/flagged/output/fff73f12467314dce395/image.webp b/flagged/output/fff73f12467314dce395/image.webp new file mode 100644 index 0000000000000000000000000000000000000000..0b51dcf8c9e672660529ad1adea9afe37e5a4f08 Binary files /dev/null and b/flagged/output/fff73f12467314dce395/image.webp differ diff --git a/get_data.py b/get_data.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e9414c933a64f1124f4eda6ec0faa8cd8ed2ee --- /dev/null +++ b/get_data.py @@ -0,0 +1,119 @@ +from pathlib import Path + +import torch +import torchvision +from torchvision.transforms import transforms, TrivialAugmentWide + +from configs.dataset_params import normalize_params +from dataset_classes.cub200 import CUB200Class +from dataset_classes.stanfordcars import StanfordCarsClass +from dataset_classes.travelingbirds import TravelingBirds + + +def get_data(dataset, crop = True, img_size=448): + batchsize = 16 + if dataset == "CUB2011": + train_transform = get_augmentation(0.1, img_size, True,not crop, True, True, normalize_params["CUB2011"]) + test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["CUB2011"]) + train_dataset = CUB200Class(True, train_transform, crop) + test_dataset = CUB200Class(False, test_transform, crop) + elif dataset == "TravelingBirds": + train_transform = get_augmentation(0.1, img_size, True, not crop, True, True, normalize_params["TravelingBirds"]) + test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["TravelingBirds"]) + train_dataset = TravelingBirds(True, train_transform, crop) + test_dataset = TravelingBirds(False, test_transform, crop) + + elif dataset == "StanfordCars": + train_transform = get_augmentation(0.1, img_size, True, True, True, True, normalize_params["StanfordCars"]) + test_transform = get_augmentation(0.1, img_size, False, True, True, True, normalize_params["StanfordCars"]) + train_dataset = StanfordCarsClass(True, train_transform) + test_dataset = StanfordCarsClass(False, test_transform) + elif dataset == "FGVCAircraft": + raise NotImplementedError + + elif dataset == "ImageNet": + # Defaults from the robustness package + if img_size != 224: + raise NotImplementedError("ImageNet is setup to only work with 224x224 images") + train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter( + brightness=0.1, + contrast=0.1, + saturation=0.1 + ), + transforms.ToTensor(), + Lighting(0.05, IMAGENET_PCA['eigval'], + IMAGENET_PCA['eigvec']) + ]) + """ + Standard training data augmentation for ImageNet-scale datasets: Random crop, + Random flip, Color Jitter, and Lighting Transform (see https://git.io/fhBOc) + """ + test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + ]) + imgnet_root = Path.home()/ "tmp" /"Datasets"/ "imagenet" + train_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='train', transform=train_transform) + test_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='val', transform=test_transform) + batchsize = 64 + + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=8) + test_loader = 
torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=8) + return train_loader, test_loader + +def get_augmentation(jitter, size, training, random_center_crop, trivialAug, hflip, normalize): + augmentation = [] + if random_center_crop: + augmentation.append(transforms.Resize(size)) + else: + augmentation.append(transforms.Resize((size, size))) + if training: + if random_center_crop: + augmentation.append(transforms.RandomCrop(size, padding=4)) + else: + if random_center_crop: + augmentation.append(transforms.CenterCrop(size)) + if training: + if hflip: + augmentation.append(transforms.RandomHorizontalFlip()) + if jitter: + augmentation.append(transforms.ColorJitter(jitter, jitter, jitter)) + if trivialAug: + augmentation.append(TrivialAugmentWide()) + augmentation.append(transforms.ToTensor()) + augmentation.append(transforms.Normalize(**normalize)) + return transforms.Compose(augmentation) + +class Lighting(object): + """ + Lighting noise (see https://git.io/fhBOc) + """ + + def __init__(self, alphastd, eigval, eigvec): + self.alphastd = alphastd + self.eigval = eigval + self.eigvec = eigvec + + def __call__(self, img): + if self.alphastd == 0: + return img + + alpha = img.new().resize_(3).normal_(0, self.alphastd) + rgb = self.eigvec.type_as(img).clone() \ + .mul(alpha.view(1, 3).expand(3, 3)) \ + .mul(self.eigval.view(1, 3).expand(3, 3)) \ + .sum(1).squeeze() + + return img.add(rgb.view(3, 1, 1).expand_as(img)) +IMAGENET_PCA = { + 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), + 'eigvec': torch.Tensor([ + [-0.5675, 0.7192, 0.4009], + [-0.5808, -0.0045, -0.8140], + [-0.5836, -0.6948, 0.4203], + ]) +} diff --git a/load_model.py b/load_model.py new file mode 100644 index 0000000000000000000000000000000000000000..13e962a7497da9a895feed12708fe9c55a24dbdd --- /dev/null +++ b/load_model.py @@ -0,0 +1,51 @@ +from argparse import ArgumentParser +from pathlib import Path + +import torch + +from architectures.model_mapping import get_model +from configs.dataset_params import dataset_constants +from evaluation.qsenn_metrics import eval_model_on_all_qsenn_metrics +from get_data import get_data + +def extract_sel_mean_std_bias_assignemnt(state_dict): + feature_sel = state_dict["linear.selection"] + #feature_sel = selection + weight_at_selection = state_dict["linear.layer.weight"] + mean = state_dict["linear.mean"] + std = state_dict["linear.std"] + bias = state_dict["linear.layer.bias"] + return feature_sel, weight_at_selection, mean, std, bias + + +def eval_model(dataset, arch,seed=123456, model_type="qsenn",crop = True, n_features = 50, n_per_class=5, img_size=448, reduced_strides=False, folder = None): + n_classes = dataset_constants[dataset]["num_classes"] + train_loader, test_loader = get_data(dataset, crop=crop, img_size=img_size) + model = get_model(arch, n_classes, reduced_strides) + if folder is None: + folder = Path.home() / f"tmp/{arch}/{dataset}/{seed}/" + print(folder) + state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth") + selection= torch.load(folder / f"SlDD_Selection_50.pt") + state_dict['linear.selection']=selection + print(state_dict.keys()) + feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict) + model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse) + model.load_state_dict(state_dict) + print(model) + metrics_finetuned = eval_model_on_all_qsenn_metrics(model, test_loader, train_loader) + +if __name__ 
== '__main__': + parser = ArgumentParser() + parser.add_argument('--dataset', default="CUB2011", type=str, help='dataset name', choices=["CUB2011", "ImageNet", "TravelingBirds", "StanfordCars"]) + parser.add_argument('--arch', default="resnet50", type=str, help='Backbone Feature Extractor', choices=["resnet50", "resnet18"]) + parser.add_argument('--model_type', default="qsenn", type=str, help='Type of Model', choices=["qsenn", "sldd"]) + parser.add_argument('--seed', default=123456, type=int, help='seed, used for naming the folder and random processes. Could be useful to set to have multiple finetune runs (e.g. Q-SENN and SLDD) on the same dense model') # 769567, 552629 + parser.add_argument('--cropGT', default=False, type=bool, + help='Whether to crop CUB/TravelingBirds based on GT Boundaries') + parser.add_argument('--n_features', default=50, type=int, help='How many features to select') #769567 + parser.add_argument('--n_per_class', default=5, type=int, help='How many features to assign to each class') + parser.add_argument('--img_size', default=448, type=int, help='Image size') + parser.add_argument('--reduced_strides', default=False, type=bool, help='Whether to use reduced strides for resnets') + args = parser.parse_args() + eval_model(args.dataset, args.arch, args.seed, args.model_type,args.cropGT, args.n_features, args.n_per_class, args.img_size, args.reduced_strides) \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..5a340862967c2c1d8befc1eff79bd00122223f93 --- /dev/null +++ b/main.py @@ -0,0 +1,79 @@ +import os +from argparse import ArgumentParser +from pathlib import Path + +import numpy as np +import torch +from tqdm import trange + +from FeatureDiversityLoss import FeatureDiversityLoss +from architectures.model_mapping import get_model +from configs.architecture_params import architecture_params +from configs.dataset_params import dataset_constants +from evaluation.qsenn_metrics import eval_model_on_all_qsenn_metrics +from finetuning.map_function import finetune +from get_data import get_data +from saving.logging import Tee +from saving.utils import json_save +from train import train, test +from training.optim import get_optimizer, get_scheduler_for_model + + +def main(dataset, arch,seed=None, model_type="qsenn", do_dense=True,crop = True, n_features = 50, n_per_class=5, img_size=448, reduced_strides=False): + # create random seed, if seed is None + if seed is None: + seed = np.random.randint(0, 1000000) + np.random.seed(seed) + torch.manual_seed(seed) + dataset_key = dataset + if crop: + assert dataset in ["CUB2011","TravelingBirds"] + dataset_key += "_crop" + log_dir = Path.home()/f"tmp/{arch}/{dataset_key}/{seed}/" + log_dir.mkdir(parents=True, exist_ok=True) + tee = Tee(log_dir / "log.txt") # save log to file + n_classes = dataset_constants[dataset]["num_classes"] + train_loader, test_loader = get_data(dataset, crop=crop, img_size=img_size) + model = get_model(arch, n_classes, reduced_strides) + fdl = FeatureDiversityLoss(architecture_params[arch]["beta"], model.linear) + OptimizationSchedule = get_scheduler_for_model(model_type, dataset) + optimizer, schedule, dense_epochs =get_optimizer(model, OptimizationSchedule) + if not os.path.exists(log_dir / "Trained_DenseModel.pth"): + if do_dense: + for epoch in trange(dense_epochs): + model = train(model, train_loader, optimizer, fdl, epoch) + schedule.step() + if epoch % 5 == 0: + test(model, test_loader,epoch) + else: + print("Using pretrained model, 
only makes sense for ImageNet") + torch.save(model.state_dict(), os.path.join(log_dir, f"Trained_DenseModel.pth")) + else: + model.load_state_dict(torch.load(log_dir / "Trained_DenseModel.pth")) + if not os.path.exists( log_dir/f"Results_DenseModel.json"): + metrics_dense = eval_model_on_all_qsenn_metrics(model, test_loader, train_loader) + json_save(os.path.join(log_dir, f"Results_DenseModel.json"), metrics_dense) + final_model = finetune(model_type, model, train_loader, test_loader, log_dir, n_classes, seed, architecture_params[arch]["beta"], OptimizationSchedule, n_per_class, n_features) + torch.save(final_model.state_dict(), os.path.join(log_dir,f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth")) + metrics_finetuned = eval_model_on_all_qsenn_metrics(final_model, test_loader, train_loader) + json_save(os.path.join(log_dir, f"Results_{model_type}_{n_features}_{n_per_class}_FinetunedModel.json"), metrics_finetuned) + print("Done") + pass + + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--dataset', default="CUB2011", type=str, help='dataset name', choices=["CUB2011", "ImageNet", "TravelingBirds", "StanfordCars"]) + parser.add_argument('--arch', default="resnet50", type=str, help='Backbone Feature Extractor', choices=["resnet50", "resnet18"]) + parser.add_argument('--model_type', default="qsenn", type=str, help='Type of Model', choices=["qsenn", "sldd"]) + parser.add_argument('--seed', default=None, type=int, help='seed, used for naming the folder and random processes. Could be useful to set to have multiple finetune runs (e.g. Q-SENN and SLDD) on the same dense model') # 769567, 552629 + parser.add_argument('--do_dense', default=True, type=bool, help='whether to train dense model. Should be true for all datasets except (maybe) ImageNet') + parser.add_argument('--cropGT', default=False, type=bool, + help='Whether to crop CUB/TravelingBirds based on GT Boundaries') + parser.add_argument('--n_features', default=50, type=int, help='How many features to select') #769567 + parser.add_argument('--n_per_class', default=5, type=int, help='How many features to assign to each class') + parser.add_argument('--img_size', default=448, type=int, help='Image size') + parser.add_argument('--reduced_strides', default=False, type=bool, help='Whether to use reduced strides for resnets') + args = parser.parse_args() + main(args.dataset, args.arch, args.seed, args.model_type, args.do_dense,args.cropGT, args.n_features, args.n_per_class, args.img_size, args.reduced_strides) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4f8641a287620411061756d1e997027170cd2a33 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +torch +torchvision +opencv-python \ No newline at end of file diff --git a/saving/logging.py b/saving/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..377e31b33e06865bb588bfec32678c947e5c3bb3 --- /dev/null +++ b/saving/logging.py @@ -0,0 +1,27 @@ +import sys + + +class Tee(object): + def __init__(self, name, file_only=False): + self.file = open(name, "a") + self.stdout = sys.stdout + self.stderr = sys.stderr + sys.stdout = self + sys.stderr = self + self.file_only = file_only + + def __del__(self): + sys.stdout = self.stdout + sys.stderr = self.stderr + self.file.close() + + def write(self, data): + self.file.write(data) + if not self.file_only: + self.stdout.write(data) + self.flush() + + def flush(self): + self.file.flush() + + diff --git 
a/saving/utils.py b/saving/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf35c3c5129e84ab6f13f85c4c04fba4f7e33b4 --- /dev/null +++ b/saving/utils.py @@ -0,0 +1,6 @@ +import json + + +def json_save(filename, data): + with open(filename, "w") as f: + json.dump(data, f, indent=4) \ No newline at end of file diff --git a/sparsification/FeatureSelection.py b/sparsification/FeatureSelection.py new file mode 100644 index 0000000000000000000000000000000000000000..885a7bc9fab8842e9eece0f07690636b1d623233 --- /dev/null +++ b/sparsification/FeatureSelection.py @@ -0,0 +1,473 @@ +from argparse import ArgumentParser +import logging +import math +import os.path +import sys +import time +import warnings + +import numpy as np +import torch +import torch.nn.functional as F +from glm_saga.elasticnet import maximum_reg_loader, get_device, elastic_loss_and_acc_loader +from torch import nn + +import torch as ch + +from sparsification.utils import safe_zip + +# TODO checkout this change: Marks changes to the group version of glmsaga + +""" +This requires glm_saga to run. +Usage to select 50 features with the parameters from the paper: +metadata contains information about the precomputed train features in feature_loaders, +args contains the default arguments for glm-saga, as described at the bottom. +def get_glm_to_zero(feature_loaders, metadata, args, num_classes, device, train_ds, Ntotal): + num_features = metadata["X"]["num_features"][0] + fittingClass = FeatureSelectionFitting(num_features, num_classes, args, 0.8, + 50, + True, 0.1, + lookback=3, tol=1e-4, + epsilon=1,) + to_drop, test_acc = fittingClass.fit(feature_loaders, metadata, device) + return to_drop + +to_drop is then used to remove the features from the downstream fitting and finetuning.
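Illustrative sketch (editorial addition, not part of the original code; variable names and values are
hypothetical): one way to turn the returned to_drop indices into a boolean keep-mask and restrict a
precomputed feature matrix to the selected features before the sparse assignment step:

    import numpy as np

    num_features = 2048                                 # feature dimension of the backbone
    to_drop = np.array([0, 3, 7])                       # hypothetical output of fittingClass.fit(...)
    features_train = np.random.rand(16, num_features)   # stand-in for the precomputed train features

    keep_mask = np.ones(num_features, dtype=bool)
    keep_mask[to_drop] = False                          # discard everything glm-saga never selected
    selected_idx = np.where(keep_mask)[0]               # indices of the surviving features
    reduced_train = features_train[:, keep_mask]        # feature matrix restricted to the selection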
+""" + + +class FeatureSelectionFitting: + def __init__(self, n_features, n_classes, args, selalpha, nKeep, lam_fac,out_dir, lookback=None, tol=None, + epsilon=None): + """ + This is an adaption of the group version of glm-saga (https://github.com/MadryLab/DebuggableDeepNetworks) + The function extended_mask_max covers the changed operator, + Args: + n_features: + n_classes: + args: default args for glmsaga + selalpha: alpha for elastic net + nKeep: target number features + lam_fac: discount factor for lambda + parameters of glmsaga + lookback: + tol: + epsilon: + """ + self.selected_features = torch.zeros(n_features, dtype=torch.bool) + self.num_features = n_features + self.selalpha = selalpha + self.lam_Fac = lam_fac + self.out_dir = out_dir + self.n_classes = n_classes + self.nKeep = nKeep + self.args = self.extend_args(args, lookback, tol, epsilon) + + # Extended Proximal Operator for Feature Selection + def extended_mask_max(self, greater_to_keep, thresh): + prev = greater_to_keep[self.selected_features] + greater_to_keep[self.selected_features] = torch.min(greater_to_keep) + max_entry = torch.argmax(greater_to_keep) + greater_to_keep[self.selected_features] = prev + mask = torch.zeros_like(greater_to_keep) + mask[max_entry] = 1 + final_mask = (greater_to_keep > thresh) + final_mask = final_mask * mask + allowed_to_keep = torch.logical_or(self.selected_features, final_mask) + return allowed_to_keep + + def extend_args(self, args, lookback, tol, epsilon): + for key, entry in safe_zip(["lookbehind", "tol", + "lr_decay_factor", ], [lookback, tol, epsilon]): + if entry is not None: + setattr(args, key, entry) + return args + + # Grouped L1 regularization + # proximal operator for f(weight) = lam * \|weight\|_2 + # where the 2-norm is taken columnwise + def group_threshold(self, weight, lam): + norm = weight.norm(p=2, dim=0) + 1e-6 + # print(ch.sum((norm > lam))) + return (weight - lam * weight / norm) * self.extended_mask_max(norm, lam) + + # Elastic net regularization with group sparsity + # proximal operator for f(x) = alpha * \|x\|_1 + beta * \|x\|_2^2 + # where the 2-norm is taken columnwise + def group_threshold_with_shrinkage(self, x, alpha, beta): + y = self.group_threshold(x, alpha) + return y / (1 + beta) + + def threshold(self, weight_new, lr, lam): + alpha = self.selalpha + if alpha == 1: + # Pure L1 regularization + weight_new = self.group_threshold(weight_new, lr * lam * alpha) + else: + # Elastic net regularization + weight_new = self.group_threshold_with_shrinkage(weight_new, lr * lam * alpha, + lr * lam * (1 - alpha)) + return weight_new + + # Train an elastic GLM with proximal SAGA + # Since SAGA stores a scalar for each example-class pair, either pass + # the number of examples and number of classes or calculate it with an + # initial pass over the loaders + def train_saga(self, linear, loader, lr, nepochs, lam, alpha, group=True, verbose=None, + state=None, table_device=None, n_ex=None, n_classes=None, tol=1e-4, + preprocess=None, lookbehind=None, family='multinomial', logger=None): + if logger is None: + logger = print + with ch.no_grad(): + weight, bias = list(linear.parameters()) + if table_device is None: + table_device = weight.device + + # get total number of examples and initialize scalars + # for computing the gradients + if n_ex is None: + n_ex = sum(tensors[0].size(0) for tensors in loader) + if n_classes is None: + if family == 'multinomial': + n_classes = max(tensors[1].max().item() for tensors in loader) + 1 + elif family == 'gaussian': + for batch in 
loader: + y = batch[1] + break + n_classes = y.size(1) + + # Storage for scalar gradients and averages + if state is None: + a_table = ch.zeros(n_ex, n_classes).to(table_device) + w_grad_avg = ch.zeros_like(weight).to(weight.device) + b_grad_avg = ch.zeros_like(bias).to(weight.device) + else: + a_table = state["a_table"].to(table_device) + w_grad_avg = state["w_grad_avg"].to(weight.device) + b_grad_avg = state["b_grad_avg"].to(weight.device) + + obj_history = [] + obj_best = None + nni = 0 + for t in range(nepochs): + total_loss = 0 + for n_batch, batch in enumerate(loader): + if len(batch) == 3: + X, y, idx = batch + w = None + elif len(batch) == 4: + X, y, w, idx = batch + else: + raise ValueError( + f"Loader must return (data, target, index) or (data, target, index, weight) but instead got a tuple of length {len(batch)}") + + if preprocess is not None: + device = get_device(preprocess) + with ch.no_grad(): + X = preprocess(X.to(device)) + X = X.to(weight.device) + out = linear(X) + + # split gradient on only the cross entropy term + # for efficient storage of gradient information + if family == 'multinomial': + if w is None: + loss = F.cross_entropy(out, y.to(weight.device), reduction='mean') + else: + loss = F.cross_entropy(out, y.to(weight.device), reduction='none') + loss = (loss * w).mean() + I = ch.eye(linear.weight.size(0)) + target = I[y].to(weight.device) # change to OHE + + # Calculate new scalar gradient + logits = F.softmax(linear(X)) + elif family == 'gaussian': + if w is None: + loss = 0.5 * F.mse_loss(out, y.to(weight.device), reduction='mean') + else: + loss = 0.5 * F.mse_loss(out, y.to(weight.device), reduction='none') + loss = (loss * (w.unsqueeze(1))).mean() + target = y + + # Calculate new scalar gradient + logits = linear(X) + else: + raise ValueError(f"Unknown family: {family}") + total_loss += loss.item() * X.size(0) + + # BS x NUM_CLASSES + a = logits - target + if w is not None: + a = a * w.unsqueeze(1) + a_prev = a_table[idx].to(weight.device) + + # weight parameter + w_grad = (a.unsqueeze(2) * X.unsqueeze(1)).mean(0) + w_grad_prev = (a_prev.unsqueeze(2) * X.unsqueeze(1)).mean(0) + w_saga = w_grad - w_grad_prev + w_grad_avg + weight_new = weight - lr * w_saga + weight_new = self.threshold(weight_new, lr, lam) + # bias parameter + b_grad = a.mean(0) + b_grad_prev = a_prev.mean(0) + b_saga = b_grad - b_grad_prev + b_grad_avg + bias_new = bias - lr * b_saga + + # update table and averages + a_table[idx] = a.to(table_device) + w_grad_avg.add_((w_grad - w_grad_prev) * X.size(0) / n_ex) + b_grad_avg.add_((b_grad - b_grad_prev) * X.size(0) / n_ex) + + if lookbehind is None: + dw = (weight_new - weight).norm(p=2) + db = (bias_new - bias).norm(p=2) + criteria = ch.sqrt(dw ** 2 + db ** 2) + + if criteria.item() <= tol: + return { + "a_table": a_table.cpu(), + "w_grad_avg": w_grad_avg.cpu(), + "b_grad_avg": b_grad_avg.cpu() + } + + weight.data = weight_new + bias.data = bias_new + + saga_obj = total_loss / n_ex + lam * alpha * weight.norm(p=1) + 0.5 * lam * (1 - alpha) * ( + weight ** 2).sum() + + # save amount of improvement + obj_history.append(saga_obj.item()) + if obj_best is None or saga_obj.item() + tol < obj_best: + obj_best = saga_obj.item() + nni = 0 + else: + nni += 1 + + # Stop if no progress for lookbehind iterationsd:]) + criteria = lookbehind is not None and (nni >= lookbehind) + + nnz = (weight.abs() > 1e-5).sum().item() + total = weight.numel() + if verbose and (t % verbose) == 0: + if lookbehind is None: + logger( + f"obj {saga_obj.item()} weight nnz 
{nnz}/{total} ({nnz / total:.4f}) criteria {criteria:.4f} {dw} {db}") + else: + logger( + f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) obj_best {obj_best}") + + if lookbehind is not None and criteria: + logger( + f"obj {saga_obj.item()} weight nnz {nnz}/{total} ({nnz / total:.4f}) obj_best {obj_best} [early stop at {t}]") + return { + "a_table": a_table.cpu(), + "w_grad_avg": w_grad_avg.cpu(), + "b_grad_avg": b_grad_avg.cpu() + } + + logger(f"did not converge at {nepochs} iterations (criteria {criteria})") + return { + "a_table": a_table.cpu(), + "w_grad_avg": w_grad_avg.cpu(), + "b_grad_avg": b_grad_avg.cpu() + } + + def glm_saga(self, linear, loader, max_lr, nepochs, alpha, dropout, tries, + table_device=None, preprocess=None, group=False, + verbose=None, state=None, n_ex=None, n_classes=None, + tol=1e-4, epsilon=0.001, k=100, checkpoint=None, + do_zero=True, lr_decay_factor=1, metadata=None, + val_loader=None, test_loader=None, lookbehind=None, + family='multinomial', encoder=None, tot_tries=1): + if encoder is not None: + warnings.warn("encoder argument is deprecated; please use preprocess instead", DeprecationWarning) + preprocess = encoder + device = get_device(linear) + checkpoint = self.out_dir + if preprocess is not None and (device != get_device(preprocess)): + raise ValueError( + f"Linear and preprocess must be on same device (got {get_device(linear)} and {get_device(preprocess)})") + + if metadata is not None: + if n_ex is None: + n_ex = metadata['X']['num_examples'] + if n_classes is None: + n_classes = metadata['y']['num_classes'] + lam_fac = (1 + (tries - 1) / tot_tries) + print("Using lam_fac ", lam_fac) + max_lam = maximum_reg_loader(loader, group=group, preprocess=preprocess, metadata=metadata, + family=family) / max( + 0.001, alpha) * lam_fac + group_lam = maximum_reg_loader(loader, group=True, preprocess=preprocess, metadata=metadata, + family=family) / max( + 0.001, alpha) * lam_fac + min_lam = epsilon * max_lam + group_min_lam = epsilon * group_lam + # logspace is base 10 but log is base e so use log10 + lams = ch.logspace(math.log10(max_lam), math.log10(min_lam), k) + lrs = ch.logspace(math.log10(max_lr), math.log10(max_lr / lr_decay_factor), k) + found = False + if do_zero: + lams = ch.cat([lams, lams.new_zeros(1)]) + lrs = ch.cat([lrs, lrs.new_ones(1) * lrs[-1]]) + + path = [] + best_val_loss = float('inf') + + if checkpoint is not None: + os.makedirs(checkpoint, exist_ok=True) + + file_handler = logging.FileHandler(filename=os.path.join(checkpoint, 'output.log')) + stdout_handler = logging.StreamHandler(sys.stdout) + handlers = [file_handler, stdout_handler] + + logging.basicConfig( + level=logging.DEBUG, + format='[%(asctime)s] %(levelname)s - %(message)s', + handlers=handlers + ) + logger = logging.getLogger('glm_saga').info + else: + logger = print + while self.selected_features.sum() < self.nKeep: # TODO checkout this change, one iteration per feature + n_feature_to_keep = self.selected_features.sum() + for i, (lam, lr) in enumerate(zip(lams, lrs)): + lam = lam * self.lam_Fac + start_time = time.time() + self.selected_features = self.selected_features.to(device) + state = self.train_saga(linear, loader, lr, nepochs, lam, alpha, + table_device=table_device, preprocess=preprocess, group=group, verbose=verbose, + state=state, n_ex=n_ex, n_classes=n_classes, tol=tol, lookbehind=lookbehind, + family=family, logger=logger) + + with ch.no_grad(): + loss, acc = elastic_loss_and_acc_loader(linear, loader, lam, alpha, preprocess=preprocess, + 
family=family) + loss, acc = loss.item(), acc.item() + + loss_val, acc_val = -1, -1 + if val_loader: + loss_val, acc_val = elastic_loss_and_acc_loader(linear, val_loader, lam, alpha, + preprocess=preprocess, + family=family) + loss_val, acc_val = loss_val.item(), acc_val.item() + + loss_test, acc_test = -1, -1 + if test_loader: + loss_test, acc_test = elastic_loss_and_acc_loader(linear, test_loader, lam, alpha, + preprocess=preprocess, family=family) + loss_test, acc_test = loss_test.item(), acc_test.item() + + params = { + "lam": lam, + "lr": lr, + "alpha": alpha, + "time": time.time() - start_time, + "loss": loss, + "metrics": { + "loss_tr": loss, + "acc_tr": acc, + "loss_val": loss_val, + "acc_val": acc_val, + "loss_test": loss_test, + "acc_test": acc_test, + }, + "weight": linear.weight.detach().cpu().clone(), + "bias": linear.bias.detach().cpu().clone() + + } + path.append(params) + if loss_val is not None and loss_val < best_val_loss: + best_val_loss = loss_val + best_params = params + found = True + nnz = (linear.weight.abs() > 1e-5).sum().item() + total = linear.weight.numel() + if family == 'multinomial': + logger( + f"{n_feature_to_keep} Feature ({i}) lambda {lam:.4f}, loss {loss:.4f}, acc {acc:.4f} [val acc {acc_val:.4f}] [test acc {acc_test:.4f}], sparsity {nnz / total} [{nnz}/{total}], time {time.time() - start_time}, lr {lr:.4f}") + elif family == 'gaussian': + logger( + f"({i}) lambda {lam:.4f}, loss {loss:.4f} [val loss {loss_val:.4f}] [test loss {loss_test:.4f}], sparsity {nnz / total} [{nnz}/{total}], time {time.time() - start_time}, lr {lr:.4f}") + + if self.check_new_feature(linear.weight): # TODO checkout this change, canceling if new feature is used + if checkpoint is not None: + ch.save(params, os.path.join(checkpoint, f"params{n_feature_to_keep}.pth")) + break + if found: + return { + 'path': path, + 'best': best_params, + 'state': state + } + else: + return False + + def check_new_feature(self, weight): + # TODO checkout this change, checking if new feature is used + copied_weight = torch.tensor(weight.cpu()) + used_features = torch.unique( + torch.nonzero(copied_weight)[:, 1]) + if len(used_features) > 0: + new_set = set(used_features.tolist()) + old_set = set(torch.nonzero(self.selected_features)[:, 0].tolist()) + diff = new_set - old_set + if len(diff) > 0: + self.selected_features[used_features] = True + return True + return False + + def fit(self, feature_loaders, metadata, device): + # TODO checkout this change, glm saga code slightly adapted to return to_drop + print("Initializing linear model...") + linear = nn.Linear(self.num_features, self.n_classes).to(device) + for p in [linear.weight, linear.bias]: + p.data.zero_() + + print("Preparing normalization preprocess and indexed dataloader") + preprocess = NormalizedRepresentation(feature_loaders['train'], + metadata=metadata, + device=linear.weight.device) + + print("Calculating the regularization path") + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + selected_features = self.glm_saga(linear, + feature_loaders['train'], + self.args.lr, + self.args.max_epochs, + self.selalpha, 0, 1, + val_loader=feature_loaders['val'], + test_loader=feature_loaders['test'], + n_classes=self.n_classes, + verbose=self.args.verbose, + tol=self.args.tol, + lookbehind=self.args.lookbehind, + lr_decay_factor=self.args.lr_decay_factor, + group=True, + epsilon=self.args.lam_factor, + metadata=metadata, + preprocess=preprocess, tot_tries=1) + to_drop = 
np.where(self.selected_features.cpu().numpy() == 0)[0] + test_acc = selected_features["path"][-1]["metrics"]["acc_test"] + torch.set_grad_enabled(True) + return to_drop, test_acc + + +class NormalizedRepresentation(ch.nn.Module): + def __init__(self, loader, metadata, device='cuda', tol=1e-5): + super(NormalizedRepresentation, self).__init__() + + assert metadata is not None + self.device = device + self.mu = metadata['X']['mean'] + self.sigma = ch.clamp(metadata['X']['std'], tol) + + def forward(self, X): + return (X - self.mu.to(self.device)) / self.sigma.to(self.device) + + + + diff --git a/sparsification/data_helpers.py b/sparsification/data_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..d48424564050c66238f9b731b433ca25d29d5a6b --- /dev/null +++ b/sparsification/data_helpers.py @@ -0,0 +1,16 @@ + +import torch + + +class NormalizedRepresentation(torch.nn.Module): + def __init__(self, loader, metadata, device='cuda', tol=1e-5): + super(NormalizedRepresentation, self).__init__() + + assert metadata is not None + self.device = device + self.mu = metadata['X']['mean'] + self.sigma = torch.clamp(metadata['X']['std'], tol) + + def forward(self, X): + return (X - self.mu.to(self.device)) / self.sigma.to(self.device) + diff --git a/sparsification/feature_helpers.py b/sparsification/feature_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..8c11867077be5ab067548f498dad17fb299fa162 --- /dev/null +++ b/sparsification/feature_helpers.py @@ -0,0 +1,378 @@ +import math +import os +import sys + +import torch.cuda + +import sparsification.utils + +sys.path.append('') +import numpy as np +import torch as ch +from torch.utils.data import Subset +from tqdm import tqdm + + + +# From glm_saga +def get_features_batch(batch, model, device='cuda'): + if not torch.cuda.is_available(): + device = "cpu" + ims, targets = batch + output, latents = model(ims.to(device), with_final_features=True ) + return latents, targets + + +def compute_features(loader, model, dataset_type, pooled_output, + batch_size, num_workers, + shuffle=False, device='cpu', n_epoch=1, + filename=None, chunk_threshold=20000, balance=False): + """Compute deep features for a given dataset using a modeln and returnss + them as a pytorch dataset and loader. + Args: + loader : Torch data loader + model: Torch model + dataset_type (str): One of vision or language + pooled_output (bool): Whether or not to pool outputs + (only relevant for some language models) + batch_size (int): Batch size for output loader + num_workers (int): Number of workers to use for output loader + shuffle (bool): Whether or not to shuffle output data loaoder + device (str): Device on which to keep the model + filename (str):Optional file to cache computed feature. Recommended + for large dataset_classes like ImageNet. 
+ chunk_threshold (int): Size of shard while caching + balance (bool): Whether or not to balance output data loader + (only relevant for some language models) + Returns: + feature_dataset: Torch dataset with deep features + feature_loader: Torch data loader with deep features + """ + if torch.cuda.is_available(): + device = "cuda" + print("mem_get_info before", torch.cuda.mem_get_info()) + torch.cuda.empty_cache() + print("mem_get_info after", torch.cuda.mem_get_info()) + model = model.to(device) + if filename is None or not os.path.exists(os.path.join(filename, f'0_features.npy')): + model.eval() + all_latents, all_targets, all_images = [], [], [] + Nsamples, chunk_id = 0, 0 + for idx_epoch in range(n_epoch): + for batch_idx, batch in tqdm(enumerate(loader), total=len(loader)): + with ch.no_grad(): + latents, targets = get_features_batch(batch, model, + device=device) + if batch_idx == 0: + print("Latents shape", latents.shape) + Nsamples += latents.size(0) + + all_latents.append(latents.cpu()) + if len(targets.shape) > 1: + targets = targets[:, 0] + all_targets.append(targets.cpu()) + # all_images.append(batch[0]) + if filename is not None and Nsamples > chunk_threshold: + if not os.path.exists(filename): os.makedirs(filename) + np.save(os.path.join(filename, f'{chunk_id}_features.npy'), ch.cat(all_latents).numpy()) + np.save(os.path.join(filename, f'{chunk_id}_labels.npy'), ch.cat(all_targets).numpy()) + + all_latents, all_targets, Nsamples = [], [], 0 + chunk_id += 1 + + if filename is not None and Nsamples > 0: + if not os.path.exists(filename): os.makedirs(filename) + np.save(os.path.join(filename, f'{chunk_id}_features.npy'), ch.cat(all_latents).numpy()) + np.save(os.path.join(filename, f'{chunk_id}_labels.npy'), ch.cat(all_targets).numpy()) + # np.save(os.path.join(filename, f'{chunk_id}_images.npy'), ch.cat(all_images).numpy()) + feature_dataset = load_features(filename) if filename is not None else \ + ch.utils.data.TensorDataset(ch.cat(all_latents), ch.cat(all_targets)) + if balance: + feature_dataset = balance_dataset(feature_dataset) + + feature_loader = ch.utils.data.DataLoader(feature_dataset, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle) + + return feature_dataset, feature_loader + + +def load_feature_loader(out_dir_feats, val_frac, batch_size, num_workers, random_seed): + feature_loaders = {} + for mode in ['train', 'test']: + print(f"For {mode} set...") + sink_path = f"{out_dir_feats}/features_{mode}" + metadata_path = f"{out_dir_feats}/metadata_{mode}.pth" + feature_ds = load_features(sink_path) + feature_loader = ch.utils.data.DataLoader(feature_ds, + num_workers=num_workers, + batch_size=batch_size) + if mode == 'train': + metadata = calculate_metadata(feature_loader, + num_classes=2048, + filename=metadata_path) + split_datasets, split_loaders = split_dataset(feature_ds, + len(feature_ds), + val_frac=val_frac, + batch_size=batch_size, + num_workers=num_workers, + random_seed=random_seed, + shuffle=True) + feature_loaders.update({mm: sparsification.utils.add_index_to_dataloader(split_loaders[mi]) + for mi, mm in enumerate(['train', 'val'])}) + + else: + feature_loaders[mode] = feature_loader + return feature_loaders, metadata + + +def balance_dataset(dataset): + """Balances a given dataset to have the same number of samples/class. 
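+    Assumes binary 0/1 labels: all positive samples are kept and an equal number of negatives is drawn with a fixed random seed.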
+ Args: + dataset : Torch dataset + Returns: + Torch dataset with equal number of samples/class + """ + + print("Balancing dataset...") + n = len(dataset) + labels = ch.Tensor([dataset[i][1] for i in range(n)]).int() + n0 = sum(labels).item() + I_pos = labels == 1 + + idx = ch.arange(n) + idx_pos = idx[I_pos] + ch.manual_seed(0) + I = ch.randperm(n - n0)[:n0] + idx_neg = idx[~I_pos][I] + idx_bal = ch.cat([idx_pos, idx_neg], dim=0) + return Subset(dataset, idx_bal) + + +def load_metadata(feature_path): + return ch.load(os.path.join(feature_path, f'metadata_train.pth')) + + +def get_mean_std(feature_path): + metadata = load_metadata(feature_path) + return metadata["X"]["mean"], metadata["X"]["std"] + + +def load_features_dataset_mode(feature_path, mode='test', + num_workers=10, batch_size=128): + """Loads precomputed deep features corresponding to the + train/test set along with normalization statitic. + Args: + feature_path (str): Path to precomputed deep features + mode (str): One of train or tesst + num_workers (int): Number of workers to use for output loader + batch_size (int): Batch size for output loader + + Returns: + features (np.array): Recovered deep features + feature_mean: Mean of deep features + feature_std: Standard deviation of deep features + """ + feature_dataset = load_features(os.path.join(feature_path, f'features_{mode}')) + feature_loader = ch.utils.data.DataLoader(feature_dataset, + num_workers=num_workers, + batch_size=batch_size, + shuffle=False) + feature_metadata = ch.load(os.path.join(feature_path, f'metadata_train.pth')) + feature_mean, feature_std = feature_metadata['X']['mean'], feature_metadata['X']['std'] + return feature_loader, feature_mean, feature_std + + +def load_joint_dataset(feature_path, mode='test', + num_workers=10, batch_size=128): + feature_dataset = load_features(os.path.join(feature_path, f'features_{mode}')) + feature_loader = ch.utils.data.DataLoader(feature_dataset, + num_workers=num_workers, + batch_size=batch_size, + shuffle=False) + features = [] + labels = [] + for _, (feature, label) in tqdm(enumerate(feature_loader), total=len(feature_loader)): + features.append(feature) + labels.append(label) + features = np.concatenate(features) + labels = np.concatenate(labels) + dataset = ch.utils.data.TensorDataset(torch.tensor(features), torch.tensor(labels)) + return dataset + + +def load_features_mode(feature_path, mode='test', + num_workers=10, batch_size=128): + """Loads precomputed deep features corresponding to the + train/test set along with normalization statitic. 
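+    The mean/std are always read from metadata_train.pth, so test features are normalized with training-set statistics.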
+ Args: + feature_path (str): Path to precomputed deep features + mode (str): One of train or tesst + num_workers (int): Number of workers to use for output loader + batch_size (int): Batch size for output loader + + Returns: + features (np.array): Recovered deep features + feature_mean: Mean of deep features + feature_std: Standard deviation of deep features + """ + feature_dataset = load_features(os.path.join(feature_path, f'features_{mode}')) + feature_loader = ch.utils.data.DataLoader(feature_dataset, + num_workers=num_workers, + batch_size=batch_size, + shuffle=False) + + feature_metadata = ch.load(os.path.join(feature_path, f'metadata_train.pth')) + feature_mean, feature_std = feature_metadata['X']['mean'], feature_metadata['X']['std'] + + features = [] + + for _, (feature, _) in tqdm(enumerate(feature_loader), total=len(feature_loader)): + features.append(feature) + + features = ch.cat(features).numpy() + return features, feature_mean, feature_std + + +def load_features(feature_path): + """Loads precomputed deep features. + Args: + feature_path (str): Path to precomputed deep features + + Returns: + Torch dataset with recovered deep features. + """ + if not os.path.exists(os.path.join(feature_path, f"0_features.npy")): + raise ValueError(f"The provided location {feature_path} does not contain any representation files") + + ds_list, chunk_id = [], 0 + while os.path.exists(os.path.join(feature_path, f"{chunk_id}_features.npy")): + features = ch.from_numpy(np.load(os.path.join(feature_path, f"{chunk_id}_features.npy"))).float() + labels = ch.from_numpy(np.load(os.path.join(feature_path, f"{chunk_id}_labels.npy"))).long() + ds_list.append(ch.utils.data.TensorDataset(features, labels)) + chunk_id += 1 + + print(f"==> loaded {chunk_id} files of representations...") + return ch.utils.data.ConcatDataset(ds_list) + + +def calculate_metadata(loader, num_classes=None, filename=None): + """Calculates mean and standard deviation of the deep features over + a given set of images. + Args: + loader : torch data loader + num_classes (int): Number of classes in the dataset + filename (str): Optional filepath to cache metadata. Recommended + for large dataset_classes like ImageNet. + + Returns: + metadata (dict): Dictionary with desired statistics. 
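+            Keys: "X" (mean, std, num_features, num_examples), "y" (mean, std, num_classes) and "max_reg" (group / nongrouped maximal regularization strength).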
+ """ + + if filename is not None and os.path.exists(filename): + print("loading Metadata from ", filename) + return ch.load(filename) + + # Calculate number of classes if not given + if num_classes is None: + num_classes = 1 + for batch in loader: + y = batch[1] + print(y) + num_classes = max(num_classes, y.max().item() + 1) + + eye = ch.eye(num_classes) + + X_bar, y_bar, y_max, n = 0, 0, 0, 0 + + # calculate means and maximum + print("Calculating means") + for ans in tqdm(loader, total=len(loader)): + X, y = ans[:2] + X_bar += X.sum(0) + y_bar += eye[y].sum(0) + y_max = max(y_max, y.max()) + n += y.size(0) + X_bar = X_bar.float() / n + y_bar = y_bar.float() / n + + # calculate std + X_std, y_std = 0, 0 + print("Calculating standard deviations") + for ans in tqdm(loader, total=len(loader)): + X, y = ans[:2] + X_std += ((X - X_bar) ** 2).sum(0) + y_std += ((eye[y] - y_bar) ** 2).sum(0) + X_std = ch.sqrt(X_std.float() / n) + y_std = ch.sqrt(y_std.float() / n) + + # calculate maximum regularization + inner_products = 0 + print("Calculating maximum lambda") + for ans in tqdm(loader, total=len(loader)): + X, y = ans[:2] + y_map = (eye[y] - y_bar) / y_std + inner_products += X.t().mm(y_map) * y_std + + inner_products_group = inner_products.norm(p=2, dim=1) + + metadata = { + "X": { + "mean": X_bar, + "std": X_std, + "num_features": X.size()[1:], + "num_examples": n + }, + "y": { + "mean": y_bar, + "std": y_std, + "num_classes": y_max + 1 + }, + "max_reg": { + "group": inner_products_group.abs().max().item() / n, + "nongrouped": inner_products.abs().max().item() / n + } + } + + if filename is not None: + ch.save(metadata, filename) + + return metadata + + +def split_dataset(dataset, Ntotal, val_frac, + batch_size, num_workers, + random_seed=0, shuffle=True, balance=False): + """Splits a given dataset into train and validation + Args: + dataset : Torch dataset + Ntotal: Total number of dataset samples + val_frac: Fraction to reserve for validation + batch_size (int): Batch size for output loader + num_workers (int): Number of workers to use for output loader + random_seed (int): Random seed + shuffle (bool): Whether or not to shuffle output data loaoder + balance (bool): Whether or not to balance output data loader + (only relevant for some language models) + + Returns: + split_datasets (list): List of dataset_classes (one each for train and val) + split_loaders (list): List of loaders (one each for train and val) + """ + + Nval = math.floor(Ntotal * val_frac) + train_ds, val_ds = ch.utils.data.random_split(dataset, + [Ntotal - Nval, Nval], + generator=ch.Generator().manual_seed(random_seed)) + if balance: + val_ds = balance_dataset(val_ds) + split_datasets = [train_ds, val_ds] + + split_loaders = [] + for ds in split_datasets: + split_loaders.append(ch.utils.data.DataLoader(ds, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle)) + return split_datasets, split_loaders diff --git a/sparsification/glmBasedSparsification.py b/sparsification/glmBasedSparsification.py new file mode 100644 index 0000000000000000000000000000000000000000..4a681147b4394281069c3a6bf0596baf435ecae0 --- /dev/null +++ b/sparsification/glmBasedSparsification.py @@ -0,0 +1,130 @@ +import logging +import os +import shutil + +import numpy as np +import pandas as pd +import torch +from glm_saga.elasticnet import glm_saga +from torch import nn + +from sparsification.FeatureSelection import FeatureSelectionFitting +from sparsification import data_helpers +from sparsification.utils import get_default_args, 
compute_features_and_metadata, select_in_loader, get_feature_loaders + + +def get_glm_selection(feature_loaders, metadata, args, num_classes, device, n_features_to_select, folder): + num_features = metadata["X"]["num_features"][0] + fittingClass = FeatureSelectionFitting(num_features, num_classes, args, 0.8, + n_features_to_select, + 0.1,folder, + lookback=3, tol=1e-4, + epsilon=1,) + to_drop, test_acc = fittingClass.fit(feature_loaders, metadata, device) + selected_features = torch.tensor([i for i in range(num_features) if i not in to_drop]) + return selected_features + + +def compute_feature_selection_and_assignment(model, train_loader, test_loader, log_folder,num_classes, seed, select_features = 50): + feature_loaders, metadata, device,args = get_feature_loaders(seed, log_folder,train_loader, test_loader, model, num_classes, ) + + if os.path.exists(log_folder / f"SlDD_Selection_{select_features}.pt"): + feature_selection = torch.load(log_folder / f"SlDD_Selection_{select_features}.pt") + else: + used_features = model.linear.weight.shape[1] + if used_features != select_features: + selection_folder = log_folder / "sldd_selection" # overwrite with None to prevent saving + feature_selection = get_glm_selection(feature_loaders, metadata, args, + num_classes, + device,select_features, selection_folder + ) + else: + feature_selection = model.linear.selection + torch.save(feature_selection, log_folder / f"SlDD_Selection_{select_features}.pt") + feature_loaders = select_in_loader(feature_loaders, feature_selection) + mean, std = metadata["X"]["mean"], metadata["X"]["std"] + mean_to_pass_in = mean + std_to_pass_in = std + if len(mean) != feature_selection.shape[0]: + mean_to_pass_in = mean[feature_selection] + std_to_pass_in = std[feature_selection] + + sparse_matrices, biases = fit_glm(log_folder, mean_to_pass_in, std_to_pass_in, feature_loaders, num_classes, select_features) + + return feature_selection, sparse_matrices, biases, mean, std + + +def fit_glm(log_dir,mean, std , feature_loaders, num_classes, select_features = 50): + output_folder = log_dir / "glm_path" + if not output_folder.exists() or len(list(output_folder.iterdir())) != 102: + shutil.rmtree(output_folder, ignore_errors=True) + output_folder.mkdir(exist_ok=True, parents=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + linear = nn.Linear(select_features, num_classes).to(device) + for p in [linear.weight, linear.bias]: + p.data.zero_() + print("Preparing normalization preprocess and indexed dataloader") + metadata = {"X": {"mean": mean, "std": std},} + preprocess = data_helpers.NormalizedRepresentation(feature_loaders['train'], + metadata=metadata, + device=linear.weight.device) + + print("Calculating the regularization path") + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + params = glm_saga(linear, + feature_loaders['train'], + 0.1, + 2000, + 0.99, k=100, + val_loader=feature_loaders['val'], + test_loader=feature_loaders['test'], + n_classes=num_classes, + checkpoint=str(output_folder), + verbose=200, + tol=1e-4, # Change for ImageNet + lookbehind=5, + lr_decay_factor=1, + group=False, + epsilon=0.001, + metadata=None, # To let it be recomputed + preprocess=preprocess, ) + results = load_glm(output_folder) + sparse_matrices = results["weights"] + biases = results["biases"] + + return sparse_matrices, biases + +def load_glm(result_dir): + Nlambda = max([int(f.split('params')[1].split('.pth')[0]) + for f in os.listdir(result_dir) if 'params' in f]) + 1 + + 
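+    # glm_saga typically writes one checkpoint per regularization step, params0.pth ... params{Nlambda-1}.pth;
+    # each file holds the dict saved during fitting, including 'lam', 'weight', 'bias' and 'metrics'.
+    # Illustrative (hypothetical) way to inspect a single step on CPU:
+    #   p = torch.load(os.path.join(result_dir, "params0.pth"), map_location="cpu")
+    #   print(p["lam"], p["metrics"]["acc_test"], (p["weight"] != 0).float().mean())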
print(f"Loading regularization path of length {Nlambda}") + + params_dict = {i: torch.load(os.path.join(result_dir, f"params{i}.pth"), + map_location=torch.device('cpu')) for i in range(Nlambda)} + + regularization_strengths = [params_dict[i]['lam'].item() for i in range(Nlambda)] + weights = [params_dict[i]['weight'] for i in range(Nlambda)] + biases = [params_dict[i]['bias'] for i in range(Nlambda)] + + metrics = {'acc_tr': [], 'acc_val': [], 'acc_test': []} + + for k in metrics.keys(): + for i in range(Nlambda): + metrics[k].append(params_dict[i]['metrics'][k]) + metrics[k] = 100 * np.stack(metrics[k]) + metrics = pd.DataFrame(metrics) + metrics = metrics.rename(columns={'acc_tr': 'acc_train'}) + + # weights_stacked = ch.stack(weights) + # sparsity = ch.sum(weights_stacked != 0, dim=2).numpy() + sparsity = np.array([torch.sum(w != 0, dim=1).numpy() for w in weights]) + + return {'metrics': metrics, + 'regularization_strengths': regularization_strengths, + 'weights': weights, + 'biases': biases, + 'sparsity': sparsity, + 'weight_dense': weights[-1], + 'bias_dense': biases[-1]} diff --git a/sparsification/qsenn.py b/sparsification/qsenn.py new file mode 100644 index 0000000000000000000000000000000000000000..45eb1bde64846d26996962f7a1c3b8d8e0ffa6ab --- /dev/null +++ b/sparsification/qsenn.py @@ -0,0 +1,63 @@ +import numpy as np +import torch + +from sparsification.glmBasedSparsification import compute_feature_selection_and_assignment + + +def compute_qsenn_feature_selection_and_assignment(model, train_loader, test_loader, log_folder, num_classes, seed,n_features, per_class = 5): + feature_sel, sparse_matrices, biases, mean, std = compute_feature_selection_and_assignment(model, train_loader, + test_loader, + log_folder, num_classes, seed, n_features) + weight_sparse, bias_sparse = get_sparsified_weights_for_factor(sparse_matrices[:-1], biases[:-1], per_class) # Last one in regularisation path has no regularisation + print(f"Number of nonzeros in weight matrix: {torch.sum(weight_sparse != 0)}") + return feature_sel, weight_sparse, bias_sparse, mean, std +def get_sparsified_weights_for_factor(weights, biases, factor,): + no_reg_result_mat, no_reg_result_bias = weights[-1], biases[-1] + goal_nonzeros = factor * no_reg_result_mat.shape[0] + values = no_reg_result_mat.flatten() + values = values[values != 0] + values = -(torch.sort(-torch.abs(values))[0]) + if goal_nonzeros < len(values): + threshold = (values[int(goal_nonzeros) - 1] + values[int(goal_nonzeros)]) / 2 + else: + threshold = values[-1] + max_val = torch.max(torch.abs(values)) + weight_sparse = discretize_2_bins_to_threshold(no_reg_result_mat, threshold, max_val) + sel_idx = len(weights) - 1 + positive_weights_per_class = np.array(torch.sum(weight_sparse > 0, dim=1)) + negative_weights_per_class = np.array(torch.sum(weight_sparse < 0, dim=1)) + total_weight_count_per_class = positive_weights_per_class - negative_weights_per_class + max_bias = torch.max(torch.abs(biases[sel_idx])) + bias_sparse = torch.ones_like(biases[sel_idx]) * max_bias + diff_n_weight = total_weight_count_per_class - np.min(total_weight_count_per_class) + steps = np.max(diff_n_weight) + single_step = 2 * max_bias / steps + bias_sparse = bias_sparse - torch.tensor(diff_n_weight) * single_step + bias_sparse = torch.clamp(bias_sparse, -max_bias, max_bias) + return weight_sparse, bias_sparse + + +def discretize_2_bins_to_threshold(data, treshold, max): + boundaries = torch.tensor([-max, -treshold, treshold, max], device=data.device) + bucketized_tensor = 
torch.bucketize(data, boundaries) + means = torch.tensor([-max, 0, max], device=data.device) + for i in range(len(means)): + if means[i] == 0: + break + positive_index = int(len(means) / 2 + 1) + i + positive_bucket = data[bucketized_tensor == positive_index + 1] + negative_bucket = data[bucketized_tensor == i + 1] + sum = 0 + total = 0 + for bucket in [positive_bucket, negative_bucket]: + if len(bucket) == 0: + continue + sum += torch.sum(torch.abs(bucket)) + total += len(bucket) + if total == 0: + continue + avg = sum / total + means[i] = -avg + means[positive_index] = avg + discretized_tensor = means.cpu()[bucketized_tensor.cpu() - 1].to(bucketized_tensor.device) + return discretized_tensor \ No newline at end of file diff --git a/sparsification/sldd.py b/sparsification/sldd.py new file mode 100644 index 0000000000000000000000000000000000000000..3eeb3733797950107c6a0987fe242a0bfe2e732a --- /dev/null +++ b/sparsification/sldd.py @@ -0,0 +1,44 @@ +import numpy as np +import torch + +from sparsification.glmBasedSparsification import compute_feature_selection_and_assignment + + +def compute_sldd_feature_selection_and_assignment(model, train_loader, test_loader, log_folder, num_classes, seed, + per_class=5, select_features=50): + feature_sel, sparse_matrices, biases, mean, std = compute_feature_selection_and_assignment(model, train_loader, + test_loader, + log_folder, num_classes, + seed, select_features=select_features) + weight_sparse, bias_sparse = get_sparsified_weights_for_factor(sparse_matrices,biases, + per_class) # Last one in regularisation path has none + return feature_sel, weight_sparse, bias_sparse, mean, std + +def get_sparsified_weights_for_factor(sparse_layer,biases,keep_per_class, drop_rate=0.5): + nonzero_entries = [torch.sum(torch.count_nonzero(sparse_layer[i])) for i in range(len(sparse_layer))] + mean_sparsity = np.array([nonzero_entries[i] / sparse_layer[i].shape[0] for i in range(len(sparse_layer))]) + factor =keep_per_class / drop_rate + # Get layer with desired sparsity + sparse_enough = mean_sparsity <= factor + sel_idx = np.argmax(sparse_enough * mean_sparsity) + if sel_idx == 0 and np.sum(mean_sparsity) > 1: # sometimes first one is odd + sparse_enough[0] = False + sel_idx = np.argmax(sparse_enough * mean_sparsity) + selected_weight = sparse_layer[sel_idx] + selected_bias = biases[sel_idx] + # only keep 5 per class on average + weight_5_per_matrix = set_lowest_percent_to_zero(selected_weight,5) + + return weight_5_per_matrix,selected_bias + + +def set_lowest_percent_to_zero(matrix, keep_per): + nonzero_indices = torch.nonzero(matrix) + values = torch.tensor([matrix[x[0], x[1]] for x in nonzero_indices]) + sorted_indices = torch.argsort(torch.abs(values)) + total_allowed = int(matrix.shape[0] * keep_per) + sorted_indices = sorted_indices[:-total_allowed] + nonzero_indices_to_zero = [nonzero_indices[x] for x in sorted_indices] + for to_zero in nonzero_indices_to_zero: + matrix[to_zero[0], to_zero[1]] = 0 + return matrix \ No newline at end of file diff --git a/sparsification/utils.py b/sparsification/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1e960e5c4131032242e41bf7ff5a8a02b571fc89 --- /dev/null +++ b/sparsification/utils.py @@ -0,0 +1,159 @@ +from argparse import ArgumentParser + +import torch + +#from sparsification.glm_saga import glm_saga +from sparsification import feature_helpers + + +def safe_zip(*args): + for iterable in args[1:]: + if len(iterable) != len(args[0]): + print("Unequally sized iterables to zip, printing 
lengths") + for i, entry in enumerate(args): + print(i, len(entry)) + raise ValueError("Unequally sized iterables to zip") + return zip(*args) + + +def compute_features_and_metadata(args, train_loader, test_loader, model, out_dir_feats, num_classes, + ): + print("Computing/loading deep features...") + + Ntotal = len(train_loader.dataset) + feature_loaders = {} + # Compute Features for not augmented train and test set + train_loader_transforms = train_loader.dataset.transform + test_loader_transforms = test_loader.dataset.transform + train_loader.dataset.transform = test_loader_transforms + for mode, loader in zip(['train', 'test', ], [train_loader, test_loader, ]): # + print(f"For {mode} set...") + + sink_path = f"{out_dir_feats}/features_{mode}" + metadata_path = f"{out_dir_feats}/metadata_{mode}.pth" + + feature_ds, feature_loader = feature_helpers.compute_features(loader, + model, + dataset_type=args.dataset_type, + pooled_output=None, + batch_size=args.batch_size, + num_workers=0, # args.num_workers, + shuffle=(mode == 'test'), + device=args.device, + filename=sink_path, n_epoch=1, + balance=False, + ) # args.balance if mode == 'test' else False) + + if mode == 'train': + metadata = feature_helpers.calculate_metadata(feature_loader, + num_classes=num_classes, + filename=metadata_path) + if metadata["max_reg"]["group"] == 0.0: + return None, False + split_datasets, split_loaders = feature_helpers.split_dataset(feature_ds, + Ntotal, + val_frac=args.val_frac, + batch_size=args.batch_size, + num_workers=args.num_workers, + random_seed=args.random_seed, + shuffle=True, + balance=False) + feature_loaders.update({mm: add_index_to_dataloader(split_loaders[mi]) + for mi, mm in enumerate(['train', 'val'])}) + + else: + feature_loaders[mode] = feature_loader + train_loader.dataset.transform = train_loader_transforms + return feature_loaders, metadata + +def get_feature_loaders(seed, log_folder,train_loader, test_loader, model, num_classes, ): + args = get_default_args() + args.random_seed = seed + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + feature_folder = log_folder / "features" + feature_loaders, metadata, = compute_features_and_metadata(args, train_loader, test_loader, model, + feature_folder + , + num_classes, + ) + return feature_loaders, metadata, device,args +def add_index_to_dataloader(loader, sample_weight=None,): + return torch.utils.data.DataLoader( + IndexedDataset(loader.dataset, sample_weight=sample_weight), + batch_size=loader.batch_size, + sampler=loader.sampler, + num_workers=loader.num_workers, + collate_fn=loader.collate_fn, + pin_memory=loader.pin_memory, + drop_last=loader.drop_last, + timeout=loader.timeout, + worker_init_fn=loader.worker_init_fn, + multiprocessing_context=loader.multiprocessing_context + ) + + +class IndexedDataset(torch.utils.data.Dataset): + def __init__(self, ds, sample_weight=None): + super(torch.utils.data.Dataset, self).__init__() + self.dataset = ds + self.sample_weight = sample_weight + + def __getitem__(self, index): + val = self.dataset[index] + if self.sample_weight is None: + return val + (index,) + else: + weight = self.sample_weight[index] + return val + (weight, index) + + def __len__(self): + return len(self.dataset) + + +def get_default_args(): + # Default args from glm_saga, https://github.com/MadryLab/glm_saga + parser = ArgumentParser() + parser.add_argument('--dataset', type=str, help='dataset name') + parser.add_argument('--dataset-type', type=str, help='One of ["language", "vision"]') + 
parser.add_argument('--dataset-path', type=str, help='path to dataset') + parser.add_argument('--model-path', type=str, help='path to model checkpoint') + parser.add_argument('--arch', type=str, help='model architecture type') + parser.add_argument('--out-path', help='location for saving results') + parser.add_argument('--cache', action='store_true', help='cache deep features') + parser.add_argument('--balance', action='store_true', help='balance classes for evaluation') + + parser.add_argument('--device', default='cuda') + parser.add_argument('--random-seed', default=0) + parser.add_argument('--num-workers', type=int, default=2) + parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument('--val-frac', type=float, default=0.1) + parser.add_argument('--lr-decay-factor', type=float, default=1) + parser.add_argument('--lr', type=float, default=0.1) + parser.add_argument('--alpha', type=float, default=0.99) + parser.add_argument('--max-epochs', type=int, default=2000) + parser.add_argument('--verbose', type=int, default=200) + parser.add_argument('--tol', type=float, default=1e-4) + parser.add_argument('--lookbehind', type=int, default=3) + parser.add_argument('--lam-factor', type=float, default=0.001) + parser.add_argument('--group', action='store_true') + args = parser.parse_args() + + args = parser.parse_args() + return args + + +def select_in_loader(feature_loaders, feature_selection): + for dataset in feature_loaders["train"].dataset.dataset.dataset.datasets: # Val is indexed via the same dataset as train + tensors = list(dataset.tensors) + if tensors[0].shape[1] == len(feature_selection): + continue + tensors[0] = tensors[0][:, feature_selection] + dataset.tensors = tensors + for dataset in feature_loaders["test"].dataset.datasets: + tensors = list(dataset.tensors) + if tensors[0].shape[1] == len(feature_selection): + continue + tensors[0] = tensors[0][:, feature_selection] + dataset.tensors = tensors + return feature_loaders + diff --git a/tmp/Datasets/CUB200/CUB_200_2011/README b/tmp/Datasets/CUB200/CUB_200_2011/README new file mode 100644 index 0000000000000000000000000000000000000000..4cf4b8f6a8e963af922b4c320df3c9700af95914 --- /dev/null +++ b/tmp/Datasets/CUB200/CUB_200_2011/README @@ -0,0 +1,140 @@ +=========================================== +The Caltech-UCSD Birds-200-2011 Dataset +=========================================== + +For more information about the dataset, visit the project website: + + http://www.vision.caltech.edu/visipedia + +If you use the dataset in a publication, please cite the dataset in +the style described on the dataset website (see url above). + +Directory Information +--------------------- + +- images/ + The images organized in subdirectories based on species. See + IMAGES AND CLASS LABELS section below for more info. +- parts/ + 15 part locations per image. See PART LOCATIONS section below + for more info. +- attributes/ + 322 binary attribute labels from MTurk workers. See ATTRIBUTE LABELS + section below for more info. 
+
+
+
+=========================
+IMAGES AND CLASS LABELS:
+=========================
+Images are contained in the directory images/, with 200 subdirectories (one for each bird species)
+
+------- List of image files (images.txt) ------
+The list of image file names is contained in the file images.txt, with each line corresponding to one image:
+
+<image_id> <image_name>
+------------------------------------------
+
+
+------- Train/test split (train_test_split.txt) ------
+The suggested train/test split is contained in the file train_test_split.txt, with each line corresponding to one image:
+
+<image_id> <is_training_image>
+
+where <image_id> corresponds to the ID in images.txt, and a value of 1 or 0 for <is_training_image> denotes that the file is in the training or test set, respectively.
+------------------------------------------------------
+
+
+------- List of class names (classes.txt) ------
+The list of class names (bird species) is contained in the file classes.txt, with each line corresponding to one class:
+
+<class_id> <class_name>
+--------------------------------------------
+
+
+------- Image class labels (image_class_labels.txt) ------
+The ground truth class labels (bird species labels) for each image are contained in the file image_class_labels.txt, with each line corresponding to one image:
+
+<image_id> <class_id>
+
+where <image_id> and <class_id> correspond to the IDs in images.txt and classes.txt, respectively.
+---------------------------------------------------------
+
+
+
+
+
+=========================
+BOUNDING BOXES:
+=========================
+
+Each image contains a single bounding box label. Bounding box labels are contained in the file bounding_boxes.txt, with each line corresponding to one image:
+
+<image_id> <x> <y> <width> <height>
+
+where <image_id> corresponds to the ID in images.txt, and <x>, <y>, <width>, and <height> are all measured in pixels
+
+
+
+
+=========================
+PART LOCATIONS:
+=========================
+
+------- List of part names (parts/parts.txt) ------
+The list of all part names is contained in the file parts/parts.txt, with each line corresponding to one part:
+
+<part_id> <part_name>
+------------------------------------------
+
+
+------- Part locations (parts/part_locs.txt) ------
+The set of all ground truth part locations is contained in the file parts/part_locs.txt, with each line corresponding to the annotation of a particular part in a particular image:
+
+<image_id> <part_id> <x> <y> <visible>
+
+where <image_id> and <part_id> correspond to the IDs in images.txt and parts/parts.txt, respectively. <x> and <y> denote the pixel location of the center of the part. <visible> is 0 if the part is not visible in the image and 1 otherwise.
+----------------------------------------------------------
+
+
+------- MTurk part locations (parts/part_click_locs.txt) ------
+A set of multiple part locations for each image and part, as perceived by multiple MTurk users is contained in parts/part_click_locs.txt, with each line corresponding to the annotation of a particular part in a particular image by a different MTurk worker:
+