Jinglong Xiong
first commit
15369ca
'''
Core implementation of the client module, implementing generic communication
patterns with Python in / Python out supporting many (nested) primitives +
special data science types like DataFrames or np.ndarrays, with gRPC + protobuf
as a backing implementation.
'''
import grpc
import io
import json
import socket
import time
from concurrent import futures
from typing import Callable, List, Tuple
import numpy as np
import pandas as pd
import polars as pl
import pyarrow
import kaggle_evaluation.core.generated.kaggle_evaluation_pb2 as kaggle_evaluation_proto
import kaggle_evaluation.core.generated.kaggle_evaluation_pb2_grpc as kaggle_evaluation_grpc
# gRPC service config, JSON-encoded into the channel options below. Schema reference:
# https://github.com/grpc/grpc-proto/blob/ec886024c2f7b7f597ba89d5b7d60c3f94627b17/grpc/service_config/service_config.proto#L377
_SERVICE_CONFIG = dict(
    methodConfig=[
        dict(
            name=[{}],  # A single empty name entry makes this config apply to every method.
            # Retry semantics: https://grpc.io/docs/guides/retry/
            retryPolicy=dict(
                maxAttempts=5,
                initialBackoff='0.1s',
                maxBackoff='1s',
                backoffMultiplier=1,  # Constant backoff keeps feedback fast if the peer has crashed.
                retryableStatusCodes=['UNAVAILABLE'],
            ),
        ),
    ],
)

_GRPC_PORT = 50_051

_GRPC_CHANNEL_OPTIONS = [
    # Message size limits; -1 lifts the limit for both sending and receiving.
    # https://github.com/grpc/grpc/blob/v1.64.x/include/grpc/impl/channel_arg_names.h#L39
    ('grpc.max_send_message_length', -1),
    ('grpc.max_receive_message_length', -1),
    # Heartbeat tuning: https://github.com/grpc/grpc/blob/master/doc/keepalive.md
    ('grpc.keepalive_time_ms', 60 * 1_000),  # Interval between heartbeat pings.
    ('grpc.keepalive_timeout_ms', 5 * 1_000),  # How long a ping may go unanswered.
    ('grpc.http2.max_pings_without_data', 0),  # Remove the cap on pings sent without data.
    ('grpc.keepalive_permit_without_calls', 1),  # Permit heartbeat pings at any time.
    ('grpc.http2.min_ping_interval_without_data_ms', 1_000),
    # Install the retry policy defined in _SERVICE_CONFIG.
    ('grpc.service_config', json.dumps(_SERVICE_CONFIG)),
]

# Default per-request timeout in seconds (one hour).
DEFAULT_DEADLINE_SECONDS = 3_600
# Pause between connection attempts while waiting for the server to start.
_RETRY_SLEEP_SECONDS = 1
# Keep the server startup window fairly strict so users get quick feedback when KaggleEvaluation is
# misconfigured: we really don't want a notebook to time out after nine hours just because somebody
# forgot to start their inference_server. Slow steps such as loading models can happen during the
# first inference call instead, if necessary.
STARTUP_LIMIT_SECONDS = 15 * 60
### Utils shared by client and server for data transfer

