Spaces:
Running
Running
from abc import abstractmethod | |
from copy import deepcopy | |
import enum | |
import torch | |
from torch import nn | |
import os | |
from .model_fbs import DomainDynamicConv2d | |
#from methods.utils.data import get_source_dataloader, get_source_normal_aug_dataloader, get_target_dataloaders | |
#from models.resnet_cifar.model_manager import ResNetCIFARManager | |
from utils.common.others import get_cur_time_str | |
from utils.dl.common.env import set_random_seed | |
from utils.dl.common.model import get_model_latency, get_model_size, get_module, set_module | |
from utils.common.log import logger | |
from utils.third_party.nni_new.compression.pytorch.speedup import ModelSpeedup | |
from utils.third_party.nni_new.compression.pytorch.utils.mask_conflict import GroupMaskConflict, ChannelMaskConflict, CatMaskPadding | |
def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None, fix_group=False, fix_channel=True, fix_padding=False):
    """Run NNI's mask-conflict fixers over a set of pruning masks.

    Args:
        masks: mask dict, or a path to a file containing one (loaded with
            ``torch.load``).
        model: model the masks apply to; required only when ``traced`` is None.
        dummy_input: example input used for tracing; required only when
            ``traced`` is None.
        traced: an already-traced model. When None, ``model`` is traced once in
            eval mode and its original train/eval mode is restored afterwards.
            The traced graph is shared by all enabled fixers.
        fix_group / fix_channel / fix_padding: enable ``GroupMaskConflict``,
            ``ChannelMaskConflict`` and ``CatMaskPadding`` respectively. They
            are applied in that order, and each one consumes the previous
            one's output.

    Returns:
        The (possibly rewritten) mask dict.
    """
    if isinstance(masks, str):
        mask_file = masks
        assert os.path.exists(mask_file)
        masks = torch.load(mask_file)
    assert len(masks) > 0, 'Mask tensor cannot be empty'

    if traced is None:
        assert model is not None and dummy_input is not None
        was_training = model.training
        model.eval()  # tracing must happen in eval mode
        traced = torch.jit.trace(model, dummy_input)
        model.train(was_training)

    if fix_group:
        masks = GroupMaskConflict(masks, model, dummy_input, traced).fix_mask()
    if fix_channel:
        masks = ChannelMaskConflict(masks, model, dummy_input, traced).fix_mask()
    if fix_padding:
        masks = CatMaskPadding(masks, model, dummy_input, traced).fix_mask()
    return masks
class FeatureBoosting(nn.Module):
    """Multiply each channel of an NCHW feature map by a fixed factor.

    The factors are stored as a non-trainable ``nn.Parameter`` of shape
    ``(1, C, 1, 1)``, so they broadcast over the batch and spatial dimensions.
    The parameter is a view of the given ``w``, so the two share storage.
    """

    def __init__(self, w: torch.Tensor):
        super().__init__()
        assert w.dim() == 1
        per_channel_scale = w[None, :, None, None]  # (C,) -> (1, C, 1, 1) view
        self.w = nn.Parameter(per_channel_scale, requires_grad=False)

    def forward(self, x):
        scaled = x * self.w
        return scaled
