Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| import filecmp | |
| import inspect | |
| import logging | |
| import os | |
| import shutil | |
| import sys | |
| from collections.abc import Iterable | |
| from importlib import import_module | |
| from types import ModuleType | |
| from typing import TYPE_CHECKING, cast | |
| import click | |
| from dask.utils import tmpfile | |
| from distributed.core import Server | |
| from distributed.utils import import_file | |
| if TYPE_CHECKING: | |
| # This has to be inside this guard to avoid a circular import | |
| from distributed.client import Client | |
# Module-level logger; used throughout this module to report preload
# creation, imports, downloads, setup/teardown runs, and file-name collisions.
logger = logging.getLogger(__name__)
def validate_preload_argv(ctx, param, value):
    """Click option callback providing validation of preload subcommand arguments.

    Extra command-line arguments in ``value`` are only meaningful when exactly
    one ``--preload`` target defines a click-configurable ``dask_setup``
    command. They are parsed against that command here, so bad arguments are
    reported when the CLI is invoked rather than later, when the preload runs.

    Parameters
    ----------
    ctx : click.Context
        Context of the invoking command; ``ctx.params["preload"]`` holds the
        preload targets parsed so far.
    param : click.Parameter
        The parameter being validated (unused).
    value : sequence of str
        Extra arguments intended for the preload's ``dask_setup``.

    Returns
    -------
    ``value``, unchanged.

    Raises
    ------
    click.NoSuchOption
        If an option-like argument is given while no preload is specified.
    click.UsageError
        If arguments are given without any preload, or without a
        click-configurable preload, or if more than one preload is
        click-configurable. Errors raised while parsing ``value`` against the
        preload's command propagate unchanged.
    """
    preload = ctx.params.get("preload", None)
    if not preload:
        if not value:
            # No preload argv provided and no preload modules specified.
            return value
        # Report a usage error matching standard click error conventions:
        # the first option-like argument is reported as an unknown option.
        unexpected_options = [v for v in value if v.startswith("-")]
        if unexpected_options:
            raise click.NoSuchOption(unexpected_options[0])
        raise click.UsageError(
            "Got unexpected extra argument%s: (%s)"
            % ("s" if len(value) > 1 else "", " ".join(value))
        )

    # Web-address preloads are not downloaded here, so they never count as
    # click-configurable.
    preload_modules = {
        name: _import_module(name) for name in preload if not is_webaddress(name)
    }
    click_setups = {
        name: module.dask_setup
        for name, module in preload_modules.items()
        if isinstance(getattr(module, "dask_setup", None), click.Command)
    }
    if len(click_setups) > 1:
        # Name only the offending (click-configurable) preloads.
        raise click.UsageError(
            "Multiple --preload modules with click-configurable setup: %s"
            % list(click_setups)
        )

    if not click_setups:
        if value:
            raise click.UsageError(
                "Unknown argument specified: %r Was click-configurable --preload target provided?"
                % (" ".join(value),)
            )
        return value

    # Parse (and thereby validate) the arguments against the single
    # click-configurable setup command; a separate context avoids clobbering
    # the caller's ``ctx``.
    (setup_command,) = click_setups.values()
    setup_ctx = click.Context(setup_command, allow_extra_args=False)
    setup_command.parse_args(setup_ctx, list(value))
    return value
def is_webaddress(s: str) -> bool:
    """Return whether ``s`` starts with an ``http://`` or ``https://`` scheme."""
    for scheme in ("http://", "https://"):
        if s.startswith(scheme):
            return True
    return False
def _import_module(name: str, file_dir: str | None = None) -> ModuleType:
    """Import a preload target and return the resulting module.

    ``name`` is interpreted as one of the following:

    * a path to a Python file, if it ends with ``.py``. When ``file_dir`` is
      given, the file is first copied there and imported from the copy.
    * an importable module name, if it contains no space.
    * otherwise, the source text of a script. The text is written to a
      temporary ``.py`` file and imported from it.

    Callers look up the ``dask_setup`` / ``dask_teardown`` hooks on the
    returned module.

    Parameters
    ----------
    name : str
        Module name, file path, or text of module or script
    file_dir : str, optional
        Path of a directory where ``.py`` files should be copied before import

    Returns
    -------
    ModuleType
        The imported module.
    """
    if name.endswith(".py"):
        # name is a file path
        if file_dir is not None:
            basename = os.path.basename(name)
            copy_dst = os.path.join(file_dir, basename)
            if not os.path.exists(copy_dst):
                shutil.copy(name, copy_dst)
            elif not filecmp.cmp(name, copy_dst):
                logger.error("File name collision: %s", basename)
                shutil.copy(name, copy_dst)
            # Otherwise an identical file (possibly ``name`` itself) is already
            # in place; copying a file onto itself would raise
            # shutil.SameFileError.
            module = import_file(copy_dst)[0]
        else:
            module = import_file(name)[0]
    elif " " not in name:
        # name is a module name; import_module returns the cached module from
        # sys.modules when it has already been imported.
        module = import_module(name)
    else:
        # not a name, actually the text of the script
        with tmpfile(extension=".py") as fn:
            # Python decodes source files as UTF-8 by default, so write the
            # script that way regardless of the locale's preferred encoding.
            with open(fn, mode="w", encoding="utf-8") as f:
                f.write(name)
            return _import_module(fn, file_dir=file_dir)
    logger.info("Import preload module: %s", name)
    return module
