Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
import sys | |
import uuid | |
from dataclasses import dataclass, field | |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
import torch.distributed.elastic.rendezvous.registry as rdzv_registry | |
from torch.distributed.elastic import events, metrics | |
from torch.distributed.elastic.agent.server.api import WorkerSpec | |
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent | |
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, SignalException | |
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError | |
from torch.distributed.elastic.rendezvous import RendezvousParameters | |
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint | |
from torch.distributed.elastic.utils.logging import get_logger | |
# Public names of this module; these are what ``from <module> import *`` exports.
__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent']

# Module-level logger, created through the torchelastic ``get_logger`` helper
# imported above.
logger = get_logger(__name__)
class LaunchConfig: | |
""" | |
Creates a rendezvous config. | |
Args: | |
min_nodes: Minimum amount of nodes that the user function will | |
be launched on. Elastic agent ensures that the user | |
function start only when the min_nodes amount enters | |
the rendezvous. | |
max_nodes: Maximum amount of nodes that the user function | |
will be launched on. | |
nproc_per_node: On each node the elastic agent will launch | |
this amount of workers that will execute user | |
defined function. | |
rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd). | |
rdzv_endpoint: The endpoint of the rdzv sync. storage. | |
rdzv_configs: Key, value pair that specifies rendezvous specific configuration. | |
rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going | |
to be removed in future versions, see the note below. The default timeout is 900 seconds. | |
run_id: The unique run id of the job (if not passed a unique one will be | |
deduced from run environment - flow workflow id in flow - or auto generated). | |
role: User defined role of the worker (defaults to "trainer"). | |
max_restarts: The maximum amount of restarts that elastic agent will conduct | |
on workers before failure. | |
monitor_interval: The interval in seconds that is used by the elastic_agent | |
as a period of monitoring workers. | |
start_method: The method is used by the elastic agent to start the | |
workers (spawn, fork, forkserver). | |
metrics_cfg: configuration to initialize metrics. | |
local_addr: address of the local node if any. If not set, a lookup on the local | |
machine's FQDN will be performed. | |
local_ranks_filter: ranks for which to show logs in console. If not set, show from all. | |
..note: | |
`rdzv_timeout` is a legacy argument that will be removed in future. | |
Set the timeout via `rdzv_configs['timeout']` | |
""" | |
min_nodes: int | |
max_nodes: int | |
nproc_per_node: int | |
logs_specs: Optional[LogsSpecs] = None | |
run_id: str = "" | |
role: str = "default_role" | |
rdzv_endpoint: str = "" | |
rdzv_backend: str = "etcd" | |
rdzv_configs: Dict[str, Any] = field(default_factory=dict) | |
rdzv_timeout: int = -1 | |
max_restarts: int = 3 | |
monitor_interval: float = 30 | |
start_method: str = "spawn" | |
log_line_prefix_template: Optional[str] = None | |
metrics_cfg: Dict[str, str] = field(default_factory=dict) | |
local_addr: Optional[str] = None | |
def __post_init__(self): | |
default_timeout = 900 | |
if self.rdzv_timeout != -1: | |
self.rdzv_configs["timeout"] = self.rdzv_timeout | |
elif "timeout" not in self.rdzv_configs: | |
self.rdzv_configs["timeout"] = default_timeout | |
# Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage | |
if self.logs_specs is None: | |
self.logs_specs = DefaultLogsSpecs() | |
class elastic_launch:
    """
    Callable wrapper that launches a torchelastic agent on the container that
    invoked the entrypoint.

    1. Pass the ``entrypoint`` arguments positionally (no named parameters);
       ``entrypoint`` can be a function or a command.
    2. The return value is a map of each worker's output, keyed by the
       worker's global rank.

    Usage

    ::

        def worker_fn(foo):
            # ...

        def main():
            # entrypoint is a function.
            outputs = elastic_launch(config, worker_fn)(foo)
            # return rank 0's output
            return outputs[0]

        # entrypoint is a command and ``script.py`` is the python module.
        outputs = elastic_launch(config, "script.py")(args)
        outputs = elastic_launch(config, "python")("script.py")

    where ``config`` is a ``LaunchConfig`` instance.
    """

    def __init__(
        self,
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
    ):
        # Only remember what to launch; nothing is started until the
        # instance is called.
        self._entrypoint = entrypoint
        self._config = config

    def __call__(self, *args):
        # launch_agent receives the positional arguments as a list.
        positional = [*args]
        return launch_agent(self._config, self._entrypoint, positional)
def _get_entrypoint_name(
    entrypoint: Union[Callable, str, None], args: List[Any]
) -> str:
    """Retrieve entrypoint name with the rule:

    1. If entrypoint is callable, use ``entrypoint.__name__``; callables that
       have no ``__name__`` (e.g. ``functools.partial`` objects) fall back to
       the name of their type.
    2. If entrypoint is a string, check its value:
        2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first
            non-empty element from ``args`` which does not start with a hyphen
            (for example, "-u" will be skipped).
        2.2 otherwise, use ``entrypoint`` value.
    3. Otherwise, return empty string.

    Args:
        entrypoint: the function or command being launched, or ``None``.
        args: the arguments that will be passed to ``entrypoint``.

    Returns:
        The name chosen by the rules above; ``""`` when none applies.
    """
    if callable(entrypoint):
        # Not every callable carries __name__ (partials, callable instances);
        # avoid an AttributeError just to compute a display name.
        return getattr(entrypoint, "__name__", type(entrypoint).__name__)

    if isinstance(entrypoint, str):
        if entrypoint != sys.executable:
            return entrypoint
        for arg in args:
            # str() + startswith() tolerates empty strings and non-string
            # arguments, which indexing with arg[0] does not.
            text = str(arg)
            if text and not text.startswith("-"):
                return text
        return ""

    return ""
