#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import abc
import logging
import os
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import time
from abc import ABC, abstractmethod
from contextlib import nullcontext
from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union

import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
from torch.distributed.elastic.multiprocessing.redirects import (
    redirect_stderr,
    redirect_stdout,
)
from torch.distributed.elastic.multiprocessing.subprocess_handler import (
    SubprocessHandler,
    get_subprocess_handler,
)
from torch.distributed.elastic.multiprocessing.tail_log import TailLog

IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"

log = logging.getLogger(__name__)

__all__ = [
    "DefaultLogsSpecs",
    "SignalException",
    "Std",
    "to_map",
    "RunProcsResult",
    "PContext",
    "get_std_cm",
    "MultiprocessContext",
    "SubprocessContext",
]


class SignalException(Exception):
    """
    Raised inside the torchelastic agent process by the termination handler
    when the process receives a death signal.
    """

    def __init__(self, msg: str, sigval: signal.Signals) -> None:
        super().__init__(msg)
        self.sigval = sigval


def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
    """Termination handler that raises an exception on the main process.

    This handler is invoked when the process receives a death signal
    (SIGTERM, SIGINT). It raises ``SignalException``, which should be handled
    by the user code. Python does not terminate the process after the handler
    returns, so the exception must not be silently swallowed; otherwise the
    process will never terminate.
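
    Example (an illustrative sketch; ``main_loop`` and ``cleanup`` are
    placeholders for user code, not part of this module)::

        signal.signal(signal.SIGTERM, _terminate_process_handler)
        try:
            main_loop()
        except SignalException as e:
            # user code decides how to propagate the signal and shut down
            cleanup(e.sigval)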
""" | |
sigval = signal.Signals(signum) | |
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval) | |


def _get_kill_signal() -> signal.Signals:
    """Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows."""
    if IS_WINDOWS:
        return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
    else:
        return signal.SIGKILL


def _get_default_signal() -> signal.Signals:
    """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows."""
    if IS_WINDOWS:
        return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
    else:
        return signal.SIGTERM


def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
    actual_keys = set(d.keys())
    expected_keys = set(range(nprocs))
    if actual_keys != expected_keys:
        raise RuntimeError(
            f"{what}, local rank mapping mismatch,"
            f" expected: {expected_keys}, actual: {actual_keys}"
        )


_MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$"
_VALUE_REGEX = r"^[0123]$"


class Std(IntFlag):
    NONE = 0
    OUT = 1
    ERR = 2
    ALL = OUT | ERR
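
    # Flag values combine bitwise: Std.OUT | Std.ERR == Std.ALL, and
    # bool(Std.ALL & Std.ERR) is True. This is how the redirect/tee settings
    # are merged and queried throughout this module.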

    @classmethod
    def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]:
        """
        Example:
        ::

            from_str("0") -> Std.NONE
            from_str("1") -> Std.OUT
            from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}

        Any other input raises an exception.
        """

        def to_std(v: str) -> Std:  # type: ignore[return]
            s = Std(int(v))
            if s in Std:
                return s
            # return None -> should NEVER reach here since we regex check input

        if re.match(_VALUE_REGEX, vm):  # vm is a number (e.g. 0)
            return to_std(vm)
        elif re.match(_MAPPING_REGEX, vm):  # vm is a mapping (e.g. 0:1,1:2)
            d: Dict[int, Std] = {}
            for m in vm.split(","):
                i, v = m.split(":")
                d[int(i)] = to_std(v)
            return d
        else:
            raise ValueError(
                f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>"
            )


def to_map(
    val_or_map: Union[Std, Dict[int, Std]], local_world_size: int
) -> Dict[int, Std]:
    """
    Certain APIs take redirect settings either as a single value (e.g. to apply to
    all local ranks) or as an explicit user-provided mapping. This convenience
    method normalizes either form into a mapping.

    Example:
    ::

        to_map(Std.OUT, local_world_size=2)  # returns: {0: Std.OUT, 1: Std.OUT}
        to_map({1: Std.OUT}, local_world_size=2)  # returns: {0: Std.NONE, 1: Std.OUT}
        to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2)  # returns: {0: Std.OUT, 1: Std.OUT}
    """
    if isinstance(val_or_map, Std):
        return dict.fromkeys(range(local_world_size), val_or_map)
    else:
        map = {}
        for i in range(local_world_size):
            map[i] = val_or_map.get(i, Std.NONE)
        return map


@dataclass
class LogsDest:
    """
    For each log type, holds a mapping of local rank ids to file paths.
    """

    stdouts: Dict[int, str] = field(default_factory=dict)
    stderrs: Dict[int, str] = field(default_factory=dict)
    tee_stdouts: Dict[int, str] = field(default_factory=dict)
    tee_stderrs: Dict[int, str] = field(default_factory=dict)
    error_files: Dict[int, str] = field(default_factory=dict)


class LogsSpecs(ABC):
    """
    Defines logs processing and redirection for each worker process.

    Args:
        log_dir:
            Base directory where logs will be written.
        redirects:
            Streams to redirect to files. Pass a single ``Std``
            enum to redirect for all workers, or a mapping keyed
            by local_rank to selectively redirect.
        tee:
            Streams to duplicate to stdout/stderr.
            Pass a single ``Std`` enum to duplicate streams for all workers,
            or a mapping keyed by local_rank to selectively duplicate.
        local_ranks_filter:
            If set, only the streams of these local ranks are tailed to the
            console; output of all other ranks goes to files only (see ``reify``).
    """

    def __init__(
        self,
        log_dir: Optional[str] = None,
        redirects: Union[Std, Dict[int, Std]] = Std.NONE,
        tee: Union[Std, Dict[int, Std]] = Std.NONE,
        local_ranks_filter: Optional[Set[int]] = None,
    ) -> None:
        self._root_log_dir = log_dir
        self._redirects = redirects
        self._tee = tee
        self._local_ranks_filter = local_ranks_filter

    @abstractmethod
    def reify(self, envs: Dict[int, Dict[str, str]]) -> LogsDest:
        """
        Given the environment variables, builds the destination of the log files for each local rank.

        The ``envs`` parameter holds one env-variable dict per local rank, with entries defined in:
        :func:`~torchelastic.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent._start_workers`.
        """

    @property
    @abstractmethod
    def root_log_dir(self) -> str:
        pass


class DefaultLogsSpecs(LogsSpecs):
    """
    Default ``LogsSpecs`` implementation:

    - ``log_dir`` will be created if it doesn't exist.
    - Generates nested folders for each attempt and rank.
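
    Example (a minimal sketch; the log dir and run id are illustrative)::

        logs_specs = DefaultLogsSpecs(log_dir="/tmp/logs", tee=Std.OUT)
        dest = logs_specs.reify({0: {"TORCHELASTIC_RUN_ID": "run42"}})
        # dest.stdouts[0] -> /tmp/logs/run42_<random>/attempt_0/0/stdout.log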
""" | |
def __init__( | |
self, | |
log_dir: Optional[str] = None, | |
redirects: Union[Std, Dict[int, Std]] = Std.NONE, | |
tee: Union[Std, Dict[int, Std]] = Std.NONE, | |
local_ranks_filter: Optional[Set[int]] = None, | |
) -> None: | |
if log_dir != os.devnull: | |
if not log_dir: | |
log_dir = tempfile.mkdtemp(prefix="torchelastic_") | |
elif not os.path.exists(log_dir): | |
os.makedirs(log_dir) | |
else: | |
if os.path.isfile(log_dir): | |
raise NotADirectoryError(f"log_dir: {log_dir} is a file") | |
super().__init__(log_dir, redirects, tee, local_ranks_filter) | |
# initialized only once | |
self._run_log_dir = None | |
def root_log_dir(self) -> str: | |
return str(self._root_log_dir) | |

    def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
        base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
        os.makedirs(base_log_dir, exist_ok=True)
        dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
        log.info("log directory set to: %s", dir)
        return dir

    def reify(self, envs: Dict[int, Dict[str, str]]) -> LogsDest:
        """
        Uses the following scheme to build the log destination paths:

        - ``<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stdout.log``
        - ``<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stderr.log``
        - ``<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/error.json``
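
        The run directory itself carries a random ``mkdtemp`` suffix (see
        ``_make_log_dir``), so a concrete path looks roughly like
        ``/tmp/logs/run42_abc123/attempt_0/1/stdout.log`` (illustrative values).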
""" | |
nprocs = len(envs) | |
global_env = {} # use only to query properies that are not dependent on a rank | |
if nprocs > 0: | |
global_env = envs[0] | |
else: | |
log.warning("Empty envs map provided when defining logging destinations.") | |
# Keys are always defined, but values can be missing in unit tests | |
run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id") | |
restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0") | |
attempt_log_dir: str = "" | |
if self._root_log_dir != os.devnull: | |
if not self._run_log_dir: | |
self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id) | |
attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}") # type: ignore[call-overload] | |
shutil.rmtree(attempt_log_dir, ignore_errors=True) | |
os.makedirs(attempt_log_dir) | |
if self._root_log_dir == os.devnull: | |
attempt_log_dir = os.devnull | |
# create subdirs for each local rank in the logs_dir | |
# logs_dir | |
# |- 0 | |
# |- error.json | |
# |- stdout.log | |
# |- stderr.log | |
# |- ... | |
# |- (nprocs-1) | |
redirs = to_map(self._redirects, nprocs) | |
ts = to_map(self._tee, nprocs) | |
# to tee stdout/stderr we first redirect into a file | |
# then tail -f stdout.log/stderr.log so add tee settings to redirects | |
for local_rank, tee_std in ts.items(): | |
redirect_std = redirs[local_rank] | |
redirs[local_rank] = redirect_std | tee_std | |
SYS_STREAM = "" # special case to indicate to output to console | |
stdouts = dict.fromkeys(range(nprocs), SYS_STREAM) | |
stderrs = dict.fromkeys(range(nprocs), SYS_STREAM) | |
tee_stdouts: Dict[int, str] = {} | |
tee_stderrs: Dict[int, str] = {} | |
error_files = {} | |
for local_rank in range(nprocs): | |
if attempt_log_dir == os.devnull: | |
tee_stdouts[local_rank] = os.devnull | |
tee_stderrs[local_rank] = os.devnull | |
error_files[local_rank] = os.devnull | |
envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = "" | |
else: | |
clogdir = os.path.join(attempt_log_dir, str(local_rank)) | |
os.mkdir(clogdir) | |
rd = redirs[local_rank] | |
if (rd & Std.OUT) == Std.OUT: | |
stdouts[local_rank] = os.path.join(clogdir, "stdout.log") | |
if (rd & Std.ERR) == Std.ERR: | |
stderrs[local_rank] = os.path.join(clogdir, "stderr.log") | |
t = ts[local_rank] | |
if t & Std.OUT == Std.OUT: | |
tee_stdouts[local_rank] = stdouts[local_rank] | |
if t & Std.ERR == Std.ERR: | |
tee_stderrs[local_rank] = stderrs[local_rank] | |
if self._local_ranks_filter and local_rank not in self._local_ranks_filter: | |
# If stream is tee'd, only write to file, but don't tail | |
if local_rank in tee_stdouts: | |
tee_stdouts.pop(local_rank, None) | |
if local_rank in tee_stderrs: | |
tee_stderrs.pop(local_rank, None) | |
# If stream is not redirected, don't print | |
if stdouts[local_rank] == SYS_STREAM: | |
stdouts[local_rank] = os.devnull | |
if stderrs[local_rank] == SYS_STREAM: | |
stderrs[local_rank] = os.devnull | |
error_file = os.path.join(clogdir, "error.json") | |
error_files[local_rank] = error_file | |
log.info("Setting worker%s reply file to: %s", local_rank, error_file) | |
envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file | |
return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files) | |

    def __repr__(self) -> str:
        return (
            f"DefaultLogsSpecs(root_log_dir={self._root_log_dir}, redirects={self._redirects}, "
            f"tee={self._tee}, local_ranks_filter={self._local_ranks_filter})"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DefaultLogsSpecs):
            return False

        return (
            self._root_log_dir == other._root_log_dir
            and self._redirects == other._redirects
            and self._tee == other._tee
            and self._local_ranks_filter == other._local_ranks_filter
        )


@dataclass
class RunProcsResult:
    """
    Results of a completed run of processes started with ``start_processes()``. Returned by ``PContext``.

    Note the following:

    1. All fields are mapped by local rank
    2. ``return_values`` - only populated for functions (not binaries)
    3. ``stdouts`` - path to stdout.log (empty string if no redirect)
    4. ``stderrs`` - path to stderr.log (empty string if no redirect)
    """

    return_values: Dict[int, Any] = field(default_factory=dict)
    failures: Dict[int, ProcessFailure] = field(default_factory=dict)
    stdouts: Dict[int, str] = field(default_factory=dict)
    stderrs: Dict[int, str] = field(default_factory=dict)

    def is_failed(self) -> bool:
        return len(self.failures) > 0


class PContext(abc.ABC):
    """
    The base class that standardizes operations over a set of processes that are launched via different mechanisms.

    The name ``PContext`` is intentional to disambiguate with ``torch.multiprocessing.ProcessContext``.

    .. warning:: stdouts and stderrs should ALWAYS be a superset of
                 tee_stdouts and tee_stderrs (respectively); this is because
                 tee is implemented as a redirect + tail -f <stdout/stderr.log>
    """

    def __init__(
        self,
        name: str,
        entrypoint: Union[Callable, str],
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        logs_specs: LogsSpecs,
        log_line_prefixes: Optional[Dict[int, str]] = None,
    ):
        self.name = name
        # validate that all mappings have the same number of keys and
        # all local ranks are accounted for
        nprocs = len(args)

        # TODO log_line_prefixes can be expanded too
        logs_dest = logs_specs.reify(envs)

        _validate_full_rank(logs_dest.stdouts, nprocs, "stdouts")
        _validate_full_rank(logs_dest.stderrs, nprocs, "stderrs")

        self.entrypoint = entrypoint
        self.args = args
        self.envs = envs
        self.stdouts = logs_dest.stdouts
        self.stderrs = logs_dest.stderrs
        self.error_files = logs_dest.error_files
        self.nprocs = nprocs

        self._stdout_tail = TailLog(
            name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes
        )
        self._stderr_tail = TailLog(
            name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes
        )

    def start(self) -> None:
        """Start processes using parameters defined in the constructor."""
        signal.signal(signal.SIGTERM, _terminate_process_handler)
        signal.signal(signal.SIGINT, _terminate_process_handler)
        if not IS_WINDOWS:
            signal.signal(signal.SIGHUP, _terminate_process_handler)
            signal.signal(signal.SIGQUIT, _terminate_process_handler)
        self._start()
        self._stdout_tail.start()
        self._stderr_tail.start()

    @abc.abstractmethod
    def _start(self) -> None:
        """Start processes using the strategy defined in a particular context."""
        raise NotImplementedError()

    @abc.abstractmethod
    def _poll(self) -> Optional[RunProcsResult]:
        """
        Poll the run status of the processes running under this context.
        This method follows an "all-or-nothing" policy and returns
        a ``RunProcsResult`` object if either all processes complete
        successfully or any process fails. Returns ``None`` if
        all processes are still running.
        """
        raise NotImplementedError()

    def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
        """
        Wait for the specified ``timeout`` seconds, polling every ``period`` seconds
        for the processes to be done. Returns ``None`` if the processes are still running
        on timeout expiry. Negative timeout values are interpreted as "wait-forever".
        A timeout value of zero simply queries the status of the processes (e.g. equivalent
        to a poll).

        .. note:: This module registers SIGTERM and SIGINT signal handlers that raise
            ``SignalException`` when those signals are received. It is up to the consumer
            of the code to properly handle the exception; it is important not to swallow
            it, otherwise the process would never terminate. A typical workflow looks like:

        .. code-block:: python

            pc = start_processes(...)
            try:
                pc.wait(1)
                .. do some other work
            except SignalException as e:
                pc.shutdown(e.sigval, timeout=30)

        If SIGTERM or SIGINT occurs, the code above tries to shut down the child processes
        by propagating the received signal. If the child processes do not terminate within
        the timeout, they will be sent SIGKILL.
        """
        if timeout == 0:
            return self._poll()

        if timeout < 0:
            timeout = sys.maxsize

        expiry = time.time() + timeout
        while time.time() < expiry:
            pr = self._poll()
            if pr:
                return pr
            time.sleep(period)

        return None

    @abc.abstractmethod
    def pids(self) -> Dict[int, int]:
        """Return pids of processes mapped by their respective local_ranks."""
        raise NotImplementedError()

    @abc.abstractmethod
    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).
        """
        raise NotImplementedError()

    def close(
        self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
    ) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).

        Args:
            death_sig: Death signal to terminate processes.
            timeout: Time to wait for processes to finish; if a process is
                still alive after this time, it will be terminated via SIGKILL.
        """
        if not death_sig:
            death_sig = _get_default_signal()
        self._close(death_sig=death_sig, timeout=timeout)
        if self._stdout_tail:
            self._stdout_tail.stop()
        if self._stderr_tail:
            self._stderr_tail.stop()


def get_std_cm(std_rd: str, redirect_fn):
    if IS_WINDOWS or IS_MACOS or not std_rd:
        return nullcontext()
    else:
        return redirect_fn(std_rd)
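

# For illustration: get_std_cm("/tmp/stdout.log", redirect_stdout) yields a context
# manager that redirects this process's stdout into that file (on Linux), while
# get_std_cm("", redirect_stdout) is a no-op nullcontext, i.e. output stays on the
# console. Redirection is always a no-op on Windows and macOS.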


def _wrap(
    local_rank: int,
    fn: Callable,
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    stdout_redirects: Dict[int, str],  # redirect file for stdout (to console if empty)
    stderr_redirects: Dict[int, str],  # redirect file for stderr (to console if empty)
    ret_vals: Dict[int, mp.SimpleQueue],
    queue_finished_reading_event: synchronize.Event,
) -> None:
    # get the per-rank params up front so we fail fast if no mapping is found
    args_ = args[local_rank]
    env_ = envs[local_rank]
    ret_val_ = ret_vals[local_rank]

    stdout_rd = stdout_redirects[local_rank]
    stderr_rd = stderr_redirects[local_rank]

    stdout_cm = get_std_cm(stdout_rd, redirect_stdout)
    stderr_cm = get_std_cm(stderr_rd, redirect_stderr)

    for k, v in env_.items():
        os.environ[k] = v

    with stdout_cm, stderr_cm:
        ret = record(fn)(*args_)
    ret_val_.put(ret)
    queue_finished_reading_event.wait()


class MultiprocessContext(PContext):
    """``PContext`` holding worker processes invoked as a function."""

    def __init__(
        self,
        name: str,
        entrypoint: Callable,
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        start_method: str,
        logs_specs: LogsSpecs,
        log_line_prefixes: Optional[Dict[int, str]] = None,
    ):
        super().__init__(
            name,
            entrypoint,
            args,
            envs,
            logs_specs,
            log_line_prefixes,
        )

        self.start_method = start_method
        # each ret_val queue will always contain a single element
        self._ret_vals = {
            local_rank: mp.get_context(self.start_method).SimpleQueue()
            for local_rank in range(self.nprocs)
        }

        # see comments in ``join()`` for what this is
        self._return_values: Dict[int, Any] = {}
        self._pc: Optional[mp.ProcessContext] = None
        # Note: the ``set()`` method should ONLY be invoked when all processes
        # finished successfully. If any process died on ``event.wait()``,
        # calling ``set()`` will deadlock.
        self._worker_finished_event = mp.get_context(self.start_method).Event()

    def _start(self):
        if self._pc:
            raise ValueError(
                "The process context is already initialized."
                " Most likely the start method got called twice."
            )
        self._pc = mp.start_processes(
            fn=_wrap,
            args=(
                self.entrypoint,
                self.args,
                self.envs,
                self.stdouts,
                self.stderrs,
                self._ret_vals,
                self._worker_finished_event,
            ),
            nprocs=self.nprocs,
            join=False,
            daemon=False,
            start_method=self.start_method,
        )

    def _is_done(self) -> bool:
        return len(self._return_values) == self.nprocs

    def _poll(self) -> Optional[RunProcsResult]:
        assert self._pc is not None  # assertion for mypy type checker

        try:
            # torch.mp.ProcessContext throws an Exception if some/all of the
            # worker processes failed.
            # timeout < 0 checks worker status and returns immediately.
            # Join will never return success since we use synchronize.Event to
            # wait for all processes to finish.
            self._pc.join(-1)

            # IMPORTANT: we use multiprocessing.Queue to carry worker return values
            # back to the parent; the worker process will wait before terminating
            # until all the buffered items have been fed by the feeder thread to
            # the underlying pipe. Hence, to prevent deadlocks on large return
            # values, we opportunistically try queue.get on each join call.
            # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms
            for local_rank in range(0, self.nprocs):
                return_queue = self._ret_vals[local_rank]
                if not return_queue.empty():
                    # save the return values temporarily into a member var
                    self._return_values[local_rank] = return_queue.get()

            if self._is_done():
                # we should ALWAYS have ALL the return values when all the processes are done
                self._worker_finished_event.set()

                # Wait until all processes are finished. At this point, workers
                # have finished executing the user function.
                self._pc.join()
                _validate_full_rank(
                    self._return_values, self.nprocs, "return_value queue"
                )
                self.close()
                return RunProcsResult(
                    return_values=self._return_values,
                    stdouts=self.stdouts,
                    stderrs=self.stderrs,
                )
            else:
                return None
        except (mp.ProcessRaisedException, mp.ProcessExitedException) as e:
            failed_local_rank = e.error_index

            # entrypoint for MultiprocessContext will always be a Callable
            fn_name = self.entrypoint.__qualname__  # type: ignore[union-attr]
            failed_proc = self._pc.processes[failed_local_rank]
            error_filepath = self.error_files[failed_local_rank]

            log.exception(
                "failed (exitcode: %s)"
                " local_rank: %s (pid: %s)"
                " of fn: %s (start_method: %s)",
                failed_proc.exitcode,
                failed_local_rank,
                e.pid,
                fn_name,
                self.start_method,
            )

            self.close()
            return RunProcsResult(
                failures={
                    failed_local_rank: ProcessFailure(
                        local_rank=failed_local_rank,
                        pid=e.pid,
                        exitcode=failed_proc.exitcode,
                        error_file=error_filepath,
                    )
                },
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )

    def pids(self) -> Dict[int, int]:
        assert self._pc is not None  # assertion for mypy type checking
        return dict(enumerate(self._pc.pids()))

    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        if not self._pc:
            return
        for proc in self._pc.processes:
            if proc.is_alive():
                log.warning(
                    "Closing process %s via signal %s", proc.pid, death_sig.name
                )
                try:
                    os.kill(proc.pid, death_sig)
                except ProcessLookupError:
                    # If the process already exited, `ProcessLookupError` is
                    # raised; it is safe to ignore it.
                    pass
        end = time.monotonic() + timeout
        for proc in self._pc.processes:
            time_to_wait = end - time.monotonic()
            if time_to_wait <= 0:
                break
            proc.join(time_to_wait)
        for proc in self._pc.processes:
            if proc.is_alive():
                log.warning(
                    "Unable to shutdown process %s via %s, forcefully exiting via %s",
                    proc.pid,
                    death_sig,
                    _get_kill_signal(),
                )
                try:
                    os.kill(proc.pid, _get_kill_signal())
                except ProcessLookupError:
                    # If the process already exited, `ProcessLookupError` is
                    # raised; it is safe to ignore it.
                    pass
                proc.join()


class SubprocessContext(PContext):
    """``PContext`` holding worker processes invoked as a binary."""

    def __init__(
        self,
        name: str,
        entrypoint: str,
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        logs_specs: LogsSpecs,
        log_line_prefixes: Optional[Dict[int, str]] = None,
    ):
        super().__init__(
            name,
            entrypoint,
            args,
            envs,
            logs_specs,
            log_line_prefixes,
        )

        # state vector; tracks which local ranks are still running
        self._running_local_ranks: Set[int] = set(range(self.nprocs))
        self._failures: Dict[int, ProcessFailure] = {}
        self.subprocess_handlers: Dict[int, SubprocessHandler] = {}

    def _start(self):
        if self.subprocess_handlers:
            raise ValueError(
                "The subprocess handlers are already initialized."
                " Most likely the start method got called twice."
            )
        self.subprocess_handlers = {
            local_rank: get_subprocess_handler(
                entrypoint=self.entrypoint,  # type: ignore[arg-type] # entrypoint is always a str
                args=self.args[local_rank],
                env=self.envs[local_rank],
                stdout=self.stdouts[local_rank],
                stderr=self.stderrs[local_rank],
                local_rank_id=local_rank,
            )
            for local_rank in range(self.nprocs)
        }

    def _poll(self) -> Optional[RunProcsResult]:
        done_local_ranks = set()
        for local_rank in self._running_local_ranks:
            handler = self.subprocess_handlers[local_rank]
            exitcode = handler.proc.poll()
            if exitcode is not None:
                done_local_ranks.add(local_rank)
                if exitcode != 0:  # failed or signaled
                    self._failures[local_rank] = ProcessFailure(
                        local_rank=local_rank,
                        pid=handler.proc.pid,
                        exitcode=exitcode,
                        error_file=self.error_files[local_rank],
                    )
                # else: --> succeeded; nothing to do

        self._running_local_ranks.difference_update(done_local_ranks)

        # if ALL procs are finished or ANY have failed
        if not self._running_local_ranks or self._failures:
            self.close()  # terminate all running procs
            result = RunProcsResult(
                failures=self._failures,
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )
            if result.is_failed():
                first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
                log.error(
                    "failed (exitcode: %s)"
                    " local_rank: %s (pid: %s)"
                    " of binary: %s",
                    first_failure.exitcode,
                    first_failure.local_rank,
                    first_failure.pid,
                    self.entrypoint,
                )
            else:
                # Populate return values with dummy values. This provides consistency with MultiprocessingHandler
                result.return_values = dict.fromkeys(range(self.nprocs))

            return result
        else:  # there are no failures and procs are still running
            return None

    def pids(self) -> Dict[int, int]:
        return {
            local_rank: sh.proc.pid
            for local_rank, sh in self.subprocess_handlers.items()
        }

    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        if not self.subprocess_handlers:
            return
        for handler in self.subprocess_handlers.values():
            if handler.proc.poll() is None:
                log.warning(
                    "Sending process %s closing signal %s",
                    handler.proc.pid,
                    death_sig.name,
                )
                handler.close(death_sig=death_sig)
        end = time.monotonic() + timeout
        for handler in self.subprocess_handlers.values():
            time_to_wait = end - time.monotonic()
            if time_to_wait <= 0:
                break
            try:
                handler.proc.wait(time_to_wait)
            except subprocess.TimeoutExpired:
                # Ignore the timeout expired exception, since
                # the child process will be forcefully terminated via SIGKILL
                pass
        for handler in self.subprocess_handlers.values():
            if handler.proc.poll() is None:
                log.warning(
                    "Unable to shutdown process %s via %s, forcefully exiting via %s",
                    handler.proc.pid,
                    death_sig,
                    _get_kill_signal(),
                )
                handler.close(death_sig=_get_kill_signal())
                handler.proc.wait()
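

if __name__ == "__main__":
    # Illustrative smoke test, not part of the library: launch two copies of
    # ``echo`` (assumed to be on PATH on a POSIX system) via SubprocessContext,
    # tee their stdout to the console, and wait for both to finish.
    ctx = SubprocessContext(
        name="echo",
        entrypoint="echo",
        args={0: ("hello from rank 0",), 1: ("hello from rank 1",)},
        envs={0: {}, 1: {}},
        logs_specs=DefaultLogsSpecs(tee=Std.OUT),
    )
    ctx.start()
    result = ctx.wait()  # polls once per second until both procs exit
    assert result is not None and not result.is_failed()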