File size: 4,627 Bytes
15369ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
'''Template for the two classes hosts should customize for each competition.'''

import abc
import os
import pathlib
import sys
import time
import traceback
import warnings
from typing import Callable, Generator, Optional, Tuple

import polars as pl

import kaggle_evaluation.core.base_gateway
import kaggle_evaluation.core.relay


# Wall-clock timestamp captured at import time, so startup overhead can be
# measured relative to module load.
_initial_import_time = time.time()
# Whether a slow-startup warning has already been issued.
# NOTE(review): the code that reads/sets this is not visible in this chunk —
# presumably startup-time warning logic elsewhere in the package.
_issued_startup_time_warning = False


class Gateway(kaggle_evaluation.core.base_gateway.BaseGateway, abc.ABC):
    '''
    Template to start with when writing a new gateway.
    In most cases, hosts should only need to write get_all_predictions.
    There are two main methods for sending data to the inference_server hosts should understand:
    - Small datasets: use `self.predict`. Competitors will receive the data passed to self.predict as
    Python objects in memory. This is just a wrapper for self.client.send(); you can write additional
    wrappers if necessary.
    - Large datasets: it's much faster to send data via self.share_files, which is equivalent to making
    files available via symlink. See base_gateway.BaseGateway.share_files for the full details.
    '''

    @abc.abstractmethod
    def generate_data_batches(self) -> Generator:
        ''' Used by the default implementation of `get_all_predictions` so we can
        ensure `validate_prediction_batch` is run every time `predict` is called.

        This method must yield both the batch of data to be sent to `predict` and a series
        of row IDs to be sent to `validate_prediction_batch`.
        '''
        raise NotImplementedError

    def get_all_predictions(self) -> Tuple[list, list]:
        ''' Run `predict` on every batch yielded by `generate_data_batches`,
        validating each prediction batch against its row IDs as we go.

        Returns:
            Tuple of (all_predictions, all_row_ids), one entry per batch.
        '''
        all_predictions = []
        all_row_ids = []
        for data_batch, row_ids in self.generate_data_batches():
            predictions = self.predict(*data_batch)
            # Normalize predictions into a named polars Series so validation and
            # submission writing can rely on a consistent type and column name.
            predictions = pl.Series(self.target_column_name, predictions)
            self.validate_prediction_batch(predictions, row_ids)
            all_predictions.append(predictions)
            all_row_ids.append(row_ids)
        return all_predictions, all_row_ids

    def predict(self, *args, **kwargs):
        ''' self.predict will send all data in args and kwargs to the user container, and
        instruct the user container to generate a `predict` response.
        '''
        try:
            return self.client.send('predict', *args, **kwargs)
        except Exception as e:
            # Delegate failures to the shared handler, which reports the error in a
            # rerun-appropriate way (and is expected to raise rather than return).
            self.handle_server_error(e, 'predict')

    def set_response_timeout_seconds(self, timeout_seconds: float):
        ''' Configure how long the gateway waits for each inference response.

        Args:
            timeout_seconds: Per-response deadline applied after the first response.
        '''
        # Also store timeout_seconds in an easy place for the competitor to access.
        self.timeout_seconds = timeout_seconds
        # Set a response deadline that will apply after the very first response
        self.client.endpoint_deadline_seconds = timeout_seconds

    def run(self) -> Optional[pathlib.Path]:
        ''' Run the full evaluation: collect predictions, write the submission,
        then shut down the client/server pair.

        Returns:
            Path to the written submission file, or None if an error prevented
            the submission from being written.

        Raises:
            GatewayRuntimeError: when running locally (not a competition rerun)
                and any error occurred during prediction.
        '''
        error = None
        submission_path = None
        try:
            predictions, row_ids = self.get_all_predictions()
            submission_path = self.write_submission(predictions, row_ids)
        except kaggle_evaluation.core.base_gateway.GatewayRuntimeError as gre:
            error = gre
        except Exception:
            # Get the full stack trace so it can be surfaced in the result file.
            exc_type, exc_value, exc_traceback = sys.exc_info()
            error_str = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback))

            error = kaggle_evaluation.core.base_gateway.GatewayRuntimeError(
                kaggle_evaluation.core.base_gateway.GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION,
                error_str
            )

        # Always tear down the connection, even if prediction failed.
        self.client.close()
        if self.server:
            self.server.stop(0)

        if kaggle_evaluation.core.base_gateway.IS_RERUN:
            # On Kaggle's rerun environment, record the outcome instead of raising.
            self.write_result(error)
        elif error:
            # For local testing
            raise error

        return submission_path


class InferenceServer(abc.ABC):
    '''
    Base class for competition participants to inherit from when writing their submission. In most cases, users should
    only need to implement a `predict` function or other endpoints to pass to this class's constructor, and hosts will
    provide a mock Gateway for testing.
    '''
    def __init__(self, endpoint_listeners: Tuple[Callable]):
        # Server hosting the user's endpoint listeners (e.g. their `predict` function).
        self.server = kaggle_evaluation.core.relay.define_server(endpoint_listeners)
        self.client = None  # The inference_server can have a client but it isn't typically necessary.

    def serve(self):
        # Start serving; on Kaggle's competition rerun environment (detected via the
        # KAGGLE_IS_COMPETITION_RERUN env var), block until the gateway shuts us down.
        self.server.start()
        if os.getenv('KAGGLE_IS_COMPETITION_RERUN') is not None:
            self.server.wait_for_termination()  # This will block all other code