|
from typing import List, Optional |
|
|
|
from gluonts.dataset.common import Dataset |
|
from gluonts.model.forecast import Forecast |
|
|
|
|
|
class AbstractPredictor:
    """Base class for predictors that fit and forecast in a single call.

    Subclasses must override :meth:`fit_predict`, which receives a
    ``Dataset`` and returns a list of ``Forecast`` objects. The base class
    only stores configuration and an optional runtime measurement.
    """

    # Quantile grid used when the caller does not supply one. Kept as an
    # immutable tuple so the shared class-level default can never be mutated;
    # each instance receives its own list copy in __init__.
    DEFAULT_QUANTILE_LEVELS = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)

    def __init__(
        self,
        prediction_length: int,
        freq: str,
        seasonality: int,
        quantile_levels: Optional[List[float]] = None,
    ):
        """Store the predictor configuration.

        Args:
            prediction_length: Stored as ``self.prediction_length`` for use
                by subclasses (presumably the forecast horizon).
            freq: Stored as ``self.freq`` (presumably a pandas-style
                frequency string -- confirm against callers).
            seasonality: Stored as ``self.seasonality`` for use by subclasses.
            quantile_levels: Quantile levels stored as
                ``self.quantile_levels``. When ``None`` *or empty*, a fresh
                list of ``DEFAULT_QUANTILE_LEVELS`` (0.1 ... 0.9) is used.
        """
        self.prediction_length = prediction_length
        self.freq = freq
        self.seasonality = seasonality
        # `or` (not `is None`) is deliberate and preserves the original
        # behavior: an empty list also falls back to the default grid.
        self.quantile_levels = quantile_levels or list(
            self.DEFAULT_QUANTILE_LEVELS
        )
        # None until save_runtime() is called.
        self._runtime: Optional[float] = None

    def fit_predict(self, dataset: "Dataset") -> "List[Forecast]":
        """Fit on ``dataset`` and return one forecast per series.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement fit_predict()"
        )

    def save_runtime(self, time: float) -> None:
        """Record a runtime measurement (units chosen by the caller)."""
        self._runtime = time

    def get_runtime(self) -> Optional[float]:
        """Return the value recorded by :meth:`save_runtime`.

        Returns ``None`` if no runtime has been recorded yet.
        """
        return self._runtime
|
|