Spaces:
Paused
Paused
''' | |
Core implementation of the client module, implementing generic communication | |
patterns with Python in / Python out supporting many (nested) primitives + | |
special data science types like DataFrames or np.ndarrays, with gRPC + protobuf | |
as a backing implementation. | |
''' | |
import grpc | |
import io | |
import json | |
import socket | |
import time | |
from concurrent import futures | |
from typing import Callable, List, Tuple | |
import numpy as np | |
import pandas as pd | |
import polars as pl | |
import pyarrow | |
import kaggle_evaluation.core.generated.kaggle_evaluation_pb2 as kaggle_evaluation_proto | |
import kaggle_evaluation.core.generated.kaggle_evaluation_pb2_grpc as kaggle_evaluation_grpc | |
# gRPC service config, JSON-encoded into the channel options below, so the same retry policy is
# handed to both the client channel and the server.
_SERVICE_CONFIG = {
    # Service config proto: https://github.com/grpc/grpc-proto/blob/ec886024c2f7b7f597ba89d5b7d60c3f94627b17/grpc/service_config/service_config.proto#L377
    'methodConfig': [
        {
            'name': [{}],  # Applies to all methods
            # See retry policy docs: https://grpc.io/docs/guides/retry/
            'retryPolicy': {
                'maxAttempts': 5,
                'initialBackoff': '0.1s',
                'maxBackoff': '1s',
                'backoffMultiplier': 1,  # Ensure relatively rapid feedback in the event of a crash
                'retryableStatusCodes': ['UNAVAILABLE'],
            },
        }
    ]
}

# Port used both by the client channel and by the server's listening socket.
_GRPC_PORT = 50051

# Options passed to both `grpc.insecure_channel` (client) and `grpc.server` (server).
_GRPC_CHANNEL_OPTIONS = [
    # -1 for unlimited message send/receive size
    # https://github.com/grpc/grpc/blob/v1.64.x/include/grpc/impl/channel_arg_names.h#L39
    ('grpc.max_send_message_length', -1),
    ('grpc.max_receive_message_length', -1),
    # https://github.com/grpc/grpc/blob/master/doc/keepalive.md
    ('grpc.keepalive_time_ms', 60_000),  # Time between heartbeat pings
    ('grpc.keepalive_timeout_ms', 5_000),  # Time allowed to respond to pings
    ('grpc.http2.max_pings_without_data', 0),  # Remove another cap on pings
    ('grpc.keepalive_permit_without_calls', 1),  # Allow heartbeat pings at any time
    ('grpc.http2.min_ping_interval_without_data_ms', 1_000),
    ('grpc.service_config', json.dumps(_SERVICE_CONFIG)),
]

# Default per-request timeout (one hour), applied to every request after the first successful one.
DEFAULT_DEADLINE_SECONDS = 60 * 60

# Pause between connection attempts while the client waits for the server to come up.
_RETRY_SLEEP_SECONDS = 1

# Enforce a relatively strict server startup time so users can get feedback quickly if they're not
# configuring KaggleEvaluation correctly. We really don't want notebooks timing out after nine hours
# because somebody forgot to start their inference_server. Slow steps like loading models
# can happen during the first inference call if necessary.
STARTUP_LIMIT_SECONDS = 60 * 15

### Utils shared by client and server for data transfer

