import time
from typing import List, Optional
from datetime import timedelta

from gluonts.dataset.split import split
from gluonts.dataset.common import Dataset
from gluonts.model.forecast import Forecast
from gluonts.model.estimator import Estimator

from .abstract import AbstractPredictor


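class GluonTSPredictor(AbstractPredictor):
    """Base class for GluonTS PyTorch models.

    Subclasses implement `_get_estimator` to return a configured estimator.
    Additional keyword arguments are accepted but currently ignored.
    """

    def __init__(
        self,
        prediction_length: int,
        freq: str,
        seasonality: int,
        time_limit: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(prediction_length, freq, seasonality)
        # Training time budget in seconds; None means no limit.
        self.time_limit = time_limit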

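    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        estimator = self._get_estimator()
        # Hold out the last `prediction_length` steps of each series: the model is
        # trained on the truncated series, and validation uses the full series.
        train_data, _ = split(dataset, offset=-self.prediction_length)
        fit_start_time = time.time()
        predictor = estimator.train(training_data=train_data, validation_data=dataset)
        # `predict` returns a lazy iterator; materialize it so that the recorded
        # runtime includes inference and the return value is actually a list.
        predictions = list(predictor.predict(dataset))
        self.save_runtime(time.time() - fit_start_time)
        return predictions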

    def _get_estimator(self) -> Estimator:
        raise NotImplementedError

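    def _get_trainer_kwargs(self) -> dict:
        if self.time_limit is None:
            # No time limit: keep the estimator's default trainer settings.
            return {}

        from pytorch_lightning.callbacks import Timer

        # Effectively unlimited epochs; the Timer callback stops training
        # once `time_limit` seconds have elapsed.
        return {"max_epochs": 100_000, "callbacks": [Timer(timedelta(seconds=self.time_limit))]}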


class DeepARPredictor(GluonTSPredictor):
    def _get_estimator(self) -> Estimator:
        from gluonts.torch.model.deepar import DeepAREstimator

        return DeepAREstimator(
            freq=self.freq,
            prediction_length=self.prediction_length,
            trainer_kwargs=self._get_trainer_kwargs(),
        )


class TFTPredictor(GluonTSPredictor):
    def _get_estimator(self) -> Estimator:
        from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

        return TemporalFusionTransformerEstimator(
            freq=self.freq,
            prediction_length=self.prediction_length,
            trainer_kwargs=self._get_trainer_kwargs(),
        )