Spaces:
Running
Running
import torch.nn as nn | |
import torch | |
import numpy as np | |
''' | |
---- 1) FLOPs: floating point operations | |
---- 2) #Activations: the number of elements of all ‘Conv2d’ outputs | |
---- 3) #Conv2d: the number of ‘Conv2d’ layers | |
# -------------------------------------------- | |
# Kai Zhang (github: https://github.com/cszn) | |
# 21/July/2020 | |
# -------------------------------------------- | |
# Reference | |
https://github.com/sovrasov/flops-counter.pytorch.git | |
# If you use this code, please consider the following citation: | |
@inproceedings{zhang2020aim, % | |
title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results}, | |
author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others}, | |
booktitle={European Conference on Computer Vision Workshops}, | |
year={2020} | |
} | |
# -------------------------------------------- | |
''' | |
def get_model_flops(model, input_res, print_per_layer_stat=True,
                    input_constructor=None):
    """Run one forward pass through ``model`` and return the counted FLOPs.

    The model is instrumented in place with forward hooks, switched to
    ``eval()`` mode (it is NOT switched back afterwards), fed a single dummy
    batch and then has its hooks removed again.

    :param model: the ``nn.Module`` to measure.
    :param input_res: tuple with the input size without the batch dimension;
        at least 3 entries are required.
    :param print_per_layer_stat: if true, print the model with per-layer counts.
    :param input_constructor: optional callable ``f(input_res) -> dict``; its
        result is passed to the model as keyword arguments instead of the
        default dummy batch.
    :return: the sum of the counters of all supported layers.
    """
    # isinstance (rather than an exact type check) also accepts tuple
    # subclasses such as torch.Size.
    assert isinstance(input_res, tuple), 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()
    try:
        # The counters only look at shapes, so no autograd graph is needed.
        with torch.no_grad():
            if input_constructor:
                inputs = input_constructor(input_res)
                _ = flops_model(**inputs)
            else:
                params = list(flops_model.parameters())
                # A model without parameters has no device to copy; use the CPU.
                device = params[-1].device if params else torch.device('cpu')
                # Zeros instead of uninitialised memory: deterministic, no NaN/Inf.
                batch = torch.zeros(1, *input_res, dtype=torch.float32, device=device)
                _ = flops_model(batch)
        if print_per_layer_stat:
            print_model_with_flops(flops_model)
        flops_count = flops_model.compute_average_flops_cost()
    finally:
        # Always detach the hooks, even when the forward pass fails.
        flops_model.stop_flops_count()
    return flops_count
def get_model_activation(model, input_res, input_constructor=None):
    """Run one forward pass and return the activation count and conv count.

    The model is instrumented in place with forward hooks, switched to
    ``eval()`` mode (it is NOT switched back afterwards), fed a single dummy
    batch and then has its hooks removed again.

    :param model: the ``nn.Module`` to measure.
    :param input_res: tuple with the input size without the batch dimension;
        at least 3 entries are required.
    :param input_constructor: optional callable ``f(input_res) -> dict``; its
        result is passed to the model as keyword arguments instead of the
        default dummy batch.
    :return: ``(activation_count, num_conv)`` as reported by
        ``compute_average_activation_cost``.
    """
    # isinstance (rather than an exact type check) also accepts tuple
    # subclasses such as torch.Size.
    assert isinstance(input_res, tuple), 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    activation_model = add_activation_counting_methods(model)
    activation_model.eval().start_activation_count()
    try:
        # The counters only look at shapes, so no autograd graph is needed.
        with torch.no_grad():
            if input_constructor:
                inputs = input_constructor(input_res)
                _ = activation_model(**inputs)
            else:
                params = list(activation_model.parameters())
                # A model without parameters has no device to copy; use the CPU.
                device = params[-1].device if params else torch.device('cpu')
                # Zeros instead of uninitialised memory: deterministic, no NaN/Inf.
                batch = torch.zeros(1, *input_res, dtype=torch.float32, device=device)
                _ = activation_model(batch)
        activation_count, num_conv = activation_model.compute_average_activation_cost()
    finally:
        # Always detach the hooks, even when the forward pass fails.
        activation_model.stop_activation_count()
    return activation_count, num_conv