# Polars column types that `_serialize` refuses to send.
# pl.Enum is currently unstable, but we should eventually consider supporting it.
# https://docs.pola.rs/api/python/stable/reference/api/polars.datatypes.Enum.html#polars.datatypes.Enum
_POLARS_TYPE_DENYLIST = set([pl.Enum, pl.Object, pl.Unknown])
def _serialize(data) -> kaggle_evaluation_proto.Payload:
    '''Maps input data of one of several allow-listed types to a protobuf message to be sent over gRPC.

    Supported inputs:
        - Numpy numeric and boolean scalars.
        - Python primitives: str, bool, int, float and None.
        - list, tuple and dict (str keys only); their elements are serialized recursively.
        - pd.DataFrame, pl.DataFrame, pd.Series, pl.Series, np.ndarray and io.BytesIO.

    Args:
        data: The input data to be mapped. Any of the types listed above are accepted.

    Returns:
        The Payload protobuf message.

    Raises:
        TypeError if data is of an unsupported type.
    '''
    # Python primitives and Numpy scalars
    if isinstance(data, np.generic):
        # Numpy functions that return a single number return numpy scalars instead of python primitives.
        # In some cases this difference matters: https://numpy.org/devdocs/release/2.0.0-notes.html#representation-of-numpy-scalars-changed
        # Ex: np.mean([1, 2]) yields np.float64(1.5) instead of 1.5.
        # Check for numpy scalars first since most of them also inherit from python primitives.
        # For example, `np.float64(1.5)` is an instance of `float` among many other things.
        # https://numpy.org/doc/stable/reference/arrays.scalars.html
        #
        # These checks raise TypeError (as documented) rather than using `assert`, so the
        # validation is not silently dropped when Python runs with -O.
        if data.shape != ():
            # Additional validation that the np.generic type remains solely for scalars
            raise TypeError(f'Numpy scalar of type {type(data)} has unexpected shape {data.shape}.')
        if not isinstance(data, (np.number, np.bool_)):
            # No support for bytes, strings, objects, etc
            raise TypeError(f'Numpy scalar type {type(data)} not supported for KaggleEvaluation.')
        buffer = io.BytesIO()
        np.save(buffer, data, allow_pickle=False)
        return kaggle_evaluation_proto.Payload(numpy_scalar_value=buffer.getvalue())
    if isinstance(data, str):
        return kaggle_evaluation_proto.Payload(str_value=data)
    if isinstance(data, bool):  # bool is a subclass of int, so check that first
        return kaggle_evaluation_proto.Payload(bool_value=data)
    if isinstance(data, int):
        return kaggle_evaluation_proto.Payload(int_value=data)
    if isinstance(data, float):
        return kaggle_evaluation_proto.Payload(float_value=data)
    if data is None:
        return kaggle_evaluation_proto.Payload(none_value=True)

    # Iterables for nested types
    if isinstance(data, list):
        payloads = [_serialize(item) for item in data]
        return kaggle_evaluation_proto.Payload(list_value=kaggle_evaluation_proto.PayloadList(payloads=payloads))
    if isinstance(data, tuple):
        payloads = [_serialize(item) for item in data]
        return kaggle_evaluation_proto.Payload(tuple_value=kaggle_evaluation_proto.PayloadList(payloads=payloads))
    if isinstance(data, dict):
        serialized_dict = {}
        for key, value in data.items():
            if not isinstance(key, str):
                raise TypeError(f'KaggleEvaluation only supports dicts with keys of type str, found {type(key)}.')
            serialized_dict[key] = _serialize(value)
        return kaggle_evaluation_proto.Payload(dict_value=kaggle_evaluation_proto.PayloadMap(payload_map=serialized_dict))

    # Allowlisted special types
    if isinstance(data, pd.DataFrame):
        buffer = io.BytesIO()
        # index=False: the DataFrame index is not transmitted.
        data.to_parquet(buffer, index=False, compression='lz4')
        return kaggle_evaluation_proto.Payload(pandas_dataframe_value=buffer.getvalue())
    if isinstance(data, pl.DataFrame):
        # Compare base types so that parametrized dtypes are matched against the denylist too.
        data_types = {dtype.base_type() for dtype in data.dtypes}
        banned_types = _POLARS_TYPE_DENYLIST.intersection(data_types)
        if banned_types:
            raise TypeError(f'Unsupported Polars data type(s): {banned_types}')
        table = data.to_arrow()
        buffer = io.BytesIO()
        with pyarrow.ipc.new_stream(buffer, table.schema, options=pyarrow.ipc.IpcWriteOptions(compression='lz4')) as writer:
            writer.write_table(table)
        return kaggle_evaluation_proto.Payload(polars_dataframe_value=buffer.getvalue())
    if isinstance(data, pd.Series):
        buffer = io.BytesIO()
        # Can't serialize a pd.Series directly to parquet, must use intermediate DataFrame
        pd.DataFrame(data).to_parquet(buffer, index=False, compression='lz4')
        return kaggle_evaluation_proto.Payload(pandas_series_value=buffer.getvalue())
    if isinstance(data, pl.Series):
        buffer = io.BytesIO()
        # Can't serialize a pl.Series directly to parquet, must use intermediate DataFrame
        pl.DataFrame(data).write_parquet(buffer, compression='lz4', statistics=False)
        return kaggle_evaluation_proto.Payload(polars_series_value=buffer.getvalue())
    if isinstance(data, np.ndarray):
        buffer = io.BytesIO()
        # allow_pickle=False: never fall back to pickling the array contents.
        np.save(buffer, data, allow_pickle=False)
        return kaggle_evaluation_proto.Payload(numpy_array_value=buffer.getvalue())
    if isinstance(data, io.BytesIO):
        return kaggle_evaluation_proto.Payload(bytes_io_value=data.getvalue())

    raise TypeError(f'Type {type(data)} not supported for KaggleEvaluation.')