class FBSSubModelExtractor:
    """Turn an FBS model built from ``DomainDynamicConv2d`` layers into a static pruned model.

    For every ``DomainDynamicConv2d`` the extractor takes a 1-D gate vector ``w``,
    derived from the layer's ``cached_w``. It then:

    * physically removes the output filters of ``raw_conv2d`` where ``w`` is
      zero, using NNI ``ModelSpeedup``;
    * replaces the layer with
      ``nn.Sequential(raw_conv2d, bn, FeatureBoosting(non-zero entries of w))``.

    ``generate_pruning_strategy`` and ``get_final_w`` are hooks for subclasses.
    The base implementations do nothing and return ``None``, so the
    ``extract_submodel_via_samples*`` entry points need a subclass that
    overrides ``get_final_w``.
    """

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _filter_mask(conv: nn.Conv2d, w: torch.Tensor) -> dict:
        """Return an NNI-style mask dict keeping the output filters of ``conv`` where ``w`` != 0."""
        kept = w.nonzero(as_tuple=True)[0]
        mask = {'weight': torch.zeros_like(conv.weight.data)}
        mask['weight'][kept, ...] = 1.
        if conv.bias is not None:
            mask['bias'] = torch.zeros_like(conv.bias.data)
            mask['bias'][kept, ...] = 1.
        return mask

    def _collect_pruning(self, fbs_model: nn.Module, final_w_fn):
        """Build ``(pruning_info, pruning_masks)`` for every DomainDynamicConv2d in ``fbs_model``.

        ``final_w_fn(layer_name, squeezed_cached_w)`` turns the layer's cached
        gate into the final 1-D gate vector.

        * ``pruning_info`` maps ``layer_name`` to that vector.
        * ``pruning_masks`` maps ``layer_name + '.0'`` to the conv's filter
          mask. ``.0`` is the conv's position inside the Sequential built by
          ``_build_pruned_model``.
        """
        pruning_info, pruning_masks = {}, {}
        for layer_name, layer in fbs_model.named_modules():
            if not isinstance(layer, DomainDynamicConv2d):
                continue
            w = final_w_fn(layer_name, layer.cached_w.squeeze())
            pruning_info[layer_name] = w
            pruning_masks[layer_name + '.0'] = self._filter_mask(layer.raw_conv2d, w)
        return pruning_info, pruning_masks

    def _build_pruned_model(self, fbs_model: nn.Module, pruning_info: dict, pruning_masks: dict,
                            dummy_input: torch.Tensor) -> nn.Module:
        """Copy ``fbs_model`` and replace its gated layers with pruned conv/bn/FeatureBoosting stacks."""
        pruned_model = deepcopy(fbs_model)

        # Collect first, then replace, so the module tree is not mutated while
        # named_modules() is still walking it.
        gated_layers = [(name, layer) for name, layer in pruned_model.named_modules()
                        if isinstance(layer, DomainDynamicConv2d)]
        for name, layer in gated_layers:
            # Drop the gate. Slot 2 is a placeholder for FeatureBoosting, added after speedup.
            set_module(pruned_model, name, nn.Sequential(layer.raw_conv2d, layer.bn, nn.Identity()))

        # ModelSpeedup is given a mask-file path here, so round-trip the masks
        # through a temporary file and always delete it, even if speedup fails.
        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)
        try:
            pruned_model.eval()
            ModelSpeedup(pruned_model, dummy_input, tmp_mask_path, dummy_input.device).speedup_model()
        finally:
            os.remove(tmp_mask_path)

        # Re-apply the surviving (non-zero) gate values as fixed per-channel scales.
        for layer_name, w in pruning_info.items():
            kept_w = w[w.nonzero(as_tuple=True)[0]]
            set_module(pruned_model, layer_name + '.2', FeatureBoosting(kept_w))
        return pruned_model

    @staticmethod
    def _borrow_shared_filters(pruned_model: nn.Module, last_submodel: nn.Module, layer_name: str,
                               w: torch.Tensor, last_w: torch.Tensor):
        """Copy conv/BN/boosting values of filters kept by both ``w`` and ``last_w`` from ``last_submodel``.

        Filter positions inside each pruned layer are recovered by counting
        entries > 0. This assumes gate values are non-negative, so that "> 0"
        matches the non-zero filters that were kept. NOTE(review): the conv
        bias (if any) is not copied.
        """
        cur_conv, last_conv = get_module(pruned_model, layer_name + '.0'), get_module(last_submodel, layer_name + '.0')
        cur_bn, last_bn = get_module(pruned_model, layer_name + '.1'), get_module(last_submodel, layer_name + '.1')
        cur_fw, last_fw = get_module(pruned_model, layer_name + '.2'), get_module(last_submodel, layer_name + '.2')

        cur_filter_i, last_filter_i = 0, 0
        for w_factor, last_w_factor in zip(w, last_w):
            if w_factor > 0 and last_w_factor > 0:  # the filter is shared
                cur_conv.weight.data[cur_filter_i] = last_conv.weight.data[last_filter_i]
                for k in ('weight', 'bias', 'running_mean', 'running_var'):
                    getattr(cur_bn, k).data[cur_filter_i] = getattr(last_bn, k).data[last_filter_i]
                cur_fw.w.data[0, cur_filter_i] = last_fw.w.data[0, last_filter_i]
            if w_factor > 0:
                cur_filter_i += 1
            if last_w_factor > 0:
                last_filter_i += 1

    # ------------------------------------------------------------- public API

    def extract_submodel_via_a_sample(self, fbs_model: nn.Module, sample: torch.Tensor):
        """Extract a pruned model from the gates ``fbs_model`` produces for one sample.

        Args:
            fbs_model: the FBS model. It is put in eval mode and run on
                ``sample``; the model itself is not otherwise modified.
            sample: a batch of exactly one sample (``dim() == 4``, size 1 on dim 0).

        Returns:
            The pruned model. Its size and the squared output difference to
            ``fbs_model`` on ``sample`` are logged.
        """
        assert sample.dim() == 4 and sample.size(0) == 1
        fbs_model.eval()
        o1 = fbs_model(sample)  # presumably also populates each gated layer's cached_w
        pruning_info, pruning_masks = self._collect_pruning(fbs_model, lambda _layer_name, w: w)
        pruned_model = self._build_pruned_model(fbs_model, pruning_info, pruning_masks, sample)

        pruned_model_size = get_model_size(pruned_model, True)
        pruned_model.eval()
        o2 = pruned_model(sample)
        diff = ((o1 - o2) ** 2).sum()
        logger.info(f'pruned model size: {pruned_model_size:.3f}MB, diff: {diff}')
        return pruned_model

    def get_final_w(self, fbs_model: nn.Module, samples: torch.Tensor, layer_name: str, w: torch.Tensor):
        """Hook: reduce a layer's (per-sample) gate ``w`` to one 1-D gate vector. Base returns None."""
        pass

    def generate_pruning_strategy(self, fbs_model: nn.Module, samples: torch.Tensor):
        """Hook: prepare state (e.g. run ``fbs_model`` on ``samples``) before gates are read. Base is a no-op."""
        pass

    def extract_submodel_via_samples(self, fbs_model: nn.Module, samples: torch.Tensor):
        """Extract a pruned model from a batch, reducing per-sample gates with ``get_final_w``.

        ``fbs_model`` is deep-copied first and is left untouched.

        Returns:
            ``(pruned_model, pruning_info)``, where ``pruning_info`` maps
            layer name to its final gate vector.
        """
        assert samples.dim() == 4
        fbs_model = deepcopy(fbs_model)
        self.generate_pruning_strategy(fbs_model, samples)
        pruning_info, pruning_masks = self._collect_pruning(
            fbs_model, lambda layer_name, w: self.get_final_w(fbs_model, samples, layer_name, w))
        pruned_model = self._build_pruned_model(fbs_model, pruning_info, pruning_masks, samples[0:1])
        return pruned_model, pruning_info

    def extract_submodel_via_samples_and_last_submodel(self, fbs_model: nn.Module, samples: torch.Tensor,
                                                       last_submodel: nn.Module, last_pruning_info: dict):
        """Like ``extract_submodel_via_samples``, but warm-start filters shared with a previous sub-model.

        A layer is treated as transferable when at least one of its kept
        filters was also kept in ``last_pruning_info``. For transferable layers,
        the conv weight, BN parameters/statistics and boosting factor of every
        shared filter are copied from ``last_submodel``. Pass ``None`` for both
        ``last_*`` arguments to disable the transfer.

        Returns:
            ``(pruned_model, pruning_info)``.
        """
        assert samples.dim() == 4
        fbs_model = deepcopy(fbs_model)
        self.generate_pruning_strategy(fbs_model, samples)
        pruning_info, pruning_masks = self._collect_pruning(
            fbs_model, lambda layer_name, w: self.get_final_w(fbs_model, samples, layer_name, w))

        incrementally_updated_layers = []
        if last_pruning_info is not None:
            for layer_name, w in pruning_info.items():
                last_w = last_pruning_info[layer_name]
                intersection_ratio = ((w > 0) * (last_w > 0)).sum() / (last_w > 0).sum()
                if intersection_ratio > 0.:
                    incrementally_updated_layers.append(layer_name)  # only similar layers are transferable

        pruned_model = self._build_pruned_model(fbs_model, pruning_info, pruning_masks, samples[0:1])

        for layer_name in incrementally_updated_layers:
            self._borrow_shared_filters(pruned_model, last_submodel, layer_name,
                                        pruning_info[layer_name], last_pruning_info[layer_name])
        return pruned_model, pruning_info

    def absorb_sub_model(self, fbs_model: nn.Module, sub_model: nn.Module, pruning_info: dict, alpha=1.):
        """Blend a (fine-tuned) sub-model's kept filters back into ``fbs_model`` in place.

        For every kept filter ``f``:
        ``fbs = (1 - alpha) * fbs + alpha * sub``. This applies to the conv
        weight and to the BN weight, bias, running_mean and running_var. The
        i-th kept filter of the FBS layer is paired with filter ``i`` of the
        sub-model layer. ``alpha == 0`` is a no-op.
        """
        if alpha == 0.:
            return
        for layer_name, w in pruning_info.items():
            kept = w.nonzero(as_tuple=True)[0]
            num_kept = len(kept)
            fbs_layer = get_module(fbs_model, layer_name)
            sub_layer = get_module(sub_model, layer_name)

            # Vectorised form of the per-filter update; `kept` holds unique indices.
            fbs_conv_w = fbs_layer.raw_conv2d.weight.data
            fbs_conv_w[kept] = (1. - alpha) * fbs_conv_w[kept] + alpha * sub_layer[0].weight.data[:num_kept]
            for k in ('weight', 'bias', 'running_mean', 'running_var'):
                fbs_t = getattr(fbs_layer.bn, k).data
                fbs_t[kept] = (1. - alpha) * fbs_t[kept] + alpha * getattr(sub_layer[1], k).data[:num_kept]
