Spaces:
Sleeping
Sleeping
File size: 10,651 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 |
import logging
import multiprocessing
import multiprocessing.connection
import os
import pickle
import signal
import sys
import tempfile
import time
import warnings
from typing import Optional
from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
# Module-level logger; ProcessContext.join uses it to warn when it has to
# SIGTERM / SIGKILL child processes.
log = logging.getLogger(__name__)
class ProcessException(Exception):
    """Base class for errors reporting that a child process failed.

    Attributes:
        msg: Human-readable description of the failure.
        error_index: Index of the failed process within its context.
        pid: OS process id of the failed process.
        error_pid: Same value as ``pid``. It matches the ``error_pid``
            constructor argument name used by the subclasses.
    """

    __slots__ = ["error_index", "error_pid"]

    def __init__(self, msg: str, error_index: int, pid: int):
        super().__init__(msg)
        self.msg = msg
        self.error_index = error_index
        self.pid = pid
        # ``error_pid`` is declared in ``__slots__`` but was never assigned,
        # so reading it raised AttributeError. Keep it in sync with ``pid``
        # (which existing callers already read).
        self.error_pid = pid

    def __reduce__(self):
        # ``Exception`` pickles as ``type(self)(*self.args)`` with
        # ``args == (msg,)``, which does not match this 3-argument
        # constructor; rebuild from the full constructor arguments instead.
        return type(self), (self.msg, self.error_index, self.pid)
class ProcessRaisedException(ProcessException):
    """Raised in the parent when a child process failed because its code raised.

    The message carries the child's formatted traceback.

    Args:
        msg: Description of the failure, including the child's traceback.
        error_index: Index of the failing process within its context.
        error_pid: OS process id of the failing process.
    """

    def __init__(self, msg: str, error_index: int, error_pid: int):
        # Everything is stored by the base class; this subclass exists so
        # callers can distinguish "code raised" from "process exited/killed".
        super().__init__(msg, error_index, error_pid)
class ProcessExitedException(ProcessException):
    """Raised in the parent when a child process exited with a non-zero code or was killed by a signal.

    Attributes:
        exit_code: The child's exit code (negative when killed by a signal).
        signal_name: Name of the terminating signal, or ``None`` when the
            process exited normally with a non-zero code.
    """

    __slots__ = ["exit_code"]

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
        exit_code: int,
        signal_name: Optional[str] = None,
    ):
        super().__init__(msg, error_index, error_pid)
        self.exit_code = exit_code
        self.signal_name = signal_name

    def __reduce__(self):
        # Reconstruct with every constructor argument so that pickling
        # round-trips exit_code and signal_name as well.
        ctor_args = (
            self.msg,
            self.error_index,
            self.pid,
            self.exit_code,
            self.signal_name,
        )
        return type(self), ctor_args
def _wrap(fn, i, args, error_file):
    """Run ``fn(i, *args)`` inside a child process and report failures.

    If ``fn`` raises an ``Exception``, the formatted traceback is pickled
    into ``error_file`` and the child exits with status 1. A
    ``KeyboardInterrupt`` is swallowed, so the child exits normally.

    Args:
        fn: Callable invoked as ``fn(i, *args)``.
        i: Index of this process.
        args: Extra positional arguments forwarded to ``fn``.
        error_file: Path the traceback is written to on failure.
    """
    # prctl(2) is Linux-specific; elsewhere this call has no effect. It asks
    # the kernel to send SIGINT to this child when the parent dies, so
    # non-daemonic children do not outlive their parent.
    _prctl_pr_set_pdeathsig(signal.SIGINT)
    try:
        fn(i, *args)
    except KeyboardInterrupt:
        # SIGINT: we were stopped by the parent; exit quietly.
        return
    except Exception:
        # Hand the original traceback to the parent through the error file.
        import traceback

        with open(error_file, "wb") as sink:
            pickle.dump(traceback.format_exc(), sink)
        sys.exit(1)
class ProcessContext:
    """Handle to a group of child processes started by ``start_processes``.

    Args:
        processes: The started ``multiprocessing`` process objects, in index
            order.
        error_files: ``error_files[i]`` is the path the i-th child writes a
            pickled traceback to if its code raises.
    """

    def __init__(self, processes, error_files):
        self.error_files = error_files
        self.processes = processes
        # sentinel -> process index. Entries are popped as processes are
        # joined, so an empty dict means every process has been joined.
        self.sentinels = {
            process.sentinel: index for index, process in enumerate(processes)
        }

    def pids(self):
        """Return the PIDs of all processes in this context, in index order."""
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""Join one or more processes within spawn context.

        Attempt to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Args:
            timeout (float): Wait this long before giving up on waiting.

        Raises:
            ProcessRaisedException: The failed process raised an exception;
                the message includes its traceback.
            ProcessExitedException: The failed process exited with a non-zero
                code or was killed by a signal.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure: stop everything that is still running, then report.
        self._terminate_all()
        self._raise_for_failure(error_index)

    def _terminate_all(self, grace_period=30):
        """Stop every live process: SIGTERM, then SIGKILL after ``grace_period`` seconds."""
        # Python signal handling is limited to the main thread; if that thread
        # is stuck in C/C++ code it never handles SIGTERM, so escalate to
        # SIGKILL for processes that are still alive after the grace period.
        for process in self.processes:
            if process.is_alive():
                log.warning("Terminating process %s via signal SIGTERM", process.pid)
                process.terminate()
        deadline = time.monotonic() + grace_period
        for process in self.processes:
            process.join(max(0, deadline - time.monotonic()))
        for process in self.processes:
            if process.is_alive():
                log.warning(
                    "Unable to shutdown process %s via SIGTERM , forcefully exiting via SIGKILL",
                    process.pid,
                )
                process.kill()
            process.join()

    def _raise_for_failure(self, error_index):
        """Raise the exception describing why process ``error_index`` failed."""
        failed_process = self.processes[error_index]
        error_file = self.error_files[error_index]

        # The file will only be created if the process crashed with an
        # exception (see _wrap); otherwise report the exit code / signal.
        if not os.access(error_file, os.R_OK):
            exitcode = failed_process.exitcode
            if exitcode < 0:
                try:
                    name = signal.Signals(-exitcode).name
                except ValueError:
                    name = f"<Unknown signal {-exitcode}>"
                raise ProcessExitedException(
                    "process %d terminated with signal %s" % (error_index, name),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                    signal_name=name,
                )
            raise ProcessExitedException(
                "process %d terminated with exit code %d" % (error_index, exitcode),
                error_index=error_index,
                error_pid=failed_process.pid,
                exit_code=exitcode,
            )

        # NOTE(review): this file lives in the shared temp directory and its
        # name is reserved by create-then-unlink in start_processes, so a
        # local attacker could in principle plant a file at that path before
        # the child writes it; unpickling it would run arbitrary code.
        # Consider a private (mkdtemp) directory for error files.
        with open(error_file, "rb") as fh:
            original_trace = pickle.load(fh)
        # Best-effort cleanup: the file has been consumed and is no longer
        # needed; failing to remove it must not mask the real error below.
        try:
            os.unlink(error_file)
        except OSError:
            pass

        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise ProcessRaisedException(msg, error_index, failed_process.pid)
