File size: 4,673 Bytes
19c4ddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional


class AttrDict(OrderedDict):
    """
    An ordered attribute dictionary that automatically handles nested keys joined by "/".

    ``d["a/b"] = 1`` creates (or reuses) a nested ``AttrDict`` under ``"a"`` and
    stores ``1`` under ``"b"`` in it; ``d["a/b"]`` and ``"a/b" in d`` walk the same
    path. Plain ``dict`` values (and plain dicts directly inside list values) are
    converted to ``AttrDict`` on assignment. Attribute access is aliased to item
    access, and looking up a missing (non-nested) key returns ``None`` instead of
    raising. Only ``str`` keys are split on "/"; other hashable keys are stored as-is.

    Originally copied from: https://stackoverflow.com/questions/3031219/recursively-access-dict-via-attributes-as-well-as-index-access
    """

    MARKER = object()

    # pylint: disable=super-init-not-called
    def __init__(self, *args, **kwargs):
        """
        Build from at most one ``dict`` positional argument and/or keyword arguments.

        Keyword entries are applied after the positional ones (as with ``dict``).
        Every entry is routed through ``__setitem__``, so "/" keys are nested and
        dict values are converted.

        :raises TypeError: if more than one positional argument is given, or the
            positional argument is not a ``dict``.
        """
        if len(args) > 1:
            raise TypeError(f"AttrDict expected at most 1 positional argument, got {len(args)}")
        if args:
            source = args[0]
            if not isinstance(source, dict):
                raise TypeError(f"AttrDict expected a dict, got {type(source).__name__}")
            for key, value in source.items():
                self[key] = value
        # Previously kwargs were silently ignored whenever a positional dict was given.
        for key, value in kwargs.items():
            self[key] = value

    def __contains__(self, key):
        """Return True if ``key`` (possibly a "/"-joined path) resolves to an entry."""
        if isinstance(key, str) and "/" in key:
            head, rest = key.split("/", 1)
            child = self.get(head)
            # Only descend into nested AttrDicts; testing `rest in child` on a str or
            # list value would otherwise do a substring / element check.
            return isinstance(child, AttrDict) and rest in child
        return super().__contains__(key)

    def __setitem__(self, key, value):
        """
        Store ``value`` under ``key``, creating intermediate ``AttrDict`` levels for
        "/"-joined string keys and converting plain dict values to ``AttrDict``.
        """
        if isinstance(key, str) and "/" in key:
            head, rest = key.split("/", 1)
            if head not in self:
                self[head] = AttrDict()
            self[head].__setitem__(rest, value)
            return

        if isinstance(value, dict) and not isinstance(value, AttrDict):
            # Positional construction (not **value) so non-string keys are accepted.
            value = AttrDict(value)
        elif isinstance(value, list):
            value = [AttrDict(item) if isinstance(item, dict) else item for item in value]
        super().__setitem__(key, value)

    def __getitem__(self, key):
        """
        Return the value stored under ``key``, or ``None`` if a non-nested key is absent.

        :raises ValueError: if a "/"-joined path passes through a value that is not a
            nested ``AttrDict`` (including a missing intermediate key).
        """
        if isinstance(key, str) and "/" in key:
            head, rest = key.split("/", 1)
            child = self.get(head)
            if not isinstance(child, AttrDict):
                raise ValueError(f"cannot resolve {key!r}: {head!r} does not hold a nested AttrDict")
            return child[rest]

        return self.get(key, None)

    def all_keys(
        self,
        leaves_only: bool = False,
        parent: Optional[str] = None,
    ) -> List[str]:
        """
        List every key path in depth-first insertion order, nested levels joined by "/".

        :param leaves_only: if True, omit paths whose value is itself a dictionary.
        :param parent: prefix prepended to every path (used for recursion).
        """
        keys: List[str] = []
        for key, value in self.items():
            path = key if parent is None else f"{parent}/{key}"
            is_branch = isinstance(value, dict)
            if not (leaves_only and is_branch):
                keys.append(path)
            if is_branch:
                keys.extend(value.all_keys(leaves_only=leaves_only, parent=path))
        return keys

    def dumpable(self, strip=True):
        """
        Convert into plain nested ``dict`` / ``list`` objects (e.g. for serialization).

        :param strip: if True, drop string keys starting with "_"; otherwise keep them
            but replace their values with ``repr(value)``. Applied at every nesting level.
        """

        def _is_private(key):
            return isinstance(key, str) and key.startswith("_")

        def _dump(val):
            if isinstance(val, AttrDict):
                # Propagate `strip` so nested private keys are treated like top-level ones.
                return val.dumpable(strip=strip)
            if isinstance(val, list):
                return [_dump(v) for v in val]
            return val

        if strip:
            return {k: _dump(v) for k, v in self.items() if not _is_private(k)}
        return {k: _dump(repr(v) if _is_private(k) else v) for k, v in self.items()}

    def map(
        self,
        map_fn: Callable[[Any, Any], Any],
        should_map: Optional[Callable[[Any, Any], bool]] = None,
    ) -> "AttrDict":
        """
        Creates a copy of self where some or all leaf values are transformed by
        ``map_fn(key, value)``. Nested ``AttrDict`` values are recursed into.

        :param should_map: If provided, only those (key, value) pairs for which it
            returns true are converted; otherwise, all leaf values are mapped.
        """

        def _apply(key, val):
            if isinstance(val, AttrDict):
                return val.map(map_fn, should_map)
            if should_map is None or should_map(key, val):
                return map_fn(key, val)
            return val

        return AttrDict({k: _apply(k, v) for k, v in self.items()})

    def __eq__(self, other):
        """
        Order-insensitive equality against any mapping exposing ``keys()`` and
        item access; returns ``NotImplemented`` for non-mappings.
        """
        try:
            other_keys = other.keys()
        except AttributeError:
            return NotImplemented
        return self.keys() == other_keys and all(self[k] == other[k] for k in self.keys())

    def __ne__(self, other):
        # Without this, `!=` falls back to OrderedDict's order-sensitive comparison
        # and can disagree with the order-insensitive __eq__ above.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def combine(
        self,
        other: Dict[str, Any],
        combine_fn: Callable[[Optional[Any], Optional[Any]], Any],
    ) -> "AttrDict":
        """
        Merge ``self`` and ``other`` key by key. Some values may be missing (they are
        passed to ``combine_fn`` as ``None``), but the dictionary structures must be
        the same. Keys keep their order: keys of ``self`` first, then keys only in
        ``other``.

        :param other: an ``AttrDict`` or plain ``dict`` (converted to ``AttrDict``).
        :param combine_fn: a (possibly non-commutative) function to combine the
            values
        """
        if not isinstance(other, AttrDict):
            # Lets plain dicts (and their nested dicts) be combined, and makes
            # missing keys read as None instead of raising KeyError.
            other = AttrDict(other)

        def _apply(val, other_val):
            if isinstance(val, AttrDict):
                assert isinstance(other_val, AttrDict)
                return val.combine(other_val, combine_fn)
            return combine_fn(val, other_val)

        # dict.fromkeys de-duplicates while preserving first-seen order.
        keys = list(dict.fromkeys([*self.keys(), *other.keys()]))
        return AttrDict({k: _apply(self[k], other[k]) for k in keys})

    __setattr__, __getattr__ = __setitem__, __getitem__