def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True,
                              input_constructor=None):
    """Return the FLOPs count and the trainable-parameter count of ``model``.

    The model is instrumented in place with forward hooks, switched to
    ``eval()`` mode (it is NOT switched back afterwards), fed a single dummy
    batch and then has its hooks removed again.

    :param model: the ``nn.Module`` to measure.
    :param input_res: tuple with the input size without the batch dimension;
        at least 3 entries are required.
    :param print_per_layer_stat: if true, print the model with per-layer counts.
    :param as_strings: if true, return both numbers formatted as strings.
    :param input_constructor: optional callable ``f(input_res) -> dict``; its
        result is passed to the model as keyword arguments instead of the
        default dummy batch.
    :return: ``(flops, params)``, either as numbers or as formatted strings.
    """
    # isinstance (rather than an exact type check) also accepts tuple
    # subclasses such as torch.Size.
    assert isinstance(input_res, tuple)
    assert len(input_res) >= 3
    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()
    try:
        # The counters only look at shapes, so no autograd graph is needed.
        with torch.no_grad():
            if input_constructor:
                inputs = input_constructor(input_res)
                _ = flops_model(**inputs)
            else:
                params = list(flops_model.parameters())
                # Build the dummy batch on the model's device so that models
                # living on an accelerator work too; fall back to the CPU for
                # models without parameters.
                device = params[-1].device if params else torch.device('cpu')
                # Zeros instead of uninitialised memory: deterministic, no NaN/Inf.
                batch = torch.zeros(1, *input_res, dtype=torch.float32, device=device)
                _ = flops_model(batch)
        if print_per_layer_stat:
            print_model_with_flops(flops_model)
        flops_count = flops_model.compute_average_flops_cost()
        params_count = get_model_parameters_number(flops_model)
    finally:
        # Always detach the hooks, even when the forward pass fails.
        flops_model.stop_flops_count()
    if as_strings:
        return flops_to_string(flops_count), params_to_string(params_count)
    return flops_count, params_count
def flops_to_string(flops, units='GMac', precision=2):
    """Format a multiply-accumulate count as a human-readable string.

    :param flops: the raw count.
    :param units: ``'GMac'``, ``'MMac'`` or ``'KMac'`` to force a unit,
        ``None`` to pick the largest unit whose integer part is positive;
        any other value yields the unscaled count suffixed with ``' Mac'``.
    :param precision: number of decimals kept when a scaled unit is used.
    :return: the formatted string, e.g. ``'2.5 GMac'``.
    """
    # (unit label, power of ten) from the largest unit to the smallest.
    scales = (('GMac', 9), ('MMac', 6), ('KMac', 3))

    if units is None:
        for label, exponent in scales:
            if flops // 10 ** exponent > 0:
                return str(round(flops / 10. ** exponent, precision)) + ' ' + label
        return str(flops) + ' Mac'

    for label, exponent in scales:
        if units == label:
            return str(round(flops / 10. ** exponent, precision)) + ' ' + units
    return str(flops) + ' Mac'
def params_to_string(params_num):
    """Format a parameter count, using ``' M'`` or ``' k'`` where applicable.

    :param params_num: the raw number of parameters.
    :return: the count scaled to millions or thousands and rounded to two
        decimals, or the plain number when it is below one thousand.
    """
    million = 10 ** 6
    thousand = 10 ** 3
    if params_num // million > 0:
        return '{} M'.format(round(params_num / million, 2))
    if params_num // thousand != 0:
        return '{} k'.format(round(params_num / thousand, 2))
    return str(params_num)
def print_model_with_flops(model, units='GMac', precision=3):
    """Print ``model`` with the FLOPs of every sub-module in its repr line.

    Each module's ``extra_repr`` is temporarily replaced by a version that
    prepends the accumulated count and its share of the model total; the
    original ``extra_repr`` is restored afterwards, even if printing fails.

    :param model: a module prepared with ``add_flops_counting_methods`` whose
        counters have already been filled by a forward pass.
    :param units: unit passed on to ``flops_to_string``.
    :param precision: number of decimals passed on to ``flops_to_string``.
    """
    total_flops = model.compute_average_flops_cost()

    def accumulate_flops(self):
        # Counted layers report their own counter; every other module sums
        # its children. The raw counter is used (no per-batch division) so the
        # value is consistent with ``total_flops`` and does not depend on a
        # ``__batch_counter__`` attribute that nothing in this file sets.
        if is_supported_instance(self):
            return self.__flops__
        return sum(child.accumulate_flops() for child in self.children())

    def flops_repr(self):
        accumulated_flops_cost = self.accumulate_flops()
        # A model without any counted layer has a total of 0; report 0 %
        # instead of raising ZeroDivisionError.
        share = accumulated_flops_cost / total_flops if total_flops else 0.0
        return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
                          '{:.3%} MACs'.format(share),
                          self.original_extra_repr()])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)
    try:
        print(model)
    finally:
        # Never leave the patched repr methods behind.
        model.apply(del_extra_repr)
