Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,687 Bytes
d1ed09d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
from __future__ import annotations
import logging
import ssl
import warnings
import weakref
from contextlib import suppress
import tlz
from tornado.httpserver import HTTPServer
import dask
from distributed.comm import get_address_host, get_tcp_server_addresses
from distributed.core import Server
from distributed.http.routing import RoutingApplication
from distributed.utils import DequeHandler, clean_dashboard_address
from distributed.versions import get_versions
class ServerNode(Server):
    """
    Common base for the server-like nodes of a distributed cluster.

    Adds version reporting, auxiliary service management, in-memory log
    capture and an HTTP server on top of ``Server``.
    """

    # TODO factor out security, listening, services, etc. here

    # XXX avoid inheriting from Server? there is some large potential for confusion
    # between base and derived attribute namespaces...

    def versions(self, packages=None):
        """Return the result of ``get_versions`` for the given *packages*."""
        return get_versions(packages=packages)

    def start_services(self, default_listen_ip):
        """
        Instantiate and start every service listed in ``self.service_specs``.

        Each key is either a service name or a ``(name, port)`` tuple, where
        ``port`` may be an int, a string that is split on ``":"``, or a
        sequence.  A one-element sequence is taken as the listen address (with
        port 0); a two-element sequence as ``(address, port)``.  Each value is
        either a factory or a ``(factory, kwargs)`` tuple.

        Parameters
        ----------
        default_listen_ip : str
            Address to listen on when the spec does not provide one.
            ``"0.0.0.0"`` is replaced by ``""``.

        Raises
        ------
        ValueError
            If a port sequence has neither one nor two elements.

        Notes
        -----
        Failures while creating or starting a service are reported with
        ``warnings.warn`` rather than raised.
        """
        if default_listen_ip == "0.0.0.0":
            default_listen_ip = ""  # for IPV6

        for spec_key, spec_value in self.service_specs.items():
            host = None
            name, port = spec_key if isinstance(spec_key, tuple) else (spec_key, 0)

            if isinstance(port, str):
                port = port.split(":")

            if isinstance(port, (tuple, list)):
                if len(port) == 1:
                    (host,) = port
                    port = 0
                elif len(port) == 2:
                    host, port = port[0], int(port[1])
                else:
                    raise ValueError(port)

            if isinstance(spec_value, tuple):
                factory, options = spec_value
            else:
                factory, options = spec_value, {}

            try:
                service = factory(self, io_loop=self.loop, **options)
                bind_host = default_listen_ip if host is None else host
                service.listen((bind_host, port))
                self.services[name] = service
            except Exception as exc:
                warnings.warn(
                    f"\nCould not launch service '{name}' on port {port}. "
                    + "Got the following message:\n\n"
                    + str(exc),
                    stacklevel=3,
                )

    def stop_services(self):
        """
        Stop the HTTP applications (if any) and every registered service.

        Applications of ``self.http_application`` are only stopped when they
        expose a callable ``stop`` attribute.
        """
        if hasattr(self, "http_application"):
            for app in self.http_application.applications:
                if not (hasattr(app, "stop") and callable(app.stop)):
                    continue
                app.stop()

        for running in self.services.values():
            running.stop()

    @property
    def service_ports(self):
        """Mapping of service name to that service's ``port`` attribute."""
        ports = {}
        for name, service in self.services.items():
            ports[name] = service.port
        return ports

    def _setup_logging(self, logger: logging.Logger) -> None:
        """
        Attach a ``DequeHandler`` to *logger* and remember it on the node.

        The handler length and format come from the dask config keys
        ``distributed.admin.log-length`` and ``distributed.admin.log-format``.
        A finalizer removes the handler from *logger* once this node is
        garbage collected.
        """
        handler = self._deque_handler = DequeHandler(
            n=dask.config.get("distributed.admin.log-length")
        )
        formatter = logging.Formatter(
            dask.config.get("distributed.admin.log-format")
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        weakref.finalize(self, logger.removeHandler, handler)

    def get_logs(self, start=0, n=None, timestamps=False):
        """
        Collect the most recent log records captured for this node.

        Records are visited newest first; the scan stops at the first record
        created before *start* or once *n* records have been collected.

        Parameters
        ----------
        start : float, optional
            Creation time below which records are no longer considered
        n : int, optional
            Upper bound on the number of entries returned; a falsy value
            disables the bound
        timestamps : bool, default False
            Whether each entry is prefixed with the record's creation time

        Returns
        -------
        List of ``(levelname, message)`` tuples, or
        ``(created, levelname, message)`` tuples when *timestamps* is set,
        newest first
        """
        handler = self._deque_handler
        entries = []
        for count, record in enumerate(reversed(handler.deque)):
            if (n and count >= n) or record.created < start:
                break
            entry = (record.levelname, handler.format(record))
            if timestamps:
                entry = (record.created, *entry)
            entries.append(entry)
        return entries

    def start_http_server(
        self, routes, dashboard_address, default_port=0, ssl_options=None
    ):
        """
        Create an HTTP server for this node and start listening.

        Parameters
        ----------
        routes
            Routes handed to ``RoutingApplication``
        dashboard_address
            Address specification passed to ``clean_dashboard_address``;
            *default_port* is used instead when this is falsy
        default_port : int, default 0
            Fallback used when *dashboard_address* is falsy
        ssl_options : optional
            Forwarded to ``HTTPServer``.  Replaced by a context built from the
            ``distributed.scheduler.dashboard.tls.*`` config keys when a
            certificate is configured there.

        Notes
        -----
        Listening is attempted up to three times per address: first with the
        requested port, then with port 0.  The last failure is re-raised.  A
        warning is emitted when a positive requested port differs from the
        port that was actually bound.
        """
        self.http_application = RoutingApplication(routes)

        # TLS configuration
        tls_key = dask.config.get("distributed.scheduler.dashboard.tls.key")
        tls_cert = dask.config.get("distributed.scheduler.dashboard.tls.cert")
        tls_ca_file = dask.config.get("distributed.scheduler.dashboard.tls.ca-file")
        if tls_cert:
            ssl_options = ssl.create_default_context(
                cafile=tls_ca_file, purpose=ssl.Purpose.CLIENT_AUTH
            )
            ssl_options.load_cert_chain(tls_cert, keyfile=tls_key)

        self.http_server = HTTPServer(self.http_application, ssl_options=ssl_options)

        http_addresses = clean_dashboard_address(dashboard_address or default_port)

        max_attempts = 3
        for http_address in http_addresses:
            if http_address["address"] is None:
                # Fall back to the host of the address this node was started on
                start_address = self._start_address
                if isinstance(start_address, (list, tuple)):
                    start_address = start_address[0]
                if start_address:
                    with suppress(ValueError):
                        http_address["address"] = get_address_host(start_address)

            for attempt in range(1, max_attempts + 1):
                try:
                    if attempt == 1:
                        # Honor the requested port on the first try
                        self.http_server.listen(**http_address)
                    else:
                        # Afterwards retry with port 0
                        self.http_server.listen(
                            **tlz.merge(http_address, {"port": 0})
                        )
                except Exception:
                    if attempt == max_attempts:
                        raise
                else:
                    break

        bound_addresses = get_tcp_server_addresses(self.http_server)

        # If more than one address is configured we just use the first here
        self.http_server.port = bound_addresses[0][1]
        self.services["dashboard"] = self.http_server

        # Warn on port changes
        expected_ports = [requested["port"] for requested in http_addresses]
        actual_ports = [bound[1] for bound in bound_addresses]
        for expected, actual in zip(expected_ports, actual_ports):
            if expected != actual and expected > 0:
                warnings.warn(
                    f"Port {expected} is already in use.\n"
                    "Perhaps you already have a cluster running?\n"
                    f"Hosting the HTTP server on port {actual} instead"
                )
|