import copy
import multiprocessing as mp
import time
from typing import List

from gluonts.dataset.common import Dataset
from gluonts.model.forecast import Forecast, SampleForecast

from .abstract import AbstractPredictor


class AutoPyTorchPredictor(AbstractPredictor):
    """Forecast with an autoPyTorch ``TimeSeriesForecastingTask``.

    ``fit_predict`` runs an autoPyTorch search on the targets of a dataset
    and wraps the resulting predictions as gluonts ``SampleForecast`` objects.
    """

    def __init__(
        self,
        prediction_length: int,
        freq: str,
        seasonality: int,
        time_limit: int = 6 * 60 * 60,
        optimize_metric: str = "mean_MASE_forecasting",
        seed: int = 1,
        **kwargs
    ):
        """Store the search configuration.

        Args:
            prediction_length: Forwarded to ``AbstractPredictor``; also used
                as ``n_prediction_steps`` of the search.
            freq: Forwarded to ``AbstractPredictor``; also used to build the
                ``freq`` argument of the search.
            seasonality: Forwarded to ``AbstractPredictor``.
            time_limit: Passed to the search as ``total_walltime_limit``
                (presumably seconds, given the ``6 * 60 * 60`` default --
                confirm against autoPyTorch).
            optimize_metric: Metric name passed to the search.
            seed: Seed passed to ``TimeSeriesForecastingTask``.
            **kwargs: Accepted and ignored.
        """
        super().__init__(prediction_length, freq, seasonality)
        self.optimize_metric = optimize_metric
        self.run_time = time_limit
        self.seed = seed

    def fit_predict(self, dataset: Dataset) -> List[Forecast]:
        """Run the autoPyTorch search on ``dataset`` and forecast each series.

        Args:
            dataset: Iterable of entries providing ``"target"``, ``"start"``
                and ``"item_id"``.

        Returns:
            One ``SampleForecast`` per paired (entry, prediction), each
            holding a single sample and starting ``len(target)`` periods
            after the entry's ``"start"``.
        """
        # Imported lazily so the module can be imported without autoPyTorch.
        from autoPyTorch.api.time_series_forecasting import TimeSeriesForecastingTask
        from autoPyTorch.datasets.resampling_strategy import HoldoutValTypes

        # Materialize once: the entries are needed both for training and for
        # building the forecasts, and re-iterating the dataset would re-read
        # it (or yield nothing for a one-shot iterable).
        entries = list(dataset)
        y_train = [entry["target"] for entry in entries]
        start_times = [entry["start"].to_timestamp(how="S") for entry in entries]

        # Prefix the "1" multiplier only when the frequency has none; blindly
        # prepending would turn e.g. "15min" into "115min".
        if self.freq[:1].isdigit():
            search_freq = self.freq
        else:
            search_freq = "1" + self.freq

        api = TimeSeriesForecastingTask(
            seed=self.seed,
            ensemble_size=20,
            resampling_strategy=HoldoutValTypes.time_series_hold_out_validation,
            resampling_strategy_args=None,
        )
        api.set_pipeline_options(early_stopping=20, torch_num_threads=mp.cpu_count())

        fit_start_time = time.time()
        api.search(
            X_train=None,
            y_train=copy.deepcopy(y_train),
            optimize_metric=self.optimize_metric,
            n_prediction_steps=self.prediction_length,
            memory_limit=16 * 1024,
            freq=search_freq,
            start_times=start_times,
            normalize_y=False,
            total_walltime_limit=self.run_time,
            min_num_test_instances=1000,
            budget_type="epochs",
            max_budget=50,
            min_budget=5,
        )

        # Refitting is skipped on purpose: as of autoPyTorch v0.2.1 it raises
        # exceptions for all models. The skipped calls would be
        # api.dataset.create_refit_set() followed by api.refit(refit_dataset, 0).

        # Predict for the test set.
        test_sets = api.dataset.generate_test_seqs()
        predictions = api.predict(test_sets)
        # The recorded runtime covers both the search and the prediction step.
        self.save_runtime(time.time() - fit_start_time)

        return [
            SampleForecast(
                # Leading axis added so the prediction is passed as a single
                # sample -- assumes pred is an array; TODO confirm.
                samples=pred[None],
                # First time step after the observed part of the series.
                start_date=entry["start"] + len(entry["target"]),
                item_id=entry["item_id"],
            )
            for entry, pred in zip(entries, predictions)
        ]