kashif's picture
kashif HF staff
Upload 10 files
45e60de
raw
history blame
2.04 kB
import time
from typing import List, Optional
from datetime import timedelta
from gluonts.dataset.split import split
from gluonts.dataset.common import Dataset
from gluonts.model.forecast import Forecast
from gluonts.torch.model.estimator import Estimator
from .abstract import AbstractPredictor
class GluonTSPredictor(AbstractPredictor):
    """Base class for predictors that wrap a GluonTS estimator.

    Subclasses supply the concrete estimator through ``_get_estimator``;
    this class splits the data, trains, forecasts and records the runtime.
    """

    def __init__(
        self,
        prediction_length: int,
        freq: str,
        seasonality: int,
        time_limit: Optional[int] = None,
        **kwargs,
    ):
        """Store the forecasting configuration.

        Args:
            prediction_length: Forwarded to the parent constructor; also used
                as the size of the window held out from the training split.
            freq: Forwarded to the parent constructor and to the estimator.
            seasonality: Forwarded to the parent constructor.
            time_limit: Training time budget in seconds. ``None`` means no
                time-based stopping is configured.
            **kwargs: Accepted and ignored by this constructor.
        """
        super().__init__(prediction_length, freq, seasonality)
        self.time_limit = time_limit

    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        """Train the estimator on ``dataset`` and forecast from it.

        The last ``prediction_length`` steps are removed from the training
        split, while the full dataset is passed as validation data. The
        elapsed wall-clock time is handed to ``save_runtime``.

        Returns:
            The forecasts produced by the trained predictor, as a list.
        """
        estimator = self._get_estimator()
        train_data, _ = split(dataset, offset=-self.prediction_length)
        fit_start_time = time.time()
        predictor = estimator.train(training_data=train_data, validation_data=dataset)
        # Materialize the forecasts before stopping the clock: the declared
        # return type is a list, and a lazily evaluated result would otherwise
        # leave inference outside the measured runtime.
        predictions = list(predictor.predict(dataset))
        self.save_runtime(time.time() - fit_start_time)
        return predictions

    def _get_estimator(self) -> Estimator:
        """Return the estimator to train; must be overridden by subclasses."""
        raise NotImplementedError

    def _get_trainer_kwargs(self):
        """Build the trainer keyword arguments passed to the estimator.

        Returns:
            An empty dict when ``time_limit`` is ``None`` (no overrides are
            applied, so the estimator is presumably left on its own default
            training schedule -- confirm against the GluonTS version in use).
            Otherwise a dict that sets a very large epoch count together with
            a ``Timer`` callback, so that the time limit rather than the epoch
            count ends training.
        """
        if self.time_limit is None:
            # timedelta(seconds=None) raises TypeError, so without a limit
            # there is no valid Timer to build.
            return {}
        from pytorch_lightning.callbacks import Timer

        # Train until time limit
        return {"max_epochs": 100_000, "callbacks": [Timer(timedelta(seconds=self.time_limit))]}
class DeepARPredictor(GluonTSPredictor):
    """Predictor that trains a GluonTS ``DeepAREstimator``."""

    def _get_estimator(self) -> Estimator:
        """Create a DeepAR estimator from this predictor's settings."""
        # Imported here rather than at module level, as in the original code.
        from gluonts.torch.model.deepar import DeepAREstimator

        estimator_args = {
            "freq": self.freq,
            "prediction_length": self.prediction_length,
            "trainer_kwargs": self._get_trainer_kwargs(),
        }
        return DeepAREstimator(**estimator_args)
class TFTPredictor(GluonTSPredictor):
    """Predictor that trains a GluonTS ``TemporalFusionTransformerEstimator``."""

    def _get_estimator(self) -> Estimator:
        """Create a Temporal Fusion Transformer estimator from this predictor's settings."""
        # Imported here rather than at module level, as in the original code.
        from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

        estimator_args = {
            "freq": self.freq,
            "prediction_length": self.prediction_length,
            "trainer_kwargs": self._get_trainer_kwargs(),
        }
        return TemporalFusionTransformerEstimator(**estimator_args)