def _get_addr_and_port(
    rdzv_parameters: RendezvousParameters,
) -> Tuple[Optional[str], Optional[int]]:
    """Extract the master address and port from a static rendezvous endpoint.

    Args:
        rdzv_parameters: rendezvous parameters; only ``backend`` and
            ``endpoint`` are read.

    Returns:
        ``(None, None)`` for any backend other than ``"static"``; otherwise the
        ``(address, port)`` pair parsed from the endpoint.

    Raises:
        ValueError: if the backend is ``"static"`` and the endpoint is blank or
            carries no port.
    """
    is_static = rdzv_parameters.backend == "static"
    if not is_static:
        return (None, None)

    stripped = rdzv_parameters.endpoint.strip()
    if not stripped:
        raise ValueError(
            "Endpoint is missing in endpoint. Try to add --master-addr and --master-port"
        )

    # -1 is handed to the parser as the default port so that an endpoint
    # without an explicit port can be detected afterwards.
    port_not_given = -1
    host, port = parse_rendezvous_endpoint(stripped, default_port=port_not_given)
    if port == port_not_given:
        raise ValueError(
            f"port is missing in endpoint: {stripped}. Try to specify --master-port"
        )
    return (host, port)
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """Build a ``LocalElasticAgent`` from ``config``, run it and return the results.

    Args:
        config: launch configuration. If ``config.run_id`` is empty, a random
            one is generated and written back onto ``config``.
        entrypoint: function or command that the workers execute.
        args: arguments forwarded to ``entrypoint``.

    Returns:
        ``return_values`` of the agent's run result.

    Raises:
        ChildFailedError: if the run result reports failure.
        ValueError: propagated from ``_get_addr_and_port`` for an invalid
            static rendezvous endpoint.
        Exception: anything else raised while running the agent is re-raised
            after a failure event has been recorded.
    """
    if not config.run_id:
        generated_id = str(uuid.uuid4().int)
        logger.warning("config has no run_id, generated a random run_id: %s", generated_id)
        config.run_id = generated_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    launch_summary = {
        "entrypoint": entrypoint_name,
        "min_nodes": config.min_nodes,
        "max_nodes": config.max_nodes,
        "nproc_per_node": config.nproc_per_node,
        "run_id": config.run_id,
        "rdzv_backend": config.rdzv_backend,
        "rdzv_endpoint": config.rdzv_endpoint,
        "rdzv_configs": config.rdzv_configs,
        "max_restarts": config.max_restarts,
        "monitor_interval": config.monitor_interval,
        "log_dir": config.logs_specs.root_log_dir,  # type: ignore[union-attr]
        "metrics_cfg": config.metrics_cfg,
    }
    logger.info(
        "Starting elastic_operator with launch configs:\n"
        " entrypoint : %(entrypoint)s\n"
        " min_nodes : %(min_nodes)s\n"
        " max_nodes : %(max_nodes)s\n"
        " nproc_per_node : %(nproc_per_node)s\n"
        " run_id : %(run_id)s\n"
        " rdzv_backend : %(rdzv_backend)s\n"
        " rdzv_endpoint : %(rdzv_endpoint)s\n"
        " rdzv_configs : %(rdzv_configs)s\n"
        " max_restarts : %(max_restarts)s\n"
        " monitor_interval : %(monitor_interval)s\n"
        " log_dir : %(log_dir)s\n"
        " metrics_cfg : %(metrics_cfg)s\n",
        launch_summary,
    )

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        local_addr=config.local_addr,
        **config.rdzv_configs,
    )

    # Validate the endpoint first so a bad static endpoint fails before a
    # rendezvous handler is created.
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)

    worker_spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=rdzv_handler,
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        master_addr=master_addr,
        master_port=master_port,
        local_addr=config.local_addr,
    )

    agent = LocalElasticAgent(
        spec=worker_spec,
        logs_specs=config.logs_specs,  # type: ignore[arg-type]
        start_method=config.start_method,
        log_line_prefix_template=config.log_line_prefix_template,
    )

    # Flipped to True only when the agent dies with a signal; see below.
    preserve_rdzv = False
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        run_result = agent.run()
        # records that agent.run() has succeeded NOT that workers have succeeded
        events.record(agent.get_event_succeeded())

        if not run_result.is_failed():
            return run_result.return_values

        # ChildFailedError is treated specially by @record
        # if the error files for the failed children exist
        # @record will copy the first error (root cause)
        # to the error file of the launcher process.
        raise ChildFailedError(
            name=entrypoint_name,
            failures=run_result.failures,
        )
    except ChildFailedError:
        # Worker failures propagate as-is, without an agent failure event.
        raise
    except SignalException:
        # when the agent dies with a signal do NOT shutdown the rdzv_handler
        # since this closes the rendezvous on this rdzv_id permanently and
        # prevents any additional scaling events
        preserve_rdzv = True
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if not preserve_rdzv:
            worker_spec.rdzv_handler.shutdown()