import io
import pickle
import warnings
from collections.abc import Collection
from typing import Dict, List, Optional, Set, Tuple, Type, Union

from torch.utils.data import IterDataPipe, MapDataPipe
from torch.utils._import_utils import dill_available

__all__ = ["traverse", "traverse_dps"]

DataPipe = Union[IterDataPipe, MapDataPipe]
DataPipeGraph = Dict[int, Tuple[DataPipe, "DataPipeGraph"]]  # type: ignore[misc]
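
# A DataPipeGraph is a nested dict keyed by `id(datapipe)`. For example, the
# graph for a two-stage chain `src_dp -> map_dp` looks roughly like:
#   {id(map_dp): (map_dp, {id(src_dp): (src_dp, {})})}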


def _stub_unpickler():
    return "STUB"


# TODO(VitalyFedyunin): Make sure it works without dill module installed
def _list_connected_datapipes(scan_obj: DataPipe, only_datapipe: bool, cache: Set[int]) -> List[DataPipe]:
    f = io.BytesIO()
    p = pickle.Pickler(f)  # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is
    if dill_available():
        from dill import Pickler as dill_Pickler
        d = dill_Pickler(f)
    else:
        d = None

    captured_connections = []
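
    # When `only_datapipe` is True, this hook replaces each DataPipe's pickled
    # state with only the attributes that can reference other DataPipes:
    # DataPipe instances themselves and Python collections that may hold them.
    # Everything else (e.g. user functions) is dropped before serialization.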
    def getstate_hook(ori_state):
        state = None
        if isinstance(ori_state, dict):
            state = {}  # type: ignore[assignment]
            for k, v in ori_state.items():
                if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
                    state[k] = v  # type: ignore[attr-defined]
        elif isinstance(ori_state, (tuple, list)):
            state = []  # type: ignore[assignment]
            for v in ori_state:
                if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
                    state.append(v)  # type: ignore[attr-defined]
        elif isinstance(ori_state, (IterDataPipe, MapDataPipe, Collection)):
            state = ori_state  # type: ignore[assignment]
        return state
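
    # This hook intercepts the serialization of every DataPipe reachable from
    # `scan_obj`. The top-level object (and anything already seen) raises
    # NotImplementedError so that pickling falls back to the default path;
    # every newly discovered nested DataPipe is recorded and replaced with a
    # cheap stub so the pickler does not recurse into it.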
    def reduce_hook(obj):
        if obj == scan_obj or id(obj) in cache:
            raise NotImplementedError
        else:
            captured_connections.append(obj)
            # Add the id so a DataPipe serialized multiple times at the same level is captured only once
            cache.add(id(obj))
            return _stub_unpickler, ()

    datapipe_classes: Tuple[Type[DataPipe]] = (IterDataPipe, MapDataPipe)  # type: ignore[assignment]
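
    # Temporarily install the hooks on both DataPipe base classes, serialize
    # `scan_obj` to trigger them, and always remove the hooks afterwards so
    # regular pickling of DataPipes is unaffected.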
    try:
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(reduce_hook)
            if only_datapipe:
                cls.set_getstate_hook(getstate_hook)
        try:
            p.dump(scan_obj)
        except (pickle.PickleError, AttributeError, TypeError):
            if dill_available():
                d.dump(scan_obj)
            else:
                raise
    finally:
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(None)
            if only_datapipe:
                cls.set_getstate_hook(None)
        if dill_available():
            from dill import extend as dill_extend
            dill_extend(False)  # Undo change to dispatch table

    return captured_connections


def traverse_dps(datapipe: DataPipe) -> DataPipeGraph:
    r"""
    Traverse the DataPipes and their attributes to extract the DataPipe graph.

    This only looks into the attributes of each DataPipe that are either a
    DataPipe or a Python collection object such as ``list``, ``tuple``,
    ``set``, and ``dict``.

    Args:
        datapipe: the end DataPipe of the graph

    Returns:
        A graph represented as a nested dictionary, where keys are ids of DataPipe instances
        and values are tuples of a DataPipe instance and its sub-graph
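
    Example::
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch; assumes the built-in ``IterableWrapper`` DataPipe
        >>> from torch.utils.data.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper(range(10))
        >>> shuffled_dp = source_dp.shuffle()
        >>> graph = traverse_dps(shuffled_dp)
        >>> # roughly: {id(shuffled_dp): (shuffled_dp, {id(source_dp): (source_dp, {})})}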
""" | |
cache: Set[int] = set() | |
return _traverse_helper(datapipe, only_datapipe=True, cache=cache) | |


def traverse(datapipe: DataPipe, only_datapipe: Optional[bool] = None) -> DataPipeGraph:
    r"""
    Traverse the DataPipes and their attributes to extract the DataPipe graph.

    [Deprecated]
    When ``only_datapipe`` is specified as ``True``, it only looks into the
    attributes of each DataPipe that are either a DataPipe or a Python
    collection object such as ``list``, ``tuple``, ``set``, and ``dict``.

    Note:
        This function is deprecated. Please use `traverse_dps` instead.

    Args:
        datapipe: the end DataPipe of the graph
        only_datapipe: If ``False`` (default), all attributes of each DataPipe are traversed.
          This argument is deprecated and will be removed after the next release.

    Returns:
        A graph represented as a nested dictionary, where keys are ids of DataPipe instances
        and values are tuples of a DataPipe instance and its sub-graph
    """
    msg = "`traverse` function is deprecated and will be removed after 1.13. " \
          "Please use `traverse_dps` instead."
    if not only_datapipe:
        msg += " And, the behavior will be changed to the equivalent of `only_datapipe=True`."
    warnings.warn(msg, FutureWarning)
    if only_datapipe is None:
        only_datapipe = False
    cache: Set[int] = set()
    return _traverse_helper(datapipe, only_datapipe, cache)


# The cache prevents infinite recursion when the DataPipe graph contains a cycle
def _traverse_helper(datapipe: DataPipe, only_datapipe: bool, cache: Set[int]) -> DataPipeGraph:
    if not isinstance(datapipe, (IterDataPipe, MapDataPipe)):
        raise RuntimeError(f"Expected `IterDataPipe` or `MapDataPipe`, but {type(datapipe)} is found")

    dp_id = id(datapipe)
    if dp_id in cache:
        return {}
    cache.add(dp_id)
    # Using `cache.copy()` here prevents the same DataPipe from polluting the cache on different paths
    items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())
    d: DataPipeGraph = {dp_id: (datapipe, {})}
    for item in items:
        # Using `cache.copy()` here prevents recursion on a single path rather than across the global graph;
        # a single DataPipe can be present multiple times in different paths of the graph
        d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
    return d
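

# Illustrative sketch (hypothetical helper, not part of the public API): one
# way to consume the DataPipeGraph returned by `traverse_dps`, flattening the
# nested dictionary into a list of all DataPipe instances it contains.
def _flatten_graph_sketch(graph: DataPipeGraph) -> List[DataPipe]:
    datapipes: List[DataPipe] = []
    for dp, sub_graph in graph.values():
        datapipes.append(dp)
        datapipes.extend(_flatten_graph_sketch(sub_graph))
    return datapipes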