# time-series-score: src/fit_model.py
# Uploaded by kashif (HF staff) in commit 45e60de ("Upload 10 files").
from gluonts.dataset.common import Dataset
from .models import (
AbstractPredictor,
AutoGluonPredictor,
AutoPyTorchPredictor,
DeepARPredictor,
TFTPredictor,
AutoARIMAPredictor,
AutoETSPredictor,
AutoThetaPredictor,
StatsEnsemblePredictor,
)
# Registry of supported models: lowercase identifier -> predictor class.
# fit_predict_with_model lower-cases the caller-supplied name before looking
# it up here, so every key must stay lowercase to be reachable.
MODEL_NAME_TO_CLASS = dict(
    autogluon=AutoGluonPredictor,
    autopytorch=AutoPyTorchPredictor,
    deepar=DeepARPredictor,
    tft=TFTPredictor,
    autoarima=AutoARIMAPredictor,
    autoets=AutoETSPredictor,
    autotheta=AutoThetaPredictor,
    statsensemble=StatsEnsemblePredictor,
)
def fit_predict_with_model(
    model_name: str,
    dataset: Dataset,
    prediction_length: int,
    freq: str,
    seasonality: int,
    **model_kwargs,
):
    """Instantiate the named predictor, fit it on ``dataset`` and predict.

    Args:
        model_name: Key into ``MODEL_NAME_TO_CLASS``. Matched
            case-insensitively (it is lower-cased before the lookup).
        dataset: GluonTS ``Dataset`` passed unchanged to the predictor's
            ``fit_predict``.
        prediction_length: Forwarded to the predictor constructor.
        freq: Forwarded to the predictor constructor.
        seasonality: Forwarded to the predictor constructor.
        **model_kwargs: Extra keyword arguments forwarded verbatim to the
            predictor constructor.

    Returns:
        A ``(predictions, info)`` tuple. ``predictions`` is whatever the
        predictor's ``fit_predict`` returns. ``info`` is a dict with a single
        ``"run_time"`` entry taken from ``model.get_runtime()``. Its units are
        defined by the predictor (presumably seconds; TODO confirm).

    Raises:
        KeyError: If ``model_name`` is not a registered model. The message
            lists the accepted names.
    """
    try:
        model_class = MODEL_NAME_TO_CLASS[model_name.lower()]
    except KeyError:
        # Keep KeyError, the type callers may already catch. Replace the bare
        # key with an actionable message, and drop the redundant chained
        # traceback.
        valid_names = ", ".join(sorted(MODEL_NAME_TO_CLASS))
        raise KeyError(
            f"Unknown model_name {model_name!r}; expected one of: {valid_names}"
        ) from None

    model: AbstractPredictor = model_class(
        prediction_length=prediction_length,
        freq=freq,
        seasonality=seasonality,
        **model_kwargs,
    )
    predictions = model.fit_predict(dataset)
    # The runtime is read after fit_predict so that it covers the run just done.
    info = {"run_time": model.get_runtime()}
    return predictions, info