# Polars base dtypes that are rejected when serializing a pl.DataFrame.
# pl.Enum is currently unstable, but we should eventually consider supporting it.
# https://docs.pola.rs/api/python/stable/reference/api/polars.datatypes.Enum.html#polars.datatypes.Enum
_POLARS_TYPE_DENYLIST = {pl.Enum, pl.Object, pl.Unknown}
def _serialize(data) -> kaggle_evaluation_proto.Payload:
    '''Maps input data of one of several allow-listed types to a protobuf message to be sent over gRPC.

    Supported types:
        - Numpy numeric and boolean scalars (np.number, np.bool_)
        - Python primitives: str, bool, int, float, None
        - Nested containers: list, tuple, and dict with str keys (values are serialized recursively)
        - pd.DataFrame, pd.Series, pl.DataFrame, pl.Series, np.ndarray, io.BytesIO

    Args:
        data: The input data to be mapped.

    Returns:
        The Payload protobuf message.

    Raises:
        TypeError if data, a nested value, or a dict key is of an unsupported type.
    '''
    # Python primitives and Numpy scalars.
    # Numpy functions that return a single number return numpy scalars instead of python primitives.
    # In some cases this difference matters: https://numpy.org/devdocs/release/2.0.0-notes.html#representation-of-numpy-scalars-changed
    # Ex: np.mean([1, 2]) yields np.float64(1.5) instead of 1.5.
    # Check for numpy scalars first since most of them also inherit from python primitives.
    # For example, `np.float64(1.5)` is an instance of `float` among many other things.
    # https://numpy.org/doc/stable/reference/arrays.scalars.html
    if isinstance(data, np.generic):
        # Explicit checks rather than `assert`, so the validation survives `python -O` and surfaces as
        # the documented TypeError. Only dimensionless numeric/bool scalars are supported: no bytes,
        # strings, objects, etc.
        if data.shape != () or not isinstance(data, (np.number, np.bool_)):
            raise TypeError(f'Type {type(data)} not supported for KaggleEvaluation.')
        buffer = io.BytesIO()
        np.save(buffer, data, allow_pickle=False)
        return kaggle_evaluation_proto.Payload(numpy_scalar_value=buffer.getvalue())
    elif isinstance(data, str):
        return kaggle_evaluation_proto.Payload(str_value=data)
    elif isinstance(data, bool):  # bool is a subclass of int, so check that first
        return kaggle_evaluation_proto.Payload(bool_value=data)
    elif isinstance(data, int):
        return kaggle_evaluation_proto.Payload(int_value=data)
    elif isinstance(data, float):
        return kaggle_evaluation_proto.Payload(float_value=data)
    elif data is None:
        return kaggle_evaluation_proto.Payload(none_value=True)

    # Iterables for nested types
    if isinstance(data, list):
        return kaggle_evaluation_proto.Payload(list_value=kaggle_evaluation_proto.PayloadList(payloads=map(_serialize, data)))
    elif isinstance(data, tuple):
        return kaggle_evaluation_proto.Payload(tuple_value=kaggle_evaluation_proto.PayloadList(payloads=map(_serialize, data)))
    elif isinstance(data, dict):
        serialized_dict = {}
        for key, value in data.items():
            if not isinstance(key, str):
                raise TypeError(f'KaggleEvaluation only supports dicts with keys of type str, found {type(key)}.')
            serialized_dict[key] = _serialize(value)
        return kaggle_evaluation_proto.Payload(dict_value=kaggle_evaluation_proto.PayloadMap(payload_map=serialized_dict))

    # Allowlisted special types
    if isinstance(data, pd.DataFrame):
        buffer = io.BytesIO()
        # The index is not transmitted.
        data.to_parquet(buffer, index=False, compression='lz4')
        return kaggle_evaluation_proto.Payload(pandas_dataframe_value=buffer.getvalue())
    elif isinstance(data, pl.DataFrame):
        data_types = {dtype.base_type() for dtype in data.dtypes}
        banned_types = _POLARS_TYPE_DENYLIST.intersection(data_types)
        if banned_types:
            raise TypeError(f'Unsupported Polars data type(s): {banned_types}')
        table = data.to_arrow()
        buffer = io.BytesIO()
        with pyarrow.ipc.new_stream(buffer, table.schema, options=pyarrow.ipc.IpcWriteOptions(compression='lz4')) as writer:
            writer.write_table(table)
        return kaggle_evaluation_proto.Payload(polars_dataframe_value=buffer.getvalue())
    elif isinstance(data, pd.Series):
        buffer = io.BytesIO()
        # Can't serialize a pd.Series directly to parquet, must use intermediate DataFrame
        pd.DataFrame(data).to_parquet(buffer, index=False, compression='lz4')
        return kaggle_evaluation_proto.Payload(pandas_series_value=buffer.getvalue())
    elif isinstance(data, pl.Series):
        buffer = io.BytesIO()
        # Can't serialize a pl.Series directly to parquet, must use intermediate DataFrame
        pl.DataFrame(data).write_parquet(buffer, compression='lz4', statistics=False)
        return kaggle_evaluation_proto.Payload(polars_series_value=buffer.getvalue())
    elif isinstance(data, np.ndarray):
        buffer = io.BytesIO()
        np.save(buffer, data, allow_pickle=False)
        return kaggle_evaluation_proto.Payload(numpy_array_value=buffer.getvalue())
    elif isinstance(data, io.BytesIO):
        return kaggle_evaluation_proto.Payload(bytes_io_value=data.getvalue())

    raise TypeError(f'Type {type(data)} not supported for KaggleEvaluation.')
