# Provenance (from the hosting page): uploaded by kashif (HF staff),
# commit 45e60de, "Upload 10 files".
import time
from typing import List, Optional
import pandas as pd
from gluonts.dataset.common import Dataset
from gluonts.model.forecast import Forecast, QuantileForecast
from .abstract import AbstractPredictor
class AutoGluonPredictor(AbstractPredictor):
    """Predictor that fits an AutoGluon ``TimeSeriesPredictor`` and returns
    GluonTS ``QuantileForecast`` objects.

    The attributes ``prediction_length``, ``freq``, ``seasonality`` and
    ``quantile_levels`` as well as ``save_runtime`` are read from ``self`` but
    not defined here; presumably they are provided by ``AbstractPredictor``
    (TODO confirm).
    """

    def __init__(
        self,
        prediction_length: int,
        freq: str,
        seasonality: int,
        time_limit: Optional[int] = None,
        presets: str = "high_quality",
        eval_metric: str = "MASE",
        seed: int = 1,
        enable_ensemble: bool = True,
        hyperparameters: Optional[dict] = None,
        **kwargs
    ):
        """Store the AutoGluon fitting configuration.

        Args:
            prediction_length: Forwarded to ``AbstractPredictor.__init__``.
            freq: Forwarded to ``AbstractPredictor.__init__``.
            seasonality: Forwarded to ``AbstractPredictor.__init__``.
            time_limit: Passed to ``TimeSeriesPredictor.fit`` as ``time_limit``.
            presets: Passed to ``TimeSeriesPredictor.fit`` as ``presets``.
            eval_metric: Passed to the ``TimeSeriesPredictor`` constructor.
            seed: Passed to ``TimeSeriesPredictor.fit`` as ``random_seed``.
            enable_ensemble: Passed to ``TimeSeriesPredictor.fit``.
            hyperparameters: Passed to ``TimeSeriesPredictor.fit``.
            **kwargs: Accepted but not used by this class (neither stored nor
                forwarded to the base class).
        """
        super().__init__(prediction_length, freq, seasonality)
        self.presets = presets
        self.eval_metric = eval_metric
        self.time_limit = time_limit
        self.seed = seed
        self.enable_ensemble = enable_ensemble
        self.hyperparameters = hyperparameters

    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        """Fit AutoGluon on ``dataset`` and forecast from that same data.

        The elapsed time of ``fit`` plus ``predict`` (not the conversion of
        ``dataset`` into a ``TimeSeriesDataFrame``) is reported through
        ``self.save_runtime``.

        Args:
            dataset: Dataset used both for training and as the forecast
                context.

        Returns:
            One forecast per paired dataset entry; the ``"mean"`` column of
            the AutoGluon predictions is dropped before conversion.
        """
        # Imported lazily so the module can be imported without AutoGluon.
        from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

        train_data = TimeSeriesDataFrame(dataset)
        predictor = TimeSeriesPredictor(
            prediction_length=self.prediction_length,
            eval_metric=self.eval_metric,
            eval_metric_seasonal_period=self.seasonality,
            quantile_levels=self.quantile_levels,
        )

        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        predictor.fit(
            train_data,
            time_limit=self.time_limit,
            presets=self.presets,
            random_seed=self.seed,
            enable_ensemble=self.enable_ensemble,
            hyperparameters=self.hyperparameters,
        )
        predictions = predictor.predict(train_data)
        self.save_runtime(time.perf_counter() - start_time)

        return self._predictions_df_to_gluonts_forecast(
            predictions_df=predictions.drop("mean", axis=1), dataset=dataset
        )

    def _predictions_df_to_gluonts_forecast(
        self, predictions_df: pd.DataFrame, dataset: Dataset
    ) -> List[Forecast]:
        """Convert a predictions frame into a list of ``QuantileForecast``.

        Args:
            predictions_df: Frame whose index has an ``"item_id"`` level; each
                remaining column becomes one forecast key.
            dataset: Entries supplying the ``item_id`` attached to each
                forecast.

        Returns:
            Forecasts in the order the items first appear in
            ``predictions_df``.
        """
        # sort=False keeps the groups in order of first appearance.
        per_item_frames = [
            frame.droplevel("item_id")
            for _, frame in predictions_df.groupby(level="item_id", sort=False)
        ]

        # NOTE(review): forecasts are paired with dataset entries by position,
        # which assumes predictions_df lists items in the same order as
        # `dataset` - confirm against the AutoGluon version in use.
        return [
            QuantileForecast(
                # Transposed so that each row holds the values of one column.
                forecast_arrays=frame.values.T,
                forecast_keys=frame.columns,
                start_date=pd.Period(frame.index[0], freq=self.freq),
                # Entries without an "item_id" yield item_id=None rather than
                # raising KeyError.
                item_id=ts.get("item_id"),
            )
            for ts, frame in zip(dataset, per_item_frames)
        ]