Spaces:
Running
Running
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT license. | |
| import logging | |
| from collections import OrderedDict | |
| import numpy as np | |
| import torch | |
# Program-level counter state: incremented by global_mutable_counting() and
# set back to 0 by _reset_global_mutable_counting().
_counter = 0
# Module-scoped logger, named after this module.
_logger = logging.getLogger(__name__)
def global_mutable_counting():
    """
    Advance the program-level counter by one and return its new value.

    The counter lives in the module-level ``_counter`` variable, so successive
    calls return successive integers.

    Returns
    -------
    int
        The counter value after the increment.
    """
    global _counter
    _counter = _counter + 1
    return _counter
def _reset_global_mutable_counting():
    """
    Put the program-level counter back to its initial state (0).

    Handy when several models relying on default keys are defined one after
    another and each should number its keys starting from 1 again.
    """
    global _counter
    _counter = 0
def to_device(obj, device):
    """
    Recursively relocate every tensor held in ``obj`` onto ``device``.

    Tensors are moved with ``Tensor.to``. Tuples, lists and dicts are rebuilt
    as a plain ``tuple``, ``list`` or ``dict`` whose elements (or values, for
    dicts; keys are kept as they are) have been converted in turn. ``int``,
    ``float`` and ``str`` objects are handed back untouched.

    Parameters
    ----------
    obj : torch.Tensor, tuple, list, dict, int, float or str
        The object to move; containers may be nested.
    device
        Target forwarded to ``Tensor.to``.

    Returns
    -------
    object
        The converted counterpart of ``obj``.

    Raises
    ------
    ValueError
        If ``obj``, or anything nested inside it, has any other type.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    # Plain scalars and strings carry no device; pass them through.
    if isinstance(obj, (int, float, str)):
        return obj
    if isinstance(obj, dict):
        return {key: to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [to_device(item, device) for item in obj]
    if isinstance(obj, tuple):
        return tuple(to_device(item, device) for item in obj)
    raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))
def to_list(arr):
    """
    Convert a tensor, numpy array, list, or tuple into plain Python data.

    Parameters
    ----------
    arr : torch.Tensor, numpy.ndarray, list, tuple or any other object
        The value to convert.

    Returns
    -------
    object
        For tensors and numpy arrays, the result of ``tolist()`` (a possibly
        nested list, or a Python scalar for a 0-dimensional input). For lists
        and tuples, a new shallow ``list`` copy. Anything else is returned
        unchanged.
    """
    if torch.is_tensor(arr):
        # Going through ``Tensor.numpy()`` fails for tensors that require
        # grad; ``Tensor.tolist()`` on the detached host copy yields the same
        # Python values without that restriction.
        return arr.detach().cpu().tolist()
    if isinstance(arr, np.ndarray):
        return arr.tolist()
    if isinstance(arr, (list, tuple)):
        return list(arr)
    return arr
class AverageMeterGroup:
    """
    Average meter group for multiple average meters.

    Meters are stored in insertion order in ``self.meters`` and can be reached
    either by key (``group["loss"]``) or as an attribute (``group.loss``).
    """

    def __init__(self):
        self.meters = OrderedDict()

    def update(self, data):
        """
        Update the meter group with a dict of metrics.
        Non-exist average meters will be automatically created.

        Parameters
        ----------
        data : dict
            Maps a metric name to the new value for that metric.
        """
        for k, v in data.items():
            if k not in self.meters:
                # NOTE(review): ":4f" is a minimum field width of 4, not four
                # decimals; kept as is so the printed output does not change.
                self.meters[k] = AverageMeter(k, ":4f")
            self.meters[k].update(v)

    def __getattr__(self, item):
        """
        Return the meter called ``item`` when no regular attribute matches.

        Raises
        ------
        AttributeError
            If there is no such meter, so that ``hasattr``, ``getattr`` with a
            default, and ``copy`` behave as they do for ordinary objects.
        """
        # This hook only runs after normal lookup failed. ``meters`` is read
        # from the instance dict directly: ``self.meters`` would re-enter this
        # method without end on an instance that has no ``meters`` yet (for
        # example one created without running ``__init__``).
        try:
            return self.__dict__["meters"][item]
        except KeyError:
            raise AttributeError(
                "'%s' object has no attribute '%s'" % (type(self).__name__, item)
            ) from None

    def __getitem__(self, item):
        """
        Return the meter stored under ``item``; raises ``KeyError`` if absent.
        """
        return self.meters[item]

    def __str__(self):
        return " ".join(str(v) for v in self.meters.values())

    def summary(self):
        """
        Return a summary string of group data.

        The string is the space-joined ``summary()`` of every meter, in
        insertion order.
        """
        return " ".join(v.summary() for v in self.meters.values())
class AverageMeter:
    """
    Computes and stores the average and current value.

    Parameters
    ----------
    name : str
        Name to display.
    fmt : str
        Format string to print the values. It is spliced in right after the
        field name, so it must start with ``:`` (for example ``':.3f'``).
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """
        Reset the meter.

        Sets the last value, the average, the weighted sum and the total
        weight back to 0.
        """
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Update with value and weight.

        Parameters
        ----------
        val : float or int
            The new value to be accounted in.
        n : int
            The weight of the new value.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. the first update arrives with n=0) has no
        # defined mean; report 0 rather than failing on a division by zero.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        # Renders as "<name> <last value> (<average>)", each number formatted
        # with ``self.fmt``.
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """
        Return "<name>: <average>" with the average formatted by ``self.fmt``.
        """
        fmtstr = '{name}: {avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)
