File size: 2,087 Bytes
15369ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import inspect
import pathlib
from types import ModuleType

from kaggle_evaluation.core import relay, templates
from kaggle_evaluation.svg_gateway import SVGGateway


def test(model_cls: type, data_path: str | pathlib.Path | None = None) -> None:
    '''Tests this competition's inference loop over the given Model class.

    The provided Model class should have a `predict` function which accepts input(s)
    and returns output(s) with the shapes and types required by this competition.
    This function performs best-effort validation of this by running an inference
    loop with a dummy test set over Model.predict.
    By default the test set is taken from the `kaggle_evaluation` directory, but you
    may override to another directory with the same test file structure via the
    `data_path` arg.

    Args:
        model_cls: The Model class itself (not an instance); it is instantiated here.
        data_path: Optional directory holding the dummy test set. When None, the
            bundled `kaggle_evaluation` data is used.

    Raises:
        ValueError: If the instantiated model has no `predict` method.
    '''
    print('Creating Model instance...')
    model = model_cls()
    if not hasattr(model, 'predict') or not inspect.ismethod(model.predict):
        # Fail fast before starting a server if the Model contract isn't met.
        # (Was an f-string with no placeholders — plain literal is correct.)
        raise ValueError('Model does not have method predict.')

    print('Running inference tests...')
    server = relay.define_server(model.predict)
    server.start()
    try:
        gateway = SVGGateway(data_path)
        submission_path = gateway.run()
        print(f'Wrote test submission file to "{str(submission_path)}".')
    except Exception as err:
        # `from None` deliberately suppresses the chained relay/gateway
        # internals so users see only the relevant failure.
        raise err from None
    finally:
        # Always release the server, even when the gateway run fails.
        server.stop(0)

    print('Success!')


def _run_gateway() -> None:
    '''Internal entry point that runs the Gateway for a Kaggle scoring session.

    Launches a scoring run against the real test set, assuming an Inference
    Server is already available to serve predictions.'''
    SVGGateway().run()


def _run_inference_server(module: ModuleType) -> None:
    '''Internal entry point that runs the Inference Server for a Kaggle scoring session.

    Given the user's already-imported submission module, instantiates their
    Model and serves its `predict` method to the Gateway.'''
    predict_fn = module.Model().predict
    inference_server = templates.InferenceServer(predict_fn)
    inference_server.serve()