class DAFBSSubModelExtractor(FBSSubModelExtractor):
    """Extractor that keeps the gate pattern of the batch's most uncertain sample.

    ``generate_pruning_strategy`` runs the FBS model once (eval mode, no grad)
    on the whole batch and caches its output in ``self.cur_output``.
    ``get_final_w`` then picks, for every gated layer, the gate row of the
    sample whose prediction has the highest softmax entropy.
    """

    def __init__(self) -> None:
        super().__init__()

    def generate_pruning_strategy(self, fbs_model: nn.Module, samples: torch.Tensor):
        """Run ``fbs_model`` on ``samples`` without gradients and cache the output."""
        with torch.no_grad():
            fbs_model.eval()
            self.cur_output = fbs_model(samples)

    def get_final_w(self, fbs_model: nn.Module, samples: torch.Tensor, layer_name: str, w: torch.Tensor):
        """Return ``w[i]``, where ``i`` is the sample with maximal prediction entropy.

        Requirements (not checked here):
        * The cached output must expose a ``.logits`` attribute with classes on dim 1.
        * ``w`` is assumed to be indexed by sample on dim 0 (TODO confirm).
        """
        # Recorded result of this strategy (per-domain accuracies, mean):
        # [0.0828, 0.1017, 0.0575, 0.3081, 0.1511, 0.3634, 0.3388, 0.3942, 0.2475, 0.3371, 0.5837, 0.145, 0.4428, 0.2159, 0.4028] 0.27815999999999996
        logits = self.cur_output.logits
        probs = logits.softmax(1)
        log_probs = logits.log_softmax(1)
        entropy_per_sample = -(probs * log_probs).sum(1)
        most_uncertain = entropy_per_sample.argmax()
        return w[most_uncertain]