def _download_module(url: str) -> ModuleType:
    """Download Python source from ``url`` and execute it as a new module.

    The GET request is retried with backoff on HTTP 429/502/503/504 responses.

    Parameters
    ----------
    url : str
        ``http://`` or ``https://`` address of the module source.

    Returns
    -------
    ModuleType
        A fresh module named after ``url``. It is not added to
        ``sys.modules``.

    Raises
    ------
    RuntimeError
        If the server answers with an HTTP error status (>= 400) that is not
        retried.
    """
    logger.info("Downloading preload at %s", url)
    # Internal invariant: callers only pass web addresses here.
    assert is_webaddress(url)

    # This is the only place where urllib3 is used and it is a relatively heavy
    # import. Do lazy import to reduce import time
    import urllib3

    with urllib3.PoolManager() as http:
        response = http.request(
            method="GET",
            url=url,
            retries=urllib3.util.Retry(
                status_forcelist=[429, 504, 503, 502],
                backoff_factor=0.2,
            ),
        )
        source = response.data

    # Without this check an error page (e.g. a 404 body) would be compiled
    # and executed as if it were the preload's source.
    if response.status >= 400:
        raise RuntimeError(
            f"Failed to download preload {url!r}: HTTP status {response.status}"
        )

    # SECURITY: this executes arbitrary code fetched from ``url``; only point
    # --preload at trusted addresses.
    compiled = compile(source, url, "exec")
    module = ModuleType(url)
    exec(compiled, module.__dict__)
    return module
class Preload:
    """
    Manage state for setup/teardown of a preload module

    Parameters
    ----------
    dask_object: dask.distributed.Server or dask.distributed.Client
        The Worker, Scheduler, or Client handed to the module's
        ``dask_setup`` and ``dask_teardown`` hooks
    name: str
        module name, ``.py`` file name, Python source text, or web address
        to load
    argv: iterable of str
        Arguments passed to a click-configurable ``dask_setup``
    file_dir: str or None
        Path of a directory where ``.py`` files should be copied before
        importing them
    """

    dask_object: Server | Client
    name: str
    argv: list[str]
    file_dir: str | None
    module: ModuleType

    def __init__(
        self,
        dask_object: Server | Client,
        name: str,
        argv: Iterable[str],
        file_dir: str | None,
    ):
        self.dask_object = dask_object
        self.name = name
        # argv may be any iterable (e.g. a tuple); store a list copy.
        self.argv = list(argv)
        self.file_dir = file_dir

        logger.info("Creating preload: %s", self.name)
        # The module is loaded eagerly, so import/download errors surface here.
        if is_webaddress(name):
            self.module = _download_module(name)
        else:
            self.module = _import_module(name, file_dir)

    async def start(self):
        """Run the module's ``dask_setup`` hook, if present.

        Called when the server finishes its start method. If ``dask_setup``
        is a ``click.Command``, ``self.argv`` is parsed into its options and
        its callback is called with the dask object plus those options. Any
        other ``dask_setup`` is called with just the dask object. Awaitable
        results are awaited.
        """
        dask_setup = getattr(self.module, "dask_setup", None)
        if not dask_setup:
            return

        logger.info("Run preload setup: %s", self.name)
        if isinstance(dask_setup, click.Command):
            context = dask_setup.make_context(
                "dask_setup", self.argv, allow_extra_args=False
            )
            # Entering the context makes it the current click context while
            # the callback runs, so click.get_current_context() and
            # @click.pass_context work. Leaving it closes the context.
            with context:
                result = dask_setup.callback(
                    self.dask_object, *context.args, **context.params
                )
        else:
            result = dask_setup(self.dask_object)
        if inspect.isawaitable(result):
            await result

    async def teardown(self):
        """Run the module's ``dask_teardown`` hook, if present.

        Called when the server starts its close method. The hook receives the
        dask object; an awaitable result is awaited.
        """
        dask_teardown = getattr(self.module, "dask_teardown", None)
        if not dask_teardown:
            return

        logger.info("Run preload teardown: %s", self.name)
        result = dask_teardown(self.dask_object)
        if inspect.isawaitable(result):
            await result
def process_preloads(
    dask_server: Server | Client,
    preload: str | list[str],
    preload_argv: list[str] | list[list[str]],
    *,
    file_dir: str | None = None,
) -> list[Preload]:
    """Create a :class:`Preload` for every requested preload target.

    Parameters
    ----------
    dask_server : Server or Client
        Object passed to each preload's setup and teardown hooks.
    preload : str or list of str
        One preload target, or several.
    preload_argv : list of str or list of list of str
        Either a single flat argument list that every preload receives, or
        one argument list per preload. An empty value means no arguments.
    file_dir : str, optional
        Directory into which ``.py`` preload files are copied.

    Returns
    -------
    list of Preload
        One entry per target, in the order given.

    Raises
    ------
    ValueError
        If the number of per-preload argument lists differs from the number
        of targets.
    """
    targets = [preload] if isinstance(preload, str) else preload

    if not preload_argv:
        # No arguments at all: every preload gets an empty argv.
        argv_per_target = [cast("list[str]", [])] * len(targets)
    elif isinstance(preload_argv[0], str):
        # A single flat argv is handed to every preload.
        argv_per_target = [cast("list[str]", preload_argv)] * len(targets)
    else:
        argv_per_target = preload_argv

    n_targets = len(targets)
    n_argvs = len(argv_per_target)
    if n_targets != n_argvs:
        raise ValueError(
            "preload and preload_argv have mismatched lengths "
            f"{n_targets} != {n_argvs}"
        )

    preloads = []
    for target, argv in zip(targets, argv_per_target):
        preloads.append(Preload(dask_server, target, argv, file_dir))
    return preloads