def get_model_parameters_number(model):
    """Count the elements of all parameters of ``model`` that require grad.

    :param model: any object exposing ``parameters()``.
    :return: the total number of trainable parameter elements.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total
def add_flops_counting_methods(net_main_module):
    """Attach the FLOPs-counting helpers to ``net_main_module`` as methods.

    The module-level functions are bound to this particular object, so that
    each of them receives it as ``self``. The counters are reset before the
    module is returned.

    :param net_main_module: the module to instrument (modified in place).
    :return: the same module object.
    """
    helpers = (
        ('start_flops_count', start_flops_count),
        ('stop_flops_count', stop_flops_count),
        ('reset_flops_count', reset_flops_count),
        ('compute_average_flops_cost', compute_average_flops_cost),
    )
    for attr_name, func in helpers:
        # __get__ turns the plain function into a method bound to this object.
        setattr(net_main_module, attr_name, func.__get__(net_main_module))

    net_main_module.reset_flops_count()
    return net_main_module
def compute_average_flops_cost(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.
    Returns the sum of the ``__flops__`` counters of all supported sub-modules.
    No division by a batch or image count is performed here.
    """
    return sum(
        layer.__flops__
        for layer in self.modules()
        if is_supported_instance(layer)
    )
def start_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.
    Registers the FLOPs-counting forward hooks on every supported sub-module.
    Call it before you run the network.
    """
    register_hook = add_flops_counter_hook_function
    self.apply(register_hook)
def stop_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.
    Removes the FLOPs-counting forward hooks from every supported sub-module.
    The counters themselves are left untouched, so counting can be paused.
    """
    unregister_hook = remove_flops_counter_hook_function
    self.apply(unregister_hook)
