''' Lower level implementation details of the gateway.
Hosts should not need to review this file before writing their competition specific gateway.
'''
import enum | |
import json | |
import os | |
import pathlib | |
import re | |
import subprocess | |
import tempfile | |
from socket import gaierror | |
from typing import Any, List, Optional, Tuple, Union | |
import grpc | |
import numpy as np | |
import pandas as pd | |
import polars as pl | |
import kaggle_evaluation.core.relay | |
# Root directory under which input files are mirrored for the user's
# inference server (see BaseGateway.share_files for the mirroring scheme).
_FILE_SHARE_DIR = '/kaggle/shared/'
# True when running inside Kaggle's competition rerun environment;
# the platform sets this env var, it is absent during local testing.
IS_RERUN = os.getenv('KAGGLE_IS_COMPETITION_RERUN') is not None
class GatewayRuntimeErrorType(enum.Enum):
    ''' Allow-listed error types that Gateways can raise, which map to canned error messages to show users.'''
    UNSPECIFIED = 0
    # The inference server process never came up at all.
    SERVER_NEVER_STARTED = 1
    # The server started but the gateway could not reach it over gRPC.
    SERVER_CONNECTION_FAILED = 2
    # The user's code raised an exception while serving a request.
    SERVER_RAISED_EXCEPTION = 3
    # The server did not register a listener for the endpoint the gateway called.
    SERVER_MISSING_ENDPOINT = 4
    # Default error type if an exception was raised that was not explicitly handled by the Gateway
    GATEWAY_RAISED_EXCEPTION = 5
    # The submission itself is malformed (wrong type, row count, etc).
    INVALID_SUBMISSION = 6
class GatewayRuntimeError(Exception):
    ''' Gateways can raise this error to capture a user-visible error enum from above and host-visible error details.

    Args:
        error_type: One of the allow-listed GatewayRuntimeErrorType values.
        error_details: Optional host-visible detail string (not shown to users).
    '''
    def __init__(self, error_type: GatewayRuntimeErrorType, error_details: Optional[str]=None):
        # Fix: forward to Exception.__init__ so `args`, repr(), and pickling
        # carry the payload (the original skipped the super call, leaving
        # `e.args == ()`).
        super().__init__(error_type, error_details)
        self.error_type = error_type
        self.error_details = error_details
