r"""Pruning methods.""" | |
import numbers | |
from abc import ABC, abstractmethod | |
from collections.abc import Iterable | |
from typing import Tuple | |
import torch | |
class BasePruningMethod(ABC): | |
r"""Abstract base class for creation of new pruning techniques. | |
Provides a skeleton for customization requiring the overriding of methods | |
such as :meth:`compute_mask` and :meth:`apply`. | |
""" | |
_tensor_name: str | |

    def __call__(self, module, inputs):
        r"""Multiply the mask into the original tensor and store the result.

        Multiplies the mask (stored in ``module[name + '_mask']``)
        into the original tensor (stored in ``module[name + '_orig']``)
        and stores the result into ``module[name]`` by using :meth:`apply_mask`.

        Args:
            module (nn.Module): module containing the tensor to prune
            inputs: not used.
        """
        setattr(module, self._tensor_name, self.apply_mask(module))

    @abstractmethod
    def compute_mask(self, t, default_mask):
        r"""Compute and return a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a new mask to
        apply on top of the ``default_mask`` according to the specific pruning
        method recipe.

        Args:
            t (torch.Tensor): tensor representing the importance scores of the
                parameter to prune.
            default_mask (torch.Tensor): base mask from previous pruning
                iterations, that needs to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``
        """
        pass

    def apply_mask(self, module):
        r"""Simply handles the multiplication between the parameter being pruned and the generated mask.

        Fetches the mask and the original tensor from the module
        and returns the pruned version of the tensor.

        Args:
            module (nn.Module): module containing the tensor to prune

        Returns:
            pruned_tensor (torch.Tensor): pruned version of the input tensor
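
        Examples:
            A minimal illustrative sketch (assumes the module has already been
            pruned, so that the mask and ``_orig`` parameter exist):

            >>> # xdoctest: +SKIP
            >>> m = prune.random_unstructured(nn.Linear(2, 3), name="weight", amount=0.5)
            >>> method = next(iter(m._forward_pre_hooks.values()))
            >>> torch.equal(method.apply_mask(m), m.weight)
            True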
""" | |
# to carry out the multiplication, the mask needs to have been computed, | |
# so the pruning method must know what tensor it's operating on | |
assert self._tensor_name is not None, f"Module {module} has to be pruned" # this gets set in apply() | |
mask = getattr(module, self._tensor_name + "_mask") | |
orig = getattr(module, self._tensor_name + "_orig") | |
pruned_tensor = mask.to(dtype=orig.dtype) * orig | |
return pruned_tensor | |

    @classmethod
    def apply(cls, module, name, *args, importance_scores=None, **kwargs):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            args: arguments passed on to a subclass of
                :class:`BasePruningMethod`
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the
                corresponding elements in the parameter being pruned.
                If unspecified or None, the parameter will be used in its place.
            kwargs: keyword arguments passed on to a subclass of a
                :class:`BasePruningMethod`
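
        Examples:
            A minimal illustrative sketch of the reparametrization performed
            here (subclasses normally wrap this call in their own ``apply``):

            >>> # xdoctest: +SKIP
            >>> m = nn.Linear(2, 3)
            >>> method = L1Unstructured.apply(m, name="weight", amount=2)
            >>> sorted(name for name, _ in m.named_parameters())
            ['bias', 'weight_orig']
            >>> hasattr(m, "weight_mask")
            True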
""" | |
def _get_composite_method(cls, module, name, *args, **kwargs): | |
# Check if a pruning method has already been applied to | |
# `module[name]`. If so, store that in `old_method`. | |
old_method = None | |
found = 0 | |
# there should technically be only 1 hook with hook.name == name | |
# assert this using `found` | |
hooks_to_remove = [] | |
for k, hook in module._forward_pre_hooks.items(): | |
# if it exists, take existing thing, remove hook, then | |
# go through normal thing | |
if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: | |
old_method = hook | |
hooks_to_remove.append(k) | |
found += 1 | |
assert ( | |
found <= 1 | |
), f"Avoid adding multiple pruning hooks to the\ | |
same tensor {name} of module {module}. Use a PruningContainer." | |
for k in hooks_to_remove: | |
del module._forward_pre_hooks[k] | |
# Apply the new pruning method, either from scratch or on top of | |
# the previous one. | |
method = cls(*args, **kwargs) # new pruning | |
# Have the pruning method remember what tensor it's been applied to | |
method._tensor_name = name | |
# combine `methods` with `old_method`, if `old_method` exists | |
if old_method is not None: # meaning that there was a hook | |
# if the hook is already a pruning container, just add the | |
# new pruning method to the container | |
if isinstance(old_method, PruningContainer): | |
old_method.add_pruning_method(method) | |
method = old_method # rename old_method --> method | |
# if the hook is simply a single pruning method, create a | |
# container, add the old pruning method and the new one | |
elif isinstance(old_method, BasePruningMethod): | |
container = PruningContainer(old_method) | |
# Have the pruning method remember the name of its tensor | |
# setattr(container, '_tensor_name', name) | |
container.add_pruning_method(method) | |
method = container # rename container --> method | |
return method | |
method = _get_composite_method(cls, module, name, *args, **kwargs) | |
# at this point we have no forward_pre_hooks but we could have an | |
# active reparametrization of the tensor if another pruning method | |
# had been applied (in which case `method` would be a PruningContainer | |
# and not a simple pruning method). | |
# Pruning is to be applied to the module's tensor named `name`, | |
# starting from the state it is found in prior to this iteration of | |
# pruning. The pruning mask is calculated based on importances scores. | |
orig = getattr(module, name) | |
if importance_scores is not None: | |
assert ( | |
importance_scores.shape == orig.shape | |
), f"importance_scores should have the same shape as parameter {name} of {module}" | |
else: | |
importance_scores = orig | |
# If this is the first time pruning is applied, take care of moving | |
# the original tensor to a new parameter called name + '_orig' and | |
# and deleting the original parameter | |
if not isinstance(method, PruningContainer): | |
# copy `module[name]` to `module[name + '_orig']` | |
module.register_parameter(name + "_orig", orig) | |
# temporarily delete `module[name]` | |
del module._parameters[name] | |
default_mask = torch.ones_like(orig) # temp | |
# If this is not the first time pruning is applied, all of the above | |
# has been done before in a previous pruning iteration, so we're good | |
# to go | |
else: | |
default_mask = ( | |
getattr(module, name + "_mask") | |
.detach() | |
.clone(memory_format=torch.contiguous_format) | |
) | |
# Use try/except because if anything goes wrong with the mask | |
# computation etc., you'd want to roll back. | |
try: | |
# get the final mask, computed according to the specific method | |
mask = method.compute_mask(importance_scores, default_mask=default_mask) | |
# reparameterize by saving mask to `module[name + '_mask']`... | |
module.register_buffer(name + "_mask", mask) | |
# ... and the new pruned tensor to `module[name]` | |
setattr(module, name, method.apply_mask(module)) | |
# associate the pruning method to the module via a hook to | |
# compute the function before every forward() (compile by run) | |
module.register_forward_pre_hook(method) | |
except Exception as e: | |
if not isinstance(method, PruningContainer): | |
orig = getattr(module, name + "_orig") | |
module.register_parameter(name, orig) | |
del module._parameters[name + "_orig"] | |
raise e | |
return method | |

    def prune(self, t, default_mask=None, importance_scores=None):
        r"""Compute and return a pruned version of input tensor ``t``.

        According to the pruning rule specified in :meth:`compute_mask`.

        Args:
            t (torch.Tensor): tensor to prune (of same dimensions as
                ``default_mask``).
            default_mask (torch.Tensor, optional): mask from previous pruning
                iteration, if any. To be considered when determining what
                portion of the tensor that pruning should act on. If None,
                default to a mask of ones.
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as ``t``) used to compute mask for pruning ``t``.
                The values in this tensor indicate the importance of the
                corresponding elements in the ``t`` that is being pruned.
                If unspecified or None, the tensor ``t`` will be used in its place.

        Returns:
            pruned version of tensor ``t``.
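
        Examples:
            A minimal illustrative sketch, pruning a raw tensor directly
            (outside of any module):

            >>> # xdoctest: +SKIP
            >>> t = torch.tensor([1.0, 2.0, 3.0, 4.0])
            >>> L1Unstructured(amount=2).prune(t)
            tensor([0., 0., 3., 4.])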
""" | |
if importance_scores is not None: | |
assert ( | |
importance_scores.shape == t.shape | |
), "importance_scores should have the same shape as tensor t" | |
else: | |
importance_scores = t | |
default_mask = default_mask if default_mask is not None else torch.ones_like(t) | |
return t * self.compute_mask(importance_scores, default_mask=default_mask) | |

    def remove(self, module):
        r"""Remove the pruning reparameterization from a module.

        The pruned parameter named ``name`` remains permanently pruned,
        and the parameter named ``name+'_orig'`` is removed from the parameter list.
        Similarly, the buffer named ``name+'_mask'`` is removed from the buffers.

        Note:
            Pruning itself is NOT undone or reversed!
        """
        # before removing pruning from a tensor, it has to have been applied
        assert (
            self._tensor_name is not None
        ), f"Module {module} has to be pruned before pruning can be removed"  # this gets set in apply()

        # to update module[name] to latest trained weights
        weight = self.apply_mask(module)  # masked weights

        # delete and reset
        if hasattr(module, self._tensor_name):
            delattr(module, self._tensor_name)
        orig = module._parameters[self._tensor_name + "_orig"]
        orig.data = weight.data
        del module._parameters[self._tensor_name + "_orig"]
        del module._buffers[self._tensor_name + "_mask"]
        setattr(module, self._tensor_name, orig)


class PruningContainer(BasePruningMethod):
    """Container holding a sequence of pruning methods for iterative pruning.

    Keeps track of the order in which pruning methods are applied and handles
    combining successive pruning calls.

    Accepts as argument an instance of a BasePruningMethod or an iterable of
    them.
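
    Examples:
        A minimal illustrative sketch: applying two pruning methods to the
        same parameter automatically builds a container behind the scenes.

        >>> # xdoctest: +SKIP
        >>> m = nn.Linear(3, 4)
        >>> m = prune.random_unstructured(m, name="weight", amount=0.5)
        >>> m = prune.l1_unstructured(m, name="weight", amount=2)
        >>> hook = next(iter(m._forward_pre_hooks.values()))
        >>> isinstance(hook, prune.PruningContainer)
        True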
""" | |
def __init__(self, *args): | |
self._pruning_methods: Tuple[BasePruningMethod, ...] = tuple() | |
if not isinstance(args, Iterable): # only 1 item | |
self._tensor_name = args._tensor_name | |
self.add_pruning_method(args) | |
elif len(args) == 1: # only 1 item in a tuple | |
self._tensor_name = args[0]._tensor_name | |
self.add_pruning_method(args[0]) | |
else: # manual construction from list or other iterable (or no args) | |
for method in args: | |
self.add_pruning_method(method) | |

    def add_pruning_method(self, method):
        r"""Add a child pruning ``method`` to the container.

        Args:
            method (subclass of BasePruningMethod): child pruning method
                to be added to the container.
        """
        # check that we're adding a pruning method to the container
        if not isinstance(method, BasePruningMethod) and method is not None:
            raise TypeError(f"{type(method)} is not a BasePruningMethod subclass")
        elif method is not None and self._tensor_name != method._tensor_name:
            raise ValueError(
                "Can only add pruning methods acting on "
                f"the parameter named '{self._tensor_name}' to PruningContainer {self}."
                + f" Found '{method._tensor_name}'"
            )
        # if all checks passed, add to _pruning_methods tuple
        self._pruning_methods += (method,)  # type: ignore[operator]

    def __len__(self):
        return len(self._pruning_methods)

    def __iter__(self):
        return iter(self._pruning_methods)

    def __getitem__(self, idx):
        return self._pruning_methods[idx]

    def compute_mask(self, t, default_mask):
        r"""Apply the latest ``method`` by computing the new partial masks and returning its combination with the ``default_mask``.

        The new partial mask should be computed on the entries or channels
        that were not zeroed out by the ``default_mask``.
        Which portions of the tensor ``t`` the new mask will be calculated from
        depends on the ``PRUNING_TYPE`` (handled by the type handler):

        * for 'unstructured', the mask will be computed from the raveled
          list of nonmasked entries;

        * for 'structured', the mask will be computed from the nonmasked
          channels in the tensor;

        * for 'global', the mask will be computed across all entries.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
                (of same dimensions as ``default_mask``).
            default_mask (torch.Tensor): mask from previous pruning iteration.

        Returns:
            mask (torch.Tensor): new mask that combines the effects
                of the ``default_mask`` and the new mask from the current
                pruning ``method`` (of same dimensions as ``default_mask`` and
                ``t``).
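
        Examples:
            A minimal illustrative sketch: combine an unstructured method with
            a pre-existing mask that already zeroed out one entry (setting
            ``_tensor_name`` by hand, as :func:`global_unstructured` does).

            >>> # xdoctest: +SKIP
            >>> t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
            >>> default_mask = torch.tensor([[1.0, 1.0], [0.0, 1.0]])
            >>> method = L1Unstructured(amount=1)
            >>> method._tensor_name = "weight"
            >>> container = PruningContainer(method)
            >>> container.compute_mask(t, default_mask)
            tensor([[0., 1.],
                    [0., 1.]])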
""" | |
def _combine_masks(method, t, mask): | |
r"""Combine the masks from all pruning methods and returns a new mask. | |
Args: | |
method (a BasePruningMethod subclass): pruning method | |
currently being applied. | |
t (torch.Tensor): tensor representing the parameter to prune | |
(of same dimensions as mask). | |
mask (torch.Tensor): mask from previous pruning iteration | |
Returns: | |
new_mask (torch.Tensor): new mask that combines the effects | |
of the old mask and the new mask from the current | |
pruning method (of same dimensions as mask and t). | |
""" | |
new_mask = mask # start off from existing mask | |
new_mask = new_mask.to(dtype=t.dtype) | |
# compute a slice of t onto which the new pruning method will operate | |
if method.PRUNING_TYPE == "unstructured": | |
# prune entries of t where the mask is 1 | |
slc = mask == 1 | |
# for struct pruning, exclude channels that have already been | |
# entirely pruned | |
elif method.PRUNING_TYPE == "structured": | |
if not hasattr(method, "dim"): | |
raise AttributeError( | |
"Pruning methods of PRUNING_TYPE " | |
'"structured" need to have the attribute `dim` defined.' | |
) | |
# find the channels to keep by removing the ones that have been | |
# zeroed out already (i.e. where sum(entries) == 0) | |
n_dims = t.dim() # "is this a 2D tensor? 3D? ..." | |
dim = method.dim | |
# convert negative indexing | |
if dim < 0: | |
dim = n_dims + dim | |
# if dim is still negative after subtracting it from n_dims | |
if dim < 0: | |
raise IndexError( | |
f"Index is out of bounds for tensor with dimensions {n_dims}" | |
) | |
# find channels along dim = dim that aren't already tots 0ed out | |
keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0 | |
# create slice to identify what to prune | |
slc = [slice(None)] * n_dims | |
slc[dim] = keep_channel | |
elif method.PRUNING_TYPE == "global": | |
n_dims = len(t.shape) # "is this a 2D tensor? 3D? ..." | |
slc = [slice(None)] * n_dims | |
else: | |
raise ValueError( | |
f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}" | |
) | |
# compute the new mask on the unpruned slice of the tensor t | |
partial_mask = method.compute_mask(t[slc], default_mask=mask[slc]) | |
new_mask[slc] = partial_mask.to(dtype=new_mask.dtype) | |
return new_mask | |
method = self._pruning_methods[-1] | |
mask = _combine_masks(method, t, default_mask) | |
return mask | |


class Identity(BasePruningMethod):
    r"""Utility pruning method that does not prune any units but generates the pruning parametrization with a mask of ones."""

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        mask = default_mask
        return mask

    @classmethod
    def apply(cls, module, name):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
        """
        return super().apply(module, name)


class RandomUnstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor at random.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        # Check that the amount of units to prune is not > than the number of
        # parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # k=0 not supported by torch.kthvalue
            prob = torch.rand_like(t)
            topk = torch.topk(prob.view(-1), k=nparams_toprune)
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
        """
        return super().apply(module, name, amount=amount)


class L1Unstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
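
    Examples:
        A minimal illustrative sketch of the mask this method computes: the
        two entries with the smallest absolute value are zeroed out.

        >>> # xdoctest: +SKIP
        >>> t = torch.tensor([[-2.0, 0.5], [1.0, 3.0]])
        >>> L1Unstructured(amount=2).compute_mask(t, default_mask=torch.ones_like(t))
        tensor([[1., 0.],
                [0., 1.]])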
""" | |
PRUNING_TYPE = "unstructured" | |
def __init__(self, amount): | |
# Check range of validity of pruning amount | |
_validate_pruning_amount_init(amount) | |
self.amount = amount | |
def compute_mask(self, t, default_mask): | |
# Check that the amount of units to prune is not > than the number of | |
# parameters in t | |
tensor_size = t.nelement() | |
# Compute number of units to prune: amount if int, | |
# else amount * tensor_size | |
nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) | |
# This should raise an error if the number of units to prune is larger | |
# than the number of units in the tensor | |
_validate_pruning_amount(nparams_toprune, tensor_size) | |
mask = default_mask.clone(memory_format=torch.contiguous_format) | |
if nparams_toprune != 0: # k=0 not supported by torch.kthvalue | |
# largest=True --> top k; largest=False --> bottom k | |
# Prune the smallest k | |
topk = torch.topk(torch.abs(t).view(-1), k=nparams_toprune, largest=False) | |
# topk will have .indices and .values | |
mask.view(-1)[topk.indices] = 0 | |
return mask | |
def apply(cls, module, name, amount, importance_scores=None): | |
r"""Add pruning on the fly and reparametrization of a tensor. | |
Adds the forward pre-hook that enables pruning on the fly and | |
the reparametrization of a tensor in terms of the original tensor | |
and the pruning mask. | |
Args: | |
module (nn.Module): module containing the tensor to prune | |
name (str): parameter name within ``module`` on which pruning | |
will act. | |
amount (int or float): quantity of parameters to prune. | |
If ``float``, should be between 0.0 and 1.0 and represent the | |
fraction of parameters to prune. If ``int``, it represents the | |
absolute number of parameters to prune. | |
importance_scores (torch.Tensor): tensor of importance scores (of same | |
shape as module parameter) used to compute mask for pruning. | |
The values in this tensor indicate the importance of the corresponding | |
elements in the parameter being pruned. | |
If unspecified or None, the module parameter will be used in its place. | |
""" | |
return super().apply( | |
module, name, amount=amount, importance_scores=importance_scores | |
) | |


class RandomStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor at random.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and return a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a random mask to
        apply on top of the ``default_mask`` by randomly zeroing out channels
        along the specified dim of the tensor.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): base mask from previous pruning
                iterations, that needs to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
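
        Examples:
            A minimal illustrative sketch: exactly one of the two columns
            (``dim=1``) is zeroed out at random.

            >>> # xdoctest: +SKIP
            >>> t = torch.ones(3, 2)
            >>> method = RandomStructured(amount=1, dim=1)
            >>> mask = method.compute_mask(t, default_mask=torch.ones_like(t))
            >>> int(mask.sum(dim=0).eq(0).sum())
            1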
""" | |
# Check that tensor has structure (i.e. more than 1 dimension) such | |
# that the concept of "channels" makes sense | |
_validate_structured_pruning(t) | |
# Check that self.dim is a valid dim to index t, else raise IndexError | |
_validate_pruning_dim(t, self.dim) | |
# Check that the amount of channels to prune is not > than the number of | |
# channels in t along the dim to prune | |
tensor_size = t.shape[self.dim] | |
# Compute number of units to prune: amount if int, | |
# else amount * tensor_size | |
nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) | |
# This should raise an error if the number of units to prune is larger | |
# than the number of units in the tensor | |
_validate_pruning_amount(nparams_toprune, tensor_size) | |
# Compute binary mask by initializing it to all 0s and then filling in | |
# 1s wherever topk.indices indicates, along self.dim. | |
# mask has the same shape as tensor t | |
def make_mask(t, dim, nchannels, nchannels_toprune): | |
# generate a random number in [0, 1] to associate to each channel | |
prob = torch.rand(nchannels) | |
# generate mask for each channel by 0ing out the channels that | |
# got assigned the k = nchannels_toprune lowest values in prob | |
threshold = torch.kthvalue(prob, k=nchannels_toprune).values | |
channel_mask = prob > threshold | |
mask = torch.zeros_like(t) | |
slc = [slice(None)] * len(t.shape) | |
slc[dim] = channel_mask | |
mask[slc] = 1 | |
return mask | |
if nparams_toprune == 0: # k=0 not supported by torch.kthvalue | |
mask = default_mask | |
else: | |
# apply the new structured mask on top of prior (potentially | |
# unstructured) mask | |
mask = make_mask(t, self.dim, tensor_size, nparams_toprune) | |
mask *= default_mask.to(dtype=mask.dtype) | |
return mask | |
def apply(cls, module, name, amount, dim=-1): | |
r"""Add pruning on the fly and reparametrization of a tensor. | |
Adds the forward pre-hook that enables pruning on the fly and | |
the reparametrization of a tensor in terms of the original tensor | |
and the pruning mask. | |
Args: | |
module (nn.Module): module containing the tensor to prune | |
name (str): parameter name within ``module`` on which pruning | |
will act. | |
amount (int or float): quantity of parameters to prune. | |
If ``float``, should be between 0.0 and 1.0 and represent the | |
fraction of parameters to prune. If ``int``, it represents the | |
absolute number of parameters to prune. | |
dim (int, optional): index of the dim along which we define | |
channels to prune. Default: -1. | |
""" | |
return super().apply(module, name, amount=amount, dim=dim) | |


class LnStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor based on their L\ ``n``-norm.

    Args:
        amount (int or float): quantity of channels to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, n, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.n = n
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and return a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a mask to apply on
        top of the ``default_mask`` by zeroing out the channels along the
        specified dim with the lowest L\ ``n``-norm.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): base mask from previous pruning
                iterations, that needs to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
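
        Examples:
            A minimal illustrative sketch: the column (``dim=-1``) with the
            smaller L2 norm is zeroed out.

            >>> # xdoctest: +SKIP
            >>> t = torch.tensor([[1.0, 2.0], [1.0, 2.0]])
            >>> method = LnStructured(amount=1, n=2, dim=-1)
            >>> method.compute_mask(t, default_mask=torch.ones_like(t))
            tensor([[0., 1.],
                    [0., 1.]])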
""" | |
# Check that tensor has structure (i.e. more than 1 dimension) such | |
# that the concept of "channels" makes sense | |
_validate_structured_pruning(t) | |
# Check that self.dim is a valid dim to index t, else raise IndexError | |
_validate_pruning_dim(t, self.dim) | |
# Check that the amount of channels to prune is not > than the number of | |
# channels in t along the dim to prune | |
tensor_size = t.shape[self.dim] | |
# Compute number of units to prune: amount if int, | |
# else amount * tensor_size | |
nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) | |
nparams_tokeep = tensor_size - nparams_toprune | |
# This should raise an error if the number of units to prune is larger | |
# than the number of units in the tensor | |
_validate_pruning_amount(nparams_toprune, tensor_size) | |
# Structured pruning prunes entire channels so we need to know the | |
# L_n norm along each channel to then find the topk based on this | |
# metric | |
norm = _compute_norm(t, self.n, self.dim) | |
# largest=True --> top k; largest=False --> bottom k | |
# Keep the largest k channels along dim=self.dim | |
topk = torch.topk(norm, k=nparams_tokeep, largest=True) | |
# topk will have .indices and .values | |
# Compute binary mask by initializing it to all 0s and then filling in | |
# 1s wherever topk.indices indicates, along self.dim. | |
# mask has the same shape as tensor t | |
def make_mask(t, dim, indices): | |
# init mask to 0 | |
mask = torch.zeros_like(t) | |
# e.g.: slc = [None, None, None], if len(t.shape) = 3 | |
slc = [slice(None)] * len(t.shape) | |
# replace a None at position=dim with indices | |
# e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3] | |
slc[dim] = indices | |
# use slc to slice mask and replace all its entries with 1s | |
# e.g.: mask[:, :, [0, 2, 3]] = 1 | |
mask[slc] = 1 | |
return mask | |
if nparams_toprune == 0: # k=0 not supported by torch.kthvalue | |
mask = default_mask | |
else: | |
mask = make_mask(t, self.dim, topk.indices) | |
mask *= default_mask.to(dtype=mask.dtype) | |
return mask | |
def apply(cls, module, name, amount, n, dim, importance_scores=None): | |
r"""Add pruning on the fly and reparametrization of a tensor. | |
Adds the forward pre-hook that enables pruning on the fly and | |
the reparametrization of a tensor in terms of the original tensor | |
and the pruning mask. | |
Args: | |
module (nn.Module): module containing the tensor to prune | |
name (str): parameter name within ``module`` on which pruning | |
will act. | |
amount (int or float): quantity of parameters to prune. | |
If ``float``, should be between 0.0 and 1.0 and represent the | |
fraction of parameters to prune. If ``int``, it represents the | |
absolute number of parameters to prune. | |
n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid | |
entries for argument ``p`` in :func:`torch.norm`. | |
dim (int): index of the dim along which we define channels to | |
prune. | |
importance_scores (torch.Tensor): tensor of importance scores (of same | |
shape as module parameter) used to compute mask for pruning. | |
The values in this tensor indicate the importance of the corresponding | |
elements in the parameter being pruned. | |
If unspecified or None, the module parameter will be used in its place. | |
""" | |
return super().apply( | |
module, | |
name, | |
amount=amount, | |
n=n, | |
dim=dim, | |
importance_scores=importance_scores, | |
) | |


class CustomFromMask(BasePruningMethod):
    r"""Prune a tensor using a user-provided mask.

    Args:
        mask (torch.Tensor): binary mask to be applied to the parameter.
    """

    PRUNING_TYPE = "global"

    def __init__(self, mask):
        self.mask = mask

    def compute_mask(self, t, default_mask):
        assert default_mask.shape == self.mask.shape
        mask = default_mask * self.mask.to(dtype=default_mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, mask):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            mask (torch.Tensor): binary mask to be applied to the parameter.
        """
        return super().apply(module, name, mask=mask)


def identity(module, name):
    r"""Apply pruning reparametrization without pruning any units.

    Applies pruning reparametrization to the tensor corresponding to the
    parameter called ``name`` in ``module`` without actually pruning any
    units. Modifies module in place (and also returns the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Note:
        The mask is a tensor of ones.

    Args:
        module (nn.Module): module containing the tensor to prune.
        name (str): parameter name within ``module`` on which pruning
            will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.identity(nn.Linear(2, 3), 'bias')
        >>> print(m.bias_mask)
        tensor([1., 1., 1.])
    """
    Identity.apply(module, name)
    return module


def random_unstructured(module, name, amount):
    r"""Prune tensor by removing random (currently unpruned) units.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) units
    selected at random.
    Modifies module in place (and also returns the modified module) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
        >>> torch.sum(m.weight_mask == 0)
        tensor(1)
    """
    RandomUnstructured.apply(module, name, amount)
    return module


def l1_unstructured(module, name, amount, importance_scores=None):
    r"""Prune tensor by removing units with the lowest L1-norm.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) units with the
    lowest L1-norm.
    Modifies module in place (and also returns the modified module) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of same
            shape as module parameter) used to compute mask for pruning.
            The values in this tensor indicate the importance of the corresponding
            elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in its place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2)
        >>> m.state_dict().keys()
        odict_keys(['bias', 'weight_orig', 'weight_mask'])
    """
    L1Unstructured.apply(
        module, name, amount=amount, importance_scores=importance_scores
    )
    return module


def random_structured(module, name, amount, dim):
    r"""Prune tensor by removing random channels along the specified dimension.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) channels
    along the specified ``dim`` selected at random.
    Modifies module in place (and also returns the modified module) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        dim (int): index of the dim along which we define channels to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_structured(
        ...     nn.Linear(5, 3), 'weight', amount=3, dim=1
        ... )
        >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0))
        >>> print(columns_pruned)
        3
    """
    RandomStructured.apply(module, name, amount, dim)
    return module


def ln_structured(module, name, amount, n, dim, importance_scores=None):
    r"""Prune tensor by removing channels with the lowest L\ ``n``-norm along the specified dimension.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) channels
    along the specified ``dim`` with the lowest L\ ``n``-norm.
    Modifies module in place (and also returns the modified module) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int): index of the dim along which we define channels to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of same
            shape as module parameter) used to compute mask for pruning.
            The values in this tensor indicate the importance of the corresponding
            elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in its place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.ln_structured(
        ...     nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf')
        ... )
    """
    LnStructured.apply(
        module, name, amount, n, dim, importance_scores=importance_scores
    )
    return module


def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):
    r"""
    Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``.

    Modifies modules in place by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        parameters (Iterable of (module, name) tuples): parameters of
            the model to prune in a global fashion, i.e. by aggregating all
            weights prior to deciding which ones to prune. module must be of
            type :class:`nn.Module`, and name must be a string.
        pruning_method (function): a valid pruning function from this module,
            or a custom one implemented by the user that satisfies the
            implementation guidelines and has ``PRUNING_TYPE='unstructured'``.
        importance_scores (dict): a dictionary mapping (module, name) tuples to
            the corresponding parameter's importance scores tensor. The tensor
            should be the same shape as the parameter, and is used for computing
            mask for pruning.
            If unspecified or None, the parameter will be used in place of its
            importance scores.
        kwargs: other keyword arguments such as:
            amount (int or float): quantity of parameters to prune across the
            specified parameters.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Raises:
        TypeError: if ``PRUNING_TYPE != 'unstructured'``

    Note:
        Since global structured pruning doesn't make much sense unless the
        norm is normalized by the size of the parameter, we now limit the
        scope of global pruning to unstructured methods.

    Examples:
        >>> from torch.nn.utils import prune
        >>> from collections import OrderedDict
        >>> net = nn.Sequential(OrderedDict([
        ...     ('first', nn.Linear(10, 4)),
        ...     ('second', nn.Linear(4, 1)),
        ... ]))
        >>> parameters_to_prune = (
        ...     (net.first, 'weight'),
        ...     (net.second, 'weight'),
        ... )
        >>> prune.global_unstructured(
        ...     parameters_to_prune,
        ...     pruning_method=prune.L1Unstructured,
        ...     amount=10,
        ... )
        >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
        tensor(10)
    """
    # ensure parameters is a list or generator of tuples
    if not isinstance(parameters, Iterable):
        raise TypeError("global_unstructured(): parameters is not an Iterable")

    importance_scores = importance_scores if importance_scores is not None else {}
    if not isinstance(importance_scores, dict):
        raise TypeError("global_unstructured(): importance_scores must be of type dict")

    # flatten importance scores to consider them all at once in global pruning
    relevant_importance_scores = torch.nn.utils.parameters_to_vector(
        [
            importance_scores.get((module, name), getattr(module, name))
            for (module, name) in parameters
        ]
    )
    # similarly, flatten the masks (if they exist), or use a flattened vector
    # of 1s of the same dimensions as t
    default_mask = torch.nn.utils.parameters_to_vector(
        [
            getattr(module, name + "_mask", torch.ones_like(getattr(module, name)))
            for (module, name) in parameters
        ]
    )

    # use the canonical pruning methods to compute the new mask, even if the
    # parameter is now a flattened out version of `parameters`
    container = PruningContainer()
    container._tensor_name = "temp"  # to make it match that of `method`
    method = pruning_method(**kwargs)
    method._tensor_name = "temp"  # to make it match that of `container`
    if method.PRUNING_TYPE != "unstructured":
        raise TypeError(
            'Only "unstructured" PRUNING_TYPE supported for '
            f"the `pruning_method`. Found method {pruning_method} of type {method.PRUNING_TYPE}"
        )

    container.add_pruning_method(method)

    # use the `compute_mask` method from `PruningContainer` to combine the
    # mask computed by the new method with the pre-existing mask
    final_mask = container.compute_mask(relevant_importance_scores, default_mask)

    # Pointer for slicing the mask to match the shape of each parameter
    pointer = 0
    for module, name in parameters:
        param = getattr(module, name)
        # The length of the parameter
        num_param = param.numel()
        # Slice the mask, reshape it
        param_mask = final_mask[pointer : pointer + num_param].view_as(param)
        # Assign the correct pre-computed mask to each parameter and add it
        # to the forward_pre_hooks like any other pruning method
        custom_from_mask(module, name, mask=param_mask)

        # Increment the pointer to continue slicing the final_mask
        pointer += num_param


def custom_from_mask(module, name, mask):
    r"""Prune tensor corresponding to parameter called ``name`` in ``module`` by applying the pre-computed mask in ``mask``.

    Modifies module in place (and also returns the modified module) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        mask (Tensor): binary mask to be applied to the parameter.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.custom_from_mask(
        ...     nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
        ... )
        >>> print(m.bias_mask)
        tensor([0., 1., 0.])
    """
    CustomFromMask.apply(module, name, mask)
    return module


def remove(module, name):
    r"""Remove the pruning reparameterization from a module and the pruning method from the forward hook.

    The pruned parameter named ``name`` remains permanently pruned, and the parameter
    named ``name+'_orig'`` is removed from the parameter list. Similarly,
    the buffer named ``name+'_mask'`` is removed from the buffers.

    Note:
        Pruning itself is NOT undone or reversed!

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.

    Examples:
        >>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2)
        >>> m = remove(m, name='weight')
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError(
        f"Parameter '{name}' of module {module} has to be pruned before pruning can be removed"
    )


def is_pruned(module):
    r"""Check if a module is pruned by looking for pruning pre-hooks.

    Check whether ``module`` is pruned by looking for
    ``forward_pre_hooks`` in its modules that inherit from the
    :class:`BasePruningMethod`.

    Args:
        module (nn.Module): object that is either pruned or unpruned

    Returns:
        binary answer to whether ``module`` is pruned.

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = nn.Linear(5, 7)
        >>> print(prune.is_pruned(m))
        False
        >>> prune.random_unstructured(m, name='weight', amount=0.2)
        >>> print(prune.is_pruned(m))
        True
    """
    for _, submodule in module.named_modules():
        for hook in submodule._forward_pre_hooks.values():
            if isinstance(hook, BasePruningMethod):
                return True
    return False


def _validate_pruning_amount_init(amount):
    r"""Validate helper to check the range of amount at init.

    Args:
        amount (int or float): quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If int, it represents the
            absolute number of parameters to prune.

    Raises:
        ValueError: if amount is a float not in [0, 1], or if it's a negative
            integer.
        TypeError: if amount is neither a float nor an integer.

    Note:
        This does not take into account the number of parameters in the
        tensor to be pruned, which is known only at prune time.
    """
    if not isinstance(amount, numbers.Real):
        raise TypeError(f"Invalid type for amount: {amount}. Must be int or float.")

    if (isinstance(amount, numbers.Integral) and amount < 0) or (
        not isinstance(amount, numbers.Integral)  # so it's a float
        and (float(amount) > 1.0 or float(amount) < 0.0)
    ):
        raise ValueError(
            f"amount={amount} should either be a float in the range [0, 1] or a non-negative integer"
        )


def _validate_pruning_amount(amount, tensor_size):
    r"""Validate that the pruning amount is meaningful with respect to the size of the data.

    Validation helper to check that the amount of parameters to prune
    is meaningful with respect to the size of the data (``tensor_size``).

    Args:
        amount (int or float): quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If int, it represents the
            absolute number of parameters to prune.
        tensor_size (int): absolute number of parameters in the tensor
            to prune.
    """
    # TODO: consider removing this check and allowing users to specify
    # a number of units to prune that is greater than the number of units
    # left to prune. In this case, the tensor will just be fully pruned.
    if isinstance(amount, numbers.Integral) and amount > tensor_size:
        raise ValueError(
            f"amount={amount} should be smaller than the number of parameters to prune={tensor_size}"
        )


def _validate_structured_pruning(t):
    r"""Validate that the tensor to be pruned is at least 2-dimensional.

    Validation helper to check that the tensor to be pruned is
    multidimensional, such that the concept of "channels" is well-defined.

    Args:
        t (torch.Tensor): tensor representing the parameter to prune

    Raises:
        ValueError: if the tensor ``t`` is not at least 2D.
    """
    shape = t.shape
    if len(shape) <= 1:
        raise ValueError(
            "Structured pruning can only be applied to "
            "multidimensional tensors. Found tensor of shape "
            f"{shape} with {len(shape)} dims"
        )


def _compute_nparams_toprune(amount, tensor_size):
    r"""Convert the pruning amount from a percentage to absolute value.

    Since amount can be expressed either in absolute value or as a
    percentage of the number of units/channels in a tensor, this utility
    function converts the percentage to absolute value to standardize
    the handling of pruning.

    Args:
        amount (int or float): quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If int, it represents the
            absolute number of parameters to prune.
        tensor_size (int): absolute number of parameters in the tensor
            to prune.

    Returns:
        int: the number of units to prune in the tensor
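
    Examples:
        A minimal illustrative sketch of the two accepted conventions:

        >>> # xdoctest: +SKIP
        >>> _compute_nparams_toprune(0.5, 10)  # fraction of the tensor
        5
        >>> _compute_nparams_toprune(3, 10)  # already an absolute count
        3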
""" | |
# incorrect type already checked in _validate_pruning_amount_init | |
if isinstance(amount, numbers.Integral): | |
return amount | |
else: | |
return round(amount * tensor_size) | |


def _validate_pruning_dim(t, dim):
    r"""Validate that the pruning dimension is within the bounds of the tensor dimension.

    Args:
        t (torch.Tensor): tensor representing the parameter to prune
        dim (int): index of the dim along which we define channels to prune

    Raises:
        IndexError: if ``dim >= t.dim()``
    """
    if dim >= t.dim():
        raise IndexError(f"Invalid index {dim} for tensor of size {t.shape}")


def _compute_norm(t, n, dim):
    r"""Compute the L_n-norm of a tensor along all dimensions except for the specified dimension.

    The L_n-norm will be computed across all entries in tensor ``t`` along all dimensions
    except for the one identified by ``dim``.
    Example: if ``t`` is of shape, say, 3x2x4 and ``dim=2`` (the last dim),
    then norm will have Size [4], and each entry will represent the
    `L_n`-norm computed using the 3x2=6 entries for each of the 4 channels.

    Args:
        t (torch.Tensor): tensor representing the parameter to prune
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int): dim identifying the channels to prune

    Returns:
        norm (torch.Tensor): L_n norm computed across all dimensions except
            for ``dim``. By construction, ``norm.shape == (t.shape[dim],)``.
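
    Examples:
        A minimal illustrative sketch matching the shape example above:

        >>> # xdoctest: +SKIP
        >>> t = torch.ones(3, 2, 4)
        >>> _compute_norm(t, n=2, dim=2).shape
        torch.Size([4])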
""" | |
# dims = all axes, except for the one identified by `dim` | |
dims = list(range(t.dim())) | |
# convert negative indexing | |
if dim < 0: | |
dim = dims[dim] | |
dims.remove(dim) | |
norm = torch.norm(t, p=n, dim=dims) | |
return norm | |