Spaces:
Running
Running
File size: 9,889 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 |
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import abc
import logging
import threading
import time
from contextlib import contextmanager
from inspect import getframeinfo, stack
from typing import Any, Dict, List, Optional, Set
__all__ = ['TimerRequest', 'TimerClient', 'RequestQueue', 'TimerServer', 'configure', 'expires']
log = logging.getLogger(__name__)
class TimerRequest:
"""
Data object representing a countdown timer acquisition and release
that is used between the ``TimerClient`` and ``TimerServer``.
A negative ``expiration_time`` should be interpreted as a "release"
request.
.. note:: the type of ``worker_id`` is implementation specific.
It is whatever the TimerServer and TimerClient implementations
have on to uniquely identify a worker.
"""
__slots__ = ["worker_id", "scope_id", "expiration_time"]
def __init__(self, worker_id: Any, scope_id: str, expiration_time: float):
self.worker_id = worker_id
self.scope_id = scope_id
self.expiration_time = expiration_time
def __eq__(self, other):
if isinstance(other, TimerRequest):
return (
self.worker_id == other.worker_id
and self.scope_id == other.scope_id
and self.expiration_time == other.expiration_time
)
return False
class TimerClient(abc.ABC):
"""
Client library to acquire and release countdown timers by communicating
with the TimerServer.
"""
@abc.abstractmethod
def acquire(self, scope_id: str, expiration_time: float) -> None:
"""
Acquires a timer for the worker that holds this client object
given the scope_id and expiration_time. Typically registers
the timer with the TimerServer.
"""
pass
@abc.abstractmethod
def release(self, scope_id: str):
"""
Releases the timer for the ``scope_id`` on the worker this
client represents. After this method is
called, the countdown timer on the scope is no longer in effect.
"""
pass
class RequestQueue(abc.ABC):
"""
Consumer queue holding timer acquisition/release requests
"""
@abc.abstractmethod
def size(self) -> int:
"""
Returns the size of the queue at the time this method is called.
Note that by the time ``get`` is called the size of the queue
may have increased. The size of the queue should not decrease
until the ``get`` method is called. That is, the following assertion
should hold:
size = q.size()
res = q.get(size, timeout=0)
assert size == len(res)
-- or --
size = q.size()
res = q.get(size * 2, timeout=1)
assert size <= len(res) <= size * 2
"""
pass
@abc.abstractmethod
def get(self, size: int, timeout: float) -> List[TimerRequest]:
"""
Gets up to ``size`` number of timer requests in a blocking fashion
(no more than ``timeout`` seconds).
"""
pass
class TimerServer(abc.ABC):
"""
Entity that monitors active timers and expires them
in a timely fashion. This server is responsible for
reaping workers that have expired timers.
"""
def __init__(
self, request_queue: RequestQueue, max_interval: float, daemon: bool = True
):
"""
:param request_queue: Consumer ``RequestQueue``
:param max_interval: max time (in seconds) to wait
for an item in the request_queue
:param daemon: whether to run the watchdog thread as a daemon
"""
super().__init__()
self._request_queue = request_queue
self._max_interval = max_interval
self._daemon = daemon
self._watchdog_thread: Optional[threading.Thread] = None
self._stop_signaled = False
@abc.abstractmethod
def register_timers(self, timer_requests: List[TimerRequest]) -> None:
"""
Processes the incoming timer requests and registers them with the server.
The timer request can either be a acquire-timer or release-timer request.
Timer requests with a negative expiration_time should be interpreted
as a release-timer request.
"""
pass
@abc.abstractmethod
def clear_timers(self, worker_ids: Set[Any]) -> None:
"""
Clears all timers for the given ``worker_ids``.
"""
pass
@abc.abstractmethod
def get_expired_timers(self, deadline: float) -> Dict[str, List[TimerRequest]]:
"""
Returns all expired timers for each worker_id. An expired timer
is a timer for which the expiration_time is less than or equal to
the provided deadline.
"""
pass
@abc.abstractmethod
def _reap_worker(self, worker_id: Any) -> bool:
"""
Reaps the given worker. Returns True if the worker has been
successfully reaped, False otherwise. If any uncaught exception
is thrown from this method, the worker is considered reaped
and all associated timers will be removed.
"""
def _reap_worker_no_throw(self, worker_id: Any) -> bool:
"""
Wraps ``_reap_worker(worker_id)``, if an uncaught exception is
thrown, then it considers the worker as reaped.
"""
try:
return self._reap_worker(worker_id)
except Exception:
log.exception(
"Uncaught exception thrown from _reap_worker(), "
"check that the implementation correctly catches exceptions",
)
return True
def _watchdog_loop(self):
while not self._stop_signaled:
try:
self._run_watchdog()
except Exception:
log.exception("Error running watchdog")
def _run_watchdog(self):
batch_size = max(1, self._request_queue.size())
timer_requests = self._request_queue.get(batch_size, self._max_interval)
self.register_timers(timer_requests)
now = time.time()
reaped_worker_ids = set()
for worker_id, expired_timers in self.get_expired_timers(now).items():
log.info(
"Reaping worker_id=[%s]."
" Expired timers: %s",
worker_id, self._get_scopes(expired_timers)
)
if self._reap_worker_no_throw(worker_id):
log.info("Successfully reaped worker=[%s]", worker_id)
reaped_worker_ids.add(worker_id)
else:
log.error(
"Error reaping worker=[%s]. Will retry on next watchdog.", worker_id
)
self.clear_timers(reaped_worker_ids)
def _get_scopes(self, timer_requests):
return [r.scope_id for r in timer_requests]
def start(self) -> None:
log.info(
"Starting %s..."
" max_interval=%s,"
" daemon=%s",
type(self).__name__, self._max_interval, self._daemon
)
self._watchdog_thread = threading.Thread(
target=self._watchdog_loop, daemon=self._daemon
)
log.info("Starting watchdog thread...")
self._watchdog_thread.start()
def stop(self) -> None:
log.info("Stopping %s", type(self).__name__)
self._stop_signaled = True
if self._watchdog_thread:
log.info("Stopping watchdog thread...")
self._watchdog_thread.join(self._max_interval)
self._watchdog_thread = None
else:
log.info("No watchdog thread running, doing nothing")
_timer_client: Optional[TimerClient] = None
def configure(timer_client: TimerClient):
"""
Configures a timer client. Must be called before using ``expires``.
"""
global _timer_client
_timer_client = timer_client
log.info("Timer client configured to: %s", type(_timer_client).__name__)
@contextmanager
def expires(
after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None
):
"""
Acquires a countdown timer that expires in ``after`` seconds from now,
unless the code-block that it wraps is finished within the timeframe.
When the timer expires, this worker is eligible to be reaped. The
exact meaning of "reaped" depends on the client implementation. In
most cases, reaping means to terminate the worker process.
Note that the worker is NOT guaranteed to be reaped at exactly
``time.now() + after``, but rather the worker is "eligible" for being
reaped and the ``TimerServer`` that the client talks to will ultimately
make the decision when and how to reap the workers with expired timers.
Usage::
torch.distributed.elastic.timer.configure(LocalTimerClient())
with expires(after=10):
torch.distributed.all_reduce(...)
"""
if client is None:
if _timer_client is None:
raise RuntimeError("Configure timer client before using countdown timers.")
client = _timer_client
if scope is None:
# grab the caller file + lineno
caller = getframeinfo(stack()[1][0])
scope = f"{caller.filename}#{caller.lineno}"
expiration = time.time() + after
client.acquire(scope, expiration)
try:
yield
finally:
client.release(scope)
|