m7n's picture
first commit
d1ed09d
raw
history blame
6.69 kB
from __future__ import annotations
import logging
import ssl
import warnings
import weakref
from contextlib import suppress
import tlz
from tornado.httpserver import HTTPServer
import dask
from distributed.comm import get_address_host, get_tcp_server_addresses
from distributed.core import Server
from distributed.http.routing import RoutingApplication
from distributed.utils import DequeHandler, clean_dashboard_address
from distributed.versions import get_versions
class ServerNode(Server):
"""
Base class for server nodes in a distributed cluster.
"""
# TODO factor out security, listening, services, etc. here
# XXX avoid inheriting from Server? there is some large potential for confusion
# between base and derived attribute namespaces...
def versions(self, packages=None):
return get_versions(packages=packages)
def start_services(self, default_listen_ip):
if default_listen_ip == "0.0.0.0":
default_listen_ip = "" # for IPV6
for k, v in self.service_specs.items():
listen_ip = None
if isinstance(k, tuple):
k, port = k
else:
port = 0
if isinstance(port, str):
port = port.split(":")
if isinstance(port, (tuple, list)):
if len(port) == 2:
listen_ip, port = (port[0], int(port[1]))
elif len(port) == 1:
[listen_ip], port = port, 0
else:
raise ValueError(port)
if isinstance(v, tuple):
v, kwargs = v
else:
kwargs = {}
try:
service = v(self, io_loop=self.loop, **kwargs)
service.listen(
(listen_ip if listen_ip is not None else default_listen_ip, port)
)
self.services[k] = service
except Exception as e:
warnings.warn(
f"\nCould not launch service '{k}' on port {port}. "
+ "Got the following message:\n\n"
+ str(e),
stacklevel=3,
)
def stop_services(self):
if hasattr(self, "http_application"):
for application in self.http_application.applications:
if hasattr(application, "stop") and callable(application.stop):
application.stop()
for service in self.services.values():
service.stop()
@property
def service_ports(self):
return {k: v.port for k, v in self.services.items()}
def _setup_logging(self, logger: logging.Logger) -> None:
self._deque_handler = DequeHandler(
n=dask.config.get("distributed.admin.log-length")
)
self._deque_handler.setFormatter(
logging.Formatter(dask.config.get("distributed.admin.log-format"))
)
logger.addHandler(self._deque_handler)
weakref.finalize(self, logger.removeHandler, self._deque_handler)
def get_logs(self, start=0, n=None, timestamps=False):
"""
Fetch log entries for this node
Parameters
----------
start : float, optional
A time (in seconds) to begin filtering log entries from
n : int, optional
Maximum number of log entries to return from filtered results
timestamps : bool, default False
Do we want log entries to include the time they were generated?
Returns
-------
List of tuples containing the log level, message, and (optional) timestamp for each filtered entry, newest first
"""
deque_handler = self._deque_handler
L = []
for count, msg in enumerate(reversed(deque_handler.deque)):
if n and count >= n or msg.created < start:
break
if timestamps:
L.append((msg.created, msg.levelname, deque_handler.format(msg)))
else:
L.append((msg.levelname, deque_handler.format(msg)))
return L
def start_http_server(
self, routes, dashboard_address, default_port=0, ssl_options=None
):
"""This creates an HTTP Server running on this node"""
self.http_application = RoutingApplication(routes)
# TLS configuration
tls_key = dask.config.get("distributed.scheduler.dashboard.tls.key")
tls_cert = dask.config.get("distributed.scheduler.dashboard.tls.cert")
tls_ca_file = dask.config.get("distributed.scheduler.dashboard.tls.ca-file")
if tls_cert:
ssl_options = ssl.create_default_context(
cafile=tls_ca_file, purpose=ssl.Purpose.CLIENT_AUTH
)
ssl_options.load_cert_chain(tls_cert, keyfile=tls_key)
self.http_server = HTTPServer(self.http_application, ssl_options=ssl_options)
http_addresses = clean_dashboard_address(dashboard_address or default_port)
for http_address in http_addresses:
if http_address["address"] is None:
address = self._start_address
if isinstance(address, (list, tuple)):
address = address[0]
if address:
with suppress(ValueError):
http_address["address"] = get_address_host(address)
change_port = False
retries_left = 3
while True:
try:
if not change_port:
self.http_server.listen(**http_address)
else:
self.http_server.listen(**tlz.merge(http_address, {"port": 0}))
break
except Exception:
change_port = True
retries_left = retries_left - 1
if retries_left < 1:
raise
bound_addresses = get_tcp_server_addresses(self.http_server)
# If more than one address is configured we just use the first here
self.http_server.port = bound_addresses[0][1]
self.services["dashboard"] = self.http_server
# Warn on port changes
for expected, actual in zip(
[a["port"] for a in http_addresses], [b[1] for b in bound_addresses]
):
if expected != actual and expected > 0:
warnings.warn(
f"Port {expected} is already in use.\n"
"Perhaps you already have a cluster running?\n"
f"Hosting the HTTP server on port {actual} instead"
)