File size: 7,513 Bytes
d1ed09d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
from __future__ import annotations

import filecmp
import inspect
import logging
import os
import shutil
import sys
from collections.abc import Iterable
from importlib import import_module
from types import ModuleType
from typing import TYPE_CHECKING, cast

import click

from dask.utils import tmpfile

from distributed.core import Server
from distributed.utils import import_file

if TYPE_CHECKING:
    # This has to be inside this guard to avoid a circular import
    from distributed.client import Client

logger = logging.getLogger(__name__)


def validate_preload_argv(ctx, param, value):
    """Click option callback providing validation of preload subcommand arguments.

    Parameters
    ----------
    ctx : click.Context
        Context of the command being parsed. ``ctx.params["preload"]`` is
        read to find the requested preload targets.
    param : click.Parameter
        Unused; present because click passes it to every option callback.
    value : sequence of str
        Extra command-line arguments intended for a click-configurable
        ``dask_setup`` command.

    Returns
    -------
    ``value``, unchanged, when validation succeeds.

    Raises
    ------
    click.NoSuchOption
        If extra arguments were given without any ``--preload`` and one of
        them looks like an option (starts with ``-``).
    click.UsageError
        If extra arguments were given without any ``--preload``, if more than
        one preload module exposes a click-configurable ``dask_setup``, or if
        extra arguments were given but no preload module accepts them.
    """
    preload_names = ctx.params.get("preload", None)

    if not preload_names:
        if not value:
            # No preload argv provided and no preload modules specified.
            return value

        # Report a usage error matching standard click error conventions.
        for arg in value:
            if arg.startswith("-"):
                raise click.NoSuchOption(arg)
        raise click.UsageError(
            "Got unexpected extra argument%s: (%s)"
            % ("s" if len(value) > 1 else "", " ".join(value))
        )

    # Web addresses are skipped here: they are not imported for validation.
    preload_modules = {
        name: _import_module(name)
        for name in preload_names
        if not is_webaddress(name)
    }

    # Only ``dask_setup`` attributes that are click commands can consume argv.
    preload_commands = [
        getattr(m, "dask_setup", None)
        for m in preload_modules.values()
        if isinstance(getattr(m, "dask_setup", None), click.Command)
    ]

    if len(preload_commands) > 1:
        raise click.UsageError(
            "Multiple --preload modules with click-configurable setup: %s"
            % list(preload_modules.keys())
        )

    if not preload_commands:
        if value:
            # Wrap ``value`` in a 1-tuple so that a tuple of arguments is
            # rendered by ``%r`` rather than unpacked as format arguments.
            raise click.UsageError(
                "Unknown argument specified: %r Was click-configurable --preload target provided?"
                % (value,)
            )
        return value

    preload_command = preload_commands[0]

    # Parse the arguments against the preload command purely for validation;
    # parse_args raises a click error on bad input and the result is discarded.
    preload_ctx = click.Context(preload_command, allow_extra_args=False)
    preload_command.parse_args(preload_ctx, list(value))

    return value


def is_webaddress(s: str) -> bool:
    """Return True when *s* begins with the ``http://`` or ``https://`` scheme."""
    return any(s.startswith(scheme) for scheme in ("http://", "https://"))


def _import_module(name: str, file_dir: str | None = None) -> ModuleType:
    """Import a preload module given as a file path, module name or source text.

    The form of ``name`` is decided as follows:

    * ends with ``.py``: treated as a path to a Python file;
    * otherwise contains no space: treated as an importable module name;
    * otherwise: treated as the text of a script, which is written to a
      temporary ``.py`` file and imported from there.

    Parameters
    ----------
    name : str
        Module name, file path, or text of module or script
    file_dir : str, optional
        Path of a directory where files should be copied before being
        imported. Only used when a file is imported.

    Returns
    -------
    ModuleType
        The imported module.
    """
    if name.endswith(".py"):
        # name is a file path
        if file_dir is not None:
            basename = os.path.basename(name)
            copy_dst = os.path.join(file_dir, basename)
            if os.path.exists(copy_dst):
                if not filecmp.cmp(name, copy_dst):
                    # A different file with the same name is already there;
                    # this is logged and the file is then overwritten below.
                    logger.error("File name collision: %s", basename)
            shutil.copy(name, copy_dst)
            module = import_file(copy_dst)[0]
        else:
            module = import_file(name)[0]

    elif " " not in name:
        # name is a module name
        if name not in sys.modules:
            import_module(name)
        module = sys.modules[name]

    else:
        # not a name, actually the text of the script
        with tmpfile(extension=".py") as fn:
            # Python decodes source files as UTF-8 by default, so write the
            # script as UTF-8 rather than in the platform's locale encoding.
            with open(fn, mode="w", encoding="utf-8") as f:
                f.write(name)
            # The recursive call takes the ``.py`` branch and does the logging.
            return _import_module(fn, file_dir=file_dir)

    logger.info("Import preload module: %s", name)
    return module


