from typing import Any, Dict, Optional, Type
from torch.nn.utils.parametrize import type_before_parametrizations, is_parametrized
from itertools import chain

from torch import nn

__all__ = [
    "module_contains_param",
    "swap_module",
    "module_to_fqn",
    "fqn_to_module",
    "get_arg_info_from_tensor_fqn",
    "FakeSparsity",
]


def module_contains_param(module: nn.Module, parametrization: Type[nn.Module]) -> bool:
    if is_parametrized(module):
        # see if any of the module tensors have a parametrization attached that matches the one passed in
        return any(
            any(isinstance(param, parametrization) for param in param_list)
            for key, param_list in module.parametrizations.items()  # type: ignore[union-attr,operator]
        )
    return False
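
# Example usage of ``module_contains_param`` (illustrative sketch, not part of
# the module API; assumes the ``FakeSparsity`` parametrization defined below):
#
#   >>> import torch
#   >>> linear = torch.nn.Linear(4, 4)
#   >>> _ = torch.nn.utils.parametrize.register_parametrization(
#   ...     linear, "weight", FakeSparsity(torch.ones(4, 4))
#   ... )
#   >>> module_contains_param(linear, FakeSparsity)
#   True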


def swap_module(
    mod: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]]
) -> nn.Module:
    r"""Swaps the module using from_dense according to the mapping passed in.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to sparse nn module

    Return:
        The corresponding sparse module of `mod` according to mapping, created using from_dense
    """
    if type_before_parametrizations(mod) in mapping:
        sparse_mod = mapping[type_before_parametrizations(mod)]

        # TODO Fix this typing, as Type[Module] has no attribute "from_dense"
        new_mod = sparse_mod.from_dense(mod)  # type: ignore[attr-defined]

        # Preserve module's pre forward hooks. They'll be called on quantized input
        for pre_hook_fn in mod._forward_pre_hooks.values():
            new_mod.register_forward_pre_hook(pre_hook_fn)
        # Preserve module's post forward hooks except _observer_forward_hook
        # After convert they'll work with quantized output
        for hook_fn in mod._forward_hooks.values():
            new_mod.register_forward_hook(hook_fn)

        # respect device affinity when swapping modules
        devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
        assert len(devices) <= 1, (
            f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
        )
        device = next(iter(devices)) if len(devices) > 0 else None
        if device:
            new_mod.to(device)

        return new_mod
    else:
        return mod
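
# Example usage of ``swap_module`` (illustrative sketch; ``SparseLinear`` is a
# hypothetical class exposing a ``from_dense`` classmethod, not something
# provided by this module):
#
#   >>> mapping = {torch.nn.Linear: SparseLinear}
#   >>> dense = torch.nn.Linear(4, 4)
#   >>> sparse = swap_module(dense, mapping)  # returns SparseLinear.from_dense(dense)
#   >>> untouched = swap_module(torch.nn.ReLU(), mapping)  # type not in mapping, returned as-is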


def module_to_fqn(
    model: nn.Module, module: nn.Module, prefix: str = ""
) -> Optional[str]:
    """
    Returns the fqn for a module or None if the module is not a descendant of the model.
    """
    if module is model:
        return ""
    for name, child in model.named_children():
        fqn = module_to_fqn(child, module, ".")
        if isinstance(fqn, str):
            return prefix + name + fqn
    return None
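
# Example usage of ``module_to_fqn`` (illustrative sketch):
#
#   >>> import torch
#   >>> model = torch.nn.Sequential(torch.nn.Sequential(torch.nn.Linear(4, 4)))
#   >>> module_to_fqn(model, model[0][0])
#   '0.0'
#   >>> module_to_fqn(model, torch.nn.Linear(4, 4)) is None  # not a descendant
#   True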


def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]:
    """
    Given an fqn, returns the corresponding module or tensor or None if the fqn given by `path`
    doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors.
    """
    if path != "":
        for name in path.split("."):
            model = getattr(model, name, None)
    return model
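
# Example usage of ``fqn_to_module`` (illustrative sketch, reusing the nested
# ``model`` from the previous example):
#
#   >>> fqn_to_module(model, "0.0") is model[0][0]
#   True
#   >>> fqn_to_module(model, "0.0.weight") is model[0][0].weight
#   True
#   >>> fqn_to_module(model, "0.0.missing") is None
#   True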


def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]:
    """
    Uses tensor_fqn to obtain a dict containing module_fqn, module, tensor_name and tensor_fqn.
    """
    # string manip to split tensor_fqn into module_fqn and tensor_name
    # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight'
    # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight'
    tensor_name = tensor_fqn.split(".")[-1]
    module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)]

    module = fqn_to_module(model, module_fqn)

    return {
        "module_fqn": module_fqn,
        "module": module,
        "tensor_name": tensor_name,
        "tensor_fqn": tensor_fqn,
    }
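
# Example usage of ``get_arg_info_from_tensor_fqn`` (illustrative sketch, again
# using the nested ``model`` from above):
#
#   >>> info = get_arg_info_from_tensor_fqn(model, "0.0.weight")
#   >>> (info["module_fqn"], info["tensor_name"])
#   ('0.0', 'weight')
#   >>> info["module"] is model[0][0]
#   True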


# Parametrizations
class FakeSparsity(nn.Module):
    r"""Parametrization for the weights. Should be attached to the 'weight' or
    any other parameter that requires a mask applied to it.

    Note::

        Once the mask is passed, the variable should not change its id. The
        contents of the mask can change, but the mask reference itself should
        not.
    """

    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, x):
        assert self.mask.shape == x.shape
        return self.mask * x

    def state_dict(self, *args, **kwargs):
        # We don't want the parametrizations to save the mask.
        # That way we make sure that the linear module doesn't store the masks
        # alongside their parametrizations.
        return {}
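
# Example usage of ``FakeSparsity`` (illustrative sketch): attach the
# parametrization to a Linear weight and update the mask in place.
#
#   >>> import torch
#   >>> from torch.nn.utils import parametrize
#   >>> linear = torch.nn.Linear(4, 4)
#   >>> mask = torch.ones(4, 4)
#   >>> _ = parametrize.register_parametrization(linear, "weight", FakeSparsity(mask))
#   >>> mask[:2] = 0  # mutate the mask in place; linear.weight reflects the change
#   >>> bool((linear.weight[:2] == 0).all())
#   True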