import torch.nn as nn
import torch
import numpy as np

'''
---- 1) FLOPs: floating point operations
---- 2) #Activations: the number of elements of all ‘Conv2d’ outputs
---- 3) #Conv2d: the number of ‘Conv2d’ layers
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 21/July/2020
# --------------------------------------------
# Reference
https://github.com/sovrasov/flops-counter.pytorch.git

# If you use this code, please consider the following citation:

@inproceedings{zhang2020aim, %
  title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results},
  author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others},
  booktitle={European Conference on Computer Vision Workshops},
  year={2020}
}
# --------------------------------------------
'''


def _run_dummy_forward(model, input_res, input_constructor=None):
    """
    Run one forward pass through ``model`` so that the registered hooks fire.

    :param model: the network to execute.
    :param input_res: tuple with the input size without the batch dimension.
    :param input_constructor: optional callable; when given it receives
        ``input_res`` and must return a dict that is passed to the model as
        keyword arguments.
    """
    if input_constructor:
        model(**input_constructor(input_res))
        return

    # Follow the device of the last parameter; fall back to the CPU for
    # parameter-free models instead of failing with an IndexError.
    params = list(model.parameters())
    device = params[-1].device if params else torch.device('cpu')
    # Zeros instead of uninitialised memory: only the shape matters for the
    # counters, but garbage values (NaN/Inf) should not reach the model.
    batch = torch.zeros((1, *input_res), dtype=torch.float32, device=device)
    model(batch)


def get_model_flops(model, input_res, print_per_layer_stat=True, input_constructor=None):
    """
    Count the operations of one forward pass with a batch of size 1.

    The model is switched to eval mode and is left in eval mode. The counting
    hooks are removed again even when the forward pass fails.

    :param model: ``nn.Module`` to analyse.
    :param input_res: tuple with the input size without the batch dimension.
    :param print_per_layer_stat: print the model with per-layer statistics.
    :param input_constructor: optional callable building the keyword
        arguments of the forward call from ``input_res``.
    :return: sum of the counters of all supported layers.
    """
    assert isinstance(input_res, tuple), 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'

    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()
    try:
        _run_dummy_forward(flops_model, input_res, input_constructor)
        if print_per_layer_stat:
            print_model_with_flops(flops_model)
        flops_count = flops_model.compute_average_flops_cost()
    finally:
        # Never leave forward hooks behind on the caller's model.
        flops_model.stop_flops_count()

    return flops_count


def get_model_activation(model, input_res, input_constructor=None):
    """
    Count the conv output elements and the conv calls of one forward pass.

    The model is switched to eval mode and is left in eval mode. The counting
    hooks are removed again even when the forward pass fails.

    :param model: ``nn.Module`` to analyse.
    :param input_res: tuple with the input size without the batch dimension.
    :param input_constructor: optional callable building the keyword
        arguments of the forward call from ``input_res``.
    :return: tuple ``(activation_count, num_conv)``.
    """
    assert isinstance(input_res, tuple), 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'

    activation_model = add_activation_counting_methods(model)
    activation_model.eval().start_activation_count()
    try:
        _run_dummy_forward(activation_model, input_res, input_constructor)
        activation_count, num_conv = activation_model.compute_average_activation_cost()
    finally:
        activation_model.stop_activation_count()

    return activation_count, num_conv


def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True,
                              input_constructor=None):
    """
    Return the operation count and the number of trainable parameters.

    The dummy input is created on the device of the model, consistent with
    ``get_model_flops`` and ``get_model_activation``.

    :param model: ``nn.Module`` to analyse.
    :param input_res: tuple with the input size without the batch dimension.
    :param print_per_layer_stat: print the model with per-layer statistics.
    :param as_strings: return formatted strings instead of numbers.
    :param input_constructor: optional callable building the keyword
        arguments of the forward call from ``input_res``.
    :return: ``(flops, params)`` as numbers, or as strings if ``as_strings``.
    """
    assert isinstance(input_res, tuple)
    assert len(input_res) >= 3

    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()
    try:
        _run_dummy_forward(flops_model, input_res, input_constructor)
        if print_per_layer_stat:
            print_model_with_flops(flops_model)
        flops_count = flops_model.compute_average_flops_cost()
        params_count = get_model_parameters_number(flops_model)
    finally:
        flops_model.stop_flops_count()

    if as_strings:
        return flops_to_string(flops_count), params_to_string(params_count)

    return flops_count, params_count


