File size: 9,183 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import importlib
from abc import ABC, abstractmethod
from pickle import (  # type: ignore[attr-defined]  # type: ignore[attr-defined]
    _getattribute,
    _Pickler,
    whichmodule as _pickle_whichmodule,
)
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple

from ._mangling import demangle, get_mangle_prefix, is_mangled

# Names exported by a star-import of this module.
__all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"]


class ObjNotFoundError(Exception):
    """Signals that an importer looked an object up by name and found nothing."""


class ObjMismatchError(Exception):
    """Signals that the name resolved to a different object than the one the user supplied."""


class Importer(ABC):
    """Represents an environment to import modules from.

    By default, you can figure out what module an object belongs to by checking
    __module__ and importing the result using __import__ or importlib.import_module.

    torch.package introduces module importers other than the default one.
    Each PackageImporter introduces a new namespace. Potentially a single
    name (e.g. 'foo.bar') is present in multiple namespaces.

    It supports two main operations:
        import_module: module_name -> module object
        get_name: object -> (parent module name, name of obj within module)

    The guarantee is that following round-trip will succeed or throw an
    ObjNotFoundError/ObjMismatchError.
        module_name, obj_name = env.get_name(obj)
        module = env.import_module(module_name)
        obj2 = getattr(module, obj_name)
        assert obj1 is obj2
    """

    # Module table searched by the default `whichmodule` when an object has no
    # `__module__`. It is only declared here, never assigned, so an
    # implementation relying on the default `whichmodule` has to provide it.
    modules: Dict[str, ModuleType]

    @abstractmethod
    def import_module(self, module_name: str) -> ModuleType:
        """Import `module_name` from this environment.

        The contract is the same as for importlib.import_module.
        """
        pass

    def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]:
        """Given an object, return a name that can be used to retrieve the
        object from this environment.

        Args:
            obj: An object to get the module-environment-relative name for.
            name: If set, use this name instead of looking up __name__ or __qualname__ on `obj`.
                This is only here to match how Pickler handles __reduce__ functions that return a string,
                don't use otherwise.

        Returns:
            A tuple (parent_module_name, attr_name) that can be used to retrieve `obj` from this environment.
            Use it like:
                mod = importer.import_module(parent_module_name)
                obj = getattr(mod, attr_name)

        Raises:
            ObjNotFoundError: we couldn't retrieve `obj` by name.
            ObjMismatchError: we found a different object with the same name as `obj`.
        """
        # Compare against None explicitly instead of testing truthiness: a
        # falsy object (e.g. a sentinel whose __bool__ returns False) may still
        # advertise its global name through __reduce__, and an object whose
        # __bool__ raises must not blow up here.
        if (
            name is None
            and obj is not None
            and _Pickler.dispatch.get(type(obj)) is None
        ):
            # Honor the string return variant of __reduce__, which will give us
            # a global name to search for in this environment.
            # TODO: I guess we should do copyreg too?
            reduce = getattr(obj, "__reduce__", None)
            if reduce is not None:
                try:
                    rv = reduce()
                    if isinstance(rv, str):
                        name = rv
                except Exception:
                    # Best effort only: a failing __reduce__ just means we fall
                    # back to __qualname__ / __name__ below.
                    pass
        if name is None:
            name = getattr(obj, "__qualname__", None)
        if name is None:
            name = obj.__name__

        orig_module_name = self.whichmodule(obj, name)
        # Demangle the module name before importing. If this obj came out of a
        # PackageImporter, `__module__` will be mangled. See mangling.md for
        # details.
        module_name = demangle(orig_module_name)

        # Check that this name will indeed return the correct object
        try:
            module = self.import_module(module_name)
            obj2, _ = _getattribute(module, name)
        except (ImportError, KeyError, AttributeError):
            raise ObjNotFoundError(
                f"{obj} was not found as {module_name}.{name}"
            ) from None

        if obj is obj2:
            return module_name, name

        def get_obj_info(obj):
            # Describe where `obj` lives so the mismatch message can tell the
            # user which importer should have been consulted first.
            assert name is not None
            module_name = self.whichmodule(obj, name)
            is_mangled_ = is_mangled(module_name)
            location = (
                get_mangle_prefix(module_name)
                if is_mangled_
                else "the current Python environment"
            )
            importer_name = (
                f"the importer for {get_mangle_prefix(module_name)}"
                if is_mangled_
                else "'sys_importer'"
            )
            return module_name, location, importer_name

        obj_module_name, obj_location, obj_importer_name = get_obj_info(obj)
        obj2_module_name, obj2_location, obj2_importer_name = get_obj_info(obj2)
        msg = (
            f"\n\nThe object provided is from '{obj_module_name}', "
            f"which is coming from {obj_location}."
            f"\nHowever, when we import '{obj2_module_name}', it's coming from {obj2_location}."
            "\nTo fix this, make sure this 'PackageExporter's importer lists "
            f"{obj_importer_name} before {obj2_importer_name}."
        )
        raise ObjMismatchError(msg)

    def whichmodule(self, obj: Any, name: str) -> str:
        """Find the module name an object belongs to.

        This should be considered internal for end-users, but developers of
        an importer can override it to customize the behavior.

        Taken from pickle.py, but modified to exclude the search into sys.modules

        Args:
            obj: The object whose defining module is wanted.
            name: Dotted attribute path under which `obj` is expected to be found.

        Returns:
            `obj.__module__` when it is set; otherwise the name of the first
            entry of `self.modules` in which `name` resolves to `obj`; otherwise
            "__main__".
        """
        module_name = getattr(obj, "__module__", None)
        if module_name is not None:
            return module_name

        # Iterate over a copy of self.modules to protect against dynamic
        # modules that trigger imports of other modules upon calls to getattr.
        for module_name, module in self.modules.copy().items():
            if (
                module_name == "__main__"
                or module_name == "__mp_main__"  # bpo-42406
                or module is None
            ):
                continue
            try:
                if _getattribute(module, name)[0] is obj:
                    return module_name
            except AttributeError:
                pass

        return "__main__"


