|
import time |
|
from typing import List, Optional |
|
from datetime import timedelta |
|
|
|
from gluonts.dataset.split import split |
|
from gluonts.dataset.common import Dataset |
|
from gluonts.model.forecast import Forecast |
|
from gluonts.torch.model.estimator import Estimator |
|
|
|
|
|
from .abstract import AbstractPredictor |
|
|
|
|
|
class GluonTSPredictor(AbstractPredictor):
    """Base predictor that trains a GluonTS estimator and forecasts with it.

    Subclasses choose the concrete model by overriding :meth:`_get_estimator`.
    """

    def __init__(
        self,
        prediction_length: int,
        freq: str,
        seasonality: int,
        time_limit: Optional[int] = None,
        **kwargs,
    ):
        """Store forecasting settings and the optional training time budget.

        Args:
            prediction_length: Number of steps to forecast; also the length of
                the tail withheld from training in :meth:`fit_predict`.
            freq: Frequency string, forwarded to the base class (subclasses
                also pass it to their estimators).
            seasonality: Forwarded to the base class.
            time_limit: Training time budget in seconds, or ``None`` for no
                time-based stopping.
            **kwargs: Accepted but not used by this class.
        """
        super().__init__(prediction_length, freq, seasonality)
        self.time_limit = time_limit

    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        """Train a fresh estimator on ``dataset`` and forecast past its end.

        The last ``prediction_length`` steps of every series are withheld from
        the training data; the full dataset is passed as validation data.

        Args:
            dataset: GluonTS dataset to train on and forecast from.

        Returns:
            The forecasts produced by the trained predictor for ``dataset``,
            materialized as a list.
        """
        estimator = self._get_estimator()
        train_data, _ = split(dataset, offset=-self.prediction_length)

        # monotonic clock: elapsed time must not jump with wall-clock changes.
        start = time.monotonic()
        predictor = estimator.train(training_data=train_data, validation_data=dataset)
        # ``predict`` yields forecasts lazily; consume it here so the recorded
        # runtime includes inference and the returned value really is a list.
        predictions = list(predictor.predict(dataset))
        self.save_runtime(time.monotonic() - start)

        return predictions

    def _get_estimator(self) -> Estimator:
        """Return the GluonTS estimator to train; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_estimator()"
        )

    def _get_trainer_kwargs(self) -> dict:
        """Build the ``trainer_kwargs`` passed to the estimator.

        With a ``time_limit``, the epoch cap is made effectively unbounded and
        a Lightning ``Timer`` callback ends training once the budget is spent.
        Without one, nothing is overridden, so the estimator presumably falls
        back to its own trainer defaults (confirm per estimator). Building the
        Timer unconditionally would fail, because
        ``timedelta(seconds=None)`` raises ``TypeError``.
        """
        if self.time_limit is None:
            return {}

        from pytorch_lightning.callbacks import Timer

        return {
            "max_epochs": 100_000,
            "callbacks": [Timer(timedelta(seconds=self.time_limit))],
        }
|
|
|
|
|
class DeepARPredictor(GluonTSPredictor):
    """GluonTS DeepAR (PyTorch) model plugged into the shared fit/predict flow."""

    def _get_estimator(self) -> Estimator:
        """Create a ``DeepAREstimator`` from this predictor's settings."""
        # Local import: the DeepAR module is loaded only when an estimator is built.
        from gluonts.torch.model.deepar import DeepAREstimator

        options = dict(
            freq=self.freq,
            prediction_length=self.prediction_length,
            trainer_kwargs=self._get_trainer_kwargs(),
        )
        return DeepAREstimator(**options)
|
|
|
|
|
class TFTPredictor(GluonTSPredictor):
    """GluonTS Temporal Fusion Transformer (PyTorch) in the shared fit/predict flow."""

    def _get_estimator(self) -> Estimator:
        """Create a ``TemporalFusionTransformerEstimator`` from this predictor's settings."""
        # Local import: the TFT module is loaded only when an estimator is built.
        from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

        options = dict(
            freq=self.freq,
            prediction_length=self.prediction_length,
            trainer_kwargs=self._get_trainer_kwargs(),
        )
        return TemporalFusionTransformerEstimator(**options)
|
|