def flops_to_string(flops, units='GMac', precision=2):
    """
    Format an operation count.

    :param flops: the count to format.
    :param units: 'GMac', 'MMac' or 'KMac' to force a unit, ``None`` to pick
        the largest unit that fits; any other value prints the plain count.
    :param precision: number of digits used when rounding the scaled value.
    :return: the formatted string, e.g. ``'2.5 GMac'``.
    """
    if units is None:
        if flops // 10**9 > 0:
            return str(round(flops / 10.**9, precision)) + ' GMac'
        if flops // 10**6 > 0:
            return str(round(flops / 10.**6, precision)) + ' MMac'
        if flops // 10**3 > 0:
            return str(round(flops / 10.**3, precision)) + ' KMac'
        return str(flops) + ' Mac'

    if units == 'GMac':
        return str(round(flops / 10.**9, precision)) + ' ' + units
    if units == 'MMac':
        return str(round(flops / 10.**6, precision)) + ' ' + units
    if units == 'KMac':
        return str(round(flops / 10.**3, precision)) + ' ' + units
    return str(flops) + ' Mac'


def params_to_string(params_num):
    """
    Format a parameter count as ``'<x> M'``, ``'<x> k'`` or the plain number.

    :param params_num: number of parameters.
    :return: the formatted string.
    """
    if params_num // 10 ** 6 > 0:
        return str(round(params_num / 10 ** 6, 2)) + ' M'
    if params_num // 10 ** 3:
        return str(round(params_num / 10 ** 3, 2)) + ' k'
    return str(params_num)


def print_model_with_flops(model, units='GMac', precision=3):
    """
    Print ``model`` with the accumulated count and its share for every module.

    ``extra_repr`` of every submodule is patched temporarily and restored
    afterwards, also when printing fails.

    :param model: a model prepared with ``add_flops_counting_methods`` whose
        counters have been filled by a forward pass.
    :param units: unit passed on to ``flops_to_string``.
    :param precision: precision passed on to ``flops_to_string``.
    """
    total_flops = model.compute_average_flops_cost()

    def accumulate_flops(self):
        if is_supported_instance(self):
            # The raw counter is used, exactly as compute_average_flops_cost
            # does; this module keeps no batch counter to normalise with.
            return self.__flops__
        return sum(child.accumulate_flops() for child in self.children())

    def flops_repr(self):
        accumulated_flops_cost = self.accumulate_flops()
        # A model without any counted layer has a total of zero.
        share = accumulated_flops_cost / total_flops if total_flops else 0.0
        return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
                          '{:.3%} MACs'.format(share),
                          self.original_extra_repr()])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)
    try:
        print(model)
    finally:
        model.apply(del_extra_repr)


def get_model_parameters_number(model):
    """
    Return the number of elements of all parameters that require gradients.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def add_flops_counting_methods(net_main_module):
    """
    Attach the counting methods to ``net_main_module`` and reset its counters.

    The functions are bound to the module object so that each of them
    receives it as ``self``.

    :return: the same module object.
    """
    net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
    net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
    net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)

    net_main_module.reset_flops_count()
    return net_main_module


def compute_average_flops_cost(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Returns the sum of the counters of all supported submodules. The sum is
    not divided by the number of processed images.
    """
    return sum(module.__flops__ for module in self.modules() if is_supported_instance(module))


def start_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Registers the counting hooks on all supported submodules.
    Call it before you run the network.
    """
    self.apply(add_flops_counter_hook_function)


def stop_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Removes the counting hooks. The counters keep their values.
    """
    self.apply(remove_flops_counter_hook_function)