def _deserialize(payload: kaggle_evaluation_proto.Payload):
    '''Maps a Payload protobuf message to a value of whichever type was set on the message.

    Args:
        payload: The message to be mapped.

    Returns:
        A value of one of several allow-listed types.

    Raises:
        TypeError if an unexpected value data type is found.
    '''
    # Resolve the populated oneof field once instead of re-querying it for every branch.
    case = payload.WhichOneof('value')

    # Primitives
    if case == 'str_value':
        return payload.str_value
    if case == 'bool_value':
        return payload.bool_value
    if case == 'int_value':
        return payload.int_value
    if case == 'float_value':
        return payload.float_value
    if case == 'none_value':
        return None

    # Iterables for nested types
    if case == 'list_value':
        return [_deserialize(item) for item in payload.list_value.payloads]
    if case == 'tuple_value':
        return tuple(_deserialize(item) for item in payload.tuple_value.payloads)
    if case == 'dict_value':
        return {key: _deserialize(value) for key, value in payload.dict_value.payload_map.items()}

    # Allowlisted special types
    if case == 'pandas_dataframe_value':
        return pd.read_parquet(io.BytesIO(payload.pandas_dataframe_value))
    if case == 'polars_dataframe_value':
        with pyarrow.ipc.open_stream(payload.polars_dataframe_value) as reader:
            table = reader.read_all()
        return pl.from_arrow(table)
    if case == 'pandas_series_value':
        # The series was sent as a single-column parquet DataFrame; unwrap that column.
        df = pd.read_parquet(io.BytesIO(payload.pandas_series_value))
        return pd.Series(df[df.columns[0]])
    if case == 'polars_series_value':
        return pl.Series(pl.read_parquet(io.BytesIO(payload.polars_series_value)))
    if case == 'numpy_array_value':
        return np.load(io.BytesIO(payload.numpy_array_value), allow_pickle=False)
    if case == 'numpy_scalar_value':
        data = np.load(io.BytesIO(payload.numpy_scalar_value), allow_pickle=False)
        # As of Numpy 2.0.2, np.load for a numpy scalar yields a dimensionless array instead of a scalar
        data = data.dtype.type(data)  # Restore the expected numpy scalar type.
        # Raise TypeError (as documented) rather than `assert`, so the validation is not
        # silently dropped when Python runs with -O.
        if data.shape != ():
            # Additional validation that the np.generic type remains solely for scalars
            raise TypeError(f'Numpy scalar payload has unexpected shape {data.shape}.')
        if not isinstance(data, (np.number, np.bool_)):
            # No support for bytes, strings, objects, etc
            raise TypeError(f'Numpy scalar type {type(data)} not supported for KaggleEvaluation.')
        return data
    if case == 'bytes_io_value':
        return io.BytesIO(payload.bytes_io_value)

    raise TypeError(f'Found unknown Payload case {case}')
