File size: 4,811 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from typing import Optional, Union

import torch


class _remote_device:
    """
    Represents a device on a remote worker.

    Args:
        remote_device (str or torch.device): Represents a device on a remote worker.
            The string format should be one of the following:

                1. "<workername>/<device>", where the device field can be parsed as torch.device type.
                   E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
                   In addition, the device field can be optional and the default value is "cpu".
                2. "rank:<rank>/<device>", where <rank> is the rank of the
                   process and device can be parsed as torch.device type.
                   E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
                3. <workername> and <rank> are optional and formats like "cpu"
                   and "cuda:1", just represent local devices.

    Raises:
        TypeError: if ``remote_device`` is neither a ``str`` nor a ``torch.device``.
        ValueError: if the string does not match any of the formats above
            (too many ``/`` separators, an empty worker name, or a malformed
            ``rank:<rank>`` prefix).

    An invalid device field is passed straight to ``torch.device`` and whatever
    that constructor raises is propagated unchanged.
    """

    def __init__(self, remote_device: Union[str, torch.device]):
        PARSE_ERROR = (
            f"Could not parse remote_device: {remote_device}. The valid format is "
            "'<workername>/<device>' or 'rank:<rank>/<device>' or '<device>'"
        )
        self._worker_name: Optional[str] = None
        self._rank: Optional[int] = None
        self._device: Optional[Union[str, int, torch.device]] = None

        if isinstance(remote_device, torch.device):
            self._device = remote_device
        elif isinstance(remote_device, str):
            fields = remote_device.split("/")
            if len(fields) == 2:
                self._worker_name, self._device = fields
            elif len(fields) == 1:
                # A single field is either a local device ("cpu", "cuda:1") or a
                # worker name / "rank:<rank>" whose device defaults to "cpu".
                if _remote_device._is_valid_local_device(fields[0]):
                    self._device = fields[0]
                else:
                    self._worker_name = fields[0]
                    self._device = "cpu"
            else:
                raise ValueError(PARSE_ERROR)
        else:
            raise TypeError(f'Invalid type for remote_device: {type(remote_device)}')

        # Basic sanity check: reject an empty worker field such as "/cpu".
        if self._worker_name is not None and not self._worker_name:
            raise ValueError(PARSE_ERROR)

        # Validate/normalize the device; torch.device raises on invalid input.
        self._device = torch.device(self._device)

        # Check for the rank-based format; a parsed rank replaces the worker name.
        if self._worker_name is not None:
            fields = self._worker_name.split(":")
            if len(fields) == 2:
                # isdecimal() (not isdigit()): isdigit() also accepts characters
                # such as superscripts ("²") that int() cannot parse, which would
                # surface as an unrelated int() error instead of PARSE_ERROR.
                if fields[0] == "rank" and fields[1].isdecimal():
                    self._rank = int(fields[1])
                    self._worker_name = None
                else:
                    raise ValueError(PARSE_ERROR)
            elif len(fields) > 2:
                raise ValueError(PARSE_ERROR)

    @staticmethod
    def _is_valid_local_device(device):
        """Return ``True`` if ``torch.device(device)`` accepts ``device``, else ``False``."""
        # Any construction failure means "not a device string"; the caller then
        # treats the value as a worker name instead.
        try:
            torch.device(device)
            return True
        except Exception:
            return False

    def worker_name(self) -> Optional[str]:
        """Return the name of remote worker representing the remote device and ``None`` if no worker name is available."""
        return self._worker_name

    def rank(self) -> Optional[int]:
        """Return the rank of remote worker representing the remote device, or ``None`` if no rank is available."""
        return self._rank

    def device(self) -> torch.device:
        """Return the local device on the remote worker."""
        return self._device  # type: ignore[return-value]

    def __repr__(self):
        if self._device is not None:
            if self._worker_name is not None:
                return f'{self._worker_name}/{self._device}'
            elif self._rank is not None:
                return f'rank:{self._rank}/{self._device}'
            else:
                return str(self._device)
        else:
            # Not reachable through __init__, which always sets a torch.device;
            # kept as a defensive fallback that mirrors the formats above.
            if self._worker_name is not None:
                return f'{self._worker_name}'
            elif self._rank is not None:
                return f'rank:{self._rank}'
            else:
                raise RuntimeError('Invalid state!')

    def __eq__(self, other):
        if not isinstance(other, _remote_device):
            return False
        return (
            self._worker_name == other._worker_name
            and self._device == other._device
            and self._rank == other._rank
        )

    def __hash__(self):
        # Hash the field tuple rather than XOR-ing per-field hashes. For a
        # purely local device both _worker_name and _rank are None, so under
        # XOR their hashes cancel (x ^ x == 0) and the result collapses to
        # hash(self._device). Equal objects still get equal hashes.
        return hash((self._worker_name, self._device, self._rank))