def _download_module(url: str) -> ModuleType:
    """Fetch Python source from ``url`` and execute it as a new module.

    Parameters
    ----------
    url : str
        An ``http://`` or ``https://`` address of the preload source.

    Returns
    -------
    ModuleType
        A fresh module, named after ``url``, populated by executing the
        downloaded source.
    """
    logger.info("Downloading preload at %s", url)
    assert is_webaddress(url)
    # This is the only place where urllib3 is used and it is a relatively
    # heavy import, so import it lazily to reduce import time.
    import urllib3

    with urllib3.PoolManager() as pool:
        # These statuses are retried with a backoff factor of 0.2.
        retry_policy = urllib3.util.Retry(
            status_forcelist=[429, 504, 503, 502],
            backoff_factor=0.2,
        )
        reply = pool.request(method="GET", url=url, retries=retry_policy)
        payload = reply.data

    # NOTE(review): this executes code fetched over the network; the URL must
    # point at a trusted source.
    downloaded = ModuleType(url)
    exec(compile(payload, url, "exec"), downloaded.__dict__)
    return downloaded


class Preload:
    """
    Manage state for setup/teardown of a preload module

    The module is loaded when the instance is created; its optional
    ``dask_setup`` and ``dask_teardown`` hooks are invoked by :meth:`start`
    and :meth:`teardown`.

    Parameters
    ----------
    dask_object: Server or Client
        The object handed to the module's ``dask_setup`` and
        ``dask_teardown`` hooks
    name: str
        module name, file name, or web address to load
    argv: [str]
        List of string arguments passed to click-configurable `dask_setup`.
    file_dir: str
        Path of a directory where files should be copied
    """

    dask_object: Server | Client
    name: str
    argv: list[str]
    file_dir: str | None
    module: ModuleType

    def __init__(
        self,
        dask_object: Server | Client,
        name: str,
        argv: Iterable[str],
        file_dir: str | None,
    ):
        self.dask_object = dask_object
        self.name = name
        # Materialise argv so that any iterable can be passed in.
        self.argv = list(argv)
        self.file_dir = file_dir

        logger.info("Creating preload: %s", self.name)

        # Web addresses are downloaded; everything else is imported.
        self.module = (
            _download_module(name)
            if is_webaddress(name)
            else _import_module(name, file_dir)
        )

    async def start(self):
        """Call the module's ``dask_setup`` hook, if it defines one.

        A click command is invoked through its callback with the parsed
        ``argv``; any other hook is called with ``dask_object`` only.
        An awaitable result is awaited.
        """
        setup_hook = getattr(self.module, "dask_setup", None)
        if not setup_hook:
            return

        logger.info("Run preload setup: %s", self.name)
        if isinstance(setup_hook, click.Command):
            click_ctx = setup_hook.make_context(
                "dask_setup", self.argv, allow_extra_args=False
            )
            outcome = setup_hook.callback(
                self.dask_object, *click_ctx.args, **click_ctx.params
            )
        else:
            outcome = setup_hook(self.dask_object)

        if inspect.isawaitable(outcome):
            await outcome

    async def teardown(self):
        """Call the module's ``dask_teardown`` hook, if it defines one.

        The hook receives ``dask_object``; an awaitable result is awaited.
        """
        teardown_hook = getattr(self.module, "dask_teardown", None)
        if not teardown_hook:
            return

        logger.info("Run preload teardown: %s", self.name)
        outcome = teardown_hook(self.dask_object)
        if inspect.isawaitable(outcome):
            await outcome


def process_preloads(
    dask_server: Server | Client,
    preload: str | list[str],
    preload_argv: list[str] | list[list[str]],
    *,
    file_dir: str | None = None,
) -> list[Preload]:
    """Build one ``Preload`` per requested preload target.

    Parameters
    ----------
    dask_server : Server or Client
        Object passed on to every ``Preload``.
    preload : str or list of str
        A single preload target or a list of them.
    preload_argv : list of str or list of list of str
        Either one flat argument list, which is reused for every target, or
        one argument list per target. An empty value gives every target an
        empty argument list.
    file_dir : str, optional
        Passed on to every ``Preload``.

    Returns
    -------
    list of Preload
        One entry per target, in the order given.

    Raises
    ------
    ValueError
        If per-target argument lists are given and their count differs from
        the number of targets.
    """
    targets = [preload] if isinstance(preload, str) else preload

    if not preload_argv:
        argv_per_target = [cast("list[str]", [])] * len(targets)
    elif isinstance(preload_argv[0], str):
        # A flat list of strings is shared by all targets.
        argv_per_target = [cast("list[str]", preload_argv)] * len(targets)
    else:
        argv_per_target = cast("list[list[str]]", preload_argv)

    n_targets, n_argv = len(targets), len(argv_per_target)
    if n_targets != n_argv:
        raise ValueError(
            "preload and preload_argv have mismatched lengths "
            f"{n_targets} != {n_argv}"
        )

    return [
        Preload(dask_server, target, target_argv, file_dir)
        for target, target_argv in zip(targets, argv_per_target)
    ]