def reset_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Resets statistics computed so far.
    """
    self.apply(add_flops_counter_variable_or_reset)


def add_flops_counter_hook_function(module):
    """
    Register the matching counting hook on ``module`` unless it already has one.
    """
    if not is_supported_instance(module) or hasattr(module, '__flops_handle__'):
        return

    # NOTE(review): nn.Conv3d is listed here but not in is_supported_instance,
    # so Conv3d layers are currently never counted - confirm the intent.
    if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
        hook = conv_flops_counter_hook
    elif isinstance(module, (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)):
        hook = relu_flops_counter_hook
    elif isinstance(module, nn.Linear):
        hook = linear_flops_counter_hook
    elif isinstance(module, nn.BatchNorm2d):
        hook = bn_flops_counter_hook
    else:
        hook = empty_flops_counter_hook
    module.__flops_handle__ = module.register_forward_hook(hook)


def remove_flops_counter_hook_function(module):
    """
    Remove the counting hook of ``module`` if one is registered.
    """
    if is_supported_instance(module) and hasattr(module, '__flops_handle__'):
        module.__flops_handle__.remove()
        del module.__flops_handle__


def add_flops_counter_variable_or_reset(module):
    """
    Create or zero the ``__flops__`` counter of a supported module.
    """
    if is_supported_instance(module):
        module.__flops__ = 0


# ---- Internal functions
def is_supported_instance(module):
    """
    Return True if ``module`` is one of the layer types that are counted.
    """
    return isinstance(module, (
        nn.Conv2d, nn.ConvTranspose2d,
        nn.BatchNorm2d,
        nn.Linear,
        nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6,
    ))


def conv_flops_counter_hook(conv_module, input, output):
    """
    Add kernel_size * in_channels * (out_channels // groups) per output
    position (batch included) to the counter. The bias is not counted.
    """
    batch_size = output.shape[0]
    output_dims = list(output.shape[2:])

    kernel_dims = list(conv_module.kernel_size)
    filters_per_channel = conv_module.out_channels // conv_module.groups

    # Convert to Python ints before the large product so that it cannot
    # overflow a fixed-width NumPy integer.
    conv_per_position_flops = int(np.prod(kernel_dims)) * conv_module.in_channels * filters_per_channel
    active_elements_count = batch_size * int(np.prod(output_dims))

    conv_module.__flops__ += int(conv_per_position_flops * active_elements_count)


def relu_flops_counter_hook(module, input, output):
    """
    Add one operation per output element to the counter.
    """
    module.__flops__ += int(output.numel())


def linear_flops_counter_hook(module, input, output):
    """
    Add in_features * out_features per input row to the counter.

    All leading dimensions of the input count as rows, so 1-D, 2-D and
    higher-dimensional inputs (e.g. batch x sequence x features) are handled
    alike. The bias is not counted.
    """
    features = input[0]  # a forward hook receives the positional inputs as a tuple
    module.__flops__ += int(features.numel() * output.shape[-1])


def bn_flops_counter_hook(module, input, output):
    """
    Add batch * num_features * spatial_size to the counter, doubled when the
    layer has affine parameters.
    """
    # TODO: need to check here (note kept from the original author)
    batch = output.shape[0]
    output_dims = output.shape[2:]
    channels = module.num_features
    batch_flops = batch * channels * np.prod(output_dims)
    if module.affine:
        batch_flops *= 2
    module.__flops__ += int(batch_flops)


# ---- Count the number of convolutional layers and the activation
def add_activation_counting_methods(net_main_module):
    """
    Attach the activation counting methods to ``net_main_module`` and reset
    its counters.

    The functions are bound to the module object so that each of them
    receives it as ``self``.

    :return: the same module object.
    """
    net_main_module.start_activation_count = start_activation_count.__get__(net_main_module)
    net_main_module.stop_activation_count = stop_activation_count.__get__(net_main_module)
    net_main_module.reset_activation_count = reset_activation_count.__get__(net_main_module)
    net_main_module.compute_average_activation_cost = compute_average_activation_cost.__get__(net_main_module)

    net_main_module.reset_activation_count()
    return net_main_module


def compute_average_activation_cost(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Returns the summed activation counters and the summed number of conv
    calls of all supported submodules. The sums are not divided by the
    number of processed images.
    """
    activation_sum = 0
    num_conv = 0
    for module in self.modules():
        if is_supported_instance_for_activation(module):
            activation_sum += module.__activation__
            num_conv += module.__num_conv__
    return activation_sum, num_conv


def start_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Registers the counting hooks on all supported submodules.
    Call it before you run the network.
    """
    self.apply(add_activation_counter_hook_function)


def stop_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Removes the counting hooks. The counters keep their values.
    """
    self.apply(remove_activation_counter_hook_function)


def reset_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Resets statistics computed so far.
    """
    self.apply(add_activation_counter_variable_or_reset)


def add_activation_counter_hook_function(module):
    """
    Register the activation hook on ``module`` unless it already has one.
    """
    if not is_supported_instance_for_activation(module):
        return
    if hasattr(module, '__activation_handle__'):
        return
    module.__activation_handle__ = module.register_forward_hook(conv_activation_counter_hook)


def remove_activation_counter_hook_function(module):
    """
    Remove the activation hook of ``module`` if one is registered.
    """
    if is_supported_instance_for_activation(module) and hasattr(module, '__activation_handle__'):
        module.__activation_handle__.remove()
        del module.__activation_handle__


def add_activation_counter_variable_or_reset(module):
    """
    Create or zero the ``__activation__`` and ``__num_conv__`` counters.
    """
    if is_supported_instance_for_activation(module):
        module.__activation__ = 0
        module.__num_conv__ = 0


def is_supported_instance_for_activation(module):
    """
    Return True if ``module`` is a Conv2d or ConvTranspose2d layer.
    """
    return isinstance(module, (nn.Conv2d, nn.ConvTranspose2d))


def conv_activation_counter_hook(module, input, output):
    """
    Calculate the activations in the convolutional operation.
    Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Designing Network Design Spaces.

    Adds the number of output elements to ``__activation__`` and one to
    ``__num_conv__``.

    :param module: the convolution layer the hook is registered on.
    :param input: tuple with the inputs of the layer (unused).
    :param output: output tensor of the layer.
    :return: None
    """
    module.__activation__ += output.numel()
    module.__num_conv__ += 1


def empty_flops_counter_hook(module, input, output):
    """
    Hook for supported layers without a dedicated counter; adds nothing.
    """
    module.__flops__ += 0


def upsample_flops_counter_hook(module, input, output):
    """
    Add the number of elements of ``output[0]`` to the counter.

    Not registered by any function visible in this file.
    """
    # NOTE(review): for a tensor output, output[0] is the first sample, so the
    # batch dimension is not part of the count - confirm this is intended.
    output_size = output[0]
    output_elements_count = 1
    for val in output_size.shape:
        output_elements_count *= val
    module.__flops__ += int(output_elements_count)


def pool_flops_counter_hook(module, input, output):
    """
    Add one operation per input element to the counter.

    Not registered by any function visible in this file.
    """
    module.__flops__ += int(np.prod(input[0].shape))


def dconv_flops_counter_hook(dconv_module, input, output):
    """
    Counter for a layer that owns a ``weight`` and a ``projection`` tensor.

    Both tensors are unpacked as 4-D; their kernels are treated as square.
    Not registered by any function visible in this file.
    """
    batch_size = input[0].shape[0]
    output_dims = list(output.shape[2:])

    m_channels, in_channels, kernel_dim1, _, = dconv_module.weight.shape
    out_channels, _, kernel_dim2, _, = dconv_module.projection.shape

    conv_per_position_flops1 = kernel_dim1 ** 2 * in_channels * m_channels
    conv_per_position_flops2 = kernel_dim2 ** 2 * out_channels * m_channels
    active_elements_count = batch_size * np.prod(output_dims)

    overall_conv_flops = (conv_per_position_flops1 + conv_per_position_flops2) * active_elements_count
    dconv_module.__flops__ += int(overall_conv_flops)