def tent_as_detector(model, x, num_iters=1, lr=1e-4, l1_wd=0., strategy='ours', untrained_ratio=0.8):
    """Select BatchNorm channels to freeze, based on how much a short TENT run moves them.

    A deep copy of ``model`` is wrapped in ``tent.Tent`` (with an SGD
    optimizer) and called on ``x``. For every ``nn.BatchNorm2d``, the
    per-channel change ``|dweight| + |dbias|`` between the adapted copy and an
    untouched copy is then measured. ``model`` itself is not modified.

    Args:
        model: model to probe (deep-copied).
        x: batch fed to the TENT-wrapped copy.
        num_iters: ``steps`` passed to ``tent.Tent``.
        lr: SGD learning rate.
        l1_wd: passed to SGD as ``weight_decay`` (note: SGD's weight decay is an L2 term).
        strategy: how to pick the "untrained" channels:
            ``'ours'`` picks the least-changed channels;
            ``'inversed_ours'`` picks the most-changed channels;
            ``'random'`` picks a random subset;
            ``'none'`` yields ``None``, which is meant as "freeze nothing".
        untrained_ratio: fraction of each layer's channels marked untrained.
            The default 0.8 is the previously hard-coded value.

    Returns:
        dict mapping each BatchNorm2d name to
        ``{'untrained_filters_index': index tensor or None,
        'conv_name': name of the last nn.Conv2d seen before it in
        named_modules() order (None if there is none)}``.

    Raises:
        ValueError: unknown ``strategy``. Previously this surfaced as an
            UnboundLocalError or silently reused a stale index.
    """
    if strategy not in ('ours', 'random', 'inversed_ours', 'none'):
        raise ValueError(f'unknown strategy: {strategy!r}')

    model = deepcopy(model)
    before_model = deepcopy(model)
    from methods.tent import tent
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=l1_wd)
    from models.resnet_cifar.model_manager import ResNetCIFARManager
    tented_model = tent.Tent(model, optimizer, ResNetCIFARManager, steps=num_iters)
    tent.configure_model(model)
    tented_model(x)

    filters_sen_info = {}
    last_conv_name = None
    # Both copies have identical module structure, so zip pairs corresponding modules.
    for (name, m1), m2 in zip(model.named_modules(), before_model.modules()):
        if isinstance(m1, nn.Conv2d):
            last_conv_name = name
        if not isinstance(m1, nn.BatchNorm2d):
            continue
        with torch.no_grad():
            features_diff = (m1.weight.data - m2.weight.data).abs() + (m1.bias.data - m2.bias.data).abs()
            num_untrained = int(len(features_diff) * untrained_ratio)
            if strategy == 'ours':
                untrained_filters_index = features_diff.argsort(descending=False)[:num_untrained]
            elif strategy == 'random':
                untrained_filters_index = torch.randperm(len(features_diff))[:num_untrained]
            elif strategy == 'inversed_ours':
                untrained_filters_index = features_diff.argsort(descending=False).flip(0)[:num_untrained]
            else:  # 'none'
                untrained_filters_index = None
        filters_sen_info[name] = dict(untrained_filters_index=untrained_filters_index, conv_name=last_conv_name)
    return filters_sen_info