class BaseGateway():
    ''' Shared plumbing for competition gateways: prediction-batch validation,
    file sharing with the inference server, and submission/result file output.
    '''
    def __init__(self, target_column_name: Optional[str]=None):
        self.client = kaggle_evaluation.core.relay.Client('inference_server' if IS_RERUN else 'localhost')
        self.server = None  # The gateway can have a server but it isn't typically necessary.
        # Only used if the predictions are made as a primitive type (int, bool, etc) rather than a dataframe.
        self.target_column_name = target_column_name
        # Fix: `file_share_dir` is read by `_standardize_and_validate_paths`
        # but was never initialized anywhere, causing an AttributeError on use.
        self.file_share_dir = _FILE_SHARE_DIR

    def validate_prediction_batch(
        self,
        prediction_batch: Any,
        row_ids: Union[pl.DataFrame, pl.Series, pd.DataFrame, pd.Series]
    ):
        ''' If competitors can submit fewer rows than expected they can save all predictions for the last batch and
        bypass the benefits of the Kaggle evaluation service. This attack was seen in a real competition with the older time series API:
        https://www.kaggle.com/competitions/riiid-test-answer-prediction/discussion/196066
        It's critically important that this check be run every time predict() is called.
        If your predictions may take a variable number of rows and you need to write a custom version of this check,
        you still must specify a minimum row count greater than zero per prediction batch.

        Raises:
            GatewayRuntimeError: if the batch is missing, has an invalid type,
                or its row count does not match the expected row IDs.
        '''
        if prediction_batch is None:
            raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'No prediction received')
        num_received_rows = None
        # Special handling for numpy ints only as numpy floats are python floats, but numpy ints aren't python ints.
        # Basic types are valid for prediction, but either don't have a length (int) or the length isn't
        # relevant for purposes of this check (str), so they count as exactly one prediction per batch.
        if isinstance(prediction_batch, (int, float, str, bool, np.int_)):
            num_received_rows = 1
        if num_received_rows is None:
            # Exact type checks (not isinstance) deliberately reject unexpected subclasses.
            if type(prediction_batch) not in [pl.DataFrame, pl.Series, pd.DataFrame, pd.Series]:
                raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, f'Invalid prediction data type, received: {type(prediction_batch)}')
            num_received_rows = len(prediction_batch)
        if type(row_ids) not in [pl.DataFrame, pl.Series, pd.DataFrame, pd.Series]:
            raise GatewayRuntimeError(GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION, f'Invalid row ID type {type(row_ids)}; expected Polars DataFrame or similar')
        num_expected_rows = len(row_ids)
        if num_expected_rows == 0:
            raise GatewayRuntimeError(GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION, 'Missing row IDs for batch')
        if num_received_rows != num_expected_rows:
            raise GatewayRuntimeError(
                GatewayRuntimeErrorType.INVALID_SUBMISSION,
                f'Invalid predictions: expected {num_expected_rows} rows but received {num_received_rows}'
            )

    def _standardize_and_validate_paths(
        self,
        input_paths: List[Union[str, pathlib.Path]]
    ) -> Tuple[List[str], List[str]]:
        ''' Validate the paths to be shared and compute their mirror locations.

        Args:
            input_paths: Paths (str or pathlib.Path) to files/directories to share.
        Returns:
            Tuple of (absolute input paths, corresponding output paths under
            `self.file_share_dir`), both as lists of str.
        Raises:
            ValueError: for non-path types, non-normalized or traversal paths,
                missing paths, duplicates, or an invalid share directory.
        '''
        # Accept a list of str or pathlib.Path, but standardize on list of str.
        for path in input_paths:
            # Fix: check the type before performing any string operations on the value.
            if type(path) not in (pathlib.Path, str):
                raise ValueError('All paths must be of type str or pathlib.Path')
            if os.pardir in str(path):
                raise ValueError(f'Send files path contains {os.pardir}: {path}')
            if str(path) != str(os.path.normpath(path)):
                # Raise an error rather than sending users unexpectedly altered paths
                raise ValueError(f'Send files path {path} must be normalized. See `os.path.normpath`')
            if not os.path.exists(path):
                raise ValueError(f'Input path {path} does not exist')
        input_paths = [os.path.abspath(path) for path in input_paths]
        if len(set(input_paths)) != len(input_paths):
            raise ValueError('Duplicate input paths found')

        # Fix: previously `output_dir` was only assigned inside the "missing
        # separator" branch, raising NameError whenever `file_share_dir`
        # already ended with the separator.
        output_dir = self.file_share_dir
        if not output_dir.endswith(os.path.sep):
            output_dir += os.path.sep
        if not os.path.exists(self.file_share_dir) or not os.path.isdir(self.file_share_dir):
            raise ValueError(f'Invalid output directory {self.file_share_dir}')
        # Can't use os.path.join for output_dir + path: os.path.join won't prepend to an abspath
        output_paths = [output_dir + path for path in input_paths]
        return input_paths, output_paths

    def share_files(
        self,
        input_paths: List[Union[str, pathlib.Path]],
    ) -> List[str]:
        ''' Makes files and/or directories available to the user's inference_server. They will be mirrored under the
        self.file_share_dir directory, using the full absolute path. An input like:
            /kaggle/input/mycomp/test.csv
        Would be written to:
            /kaggle/shared/kaggle/input/mycomp/test.csv

        Args:
            input_paths: List of paths to files and/or directories that should be shared.
        Returns:
            The output paths that were shared.
        Raises:
            ValueError if any invalid paths are passed.
        '''
        input_paths, output_paths = self._standardize_and_validate_paths(input_paths)
        for in_path, out_path in zip(input_paths, output_paths):
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            # This makes the files available to the InferenceServer as read-only. Only the Gateway can mount files.
            # mount will only work in live kaggle evaluation rerun sessions. Otherwise use a symlink.
            # Fix: pass argument lists with shell=False so paths containing
            # spaces or shell metacharacters can't break/inject the command.
            if IS_RERUN:
                if not os.path.isdir(out_path):
                    pathlib.Path(out_path).touch()
                subprocess.run(['mount', '--bind', in_path, out_path], check=True)
            else:
                subprocess.run(['ln', '-s', in_path, out_path], check=True)
        return output_paths

    def write_submission(self, predictions, row_ids: List[Union[pl.Series, pl.DataFrame, pd.Series, pd.DataFrame]]) -> pathlib.Path:
        ''' Export the predictions to a submission file.

        Args:
            predictions: A dataframe, or a list of per-batch predictions
                (dataframes, series, or scalar values).
            row_ids: Per-batch row IDs matching the predictions.
        Returns:
            Path of the written submission CSV.
        Raises:
            GatewayRuntimeError: on inconsistent batch schemas or invalid row IDs.
            ValueError: if the final predictions type cannot be written as CSV.
        '''
        if isinstance(predictions, list):
            if isinstance(predictions[0], pd.DataFrame):
                predictions = pd.concat(predictions, ignore_index=True)
            elif isinstance(predictions[0], pl.DataFrame):
                try:
                    # vertical_relaxed allows compatible dtypes (e.g. int/float) to be coerced.
                    predictions = pl.concat(predictions, how='vertical_relaxed')
                except pl.exceptions.SchemaError:
                    raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'Inconsistent prediction types')
                except pl.exceptions.ComputeError:
                    raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'Inconsistent prediction column counts')
            elif isinstance(predictions[0], pl.Series):
                try:
                    predictions = pl.concat(predictions, how='vertical')
                except pl.exceptions.SchemaError:
                    raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'Inconsistent prediction types')
                except pl.exceptions.ComputeError:
                    raise GatewayRuntimeError(GatewayRuntimeErrorType.INVALID_SUBMISSION, 'Inconsistent prediction column counts')
            else:
                # Scalar (primitive) predictions: stitch the row IDs and values into one dataframe.
                if type(row_ids[0]) in [pl.Series, pl.DataFrame]:
                    row_ids = pl.concat(row_ids)
                elif type(row_ids[0]) in [pd.Series, pd.DataFrame]:
                    row_ids = pd.concat(row_ids).reset_index(drop=True)
                else:
                    raise GatewayRuntimeError(GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION, f'Invalid row ID datatype {type(row_ids[0])}. Expected Polars series or dataframe.')
                if self.target_column_name is None:
                    raise GatewayRuntimeError(GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION, '`target_column_name` must be set in order to use scalar value predictions.')
                # NOTE(review): assumes the concatenated row_ids expose `.columns`
                # (i.e. a DataFrame, not a Series) — confirm against callers.
                predictions = pl.DataFrame(data={row_ids.columns[0]: row_ids, self.target_column_name: predictions})

        submission_path = pathlib.Path('/kaggle/working/submission.csv')
        if not IS_RERUN:
            # Outside the rerun environment /kaggle/working may not exist; write to a temp file instead.
            with tempfile.NamedTemporaryFile(prefix='kaggle-evaluation-submission-', suffix='.csv', delete=False, mode='w+') as f:
                submission_path = pathlib.Path(f.name)
        if isinstance(predictions, pd.DataFrame):
            predictions.to_csv(submission_path, index=False)
        elif isinstance(predictions, pl.DataFrame):
            predictions.write_csv(submission_path)
        else:
            raise ValueError(f"Unsupported predictions type {type(predictions)}; can't write submission file")
        return submission_path

    def write_result(self, error: Optional[GatewayRuntimeError]=None):
        ''' Export a result.json containing error details if applicable.'''
        result = {'Succeeded': error is None}
        if error is not None:
            result['ErrorType'] = error.error_type.value
            result['ErrorName'] = error.error_type.name
            # Max error detail length is 8000.
            # Fix: convert to str before slicing so non-str details can't crash here.
            result['ErrorDetails'] = str(error.error_details)[:8000] if error.error_details else None
        with open('result.json', 'w') as f_open:
            json.dump(result, f_open)

    def handle_server_error(self, exception: Exception, endpoint: str):
        ''' Determine how to handle an exception raised when calling the inference server. Typically just format the
        error into a GatewayRuntimeError and raise.

        Args:
            exception: The exception raised while calling the server.
            endpoint: Name of the endpoint that was being called.
        Raises:
            GatewayRuntimeError: for every recognized failure mode; otherwise
                re-raises the original exception.
        '''
        exception_str = str(exception)
        if isinstance(exception, gaierror) or (isinstance(exception, RuntimeError) and 'Failed to connect to server after waiting' in exception_str):
            raise GatewayRuntimeError(GatewayRuntimeErrorType.SERVER_NEVER_STARTED) from None
        if f'No listener for {endpoint} was registered' in exception_str:
            raise GatewayRuntimeError(GatewayRuntimeErrorType.SERVER_MISSING_ENDPOINT, f'Server did not register a listener for {endpoint}') from None
        if 'Exception calling application' in exception_str:
            # Extract just the exception message raised by the inference server
            message_match = re.search('"Exception calling application: (.*)"', exception_str, re.IGNORECASE)
            message = message_match.group(1) if message_match else exception_str
            raise GatewayRuntimeError(GatewayRuntimeErrorType.SERVER_RAISED_EXCEPTION, message) from None
        if isinstance(exception, grpc._channel._InactiveRpcError):
            raise GatewayRuntimeError(GatewayRuntimeErrorType.SERVER_CONNECTION_FAILED, exception_str) from None
        raise exception