|
import time |
|
from typing import List, Optional |
|
import pandas as pd |
|
|
|
from gluonts.dataset.common import Dataset |
|
from gluonts.model.forecast import Forecast, QuantileForecast |
|
|
|
from .abstract import AbstractPredictor |
|
|
|
|
|
class AutoGluonPredictor(AbstractPredictor):
    """Forecast a GluonTS dataset with AutoGluon's ``TimeSeriesPredictor``.

    A new AutoGluon predictor is trained on the whole dataset and then asked
    for the next ``prediction_length`` steps of every series. The combined
    fit + predict wall-clock time is reported through ``save_runtime``.
    """

    def __init__(
        self,
        prediction_length: int,
        freq: str,
        seasonality: int,
        time_limit: Optional[int] = None,
        presets: str = "high_quality",
        eval_metric: str = "MASE",
        seed: int = 1,
        enable_ensemble: bool = True,
        hyperparameters: Optional[dict] = None,
        **kwargs,
    ):
        """Store the AutoGluon configuration used by ``fit_predict``.

        Args:
            prediction_length: Forecast horizon; forwarded to
                ``TimeSeriesPredictor(prediction_length=...)``.
            freq: pandas frequency string; used to build the start
                ``pd.Period`` of each returned forecast.
            seasonality: Seasonal period; forwarded to the base class and to
                ``TimeSeriesPredictor(eval_metric_seasonal_period=...)``.
            time_limit: Training budget forwarded to ``fit(time_limit=...)``;
                ``None`` leaves the choice to AutoGluon.
            presets: AutoGluon presets name forwarded to ``fit``.
            eval_metric: Metric name forwarded to ``TimeSeriesPredictor``.
            seed: Forwarded to ``fit(random_seed=...)``.
            enable_ensemble: Forwarded to ``fit(enable_ensemble=...)``.
            hyperparameters: Forwarded to ``fit(hyperparameters=...)``.
            **kwargs: Accepted but ignored (not stored or forwarded).
                NOTE(review): presumably lets shared configs pass extra keys
                without error — confirm against callers.
        """
        super().__init__(prediction_length, freq, seasonality)
        self.presets = presets
        self.eval_metric = eval_metric
        self.time_limit = time_limit
        self.seed = seed
        self.enable_ensemble = enable_ensemble
        self.hyperparameters = hyperparameters

    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        """Train AutoGluon on ``dataset`` and forecast every series in it.

        Args:
            dataset: GluonTS dataset used both as training data and as the
                context the forecasts continue from.

        Returns:
            One ``QuantileForecast`` per series, in ``dataset`` iteration order.
        """
        # Imported inside the method rather than at module level — presumably
        # so AutoGluon is only required when this predictor is actually used.
        from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

        train_data = TimeSeriesDataFrame(dataset)
        predictor = TimeSeriesPredictor(
            prediction_length=self.prediction_length,
            eval_metric=self.eval_metric,
            eval_metric_seasonal_period=self.seasonality,
            # NOTE(review): quantile_levels is not set in this class; it is
            # assumed to come from AbstractPredictor — confirm.
            quantile_levels=self.quantile_levels,
        )

        # perf_counter is monotonic, so the measured runtime cannot be skewed
        # by system clock adjustments during a long training run.
        start_time = time.perf_counter()
        predictor.fit(
            train_data,
            time_limit=self.time_limit,
            presets=self.presets,
            random_seed=self.seed,
            enable_ensemble=self.enable_ensemble,
            hyperparameters=self.hyperparameters,
        )
        predictions = predictor.predict(train_data)
        self.save_runtime(time.perf_counter() - start_time)

        # Drop the point-forecast column so only the quantile columns remain.
        return self._predictions_df_to_gluonts_forecast(
            predictions_df=predictions.drop("mean", axis=1), dataset=dataset
        )

    def _predictions_df_to_gluonts_forecast(
        self, predictions_df, dataset: Dataset
    ) -> List[Forecast]:
        """Convert AutoGluon's long-format predictions into GluonTS forecasts.

        Args:
            predictions_df: Predictions with an ``item_id`` index level plus a
                time index level, and one column per forecast key (presumably
                the quantile levels as strings, e.g. ``"0.1"`` — confirm).
            dataset: The dataset that was forecast; each entry's optional
                ``"item_id"`` is copied onto the matching forecast.

        Returns:
            One ``QuantileForecast`` per series, paired positionally with the
            entries of ``dataset``.

        Raises:
            ValueError: If the number of series in ``predictions_df`` differs
                from the number of entries in ``dataset`` (pairing them would
                otherwise silently drop or mis-align forecasts).
        """
        # sort=False keeps items in order of first appearance. That order is
        # assumed to match the iteration order of ``dataset``.
        # NOTE(review): relies on AutoGluon preserving input item order — confirm.
        per_item_forecasts = [
            group.droplevel("item_id")
            for _, group in predictions_df.groupby(level="item_id", sort=False)
        ]

        # Only the ids are kept, so the dataset is not materialized twice.
        item_ids = [entry.get("item_id") for entry in dataset]
        if len(item_ids) != len(per_item_forecasts):
            raise ValueError(
                f"Dataset has {len(item_ids)} series but AutoGluon returned "
                f"forecasts for {len(per_item_forecasts)}"
            )

        return [
            QuantileForecast(
                # Rows are forecast keys, columns are forecast time steps.
                forecast_arrays=forecast.to_numpy().T,
                forecast_keys=list(forecast.columns),
                start_date=pd.Period(forecast.index[0], freq=self.freq),
                item_id=item_id,
            )
            for item_id, forecast in zip(item_ids, per_item_forecasts)
        ]
|
|