|
from gluonts.dataset.common import Dataset |
|
|
|
from .models import ( |
|
AbstractPredictor, |
|
AutoGluonPredictor, |
|
AutoPyTorchPredictor, |
|
DeepARPredictor, |
|
TFTPredictor, |
|
AutoARIMAPredictor, |
|
AutoETSPredictor, |
|
AutoThetaPredictor, |
|
StatsEnsemblePredictor, |
|
) |
|
|
|
# Registry of supported forecasting models, keyed by the lowercase name that
# callers pass to ``fit_predict_with_model``. Each value is a predictor class
# imported from ``.models``; entries keep the same insertion order as before.
MODEL_NAME_TO_CLASS = dict(
    autogluon=AutoGluonPredictor,
    autopytorch=AutoPyTorchPredictor,
    deepar=DeepARPredictor,
    tft=TFTPredictor,
    autoarima=AutoARIMAPredictor,
    autoets=AutoETSPredictor,
    autotheta=AutoThetaPredictor,
    statsensemble=StatsEnsemblePredictor,
)
|
|
|
|
|
def fit_predict_with_model(
    model_name: str,
    dataset: Dataset,
    prediction_length: int,
    freq: str,
    seasonality: int,
    **model_kwargs,
):
    """Build the named model, fit it on ``dataset`` and return its predictions.

    Args:
        model_name: Name of the model to use. It is lowercased and looked up in
            ``MODEL_NAME_TO_CLASS``, so matching is case-insensitive.
        dataset: GluonTS ``Dataset`` passed to the model's ``fit_predict``.
        prediction_length: Passed as ``prediction_length=`` to the model
            constructor.
        freq: Passed as ``freq=`` to the model constructor.
        seasonality: Passed as ``seasonality=`` to the model constructor.
        **model_kwargs: Any extra keyword arguments, forwarded unchanged to the
            model constructor.

    Returns:
        A ``(predictions, info)`` tuple. ``predictions`` is whatever the model's
        ``fit_predict(dataset)`` returns. ``info`` is a dict whose
        ``"run_time"`` entry holds the value of ``model.get_runtime()``.

    Raises:
        KeyError: If ``model_name`` does not name a registered model. The
            message lists the available model names.
    """
    key = model_name.lower()
    try:
        model_class = MODEL_NAME_TO_CLASS[key]
    except KeyError:
        # Keep the KeyError type so existing callers that catch it still work,
        # but tell the user which names are valid instead of echoing the key.
        available = ", ".join(sorted(MODEL_NAME_TO_CLASS))
        raise KeyError(
            f"Unknown model name {model_name!r}; available models: {available}"
        ) from None

    model: AbstractPredictor = model_class(
        prediction_length=prediction_length,
        freq=freq,
        seasonality=seasonality,
        **model_kwargs,
    )
    predictions = model.fit_predict(dataset)
    # Runtime is queried after fit_predict has returned.
    info = {"run_time": model.get_runtime()}
    return predictions, info
|
|