class SpawnContext(ProcessContext):
    """Deprecated alias of :class:`ProcessContext`, kept for backward compatibility.

    Constructing it emits a warning (default category, ``UserWarning``) and
    otherwise behaves exactly like :class:`ProcessContext`.
    """

    def __init__(self, processes, error_files):
        # stacklevel=2 attributes the warning to the code constructing
        # SpawnContext rather than to this line inside the library.
        warnings.warn(
            "SpawnContext is renamed to ProcessContext since 1.4 release.",
            stacklevel=2,
        )
        super().__init__(processes, error_files)
# Note: [start_processes]
# mp.start_processes handles both start_method='spawn' and 'fork'. It is meant to be a
# more general API than mp.spawn. Currently we only document mp.spawn, because 'spawn'
# is the CUDA-compatible start_method. However, in environments like IPython notebooks,
# 'fork' works better than 'spawn'. Every helper function we created for mp.spawn is
# general enough that backends like XLA can reuse them in Colab notebooks as well.
# For now this API is added without documentation; we can consider documenting it
# in the future if needed.
def start_processes(
    fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"
):
    """Start ``nprocs`` processes running ``fn(i, *args)`` with ``start_method``.

    Args:
        fn: Entry point, called in each child as ``fn(i, *args)``.
        args: Extra positional arguments forwarded to ``fn``.
        nprocs: Number of processes to start.
        join: If true, block until all processes have been joined.
        daemon: Daemon flag for the created processes.
        start_method: ``multiprocessing`` start method (e.g. ``"spawn"``).

    Returns:
        ``None`` when ``join`` is true, otherwise the :class:`ProcessContext`.
    """
    mp_ctx = multiprocessing.get_context(start_method)

    def _reserve_error_path():
        # Each child gets a file to write its traceback to; the file existing
        # signals an exception (vs. an expected shutdown). The file is created
        # only to obtain a unique name and removed right away. This replaced a
        # multiprocessing.Queue, which was prone to deadlocks, with a simpler
        # one-shot channel.
        handle = tempfile.NamedTemporaryFile(
            prefix="pytorch-errorfile-", suffix=".pickle", delete=False
        )
        handle.close()
        os.unlink(handle.name)
        return handle.name

    error_files = []
    processes = []
    for proc_index in range(nprocs):
        error_path = _reserve_error_path()
        child = mp_ctx.Process(
            target=_wrap,
            args=(fn, proc_index, args, error_path),
            daemon=daemon,
        )
        child.start()
        error_files.append(error_path)
        processes.append(child)

    context = ProcessContext(processes, error_files)
    if join:
        # join() blocks until some process finishes; repeat until it reports
        # that all are done (or raises on failure).
        while not context.join():
            pass
        return None
    return context
def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If one of the processes exits with a non-zero exit status, the
    remaining processes are killed and an exception is raised with the
    cause of termination. In the case an exception was caught in the
    child process, it is forwarded and its traceback is included in
    the exception raised in the parent process.

    Args:
        fn (function): Function is called as the entrypoint of the
            spawned process. This function must be defined at the top
            level of a module so it can be pickled and spawned. This
            is a requirement imposed by multiprocessing.

            The function is called as ``fn(i, *args)``, where ``i`` is
            the process index and ``args`` is the passed through tuple
            of arguments.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.
        daemon (bool): The spawned processes' daemon flag. If set to True,
                       daemonic processes will be created.
        start_method (str): (deprecated) this method will always use ``spawn``
                               as the start method. To use a different start method
                               use ``start_processes()``.

    Returns:
        None if ``join`` is ``True``,
        :class:`~ProcessContext` if ``join`` is ``False``
    """
    if start_method != "spawn":
        msg = (
            "This method only supports start_method=spawn (got: %s).\n"
            "To use a different start_method use:\n\t\t"
            " torch.multiprocessing.start_processes(...)" % start_method
        )
        # stacklevel=2 attributes the warning to the caller of spawn(), so
        # users can find the call that passed a non-"spawn" start_method.
        warnings.warn(msg, stacklevel=2)
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
|