File size: 11,548 Bytes
cf2a15a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for working with python gRPC stubs."""


import enum
import functools
import random
import threading
import time

import grpc

from tensorboard import version
from tensorboard.util import tb_logging

logger = tb_logging.get_logger()

# Default per-attempt RPC timeout, in seconds. Passed as `timeout=` on every
# stub invocation made by `call_with_retries` and `async_call_with_retries`.
_GRPC_DEFAULT_TIMEOUT_SECS = 30

# Max number of times to attempt an RPC (the first try counts as attempt 1),
# retrying on transient failures.
_GRPC_RETRY_MAX_ATTEMPTS = 5

# Parameters to control the exponential backoff behavior: after attempt `n`
# fails, the wait is `_GRPC_RETRY_EXPONENTIAL_BASE ** n` seconds multiplied by
# a jitter factor drawn uniformly from
# [_GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX].
_GRPC_RETRY_EXPONENTIAL_BASE = 2
_GRPC_RETRY_JITTER_FACTOR_MIN = 1.1
_GRPC_RETRY_JITTER_FACTOR_MAX = 1.5

# Status codes from gRPC for which it's reasonable to retry the RPC. Any other
# status code is treated as final and surfaced to the caller immediately.
_GRPC_RETRYABLE_STATUS_CODES = frozenset(
    [
        grpc.StatusCode.ABORTED,
        grpc.StatusCode.DEADLINE_EXCEEDED,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.UNAVAILABLE,
    ]
)

# gRPC metadata key whose value contains the client version. Written by
# `version_metadata()` and read back by `extract_version()`.
_VERSION_METADATA_KEY = "tensorboard-version"


class AsyncCallFuture:
    """Encapsulates the future value of a retriable async gRPC request.

    Abstracts over the set of futures returned by a set of gRPC calls
    comprising a single logical gRPC request with retries.  Communicates
    to the caller the result or exception resulting from the request.

    Args:
      completion_event: The constructor should provide a `threading.Event`
        which will be used to communicate when the set of gRPC requests is
        complete.
    """

    def __init__(self, completion_event):
        # Future of the most recently issued attempt; `result()` delegates to
        # it once `completion_event` has been set.
        self._active_grpc_future = None
        self._active_grpc_future_lock = threading.Lock()
        self._completion_event = completion_event

    def _set_active_future(self, grpc_future):
        """Records `grpc_future` as the attempt whose outcome `result()` reports.

        Raises:
          RuntimeError: if `grpc_future` is `None`.
        """
        if grpc_future is None:
            raise RuntimeError(
                "_set_active_future invoked with grpc_future=None."
            )
        with self._active_grpc_future_lock:
            self._active_grpc_future = grpc_future

    def result(self, timeout=None):
        """Analogous to `grpc.Future.result`. Returns the value or exception.

        This method will wait until the full set of gRPC requests is complete
        and then act as `grpc.Future.result` for the single gRPC invocation
        corresponding to the first successful call or final failure, as
        appropriate.

        Args:
          timeout: How long to wait in seconds before giving up and raising.
            If `None` (the default, as in `grpc.Future.result`), waits until
            the request completes, with no time limit.

        Returns:
          The result of the future corresponding to the single gRPC
          corresponding to the successful call.

        Raises:
          * `grpc.FutureTimeoutError` if timeout seconds elapse before the gRPC
          calls could complete, including waits and retries.
          * The exception of the final gRPC attempt in the case that a
          successful gRPC request was not made (a non-retryable error, or the
          last retryable error once attempts are exhausted).
          * `RuntimeError` if completion was signaled but no active future
          was ever set.
        """
        # `Event.wait(None)` blocks until the event is set and then returns True.
        if not self._completion_event.wait(timeout):
            raise grpc.FutureTimeoutError(
                f"AsyncCallFuture timed out after {timeout} seconds"
            )
        with self._active_grpc_future_lock:
            if self._active_grpc_future is None:
                raise RuntimeError("AsyncFuture never had an active future set")
            return self._active_grpc_future.result()


def async_call_with_retries(api_method, request, clock=None):
    """Initiate an asynchronous call to a gRPC stub, with retry logic.

    This is the asynchronous counterpart of `call_with_retries`. Each attempt
    is issued via `api_method.future(...)`, and completion (including any
    retries) is driven from gRPC done-callbacks, so it may be handled by
    another thread. The caller obtains the outcome by calling `result()` on
    the returned `AsyncCallFuture`. That call either returns the response or
    raises the error of the final attempt.

    Retries are handled with jittered exponential backoff to spread out failures
    due to request spikes.

    This only supports unary-unary RPCs: i.e., no streaming on either end.

    Args:
      api_method: Callable for the API method to invoke; must provide a
        `future(request, timeout=..., metadata=...)` method.
      request: Request protocol buffer to pass to the API method.
      clock: an interface object supporting `time()` and `sleep()` methods
        like the standard `time` module; if not passed, uses the normal module.

    Returns:
      An `AsyncCallFuture` which will encapsulate the `grpc.Future`
      corresponding to the gRPC call which either completes successfully or
      represents the final try.
    """
    if clock is None:
        clock = time
    logger.debug("Async RPC call %s with request: %r", api_method, request)

    completion_event = threading.Event()
    async_future = AsyncCallFuture(completion_event)

    def async_call(handler):
        """Issues one attempt and registers `handler` as its done-callback."""
        future = api_method.future(
            request,
            timeout=_GRPC_DEFAULT_TIMEOUT_SECS,
            metadata=version_metadata(),
        )
        # Ensure we set the active future before invoking the done callback, to
        # avoid the case where the done callback completes immediately and
        # triggers completion event while async_future still holds the old
        # future.
        async_future._set_active_future(future)
        future.add_done_callback(handler)

    # retry_handler is the continuation of `async_call`. It should:
    #   * If the gRPC call succeeds: trigger the `completion_event`.
    #   * If the error is not retryable, or there are no more attempts:
    #     trigger the `completion_event`.
    #   * Otherwise, wait out the backoff and invoke a new `async_call` with
    #     the same retry_handler.
    def retry_handler(future, num_attempts):
        try:
            e = future.exception()
            if e is None:
                completion_event.set()
                return
            logger.info("RPC call %s got error %s", api_method, e)
            # If unable to retry, proceed to completion.
            if e.code() not in _GRPC_RETRYABLE_STATUS_CODES:
                completion_event.set()
                return
            if num_attempts >= _GRPC_RETRY_MAX_ATTEMPTS:
                completion_event.set()
                return
            # If able to retry, wait then do so.
            # NOTE(review): this sleep runs inside the done-callback, so it
            # occupies whichever thread gRPC uses to deliver the callback.
            backoff_secs = _compute_backoff_seconds(num_attempts)
            clock.sleep(backoff_secs)
            async_call(
                functools.partial(retry_handler, num_attempts=num_attempts + 1)
            )
        except Exception:
            # Without this, any exception escaping the callback (e.g. issuing
            # the retry fails synchronously) would leave `completion_event`
            # unset, and `AsyncCallFuture.result()` would block until its
            # timeout. Signal completion instead. If issuing the retry
            # failed, the active future is still the previous attempt, so
            # `result()` reports that attempt's error. Re-raise so the
            # unexpected failure is not hidden.
            completion_event.set()
            raise

    async_call(functools.partial(retry_handler, num_attempts=1))
    return async_future


def _compute_backoff_seconds(num_attempts):
    """Return the jittered exponential delay, in seconds, before a retry.

    Args:
      num_attempts: How many attempts have been made so far (1 after the
        first try).

    Returns:
      `_GRPC_RETRY_EXPONENTIAL_BASE ** num_attempts`, scaled by a factor drawn
      uniformly from the range [`_GRPC_RETRY_JITTER_FACTOR_MIN`,
      `_GRPC_RETRY_JITTER_FACTOR_MAX`].
    """
    base_delay = _GRPC_RETRY_EXPONENTIAL_BASE**num_attempts
    return base_delay * random.uniform(
        _GRPC_RETRY_JITTER_FACTOR_MIN, _GRPC_RETRY_JITTER_FACTOR_MAX
    )


def call_with_retries(api_method, request, clock=None):
    """Invoke a unary-unary gRPC stub method, retrying transient failures.

    Every attempt uses the default RPC timeout and carries the TensorBoard
    version metadata. If an attempt fails with a status code listed in
    `_GRPC_RETRYABLE_STATUS_CODES`, the call is retried after a jittered
    exponential backoff. At most `_GRPC_RETRY_MAX_ATTEMPTS` attempts are made
    in total.

    Streaming RPCs are not supported. After a gRPC error the whole request has
    to be re-sent, because gRPC offers no "retry-resume", so streamed RPCs
    generally need application-level pagination instead.

    Args:
      api_method: Callable for the API method to invoke.
      request: Request protocol buffer to pass to the API method.
      clock: an interface object supporting `time()` and `sleep()` methods
        like the standard `time` module; if not passed, uses the normal module.

    Returns:
      Response protocol buffer returned by the API method.

    Raises:
      grpc.RpcError: if a non-retryable error is returned, or if all retry
        attempts have been exhausted.
    """
    sleeper = time if clock is None else clock
    # gRPC stub callables are not real methods and do not expose a usable
    # `__name__`, so derive a log label from the request message type instead.
    rpc_name = request.__class__.__name__.replace("Request", "")
    logger.debug("RPC call %s with request: %r", rpc_name, request)

    attempt = 0
    while True:
        attempt += 1
        try:
            return api_method(
                request,
                timeout=_GRPC_DEFAULT_TIMEOUT_SECS,
                metadata=version_metadata(),
            )
        except grpc.RpcError as err:
            logger.info("RPC call %s got error %s", rpc_name, err)
            retryable = err.code() in _GRPC_RETRYABLE_STATUS_CODES
            if not retryable or attempt >= _GRPC_RETRY_MAX_ATTEMPTS:
                raise
        # Reached only after a retryable failure with attempts remaining.
        delay = _compute_backoff_seconds(attempt)
        logger.info(
            "RPC call %s attempted %d times, retrying in %.1f seconds",
            rpc_name,
            attempt,
            delay,
        )
        sleeper.sleep(delay)


def version_metadata():
    """Build gRPC invocation metadata advertising this TensorBoard's version.

    Pass the result as the `metadata` kwarg of a stub call, for example
    `stub.MyRpc(request, metadata=version_metadata())`.

    Returns:
      A one-element tuple containing the 2-tuple
      `(_VERSION_METADATA_KEY, version.VERSION)`.
    """
    version_entry = (_VERSION_METADATA_KEY, version.VERSION)
    return (version_entry,)


def extract_version(metadata):
    """Extracts the TensorBoard version from invocation metadata.

    Args:
      metadata: `(key, value)` pairs (or a mapping) such as the result of a
        prior call to `version_metadata`, possibly combined with other
        metadata entries.

    Returns:
      The TensorBoard version listed in this metadata, or `None` if none
      is listed. If the version key appears more than once, the value from
      the last occurrence is returned.
    """
    return dict(metadata).get(_VERSION_METADATA_KEY)


@enum.unique
class ChannelCredsType(enum.Enum):
    """Kinds of channel credentials, selectable by their user-facing value."""

    LOCAL = "local"
    SSL = "ssl"
    SSL_DEV = "ssl_dev"

    def channel_config(self):
        """Build the credentials and channel options for this credential kind.

        Returns:
          A pair `(channel_creds, channel_options)` that may be passed to
          `grpc.secure_channel`. `channel_creds` is a
          `grpc.ChannelCredentials`, and `channel_options` is a list of
          `(key, value)` tuples, empty unless extra options are required.
        """
        if self is ChannelCredsType.LOCAL:
            return (grpc.local_channel_credentials(), [])
        if self is ChannelCredsType.SSL:
            return (grpc.ssl_channel_credentials(), [])
        if self is ChannelCredsType.SSL_DEV:
            # Point gRPC at the dev certificate via the environment variable
            # GRPC_DEFAULT_SSL_ROOTS_FILE_PATH=path/to/cert.crt. The expected
            # server name is overridden to "localhost".
            dev_options = [("grpc.ssl_target_name_override", "localhost")]
            return (grpc.ssl_channel_credentials(), dev_options)
        raise AssertionError("unhandled ChannelCredsType: %r" % self)

    @classmethod
    def choices(cls):
        """Return a view over every member, in definition order."""
        members_by_name = cls.__members__
        return members_by_name.values()

    def __str__(self):
        # Flag choices are rendered with str(), so show the user-facing value
        # instead of the default Enum string form.
        return self.value