from __future__ import annotations

import contextlib
import logging
from collections import deque
from typing import Any

from tornado import gen, locks
from tornado.ioloop import IOLoop

import dask
from dask.utils import parse_timedelta

from distributed.core import CommClosedError
from distributed.metrics import time

logger = logging.getLogger(__name__)


class BatchedSend:
    """Batch messages on a stream

    This takes a comm and an interval (in ms) and ensures that we perform
    no more than one write every ``interval`` milliseconds. Messages are
    written as lists, so each write delivers a batch of messages.

    Batching several messages into one write helps performance when sending
    a myriad of tiny messages.

    Examples
    --------
    >>> stream = await connect(address)
    >>> bstream = BatchedSend(interval='10 ms')
    >>> bstream.start(stream)
    >>> bstream.send('Hello,')
    >>> bstream.send('world!')

    On the other side, the recipient will get a message like the following::

        ['Hello,', 'world!']
    """

    # XXX why doesn't BatchedSend follow either the IOStream or Comm API?

    def __init__(self, interval, loop=None, serializers=None):
        # XXX is the loop arg useful?
        self.loop = loop or IOLoop.current()
        self.interval = parse_timedelta(interval, default="ms")
        # Wakes the background coroutine when there is something to send
        self.waker = locks.Event()
        # Set once the background coroutine has exited
        self.stopped = locks.Event()
        self.please_stop = False
        self.buffer = []
        self.comm = None
        self.message_count = 0
        self.batch_count = 0
        self.byte_count = 0
        # Time before which we must not perform the next write
        self.next_deadline = None
        self.recent_message_log = deque(
            maxlen=dask.config.get("distributed.comm.recent-messages-log-length")
        )
        self.serializers = serializers
        self._consecutive_failures = 0

    def start(self, comm):
        self.comm = comm
        self.loop.add_callback(self._background_send)

    def closed(self):
        return self.comm and self.comm.closed()

    def __repr__(self):
        if self.closed():
            return "<BatchedSend: closed>"
        else:
            return "<BatchedSend: %d in buffer>" % len(self.buffer)

    __str__ = __repr__
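
    # The coroutine below runs for the lifetime of the comm: it wakes up
    # whenever `send` sets `self.waker` or the batching deadline expires,
    # then writes everything accumulated in `self.buffer` as a single batch.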
    @gen.coroutine
    def _background_send(self):
        while not self.please_stop:
            try:
                yield self.waker.wait(self.next_deadline)
                self.waker.clear()
            except gen.TimeoutError:
                pass
            if not self.buffer:
                # Nothing to send
                self.next_deadline = None
                continue
            if self.next_deadline is not None and time() < self.next_deadline:
                # Send interval not expired yet
                continue
            payload, self.buffer = self.buffer, []
            self.batch_count += 1
            self.next_deadline = time() + self.interval
            try:
                # NOTE: Since `BatchedSend` doesn't have a handle on the running
                # `_background_send` coroutine, the only thing with a reference to this
                # coroutine is the event loop itself. If the event loop stops while
                # we're waiting on a `write`, the `_background_send` coroutine object
                # may be garbage collected. If that happens, the `yield coro` will raise
                # `GeneratorExit`. But because this is an old-school `gen.coroutine`,
                # and we're using `yield` and not `await`, the `write` coroutine object
                # will not actually have been awaited, and it will remain sitting around
                # for someone to retrieve it. At interpreter exit, this will warn
                # something like `RuntimeWarning: coroutine 'TCP.write' was never
                # awaited`. By using the `closing` contextmanager, the `write` coroutine
                # object is always cleaned up, even if `yield` raises `GeneratorExit`.
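                #
                # A minimal sketch of the pattern (`make_coro` is a
                # hypothetical stand-in for `self.comm.write(...)`; `closing`
                # simply calls `.close()` on the coroutine object however the
                # block exits):
                #
                #     with contextlib.closing(make_coro()) as coro:
                #         result = yield coro
                #
                # Even if the `yield` raises `GeneratorExit`, `coro.close()`
                # still runs, so no "never awaited" warning is emitted.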
                with contextlib.closing(
                    self.comm.write(
                        payload, serializers=self.serializers, on_error="raise"
                    )
                ) as coro:
                    nbytes = yield coro
                # Keep small payloads around for debugging; only note the
                # presence of large ones
                if nbytes < 1e6:
                    self.recent_message_log.append(payload)
                else:
                    self.recent_message_log.append("large-message")
                self.byte_count += nbytes
            except CommClosedError:
                logger.info("Batched Comm Closed %r", self.comm, exc_info=True)
                break
            except Exception:
                # We cannot safely retry self.comm.write, as we have no idea
                # what (if anything) was actually written to the underlying stream.
                # Re-writing messages could result in complete garbage (e.g. if a frame
                # header has been written, but not the frame payload), therefore
                # the only safe thing to do here is to abort the stream without
                # any attempt to retry `write`.
                logger.exception("Error in batched write")
                break
            finally:
                payload = None  # lose ref
        else:
            # nobreak. We've been gracefully closed.
            self.stopped.set()
            return

        # If we've reached here, it means `break` was hit above and
        # there was an exception when using `comm`.
        # We can't close gracefully via `.close()` since we can't send messages.
        # So we just abort.
        # This means that any messages in our buffer are lost.
        # To propagate exceptions, we rely on subsequent `BatchedSend.send`
        # calls to raise CommClosedErrors.
        self.stopped.set()
        self.abort()

    def send(self, *msgs: Any) -> None:
        """Schedule a message for sending to the other side

        This completes quickly and synchronously.
        """
        if self.comm is not None and self.comm.closed():
            raise CommClosedError(f"Comm {self.comm!r} already closed.")

        self.message_count += len(msgs)
        self.buffer.extend(msgs)
        # Avoid spurious wakeups if possible
        if self.next_deadline is None:
            self.waker.set()

    @gen.coroutine
    def close(self, timeout=None):
        """Flush existing messages and then close comm

        If a timeout is given, raises `tornado.util.TimeoutError` if the
        flush does not complete in time.
        """
        if self.comm is None:
            return
        self.please_stop = True
        self.waker.set()
        yield self.stopped.wait(timeout=timeout)
        if not self.comm.closed():
            try:
                if self.buffer:
                    self.buffer, payload = [], self.buffer
                    # See note in `_background_send` for explanation of `closing`.
                    with contextlib.closing(
                        self.comm.write(
                            payload, serializers=self.serializers, on_error="raise"
                        )
                    ) as coro:
                        yield coro
            except CommClosedError:
                pass
            yield self.comm.close()
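
    # `abort` is the non-flushing counterpart of `close`: it drops any
    # buffered messages and tears the comm down immediately, without
    # attempting a final write.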
    def abort(self):
        if self.comm is None:
            return
        self.please_stop = True
        self.buffer = []
        self.waker.set()
        if not self.comm.closed():
            self.comm.abort()
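

if __name__ == "__main__":
    # A minimal, self-contained smoke test (an illustrative sketch, not part
    # of the upstream module): `MockComm` is a hypothetical stand-in for a
    # real comm from `distributed.comm.connect`, recording batches instead of
    # writing them to a socket. tornado's IOLoop integrates with the running
    # asyncio event loop, so plain `asyncio.run` suffices to drive it.
    import asyncio

    class MockComm:
        def __init__(self):
            self.batches = []
            self._closed = False

        async def write(self, payload, serializers=None, on_error=None):
            self.batches.append(payload)
            return len(repr(payload))  # pretend byte count

        async def close(self):
            self._closed = True

        def abort(self):
            self._closed = True

        def closed(self):
            return self._closed

    async def demo():
        comm = MockComm()
        bstream = BatchedSend(interval="5ms")
        bstream.start(comm)
        bstream.send("Hello,")
        bstream.send("world!")
        await asyncio.sleep(0.05)  # let the background coroutine flush once
        await bstream.close()
        print(comm.batches)  # expected: [['Hello,', 'world!']]

    asyncio.run(demo())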