Spaces:
Running
Running
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT license. | |
import json | |
import logging | |
from .mutables import InputChoice, LayerChoice, MutableScope | |
from .mutator import Mutator | |
from .utils import to_list | |
_logger = logging.getLogger(__name__) | |
class FixedArchitecture(Mutator):
    """
    Fixed architecture mutator that always selects a certain graph.

    Parameters
    ----------
    model : nn.Module
        A mutable network.
    fixed_arc : dict
        Preloaded architecture object, keyed by mutable key.
    strict : bool
        If True, raise ``RuntimeError`` when ``fixed_arc`` contains keys that do not
        match any (non-scope) mutable of ``model``. If False, such keys are logged
        with a warning and ignored. Keys that are missing from ``fixed_arc`` always
        raise, because every mutable needs a decision.
    verbose : bool
        Print log messages if set to True.
    """

    def __init__(self, model, fixed_arc, strict=True, verbose=True):
        super().__init__(model)
        self.verbose = verbose

        # Mutable scopes only group other mutables; they carry no decision of their own.
        mutable_keys = {mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)}
        fixed_arc_keys = set(fixed_arc.keys())
        unexpected_keys = fixed_arc_keys - mutable_keys
        missing_keys = mutable_keys - fixed_arc_keys
        if unexpected_keys:
            if strict:
                raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(unexpected_keys))
            _logger.warning("Ignoring unexpected keys in fixed architecture: %s.", unexpected_keys)
            fixed_arc = {k: v for k, v in fixed_arc.items() if k in mutable_keys}
        if missing_keys:
            raise RuntimeError("Missing keys in fixed architecture: {}.".format(missing_keys))
        self._fixed_arc = self._from_human_readable_architecture(fixed_arc)

    def _from_human_readable_architecture(self, human_arc):
        """
        Normalize an exported architecture into one list per mutable key.

        For ``LayerChoice`` / ``InputChoice`` keys, the result is a multi-hot list with one
        entry per candidate. A value that already is a boolean or float list of the right
        length is kept as-is. Otherwise the value is read as indices and/or candidate names
        and converted to a boolean list. Other keys are only converted to lists.

        Parameters
        ----------
        human_arc : dict
            Exported architecture. Values may be scalars, lists, or array-likes
            (converted with ``to_list``).

        Returns
        -------
        dict
            Mapping from mutable key to a list.
        """
        # Values could be tensors, numpy arrays, etc.; convert them to plain Python objects first.
        result_arc = {k: to_list(v) for k, v in human_arc.items()}
        # Wrap scalars: {"op1": 0} or {"op1": "conv"} means {"op1": [0]} or {"op1": ["conv"]}.
        result_arc = {k: v if isinstance(v, list) else [v] for k, v in result_arc.items()}
        # Next, infer which values are multi-hot arrays and which are in human-readable format.
        # This is ambiguous: [0, 1] could mean [false, true] or "indices 0 and 1". We assume a
        # multi-hot array must be all-boolean or all-float AND match the number of candidates.
        for mutable in self.mutables:
            if mutable.key not in result_arc:
                continue  # skip silently
            choice_arr = result_arc[mutable.key]
            if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr):
                if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \
                        (isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)):
                    # already multi-hot, keep as-is
                    continue
            if isinstance(mutable, LayerChoice):
                # candidate names are resolved to their index among the layer-choice names
                choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr]
                choice_arr = [i in choice_arr for i in range(len(mutable))]
            elif isinstance(mutable, InputChoice):
                # input names are resolved to their index in ``choose_from``
                choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr]
                choice_arr = [i in choice_arr for i in range(mutable.n_candidates)]
            result_arc[mutable.key] = choice_arr
        return result_arc

    def sample_search(self):
        """
        Always returns the fixed architecture.
        """
        return self._fixed_arc

    def sample_final(self):
        """
        Always returns the fixed architecture.
        """
        return self._fixed_arc

    def replace_layer_choice(self, module=None, prefix=""):
        """
        Replace layer choices with selected candidates, on a best-effort basis.

        If exactly one candidate is chosen (one entry equal to 1, all others 0) and the
        choice does not use ``return_mask``, the ``LayerChoice`` is replaced by that
        candidate module. Otherwise the ``LayerChoice`` is kept. Candidates whose choice
        value is an integer/boolean zero are set to ``None``. Candidates with a float
        weight of ``0.0`` are kept.

        Parameters
        ----------
        module : nn.Module
            Module to be processed. Defaults to ``self.model``.
        prefix : str
            Module name under global namespace.
        """
        if module is None:
            module = self.model
        for name, mutable in module.named_children():
            global_name = (prefix + "." if prefix else "") + name
            if isinstance(mutable, LayerChoice):
                chosen = self._fixed_arc[mutable.key]
                # Exactly one candidate: entries are non-negative, the max is 1 and they sum to 1.
                # Works for integer, boolean and float arrays alike. The ``min`` guard rejects
                # arrays like [1, 1, -1] that also sum to one.
                if sum(chosen) == 1 and max(chosen) == 1 and min(chosen) >= 0 and not mutable.return_mask:
                    chosen_index = chosen.index(1)
                    if self.verbose:
                        _logger.info("Replacing %s with candidate number %d.", global_name, chosen_index)
                    setattr(module, name, mutable[chosen_index])
                else:
                    if mutable.return_mask and self.verbose:
                        _logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, "
                                     "LayerChoice will not be replaced.", global_name)
                    # Remove unused candidates; float weights (even 0.0) are deliberately kept.
                    for ch, n in zip(chosen, mutable.names):
                        if ch == 0 and not isinstance(ch, float):
                            setattr(mutable, n, None)
            else:
                self.replace_layer_choice(mutable, global_name)
def apply_fixed_architecture(model, fixed_arc, verbose=True):
    """
    Load architecture from ``fixed_arc`` and apply it to ``model``.

    Parameters
    ----------
    model : torch.nn.Module
        Model with mutables.
    fixed_arc : str or dict
        Path to the JSON file that stores the architecture, or a dict that stores
        the exported architecture.
    verbose : bool
        Print log messages if set to True.

    Returns
    -------
    FixedArchitecture
        Mutator that is responsible for fixing the graph.
    """
    if isinstance(fixed_arc, str):
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(fixed_arc, encoding="utf-8") as f:
            fixed_arc = json.load(f)
    # ``verbose`` must be passed by keyword: the third positional parameter of
    # ``FixedArchitecture`` is ``strict``, not ``verbose``.
    architecture = FixedArchitecture(model, fixed_arc, verbose=verbose)
    architecture.reset()
    # for the convenience of parameters counting
    architecture.replace_layer_choice()
    return architecture