def _deserialize(payload: kaggle_evaluation_proto.Payload):
    '''Maps a Payload protobuf message to a value of whichever type was set on the message.

    Counterpart of `_serialize`: nested list/tuple/dict payloads are deserialized recursively.

    Args:
        payload: The message to be mapped.

    Returns:
        A value of one of several allow-listed types.

    Raises:
        TypeError if an unexpected value data type is found, including a numpy scalar payload that
        does not hold a dimensionless numeric/bool value.
    '''
    # Resolve the oneof case once instead of re-querying it for every branch.
    kind = payload.WhichOneof('value')

    # Primitives
    if kind == 'str_value':
        return payload.str_value
    if kind == 'bool_value':
        return payload.bool_value
    if kind == 'int_value':
        return payload.int_value
    if kind == 'float_value':
        return payload.float_value
    if kind == 'none_value':
        return None

    # Iterables for nested types
    if kind == 'list_value':
        return [_deserialize(item) for item in payload.list_value.payloads]
    if kind == 'tuple_value':
        return tuple(_deserialize(item) for item in payload.tuple_value.payloads)
    if kind == 'dict_value':
        return {key: _deserialize(value) for key, value in payload.dict_value.payload_map.items()}

    # Allowlisted special types
    if kind == 'pandas_dataframe_value':
        return pd.read_parquet(io.BytesIO(payload.pandas_dataframe_value))
    if kind == 'polars_dataframe_value':
        with pyarrow.ipc.open_stream(payload.polars_dataframe_value) as reader:
            table = reader.read_all()
        return pl.from_arrow(table)
    if kind == 'pandas_series_value':
        # The series was written as a single-column parquet file, and pandas reads it back as a
        # DataFrame, so unwrap that one column.
        df = pd.read_parquet(io.BytesIO(payload.pandas_series_value))
        return pd.Series(df[df.columns[0]])
    if kind == 'polars_series_value':
        return pl.Series(pl.read_parquet(io.BytesIO(payload.polars_series_value)))
    if kind == 'numpy_array_value':
        return np.load(io.BytesIO(payload.numpy_array_value), allow_pickle=False)
    if kind == 'numpy_scalar_value':
        data = np.load(io.BytesIO(payload.numpy_scalar_value), allow_pickle=False)
        # As of Numpy 2.0.2, np.load for a numpy scalar yields a dimensionless array instead of a scalar
        data = data.dtype.type(data)  # Restore the expected numpy scalar type.
        # Validate with explicit checks rather than `assert`: this is data received from the peer and
        # the checks must not disappear under `python -O`. No support for bytes, strings, objects, etc.
        if data.shape != () or not isinstance(data, (np.number, np.bool_)):
            raise TypeError(f'Unsupported numpy scalar type {type(data)} found in Payload')
        return data
    if kind == 'bytes_io_value':
        return io.BytesIO(payload.bytes_io_value)

    raise TypeError(f'Found unknown Payload case {kind}')
### Client code
class Client:
    '''
    Class which allows callers to make KaggleEvaluation requests.

    Can be used as a context manager; the gRPC channel is closed on exit.
    '''

    def __init__(self, channel_address: str = 'localhost'):
        '''
        Args:
            channel_address: Host of the server. The client connects to `{channel_address}:{_GRPC_PORT}`.
        '''
        self.channel_address = channel_address
        self.channel = grpc.insecure_channel(f'{channel_address}:{_GRPC_PORT}', options=_GRPC_CHANNEL_OPTIONS)
        # Flips to True after the first successful request; later requests skip the startup-wait loop.
        self._made_first_connection = False
        # Timeout in seconds applied to each request once the first connection has been made.
        self.endpoint_deadline_seconds = DEFAULT_DEADLINE_SECONDS
        self.stub = kaggle_evaluation_grpc.KaggleEvaluationServiceStub(self.channel)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def _send_with_deadline(self, request):
        '''Sends a message to the server while also:
        - Throwing an error as soon as the inference_server container has been shut down.
        - Setting a deadline of STARTUP_LIMIT_SECONDS for the inference_server to startup.

        After one request has succeeded, requests are sent directly with a timeout of
        `endpoint_deadline_seconds`. Until then, a request is retried every _RETRY_SLEEP_SECONDS
        seconds while the server reports UNAVAILABLE; the first request itself has no per-call timeout.

        Raises:
            grpc.RpcError: any RPC failure other than UNAVAILABLE while waiting for startup, or any
                RPC failure after the first connection.
            socket.gaierror: if `channel_address` stops resolving while waiting for startup.
            RuntimeError: if the server is still unavailable after STARTUP_LIMIT_SECONDS.
        '''
        if self._made_first_connection:
            return self.stub.Send(request, wait_for_ready=False, timeout=self.endpoint_deadline_seconds)

        # time.monotonic() is unaffected by wall-clock adjustments, unlike time.time().
        startup_deadline = time.monotonic() + STARTUP_LIMIT_SECONDS
        # Allow time for the server to start as long as its container is running
        while time.monotonic() < startup_deadline:
            try:
                response = self.stub.Send(request, wait_for_ready=False)
            except grpc.RpcError as err:
                # Use the public status-code API instead of the private `grpc._channel._InactiveRpcError`
                # class and string matching. Only UNAVAILABLE means "server not up yet".
                status = err.code() if callable(getattr(err, 'code', None)) else None
                if status != grpc.StatusCode.UNAVAILABLE:
                    raise
            else:
                self._made_first_connection = True
                return response
            # Confirm the inference_server container is still alive & it's worth waiting on the server.
            # If the inference_server container is no longer running this will throw a socket.gaierror.
            socket.gethostbyname(self.channel_address)
            time.sleep(_RETRY_SLEEP_SECONDS)

        raise RuntimeError(f'Failed to connect to server after waiting {STARTUP_LIMIT_SECONDS} seconds')

    def serialize_request(self, name: str, *args, **kwargs) -> kaggle_evaluation_proto.KaggleEvaluationRequest:
        ''' Serialize a single request. Exists as a separate function from `send`
        to enable gateway concurrency for some competitions.

        If the only positional argument is already a KaggleEvaluationRequest, it is returned unchanged
        (`name` and `kwargs` are ignored in that case).
        '''
        already_serialized = (len(args) == 1) and isinstance(args[0], kaggle_evaluation_proto.KaggleEvaluationRequest)
        if already_serialized:
            return args[0]  # args is a tuple of length 1 containing the request
        return kaggle_evaluation_proto.KaggleEvaluationRequest(
            name=name,
            args=map(_serialize, args),
            kwargs={key: _serialize(value) for key, value in kwargs.items()},
        )

    def send(self, name: str, *args, **kwargs):
        '''Sends a single KaggleEvaluation request.

        Args:
            name: The endpoint name for the request.
            *args: Variable-length/type arguments to be supplied on the request.
            **kwargs: Key-value arguments to be supplied on the request.

        Returns:
            The response, which is of one of several allow-listed data types.
        '''
        request = self.serialize_request(name, *args, **kwargs)
        response = self._send_with_deadline(request)
        return _deserialize(response.payload)

    def close(self):
        '''Closes the underlying gRPC channel.'''
        self.channel.close()
