File size: 1,100 Bytes
45e60de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from gluonts.dataset.common import Dataset

from .models import (
    AbstractPredictor,
    AutoGluonPredictor,
    AutoPyTorchPredictor,
    DeepARPredictor,
    TFTPredictor,
    AutoARIMAPredictor,
    AutoETSPredictor,
    AutoThetaPredictor,
    StatsEnsemblePredictor,
)

# Registry of supported forecasting models, keyed by lower-case name.
# ``fit_predict_with_model`` lower-cases its ``model_name`` argument and
# looks it up here to pick the predictor class to instantiate.
MODEL_NAME_TO_CLASS = dict(
    autogluon=AutoGluonPredictor,
    autopytorch=AutoPyTorchPredictor,
    deepar=DeepARPredictor,
    tft=TFTPredictor,
    autoarima=AutoARIMAPredictor,
    autoets=AutoETSPredictor,
    autotheta=AutoThetaPredictor,
    statsensemble=StatsEnsemblePredictor,
)


def fit_predict_with_model(
    model_name: str,
    dataset: Dataset,
    prediction_length: int,
    freq: str,
    seasonality: int,
    **model_kwargs,
):
    """Instantiate the named model, fit it on ``dataset`` and return its forecasts.

    Args:
        model_name: Case-insensitive name of a model registered in
            ``MODEL_NAME_TO_CLASS`` (e.g. ``"deepar"``, ``"AutoETS"``).
        dataset: Dataset passed unchanged to the model's ``fit_predict``.
        prediction_length: Forwarded to the model constructor.
        freq: Forwarded to the model constructor.
        seasonality: Forwarded to the model constructor.
        **model_kwargs: Any extra keyword arguments, forwarded to the model
            constructor.

    Returns:
        A ``(predictions, info)`` tuple. ``predictions`` is whatever
        ``model.fit_predict(dataset)`` returns. ``info`` is a dict whose
        ``"run_time"`` entry holds ``model.get_runtime()``.

    Raises:
        KeyError: If ``model_name`` (after lower-casing) is not a registered
            model name. The message lists the accepted names.
    """
    key = model_name.lower()
    try:
        model_class = MODEL_NAME_TO_CLASS[key]
    except KeyError:
        # Keep the exception type callers may already catch, but make the
        # message actionable instead of echoing only the bad key.
        available = ", ".join(sorted(MODEL_NAME_TO_CLASS))
        raise KeyError(
            f"Unknown model_name {model_name!r}; expected one of: {available}"
        ) from None

    model: AbstractPredictor = model_class(
        prediction_length=prediction_length,
        freq=freq,
        seasonality=seasonality,
        **model_kwargs,
    )
    predictions = model.fit_predict(dataset)
    info = {"run_time": model.get_runtime()}
    return predictions, info