### Client code | |
class Client():
    '''
    Class which allows callers to make KaggleEvaluation requests.

    Attributes:
        channel_address: Host name of the server to talk to.
        channel: The insecure gRPC channel to `channel_address` on `_GRPC_PORT`.
        endpoint_deadline_seconds: Timeout applied to every request after the first successful one.
        stub: The generated gRPC service stub bound to `channel`.
    '''
    def __init__(self, channel_address: str='localhost'):
        self.channel_address = channel_address
        self.channel = grpc.insecure_channel(f'{channel_address}:{_GRPC_PORT}', options=_GRPC_CHANNEL_OPTIONS)
        self._made_first_connection = False
        self.endpoint_deadline_seconds = DEFAULT_DEADLINE_SECONDS
        self.stub = kaggle_evaluation_grpc.KaggleEvaluationServiceStub(self.channel)

    def _send_with_deadline(self, request):
        ''' Sends a message to the server while also:
        - Throwing an error as soon as the inference_server container has been shut down.
        - Setting a deadline of STARTUP_LIMIT_SECONDS for the inference_server to startup.

        Args:
            request: The KaggleEvaluationRequest protobuf message to send.

        Returns:
            The server's response message.

        Raises:
            grpc.RpcError for any RPC failure other than UNAVAILABLE during startup.
            socket.gaierror if the server host name no longer resolves while waiting for startup.
            RuntimeError if no connection could be made within STARTUP_LIMIT_SECONDS.
        '''
        if self._made_first_connection:
            return self.stub.Send(request, wait_for_ready=False, timeout=self.endpoint_deadline_seconds)

        # Use a monotonic clock so wall-clock adjustments can't stretch or shrink the startup window.
        startup_deadline = time.monotonic() + STARTUP_LIMIT_SECONDS
        # Allow time for the server to start as long as its container is running
        while time.monotonic() < startup_deadline:
            try:
                # Note: no per-call timeout is applied to this first request.
                response = self.stub.Send(request, wait_for_ready=False)
            except grpc.RpcError as err:
                # Inspect the status through the public grpc.Call API instead of matching on the
                # error's string representation or catching a private exception class.
                status_code = err.code() if isinstance(err, grpc.Call) else None
                if status_code != grpc.StatusCode.UNAVAILABLE:
                    raise
                # Confirm the inference_server container is still alive & it's worth waiting on the server.
                # If the inference_server container is no longer running this will throw a socket.gaierror.
                socket.gethostbyname(self.channel_address)
                time.sleep(_RETRY_SLEEP_SECONDS)
            else:
                self._made_first_connection = True
                return response

        raise RuntimeError(f'Failed to connect to server after waiting {STARTUP_LIMIT_SECONDS} seconds')

    def serialize_request(self, name: str, *args, **kwargs) -> kaggle_evaluation_proto.KaggleEvaluationRequest:
        ''' Serialize a single request. Exists as a separate function from `send`
        to enable gateway concurrency for some competitions.

        Args:
            name: The endpoint name for the request.
            *args: Positional arguments for the endpoint, or a single already-built
                KaggleEvaluationRequest, which is returned unchanged.
            **kwargs: Key-value arguments for the endpoint.

        Returns:
            The KaggleEvaluationRequest protobuf message.
        '''
        already_serialized = (len(args) == 1) and isinstance(args[0], kaggle_evaluation_proto.KaggleEvaluationRequest)
        if already_serialized:
            return args[0]  # args is a tuple of length 1 containing the request
        return kaggle_evaluation_proto.KaggleEvaluationRequest(
            name=name,
            args=[_serialize(arg) for arg in args],
            kwargs={key: _serialize(value) for key, value in kwargs.items()},
        )

    def send(self, name: str, *args, **kwargs):
        '''Sends a single KaggleEvaluation request.

        Args:
            name: The endpoint name for the request.
            *args: Variable-length/type arguments to be supplied on the request.
            **kwargs: Key-value arguments to be supplied on the request.

        Returns:
            The response, which is of one of several allow-listed data types.
        '''
        request = self.serialize_request(name, *args, **kwargs)
        response = self._send_with_deadline(request)
        return _deserialize(response.payload)

    def close(self):
        '''Closes the underlying gRPC channel.'''
        self.channel.close()