def reset_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.
    Sets the FLOPs counter of every supported sub-module back to zero,
    creating the counter where it does not exist yet.
    """
    reset_counter = add_flops_counter_variable_or_reset
    self.apply(reset_counter)
def add_flops_counter_hook_function(module):
    """Register the matching FLOPs forward hook on a supported ``module``.

    Unsupported modules and modules that already carry a hook handle are left
    alone. The handle is stored on the module as ``__flops_handle__`` so the
    hook can be removed later.
    """
    if not is_supported_instance(module):
        return
    if hasattr(module, '__flops_handle__'):
        return

    # Checked in order; the first matching entry wins.
    hook_table = (
        ((nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d), conv_flops_counter_hook),
        ((nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6), relu_flops_counter_hook),
        (nn.Linear, linear_flops_counter_hook),
        (nn.BatchNorm2d, bn_flops_counter_hook),
    )
    selected_hook = empty_flops_counter_hook
    for layer_types, hook in hook_table:
        if isinstance(module, layer_types):
            selected_hook = hook
            break

    module.__flops_handle__ = module.register_forward_hook(selected_hook)
def remove_flops_counter_hook_function(module):
    """Remove the FLOPs forward hook from ``module`` if one is registered.

    Only supported modules that carry a ``__flops_handle__`` are touched; the
    handle attribute is deleted after the hook has been removed.
    """
    has_hook = is_supported_instance(module) and hasattr(module, '__flops_handle__')
    if not has_hook:
        return
    module.__flops_handle__.remove()
    del module.__flops_handle__
def add_flops_counter_variable_or_reset(module):
    """Create or zero the ``__flops__`` counter of a supported ``module``."""
    if not is_supported_instance(module):
        return
    module.__flops__ = 0
# ---- Internal functions | |
def is_supported_instance(module):
    """Tell whether ``module`` is one of the layer types that get a FLOPs counter.

    :return: ``True`` for 2-D (transposed) convolutions, 2-D batch norm,
        linear layers and the listed ReLU-style activations, else ``False``.
    """
    counted_layer_types = (
        nn.Conv2d, nn.ConvTranspose2d,
        nn.BatchNorm2d,
        nn.Linear,
        nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6,
    )
    return isinstance(module, counted_layer_types)
def conv_flops_counter_hook(conv_module, input, output):
    """Forward hook adding the multiply-accumulates of a convolution.

    The count is ``prod(kernel_size) * in_channels * (out_channels // groups)``
    per output position, multiplied by the number of output positions
    (batch size times the product of the output's trailing dimensions).
    The bias is not counted.

    :param conv_module: the convolution whose ``__flops__`` is increased.
    :param input: the inputs of the forward call (unused).
    :param output: the output tensor of the forward call.
    """
    out_shape = output.shape
    n_samples = out_shape[0]
    spatial_dims = list(out_shape[2:])

    kernel_volume = np.prod(list(conv_module.kernel_size))
    outputs_per_group = conv_module.out_channels // conv_module.groups
    macs_per_position = kernel_volume * conv_module.in_channels * outputs_per_group

    n_positions = n_samples * np.prod(spatial_dims)

    # Convert to Python ints before multiplying so the product cannot
    # overflow a fixed-width numpy integer.
    total_macs = int(macs_per_position) * int(n_positions)
    conv_module.__flops__ += int(total_macs)
def relu_flops_counter_hook(module, input, output):
    """Forward hook counting one operation per element of ``output``.

    :param module: the activation whose ``__flops__`` is increased.
    :param input: the inputs of the forward call (unused).
    :param output: the output tensor of the forward call.
    """
    n_elements = int(output.numel())
    module.__flops__ += n_elements
def linear_flops_counter_hook(module, input, output):
    """Forward hook adding the multiply-accumulates of a linear layer.

    Every output element costs one multiply-accumulate per input feature, so
    the count is ``output.numel() * in_features``. This gives the same result
    as before for 1-D and 2-D inputs and is also correct for inputs with
    extra leading dimensions such as ``(batch, seq, features)``, where
    indexing dimension 1 would pick the sequence length instead of the
    feature size.

    :param module: the linear layer whose ``__flops__`` is increased.
    :param input: tuple of inputs of the forward call; only the first is used.
    :param output: the output tensor of the forward call.
    """
    first_input = input[0]
    # The feature dimension is always the last one, whatever the rank.
    in_features = first_input.shape[-1]
    module.__flops__ += int(output.numel() * in_features)
def bn_flops_counter_hook(module, input, output):
    """Forward hook adding the operations of a batch-norm layer.

    One operation is counted per element (batch size times ``num_features``
    times the product of the output's trailing dimensions); the count is
    doubled when the layer is affine.

    :param module: the batch-norm layer whose ``__flops__`` is increased.
    :param input: the inputs of the forward call (unused).
    :param output: the output tensor of the forward call.
    """
    out_shape = output.shape
    n_elements = out_shape[0] * module.num_features * np.prod(out_shape[2:])
    # The affine transform adds a second pass over every element.
    ops_per_element = 2 if module.affine else 1
    module.__flops__ += int(n_elements * ops_per_element)
# ---- Count the number of convolutional layers and the activation | |
def add_activation_counting_methods(net_main_module):
    """Attach the activation-counting helpers to ``net_main_module`` as methods.

    The module-level functions are bound to this particular object, so that
    each of them receives it as ``self``. The counters are reset before the
    module is returned.

    :param net_main_module: the module to instrument (modified in place).
    :return: the same module object.
    """
    helpers = (
        ('start_activation_count', start_activation_count),
        ('stop_activation_count', stop_activation_count),
        ('reset_activation_count', reset_activation_count),
        ('compute_average_activation_cost', compute_average_activation_cost),
    )
    for attr_name, func in helpers:
        # __get__ turns the plain function into a method bound to this object.
        setattr(net_main_module, attr_name, func.__get__(net_main_module))

    net_main_module.reset_activation_count()
    return net_main_module
def compute_average_activation_cost(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.
    Returns a pair: the summed ``__activation__`` counters and the summed
    ``__num_conv__`` counters of all supported convolution sub-modules.
    No division by a batch or image count is performed here.
    """
    conv_layers = [
        layer for layer in self.modules()
        if is_supported_instance_for_activation(layer)
    ]
    activation_sum = sum(layer.__activation__ for layer in conv_layers)
    num_conv = sum(layer.__num_conv__ for layer in conv_layers)
    return activation_sum, num_conv
def start_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.
    Registers the activation-counting forward hooks on every supported sub-module.
    Call it before you run the network.
    """
    register_hook = add_activation_counter_hook_function
    self.apply(register_hook)
def stop_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.
    Removes the activation-counting forward hooks from every supported sub-module.
    The counters themselves are left untouched, so counting can be paused.
    """
    unregister_hook = remove_activation_counter_hook_function
    self.apply(unregister_hook)
def reset_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.
    Sets the activation and conv counters of every supported sub-module back
    to zero, creating them where they do not exist yet.
    """
    reset_counters = add_activation_counter_variable_or_reset
    self.apply(reset_counters)
def add_activation_counter_hook_function(module):
    """Register the activation-counting forward hook on a convolution.

    Unsupported modules and modules that already carry a hook handle are left
    alone. The handle is stored on the module as ``__activation_handle__`` so
    the hook can be removed later.
    """
    if not is_supported_instance_for_activation(module):
        return
    if hasattr(module, '__activation_handle__'):
        return
    if not isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        return
    module.__activation_handle__ = module.register_forward_hook(conv_activation_counter_hook)
def remove_activation_counter_hook_function(module):
    """Remove the activation forward hook from ``module`` if one is registered.

    Only supported modules that carry an ``__activation_handle__`` are
    touched; the handle attribute is deleted after the hook has been removed.
    """
    has_hook = (is_supported_instance_for_activation(module)
                and hasattr(module, '__activation_handle__'))
    if not has_hook:
        return
    module.__activation_handle__.remove()
    del module.__activation_handle__
def add_activation_counter_variable_or_reset(module):
    """Create or zero ``__activation__`` and ``__num_conv__`` on a supported ``module``."""
    if not is_supported_instance_for_activation(module):
        return
    module.__activation__ = 0
    module.__num_conv__ = 0
def is_supported_instance_for_activation(module):
    """Tell whether ``module`` is a layer whose activations are counted.

    :return: ``True`` for ``nn.Conv2d`` and ``nn.ConvTranspose2d`` instances,
        else ``False``.
    """
    counted_layer_types = (nn.Conv2d, nn.ConvTranspose2d)
    return isinstance(module, counted_layer_types)
def conv_activation_counter_hook(module, input, output):
    """
    Forward hook that records the activations of a convolution.
    Adds the number of elements of ``output`` to ``module.__activation__`` and
    increases ``module.__num_conv__`` by one per call.
    Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Designing Network Design Spaces.
    :param module: the convolution whose counters are updated.
    :param input: the inputs of the forward call (unused).
    :param output: the output tensor of the forward call.
    :return: None
    """
    n_output_elements = output.numel()
    module.__activation__ += n_output_elements
    module.__num_conv__ += 1
def empty_flops_counter_hook(module, input, output):
    """Forward hook for layers that contribute nothing to the FLOPs count.

    The counter is still read and written, so the module must already have
    a ``__flops__`` attribute.
    """
    no_cost = 0
    module.__flops__ += no_cost
def upsample_flops_counter_hook(module, input, output):
    """Forward hook counting one operation per element of the upsampled output.

    A forward hook receives the output tensor itself, so indexing it with
    ``[0]`` selects the first sample and silently drops the batch dimension.
    The whole tensor is counted here, which matches the other hooks in this
    file (they all include the batch size). For a batch of one the result is
    unchanged.

    :param module: the layer whose ``__flops__`` is increased.
    :param input: the inputs of the forward call (unused).
    :param output: the output tensor of the forward call; if a tuple or list
        is given, its first element is used.
    """
    if isinstance(output, (tuple, list)):
        output = output[0]
    module.__flops__ += int(output.numel())
def pool_flops_counter_hook(module, input, output):
    """Forward hook counting one operation per element of the first input.

    :param module: the layer whose ``__flops__`` is increased.
    :param input: tuple of inputs of the forward call; only the first is used.
    :param output: the output of the forward call (unused).
    """
    first_input = input[0]
    n_input_elements = np.prod(first_input.shape)
    module.__flops__ += int(n_input_elements)
def dconv_flops_counter_hook(dconv_module, input, output):
    """Forward hook adding the multiply-accumulates of a two-stage convolution.

    The module must expose two 4-D tensors, ``weight`` and ``projection``.
    Per output position the cost is
    ``k1**2 * weight.shape[1] * weight.shape[0]`` for the first stage plus
    ``k2**2 * projection.shape[0] * weight.shape[0]`` for the second, where
    ``k1`` and ``k2`` are the third dimensions of the respective tensors.
    The number of positions is the batch size of the first input times the
    product of the output's trailing dimensions.

    :param dconv_module: the module whose ``__flops__`` is increased.
    :param input: tuple of inputs of the forward call; only the first is used.
    :param output: the output tensor of the forward call.
    """
    first_input = input[0]
    n_samples = first_input.shape[0]
    spatial_dims = list(output.shape[2:])

    mid_channels, in_channels, kernel_first, _ = dconv_module.weight.shape
    out_channels, _, kernel_second, _ = dconv_module.projection.shape

    macs_first_stage = kernel_first ** 2 * in_channels * mid_channels
    macs_second_stage = kernel_second ** 2 * out_channels * mid_channels

    n_positions = n_samples * np.prod(spatial_dims)
    total_macs = (macs_first_stage + macs_second_stage) * n_positions
    dconv_module.__flops__ += int(total_macs)