Spaces:
Runtime error
Runtime error
File size: 4,673 Bytes
19c4ddf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any, Callable, Dict, List, Optional
class AttrDict(OrderedDict):
    """
    An attribute dictionary that automatically handles nested keys joined by "/".

    - ``d["a/b"] = 1`` creates ``d["a"]`` as an ``AttrDict`` if it is missing
      and stores ``1`` under ``"b"`` inside it; ``d["a/b"]`` and ``"a/b" in d``
      walk the same path.
    - ``dict`` values, and ``dict`` items inside ``list`` values, are converted
      to ``AttrDict`` on assignment.
    - Attribute access is an alias for item access (``d.a`` is ``d["a"]`` and
      ``d.a = 1`` is ``d["a"] = 1``).
    - Reading a missing flat key returns ``None`` instead of raising.

    Originally copied from: https://stackoverflow.com/questions/3031219/recursively-access-dict-via-attributes-as-well-as-index-access
    """

    # Sentinel object; not referenced inside this class (presumably used by
    # callers elsewhere -- verify before removing).
    MARKER = object()

    # pylint: disable=super-init-not-called
    def __init__(self, *args, **kwargs):
        """
        Build from at most one ``dict`` positional argument plus keyword items.

        Items are inserted one at a time through ``__setitem__`` so nested
        "/" keys and ``dict`` values get converted. Positional items come
        first, then keyword items (which win on collisions), like ``dict()``.

        :raises TypeError: if more than one positional argument is given, or
            the positional argument is not a ``dict``.
        """
        if len(args) > 1:
            raise TypeError(
                f"AttrDict expected at most 1 positional argument, got {len(args)}"
            )
        if args:
            source = args[0]
            if not isinstance(source, dict):
                raise TypeError(
                    f"AttrDict expected a dict, got {type(source).__name__}"
                )
            for key, value in source.items():
                self[key] = value
        # Previously kwargs were ignored whenever a positional dict was given.
        for key, value in kwargs.items():
            self[key] = value

    @staticmethod
    def _split_key(key):
        """Split ``"a/b/c"`` into ``("a", "b/c")``; return ``None`` for flat keys."""
        # Non-str keys cannot be nested paths; the old `"/" in key` crashed on them.
        if isinstance(key, str) and "/" in key:
            head, rest = key.split("/", 1)
            return head, rest
        return None

    def __contains__(self, key):
        """
        Return whether ``key`` is present.

        A "/" path only matches when each intermediate value is an
        ``AttrDict`` (the same rule ``__getitem__`` uses), so
        ``"a/x" in AttrDict(a="xyz")`` is False rather than a substring test.
        """
        split = self._split_key(key)
        if split is None:
            return super().__contains__(key)
        head, rest = split
        child = self.get(head)
        return isinstance(child, AttrDict) and rest in child

    def __setitem__(self, key, value):
        """
        Store ``value`` under ``key``, creating intermediate ``AttrDict``s for
        "/" paths and converting ``dict`` values (and ``dict`` list items).
        """
        split = self._split_key(key)
        if split is not None:
            head, rest = split
            if head not in self:
                self[head] = AttrDict()
            self[head].__setitem__(rest, value)
            return
        if isinstance(value, dict) and not isinstance(value, AttrDict):
            # AttrDict(value) rather than AttrDict(**value): works for non-str keys too.
            value = AttrDict(value)
        elif isinstance(value, list):
            value = [AttrDict(item) if isinstance(item, dict) else item for item in value]
        super().__setitem__(key, value)

    def __getitem__(self, key):
        """
        Return the value at ``key`` (``None`` if a flat key is missing).

        :raises ValueError: if an intermediate segment of a "/" path does not
            hold an ``AttrDict`` (including when it is missing).
        """
        split = self._split_key(key)
        if split is None:
            return self.get(key, None)
        head, rest = split
        child = self[head]
        if not isinstance(child, AttrDict):
            raise ValueError(
                f"Cannot resolve {key!r}: {head!r} does not hold an AttrDict"
            )
        return child.__getitem__(rest)

    def all_keys(
        self,
        leaves_only: bool = False,
        parent: Optional[str] = None,
    ) -> List[str]:
        """
        List every key path, depth-first in insertion order, joined by "/".

        :param leaves_only: if True, omit keys whose value is a nested dict.
        :param parent: prefix prepended (with "/") to every returned path.
        """
        keys = []
        for key, val in self.items():
            path = key if parent is None else f"{parent}/{key}"
            is_branch = isinstance(val, dict)
            if not (leaves_only and is_branch):
                keys.append(path)
            if is_branch:
                keys.extend(val.all_keys(leaves_only=leaves_only, parent=path))
        return keys

    def dumpable(self, strip=True):
        """
        Convert into plain nested ``dict``/``list`` structures.

        :param strip: if True, drop keys starting with "_" at every nesting
            level; otherwise keep them, with their values replaced by
            ``repr(value)``. The flag now applies to nested levels too.
        :returns: a plain ``dict``.
        """

        def _is_private(k):
            return isinstance(k, str) and k.startswith("_")

        def _dump(val):
            if isinstance(val, AttrDict):
                # Propagate `strip`; previously nested levels always stripped.
                return val.dumpable(strip=strip)
            if isinstance(val, list):
                return [_dump(v) for v in val]
            return val

        if strip:
            return {k: _dump(v) for k, v in self.items() if not _is_private(k)}
        return {k: _dump(repr(v) if _is_private(k) else v) for k, v in self.items()}

    def map(
        self,
        map_fn: Callable[[Any, Any], Any],
        should_map: Optional[Callable[[Any, Any], bool]] = None,
    ) -> "AttrDict":
        """
        Creates a copy of self where some or all values are transformed by
        map_fn.

        Nested ``AttrDict`` values are recursed into rather than passed to
        ``map_fn``; other values (including lists) are passed as-is.

        :param map_fn: called as ``map_fn(key, value)``; its result replaces
            the value.
        :param should_map: If provided, only those values that evaluate to true
            are converted; otherwise, all values are mapped.
        """

        def _apply(key, val):
            if isinstance(val, AttrDict):
                return val.map(map_fn, should_map)
            if should_map is None or should_map(key, val):
                return map_fn(key, val)
            return val

        return AttrDict({k: _apply(k, v) for k, v in self.items()})

    def __eq__(self, other):
        """
        Order-insensitive equality: same key set and pairwise-equal values.

        Returns ``NotImplemented`` for non-mappings (previously e.g.
        ``AttrDict() == None`` raised ``AttributeError``).
        """
        if not isinstance(other, Mapping):
            return NotImplemented
        return self.keys() == other.keys() and all(self[k] == other[k] for k in self.keys())

    def __ne__(self, other):
        """
        Inverse of ``__eq__``.

        Without this override ``!=`` falls through to ``OrderedDict.__ne__``,
        which is order-sensitive and can disagree with ``__eq__``.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def combine(
        self,
        other: Dict[str, Any],
        combine_fn: Callable[[Optional[Any], Optional[Any]], Any],
    ) -> "AttrDict":
        """
        Merge ``self`` and ``other`` key by key into a new ``AttrDict``.

        Some values may be missing (they are passed to ``combine_fn`` as
        ``None``), but the dictionary structures must be the same: wherever
        ``self`` holds a nested ``AttrDict``, ``other`` must hold a dict.
        Keys keep their order: keys of ``self`` first, then keys only in ``other``.

        :param combine_fn: a (possibly non-commutative) function to combine the
            values, called as ``combine_fn(self_value, other_value)``
        """

        def _apply(val, other_val):
            if isinstance(val, AttrDict):
                # Plain nested dicts in `other` are fine: recursion only uses
                # .keys() and .get() on them.
                assert isinstance(other_val, dict)
                return val.combine(other_val, combine_fn)
            return combine_fn(val, other_val)

        # dict.fromkeys dedups while preserving first-seen order (unlike a set union).
        keys = dict.fromkeys([*self.keys(), *other.keys()])
        # other.get: keys missing from a plain-dict `other` become None, not KeyError.
        return AttrDict({k: _apply(self[k], other.get(k)) for k in keys})

    __setattr__, __getattr__ = __setitem__, __getitem__
|