### Server code | |
class KaggleEvaluationServiceServicer(kaggle_evaluation_grpc.KaggleEvaluationServiceServicer):
    '''
    Class which allows serving responses to KaggleEvaluation requests. The inference_server will run this service to listen for and respond
    to requests from the Gateway. The Gateway may also listen for requests from the inference_server in some cases.
    '''
    def __init__(self, listeners: List[Callable]):
        '''Registers the endpoint handlers.

        Args:
            listeners: Named callables. Each one handles requests whose endpoint name equals the
                callable's `__name__`; if two share a name, the later one wins.
        '''
        self.listeners_map = {func.__name__: func for func in listeners}

    # pylint: disable=unused-argument
    def Send(self, request: kaggle_evaluation_proto.KaggleEvaluationRequest, context: grpc.ServicerContext) -> kaggle_evaluation_proto.KaggleEvaluationResponse:
        '''Handler for gRPC requests that deserializes arguments, calls a user-registered function for handling the
        requested endpoint, then serializes and returns the response.

        Args:
            request: The KaggleEvaluationRequest protobuf message.
            context: (Unused) gRPC context.

        Returns:
            The KaggleEvaluationResponse protobuf message.

        Raises:
            NotImplementedError if the caller has not registered a handler for the requested endpoint.
        '''
        if request.name not in self.listeners_map:
            raise NotImplementedError(f'No listener for {request.name} was registered.')
        response_function = self.listeners_map[request.name]
        args = [_deserialize(arg) for arg in request.args]
        kwargs = {key: _deserialize(value) for key, value in request.kwargs.items()}
        response_payload = _serialize(response_function(*args, **kwargs))
        return kaggle_evaluation_proto.KaggleEvaluationResponse(payload=response_payload)
def define_server(*endpoint_listeners: Callable) -> 'grpc.Server':
    '''Registers the endpoints that the container is able to respond to, then builds a server which listens for
    those endpoints. The endpoints that need to be implemented will depend on the specific competition.

    Args:
        endpoint_listeners: Functions that define how requests to the endpoint of the function name should be
            handled. Each must be a named callable (lambdas are rejected).

    Returns:
        The configured gRPC server object. It is NOT started by this function: the caller must call
        `start()` on it, and should stop it at exit time.

    Raises:
        ValueError if parameter values are invalid.
    '''
    if not endpoint_listeners:
        raise ValueError('Must pass at least one endpoint listener, e.g. `predict`')
    for func in endpoint_listeners:
        if not callable(func):
            raise ValueError('Endpoint listeners passed to `serve` must be functions')
        # Endpoints are keyed by `__name__`, so lambdas and callables without a name
        # (e.g. functools.partial objects) cannot be registered.
        if getattr(func, '__name__', '<lambda>') == '<lambda>':
            raise ValueError('Functions passed as endpoint listeners must be named')

    # The executor backing the server is limited to a single worker thread.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1), options=_GRPC_CHANNEL_OPTIONS)
    kaggle_evaluation_grpc.add_KaggleEvaluationServiceServicer_to_server(KaggleEvaluationServiceServicer(endpoint_listeners), server)
    server.add_insecure_port(f'[::]:{_GRPC_PORT}')
    return server