class _SysImporter(Importer):
    """Importer that mirrors Python's stock import behavior."""

    def import_module(self, module_name: str):
        """Resolve `module_name` through `importlib.import_module`."""
        imported = importlib.import_module(module_name)
        return imported

    def whichmodule(self, obj: Any, name: str) -> str:
        """Delegate the owning-module lookup to `pickle.whichmodule`."""
        owner = _pickle_whichmodule(obj, name)
        return owner


# Module-level instance of the default (plain Python import system) importer.
sys_importer = _SysImporter()


class OrderedImporter(Importer):
    """Compound importer that consults a sequence of importers in order.

    The first importer in the sequence that yields a result "wins".
    """

    def __init__(self, *args):
        self._importers: List[Importer] = [*args]

    def _is_torchpackage_dummy(self, module):
        """Tell whether `module` is an empty PackageNode in a torch.package.

        If you intern `a.b` but never use `a` in your code, then `a` will be an
        empty module with no source. This can break cases where we are trying to
        re-package an object after adding a real dependency on `a`, since
        OrderedImporter will resolve `a` to the dummy package and stop there.

        See: https://github.com/pytorch/pytorch/pull/71520#issuecomment-1029603769
        """
        from_package = getattr(module, "__torch_package__", False)
        if not from_package or not hasattr(module, "__path__"):
            return False
        # A packaged package with a missing or None `__file__` has no source.
        return getattr(module, "__file__", None) is None

    def import_module(self, module_name: str) -> ModuleType:
        """Return `module_name` from the first importer providing a real module.

        Importers that raise ModuleNotFoundError, or that only hold an empty
        torch.package placeholder for the name, are passed over.

        Raises:
            TypeError: an entry in the importer sequence is not an `Importer`.
            ModuleNotFoundError: no importer produced a usable module; the
                most recent error from an importer is re-raised when there
                is one.
        """
        most_recent_failure = None
        for importer in self._importers:
            if not isinstance(importer, Importer):
                raise TypeError(
                    f"{importer} is not a Importer. "
                    "All importers in OrderedImporter must inherit from Importer."
                )
            try:
                found = importer.import_module(module_name)
                if not self._is_torchpackage_dummy(found):
                    return found
            except ModuleNotFoundError as exc:
                most_recent_failure = exc

        if most_recent_failure is None:
            raise ModuleNotFoundError(module_name)
        raise most_recent_failure

    def whichmodule(self, obj: Any, name: str) -> str:
        """Return the first answer other than "__main__" given by an importer.

        Falls back to "__main__" when every importer reports "__main__" or
        there are no importers.
        """
        for importer in self._importers:
            owner = importer.whichmodule(obj, name)
            if owner == "__main__":
                continue
            return owner

        return "__main__"