import time
from datetime import timedelta
from typing import List, Optional

from gluonts.dataset.split import split
from gluonts.dataset.common import Dataset
from gluonts.model.forecast import Forecast
from gluonts.torch.model.estimator import Estimator

from .abstract import AbstractPredictor


class GluonTSPredictor(AbstractPredictor):
    """Base predictor that trains a GluonTS PyTorch estimator and forecasts with it.

    Subclasses supply the concrete estimator by overriding ``_get_estimator``.
    """

    def __init__(
        self,
        prediction_length: int,
        freq: str,
        seasonality: int,
        time_limit: Optional[int] = None,
        **kwargs,
    ):
        """Store configuration for the estimator.

        Args:
            prediction_length: Forecast horizon, forwarded to ``AbstractPredictor``
                and to the estimator.
            freq: Frequency string, forwarded to ``AbstractPredictor`` and to the
                estimator.
            seasonality: Forwarded to ``AbstractPredictor``.
            time_limit: Training time budget in seconds. When ``None``, no
                ``Timer`` callback is installed and the estimator's default
                trainer settings are used.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        super().__init__(prediction_length, freq, seasonality)
        self.time_limit = time_limit

    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        """Train a fresh estimator on ``dataset`` and return its forecasts.

        The last ``prediction_length`` steps are cut off for training, while the
        full ``dataset`` is passed as validation data. The wall-clock time of
        training plus inference is reported via ``save_runtime``.

        Args:
            dataset: GluonTS dataset to train on and predict for.

        Returns:
            The forecasts produced by the trained predictor, as a list.
        """
        estimator = self._get_estimator()
        # Drop the final prediction_length steps for the training split.
        train_data, _ = split(dataset, offset=-self.prediction_length)

        fit_start_time = time.time()
        predictor = estimator.train(training_data=train_data, validation_data=dataset)
        # Predictor.predict yields forecasts lazily; materialize them so that the
        # timed section includes inference and the return type is really a list.
        predictions = list(predictor.predict(dataset))
        self.save_runtime(time.time() - fit_start_time)
        return predictions

    def _get_estimator(self) -> Estimator:
        """Return the configured GluonTS estimator; must be overridden."""
        raise NotImplementedError

    def _get_trainer_kwargs(self) -> dict:
        """Build the ``trainer_kwargs`` passed to the estimator.

        With a ``time_limit``, training is allowed an effectively unbounded number
        of epochs and is stopped by a Lightning ``Timer`` callback. Without one,
        an empty dict is returned so the estimator's own defaults apply.
        """
        if self.time_limit is None:
            # timedelta(seconds=None) would raise TypeError, and 100k epochs with
            # no timer would effectively never finish.
            return {}

        from pytorch_lightning.callbacks import Timer

        # Train until the time limit is reached.
        return {
            "max_epochs": 100_000,
            "callbacks": [Timer(timedelta(seconds=self.time_limit))],
        }


class DeepARPredictor(GluonTSPredictor):
    """GluonTS DeepAR (PyTorch) predictor."""

    def _get_estimator(self) -> Estimator:
        """Return a ``DeepAREstimator`` configured from this predictor's settings."""
        from gluonts.torch.model.deepar import DeepAREstimator

        return DeepAREstimator(
            freq=self.freq,
            prediction_length=self.prediction_length,
            trainer_kwargs=self._get_trainer_kwargs(),
        )


class TFTPredictor(GluonTSPredictor):
    """GluonTS Temporal Fusion Transformer (PyTorch) predictor."""

    def _get_estimator(self) -> Estimator:
        """Return a ``TemporalFusionTransformerEstimator`` configured from this predictor's settings."""
        from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

        return TemporalFusionTransformerEstimator(
            freq=self.freq,
            prediction_length=self.prediction_length,
            trainer_kwargs=self._get_trainer_kwargs(),
        )