### Server code
class KaggleEvaluationServiceServicer(kaggle_evaluation_grpc.KaggleEvaluationServiceServicer):
    '''
    Class which allows serving responses to KaggleEvaluation requests. The inference_server will run this service to listen for and respond
    to requests from the Gateway. The Gateway may also listen for requests from the inference_server in some cases.
    '''

    def __init__(self, listeners: List[Callable]):
        '''
        Args:
            listeners: Handler functions. Each is registered under its `__name__`; if two listeners
                share a name, the later one wins.
        '''
        self.listeners_map = {func.__name__: func for func in listeners}

    # pylint: disable=unused-argument
    def Send(self, request: kaggle_evaluation_proto.KaggleEvaluationRequest, context: grpc.ServicerContext) -> kaggle_evaluation_proto.KaggleEvaluationResponse:
        '''Handler for gRPC requests that deserializes arguments, calls a user-registered function for handling the
        requested endpoint, then serializes and returns the response.

        Args:
            request: The KaggleEvaluationRequest protobuf message.
            context: (Unused) gRPC context.

        Returns:
            The KaggleEvaluationResponse protobuf message.

        Raises:
            NotImplementedError if the caller has not registered a handler for the requested endpoint.
        '''
        if request.name not in self.listeners_map:
            raise NotImplementedError(f'No listener for {request.name} was registered.')

        args = [_deserialize(arg) for arg in request.args]
        kwargs = {key: _deserialize(value) for key, value in request.kwargs.items()}
        response_function = self.listeners_map[request.name]
        response_payload = _serialize(response_function(*args, **kwargs))
        return kaggle_evaluation_proto.KaggleEvaluationResponse(payload=response_payload)
def define_server(*endpoint_listeners: Callable) -> 'grpc.Server':
    '''Registers the endpoints that the container is able to respond to, then builds a server which will
    listen for those endpoints. The endpoints that need to be implemented will depend on the specific
    competition.

    Args:
        endpoint_listeners: Functions that define how requests to the endpoint of the same name (the
            function's `__name__`) should be handled.

    Returns:
        The configured gRPC server object. It has NOT been started: the caller must call
        `server.start()`, and should stop it at exit time.

    Raises:
        ValueError if parameter values are invalid.
    '''
    if not endpoint_listeners:
        raise ValueError('Must pass at least one endpoint listener, e.g. `predict`')
    for func in endpoint_listeners:
        if not callable(func):
            raise ValueError('Endpoint listeners passed to `serve` must be functions')
        # Requests are routed by __name__, so callables without a usable name (lambdas,
        # functools.partial objects, ...) could never be addressed.
        listener_name = getattr(func, '__name__', None)
        if listener_name is None or listener_name == '<lambda>':
            raise ValueError('Functions passed as endpoint listeners must be named')

    # The executor has a single worker thread, so at most one request is handled at a time.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1), options=_GRPC_CHANNEL_OPTIONS)
    kaggle_evaluation_grpc.add_KaggleEvaluationServiceServicer_to_server(KaggleEvaluationServiceServicer(endpoint_listeners), server)
    server.add_insecure_port(f'[::]:{_GRPC_PORT}')
    return server