class StructuredMutableTreeNode:
    """
    One node of a tree that mirrors the structure of a search space.

    The tree starts at a root whose ``mutable`` is ``None``; everything else
    hangs off it through ``children``. It can be read as a "flattened" module
    tree. Because nested mutable entities are not supported yet, each subtree
    is expected to correspond to a ``MutableScope`` and each leaf to a
    ``Mutable`` that is not a ``MutableScope``.

    Parameters
    ----------
    mutable : nni.nas.pytorch.mutables.Mutable
        The mutable this node is linked with.
    """

    def __init__(self, mutable):
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        """
        Wrap ``mutable`` in a new node, append that node to this node's
        children, and return the new node.
        """
        node = StructuredMutableTreeNode(mutable)
        self.children.append(node)
        return node

    def type(self):
        """
        Return the ``type`` of the mutable stored in this node.
        """
        return type(self.mutable)

    def __iter__(self):
        # Iterating a node is a traversal with the default arguments.
        return self.traverse()

    def traverse(self, order="pre", deduplicate=True, memo=None):
        """
        Lazily walk this subtree and yield the mutables stored in it.

        Parameters
        ----------
        order : str
            ``"pre"`` or ``"post"``. With ``"pre"`` a node's own mutable is
            yielded before those of its children, with ``"post"`` after them.
        deduplicate : bool
            If true, a mutable whose key has already been seen is skipped.
        memo : set
            Keys seen so far; shared with the recursive calls so that
            deduplication works across the whole tree. A fresh set is created
            when omitted.

        Returns
        -------
        generator of Mutable
        """
        if memo is None:
            memo = set()
        assert order in ["pre", "post"]

        def own_mutable():
            # Yields this node's mutable at most once. Evaluated lazily so
            # that, in post-order, keys recorded while visiting the children
            # are already in ``memo``.
            current = self.mutable
            if current is None:
                return
            if deduplicate and current.key in memo:
                return
            memo.add(current.key)
            yield current

        if order == "pre":
            yield from own_mutable()
        for node in self.children:
            yield from node.traverse(order=order, deduplicate=deduplicate, memo=memo)
        if order == "post":
            yield from own_mutable()