class SGDF(torch.optim.SGD):
    """SGD (with momentum / dampening / Nesterov / weight decay) that freezes selected filters.

    ``step`` takes the model plus two per-layer dicts, in the format returned
    by ``tent_as_detector``:

    * ``filters_sen_info`` is keyed by BatchNorm layer name;
    * ``conv_filters_sen_info`` is keyed by conv layer name.

    For each parameter, the dict entry of its owning layer is looked up, with
    the BN dict taking precedence. The rows ``untrained_filters_index`` of the
    final update are then zeroed, so those filters do not move.
    """

    @torch.no_grad()  # required: the override loses SGD.step's no_grad, and p.add_ on a leaf would fail
    def step(self, model, conv_filters_sen_info, filters_sen_info, closure=None):
        """Performs a single optimization step.

        Arguments:
            model (nn.Module): module owning the optimized parameters (used only for their names).
            conv_filters_sen_info (dict): conv layer name -> {'untrained_filters_index': ...}.
            filters_sen_info (dict): BN layer name -> {'untrained_filters_index': ...}.
                An index of ``None`` freezes nothing.
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Resolve names by identity, so parameters left out of the optimizer
        # (e.g. a frozen head) cannot shift the name <-> parameter pairing.
        param_names = {id(p): name for name, p in model.named_parameters()}

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            for p in group['params']:
                if p.grad is None:
                    continue
                layer_name = param_names.get(id(p), '').rpartition('.')[0]
                if layer_name in filters_sen_info:
                    untrained_filters_index = filters_sen_info[layer_name]['untrained_filters_index']
                elif layer_name in conv_filters_sen_info:
                    untrained_filters_index = conv_filters_sen_info[layer_name]['untrained_filters_index']
                else:
                    untrained_filters_index = None

                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf
                # `d_p[None] = 0.` would zero the WHOLE tensor, so None must be skipped explicitly.
                # Note: when d_p aliases the momentum buffer or the grad, this zeroes those entries too.
                if untrained_filters_index is not None:
                    d_p[untrained_filters_index] = 0.
                p.add_(d_p, alpha=-group['lr'])
        return loss
if __name__ == '__main__':
    # Experiment driver: for each target domain, extract a pruned sub-model from
    # the FBS model, adapt it with a SHOT-style entropy objective while freezing
    # TENT-selected filters (SGDF), record accuracy before/after adaptation, and
    # blend the adapted weights back into the FBS model (absorb_sub_model).
    import itertools
    import shutil
    import sys

    # Both names are used below, but their imports are commented out at the top
    # of this file; import them here so the script can run.
    from methods.utils.data import get_source_dataloader
    from models.resnet_cifar.model_manager import ResNetCIFARManager

    set_random_seed(0)
    # NOTE(review): `tag` and `fbs_model_path` both read sys.argv[1], so the model
    # path also ends up in the result directory name. Confirm the intended CLI.
    tag = sys.argv[1]
    alpha = 0.2  # interpolation weight for absorb_sub_model
    fbs_model_path = sys.argv[1]

    cur_time_str = get_cur_time_str()
    res_save_dir = f'logs/experiments_trial/CIFAR100C/ours_fbs_more_challenging/{cur_time_str[0:8]}/{cur_time_str[8:]}-{tag}'
    os.makedirs(res_save_dir)
    # Snapshot this method's source next to the results for reproducibility.
    shutil.copytree(os.path.dirname(__file__),
                    os.path.join(res_save_dir, 'method'),
                    ignore=shutil.ignore_patterns('*.pt', '*.pth', 'log', '__pycache__'))
    logger.info(f'res save dir: {res_save_dir}')

    model = torch.load(fbs_model_path)
    # Only used by the (currently disabled) L1 baseline comparison.
    xgf_model = torch.load('logs/experiments_trial/CIFAR100C/ours_dynamic_filters/20220731/224212-cifar10_svhn_raw/last_model.pt')

    train_dataloaders = [
        get_source_dataloader(dataset_name, 128, 4, 'train', True, None, True)
        for dataset_name in ['SVHN', 'CIFAR10', 'SVHN']
    ][::-1] * 10
    test_dataloaders = [
        get_source_dataloader('USPS', 128, 4, 'test', False, False, False),
        get_source_dataloader('STL10-wo-monkey', 128, 4, 'test', False, False, False),
        get_source_dataloader('MNIST', 128, 4, 'test', False, False, False),
    ][::-1] * 10
    y_offsets = [10, 0, 10][::-1] * 10
    domain_names = ['USPS', 'STL10', 'MNIST'][::-1] * 10

    def entropy(probs):
        """Per-sample entropy of a (batch, classes) probability tensor."""
        return torch.sum(-probs * torch.log(probs + 1e-5), dim=1)

    def shot_loss(_model, x):
        """SHOT objective: mean per-sample entropy minus the entropy of the mean prediction."""
        softmax_out = nn.Softmax(dim=1)(ResNetCIFARManager.forward(_model, x))
        loss = torch.mean(entropy(softmax_out))
        msoftmax = softmax_out.mean(dim=0)
        loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
        return loss

    def bn_cal(_model, train_dataloader, num_batches=100):
        """Re-estimate BatchNorm running statistics from ~one epoch of training batches."""
        for m in _model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.reset_running_stats()
                m.train()
        with torch.no_grad():
            # islice works for DataLoaders and iterators alike (next() on a DataLoader raised TypeError).
            for x, _ in itertools.islice(train_dataloader, num_batches):
                _model(x.cuda())

    def shot(_model, samples, lr=6e-4, num_iters_scale=1, wd=0.):
        """Adapt all parameters except the linear head with the SHOT objective."""
        _model.requires_grad_(True)
        _model.linear.requires_grad_(False)
        optimizer = torch.optim.SGD([p for p in _model.parameters() if p.requires_grad],
                                    lr=lr, momentum=0.9, weight_decay=wd)
        for _ in range(100 * num_iters_scale):
            _model.train()
            loss = shot_loss(_model, samples)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def shot_w_part_filter(_model, samples, lr=6e-4, num_iters_scale=1, wd=0.):
        """Like shot(), but filters chosen by tent_as_detector are frozen via SGDF."""
        _model.requires_grad_(True)
        _model.linear.requires_grad_(False)
        optimizer = SGDF([p for p in _model.parameters() if p.requires_grad],
                         lr=lr, momentum=0.9, weight_decay=wd)
        filters_sen_info = tent_as_detector(_model, samples, strategy='ours')
        conv_filters_sen_info = {v['conv_name']: v for v in filters_sen_info.values()}
        for _ in range(100 * num_iters_scale):
            _model.train()
            loss = shot_loss(_model, samples)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step(_model, conv_filters_sen_info, filters_sen_info)

    def run_tent(_model, samples):
        """One TENT step (Adam, lr=1e-4) on `samples`."""
        from methods.tent import tent
        _model = tent.configure_model(_model)
        params, _param_names = tent.collect_params(_model)
        optimizer = torch.optim.Adam(params, lr=1e-4)
        tent_model = tent.Tent(_model, optimizer, ResNetCIFARManager, steps=1)
        tent.configure_model(_model)
        tent_model(samples)

    def tent_configure_bn(_model):
        """Make every BatchNorm2d trainable and always use batch statistics."""
        for m in _model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.requires_grad_(True)
                # force use of batch stats in train and eval modes
                m.track_running_stats = False
                m.running_mean = None
                m.running_var = None
        return _model  # previously returned the outer `model` by mistake

    def sl(_model, samples, labels, lr=6e-4, num_iters_scale=1, wd=0.):
        """Supervised fine-tuning (all but the linear head) on labelled `samples`."""
        _model.requires_grad_(True)
        _model.linear.requires_grad_(False)
        optimizer = torch.optim.SGD([p for p in _model.parameters() if p.requires_grad],
                                    lr=lr, momentum=0.9, weight_decay=wd)
        for _ in range(100 * num_iters_scale):
            _model.train()
            loss = ResNetCIFARManager.forward_to_gen_loss(_model, samples, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    baseline_before, baseline_after, ours_before, ours_after = [], [], [], []
    last_pruned_model, last_pruning_info = None, None
    model_extractor = DAFBSSubModelExtractor()
    for ti, (test_dataloader, y_offset) in enumerate(zip(test_dataloaders, y_offsets)):
        samples, labels = next(iter(test_dataloader))
        samples, labels = samples.cuda(), labels.cuda()
        labels += y_offset

        # Single-sample sanity check (logs size and output diff). It requires a
        # (1, C, H, W) batch: samples[0] is 3-D and failed the method's assertion.
        model_extractor.extract_submodel_via_a_sample(model, samples[0:1])
        pruned_model, pruning_info = model_extractor.extract_submodel_via_samples_and_last_submodel(
            model, samples, None, None)

        acc = ResNetCIFARManager.get_accuracy(pruned_model, test_dataloader, 'cuda', y_offset)
        print(acc)
        ours_before.append(acc)

        # Alternatives: run_tent(pruned_model, samples); bn_cal(pruned_model, train_dataloaders[ti]);
        # sl(pruned_model, samples, labels)
        shot_w_part_filter(pruned_model, samples, 6e-4, 1, 1e-3)

        acc = ResNetCIFARManager.get_accuracy(pruned_model, test_dataloader, 'cuda', y_offset)
        print(acc)
        ours_after.append(acc)

        last_pruned_model, last_pruning_info = deepcopy(pruned_model), deepcopy(pruning_info)
        model_extractor.absorb_sub_model(model, pruned_model, pruning_info, alpha)

    import matplotlib.pyplot as plt
    from visualize.util import *
    set_figure_settings(3)

    def avg(arr):
        return sum(arr) / len(arr)

    domain_xs = list(range(len(test_dataloaders)))
    plt.plot(domain_xs, ours_before, lw=2, linestyle='--', color=RED, label=f'ours before DA ({avg(ours_before):.4f})')
    plt.plot(domain_xs, ours_after, lw=2, linestyle='-', color=RED, label=f'ours after DA ({avg(ours_after):.4f})')
    plt.xlabel('domains')
    plt.ylabel('accuracy')
    plt.xticks(list(range(len(domain_names))), domain_names, rotation=90)
    plt.legend(loc=2, bbox_to_anchor=(1.05, 1.0), fontsize=16)
    plt.tight_layout()
    plt.savefig(os.path.join(res_save_dir, 'main.png'), dpi=300)
    plt.clf()
    torch.save((baseline_before, baseline_after, ours_before, ours_after), os.path.join(res_save_dir, 'main.png.data'))
# with open('./tmp.csv', 'a') as f: | |
# f.write(f'{alpha:.2f},{avg(baseline_after):.4f},{avg(ours_after):.4f}') | |
# std: logs/experiments_trial/CIFAR100C/ours_dynamic_filters/20220730/161404-submodel/main.png | |
# accs = [] | |
# for i in tqdm.tqdm(range(100)): | |
# model_extractor.debug_sample_i = i | |
# pruned_model = model_extractor.extract_submodel_via_samples(model, samples) | |
# acc = ResNetCIFARManager.get_accuracy(pruned_model, test_dataloader, 'cuda') | |
# accs += [acc] | |
# import matplotlib.pyplot as plt | |
# plt.plot(list(range(100)), accs) | |
# plt.savefig('./tmp.png', dpi=300) | |
# plt.clf() | |
# ------------------------------ | |
# perf test | |
# sample, _ = next(iter(test_dataloader)) | |
# sample = sample[0: 1].cuda() | |
# pruned_model = FBSSubModelExtractor().extract_submodel_via_a_sample(model, sample) | |
# bs = 1 | |
# def perf_test(model, batch_size, device): | |
# model = model.to(device) | |
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) | |
# # warmup | |
# for _ in range(100): | |
# rand_input = torch.rand((batch_size, 3, 32, 32)).to(device) | |
# o = model(rand_input) | |
# forward_latency = 0. | |
# backward_latency = 0. | |
# for _ in range(100): | |
# rand_input = torch.rand((batch_size, 3, 32, 32)).to(device) | |
# optimizer.zero_grad() | |
# s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) | |
# s.record() | |
# o = model(rand_input) | |
# e.record() | |
# torch.cuda.synchronize() | |
# forward_latency += s.elapsed_time(e) / 1000. | |
# loss = ((o - 1) ** 2).sum() | |
# s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) | |
# s.record() | |
# loss.backward() | |
# optimizer.step() | |
# e.record() | |
# torch.cuda.synchronize() | |
# backward_latency += s.elapsed_time(e) / 1000. | |
# forward_latency /= 100 | |
# backward_latency /= 100 | |
# print(forward_latency, backward_latency) | |
# for bs in [1, 128]: | |
# for device in ['cuda', 'cpu']: | |
# for m in [model, pruned_model]: | |
# print(bs, device) | |
# perf_test(m, bs, device) |