File size: 2,615 Bytes
45e60de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import time
from typing import List, Optional
import pandas as pd
from gluonts.dataset.common import Dataset
from gluonts.model.forecast import Forecast, QuantileForecast
from .abstract import AbstractPredictor
class AutoGluonPredictor(AbstractPredictor):
    """Fit an AutoGluon ``TimeSeriesPredictor`` and return GluonTS forecasts.

    The predictor is trained on the supplied dataset, asked to forecast that
    same dataset, and its quantile columns are repackaged as one
    ``QuantileForecast`` per dataset entry.

    NOTE(review): ``self.quantile_levels`` and ``self.save_runtime`` are used
    below but not defined in this class; presumably they are provided by
    ``AbstractPredictor`` -- confirm against the base class.
    """

    def __init__(
        self,
        prediction_length: int,
        freq: str,
        seasonality: int,
        time_limit: Optional[int] = None,
        presets: str = "high_quality",
        eval_metric: str = "MASE",
        seed: int = 1,
        enable_ensemble: bool = True,
        hyperparameters: Optional[dict] = None,
        **kwargs
    ):
        """Store the AutoGluon configuration used later by ``fit_predict``.

        Args:
            prediction_length: Forwarded to the base-class constructor and
                later used as the predictor's ``prediction_length``.
            freq: Forwarded to the base-class constructor; used as the
                frequency of each forecast's start ``pd.Period``.
            seasonality: Forwarded to the base-class constructor; used as
                ``eval_metric_seasonal_period``.
            time_limit: Passed to ``TimeSeriesPredictor.fit`` as ``time_limit``.
            presets: Passed to ``TimeSeriesPredictor.fit`` as ``presets``.
            eval_metric: Passed to the ``TimeSeriesPredictor`` constructor.
            seed: Passed to ``TimeSeriesPredictor.fit`` as ``random_seed``.
            enable_ensemble: Passed to ``TimeSeriesPredictor.fit``.
            hyperparameters: Passed to ``TimeSeriesPredictor.fit``.
            **kwargs: Accepted but neither stored nor forwarded here.
        """
        super().__init__(prediction_length, freq, seasonality)
        # Options consumed by TimeSeriesPredictor.fit().
        self.time_limit = time_limit
        self.presets = presets
        self.seed = seed
        self.enable_ensemble = enable_ensemble
        self.hyperparameters = hyperparameters
        # Option consumed by the TimeSeriesPredictor constructor.
        self.eval_metric = eval_metric

    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        """Train on ``dataset``, forecast it, and return GluonTS forecasts.

        The elapsed wall-clock time of fitting plus prediction is reported
        through ``self.save_runtime``.

        Args:
            dataset: GluonTS dataset used both for training and as the
                context the forecasts are generated from.

        Returns:
            One forecast per dataset entry, built from the prediction frame
            with its ``"mean"`` column removed.
        """
        # Imported lazily so that autogluon is only required when this
        # method actually runs.
        from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

        history = TimeSeriesDataFrame(dataset)
        model = TimeSeriesPredictor(
            prediction_length=self.prediction_length,
            eval_metric=self.eval_metric,
            eval_metric_seasonal_period=self.seasonality,
            quantile_levels=self.quantile_levels,
        )

        # The timed region covers fit() and predict() only.
        started_at = time.time()
        fit_options = {
            "time_limit": self.time_limit,
            "presets": self.presets,
            "random_seed": self.seed,
            "enable_ensemble": self.enable_ensemble,
            "hyperparameters": self.hyperparameters,
        }
        model.fit(history, **fit_options)
        raw_predictions = model.predict(history)
        self.save_runtime(time.time() - started_at)

        # Drop the point-forecast column so only quantile columns remain.
        quantiles_only = raw_predictions.drop("mean", axis=1)
        return self._predictions_df_to_gluonts_forecast(
            predictions_df=quantiles_only, dataset=dataset
        )

    def _predictions_df_to_gluonts_forecast(
        self, predictions_df, dataset: Dataset
    ) -> List[Forecast]:
        """Convert a per-item prediction frame into ``QuantileForecast`` objects.

        Args:
            predictions_df: Frame with an ``"item_id"`` index level; every
                column is treated as one forecast key.
            dataset: Entries providing the ``"item_id"`` attached to each
                forecast.

        Returns:
            A list of ``QuantileForecast``. Entries and per-item frames are
            paired by position, so the result is as long as the shorter of
            the two sequences.
        """
        # Split into one frame per item, keeping the order in which items
        # first appear (sort=False), and strip the item level from the index.
        per_item_frames = [
            frame.droplevel("item_id")
            for _, frame in predictions_df.groupby(level="item_id", sort=False)
        ]

        return [
            QuantileForecast(
                item_id=entry["item_id"],
                # Transposed so each row of the array corresponds to one
                # column (forecast key) of the frame.
                forecast_arrays=frame.values.T,
                forecast_keys=frame.columns,
                start_date=pd.Period(frame.index[0], freq=self.freq),
            )
            for entry, frame in zip(dataset, per_item_frames)
        ]
|