File size: 2,565 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""torch.multiprocessing is a wrapper around the native :mod:`multiprocessing` module.



It registers custom reducers, that use shared memory to provide shared

views on the same data in different processes. Once the tensor/storage is moved

to shared_memory (see :func:`~torch.Tensor.share_memory_`), it will be possible

to send it to other processes without making any copies.



The API is 100% compatible with the original module - it's enough to change

``import multiprocessing`` to ``import torch.multiprocessing`` to have all the

tensors sent through the queues or shared via other mechanisms, moved to shared

memory.



Because of the similarity of APIs we do not document most of this package

contents, and we recommend referring to very good docs of the original module.

"""
import multiprocessing
import sys

import torch
from .reductions import init_reductions

# Public helpers defined in this module; the stdlib multiprocessing names are
# appended to __all__ below so a star-import of this module exposes both.
__all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"]


# Re-export the stdlib multiprocessing API so this module can be used as a
# drop-in replacement for it (see the module docstring).
from multiprocessing import *  # noqa: F403


__all__ += multiprocessing.__all__  # noqa: PLE0605 type: ignore[attr-defined]


# This call adds a Linux specific prctl(2) wrapper function to this module.
# See https://github.com/pytorch/pytorch/pull/14391 for more information.
torch._C._multiprocessing_init()


"""Add helper function to spawn N processes and wait for completion of any of

them. This depends `mp.get_context` which was added in Python 3.4."""
from .spawn import (
    ProcessContext,
    ProcessExitedException,
    ProcessRaisedException,
    spawn,
    SpawnContext,
    start_processes,
)


# Choose the default CPU-tensor sharing strategy and the set of strategies
# offered on this platform: everywhere except macOS and Windows, both
# "file_descriptor" (the default) and "file_system" are available; on macOS and
# Windows only "file_system" is offered.
if sys.platform not in ("darwin", "win32"):
    _sharing_strategy = "file_descriptor"
    _all_sharing_strategies = {"file_descriptor", "file_system"}
else:
    _sharing_strategy = "file_system"
    _all_sharing_strategies = {"file_system"}


def set_sharing_strategy(new_strategy):
    """Set the strategy for sharing CPU tensors.

    Args:
        new_strategy (str): Name of the selected strategy. Should be one of
            the values returned by :func:`get_all_sharing_strategies()`.

    Raises:
        AssertionError: If ``new_strategy`` is not one of the strategies
            supported on the current platform.
    """
    global _sharing_strategy
    # Explicit check instead of ``assert`` so validation is not stripped when
    # Python runs with ``-O``. AssertionError is kept so callers that already
    # catch it keep working; the message now names the valid choices.
    if new_strategy not in _all_sharing_strategies:
        raise AssertionError(
            f"Invalid sharing strategy {new_strategy!r}; expected one of "
            f"{sorted(_all_sharing_strategies)}"
        )
    _sharing_strategy = new_strategy


def get_sharing_strategy():
    """Return the name of the strategy currently used to share CPU tensors.

    This is the value last passed to :func:`set_sharing_strategy`, or the
    platform default selected when this module was imported.
    """
    current = _sharing_strategy
    return current


def get_all_sharing_strategies():
    """Return a set of sharing strategies supported on the current system.

    A fresh copy is returned so that callers cannot mutate the module's
    internal set of allowed strategies, which :func:`set_sharing_strategy`
    validates against.
    """
    return set(_all_sharing_strategies)


# Register this package's custom reducers (imported from .reductions) when the
# module is imported; per the module docstring, these use shared memory so
# tensors/storages can be sent to other processes without being copied.
init_reductions()