Spaces:
Running
on
T4
Running
on
T4
| # Copyright (c) NXAI GmbH. | |
| # This software may be used and distributed according to the terms of the NXAI Community License Agreement. | |
| from abc import ABC, abstractmethod | |
| from typing import Literal | |
| import torch | |
| from .standard_adapter import ContextType, get_batches | |
# Optional dependency: the GluonTS adapter. The flag records whether the import succeeded so
# that GluonTS-specific code paths can raise an explicit error instead of hitting a NameError.
try:
    from .gluon import format_gluonts_output, get_gluon_batches

    _GLUONTS_AVAILABLE = True
except ImportError:
    _GLUONTS_AVAILABLE = False

# Optional dependency: the Hugging Face `datasets` adapter, guarded the same way.
try:
    from .hf_data import get_hfdata_batches

    _HF_DATASETS_AVAILABLE = True
except ImportError:
    _HF_DATASETS_AVAILABLE = False

# Default column names. They are not referenced by the code visible in this section.
# NOTE(review): presumably consumed by the dataset adapters - confirm before renaming.
DEF_TARGET_COLUMN = "target"
DEF_META_COLUMNS = ("start", "item_id")
| def _format_output( | |
| quantiles: torch.Tensor, | |
| means: torch.Tensor, | |
| sample_meta: list[dict], | |
| quantile_levels: list[float], | |
| output_type: Literal["torch", "numpy", "gluonts"], | |
| ): | |
| if output_type == "torch": | |
| return quantiles.cpu(), means.cpu() | |
| elif output_type == "numpy": | |
| return quantiles.cpu().numpy(), means.cpu().numpy() | |
| elif output_type == "gluonts": | |
| if not _GLUONTS_AVAILABLE: | |
| raise ValueError("output_type glutonts needs GluonTs but GluonTS is not available (not installed)!") | |
| return format_gluonts_output(quantiles, means, sample_meta, quantile_levels) | |
| else: | |
| raise ValueError(f"Invalid output type: {output_type}") | |
| def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs): | |
| for batch_ctx, batch_meta in batches: | |
| quantiles, mean = fc_func(batch_ctx, **predict_kwargs) | |
| yield _format_output( | |
| quantiles=quantiles, | |
| means=mean, | |
| sample_meta=batch_meta, | |
| quantile_levels=quantile_levels, | |
| output_type=output_type, | |
| ) | |
| def _gen_forecast(fc_func, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs): | |
| if yield_per_batch: | |
| return _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs) | |
| prediction_q = [] | |
| prediction_m = [] | |
| sample_meta = [] | |
| for batch_ctx, batch_meta in batches: | |
| quantiles, mean = fc_func(batch_ctx, **predict_kwargs) | |
| prediction_q.append(quantiles) | |
| prediction_m.append(mean) | |
| sample_meta.extend(batch_meta) | |
| prediction_q = torch.cat(prediction_q, dim=0) | |
| prediction_m = torch.cat(prediction_m, dim=0) | |
| return _format_output( | |
| quantiles=prediction_q, | |
| means=prediction_m, | |
| sample_meta=sample_meta, | |
| quantile_levels=quantile_levels, | |
| output_type=output_type, | |
| ) | |
| def _common_forecast_doc(): | |
| common_doc = f""" | |
| This method takes historical context data as input and outputs probabilistic forecasts. | |
| Args: | |
| output_type (Literal["torch", "numpy", "gluonts"], optional): | |
| Specifies the desired format of the returned forecasts: | |
| - "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, |quantile_levels|] | |
| - "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, |quantile_levels|] | |
| - "gluonts": Returns forecasts as a list of GluonTS `Forecast` objects. | |
| Defaults to "torch". | |
| batch_size (int, optional): The number of time series instances to process concurrently by the model. | |
| Defaults to 512. Must be $>= 1$. | |
| quantile_levels (List[float], optional): Quantile levels for which predictions should be generated. | |
| Defaults to (0.1, 0.2, ..., 0.9). | |
| yield_per_batch (bool, optional): If `True`, the method will act as a generator, yielding | |
| forecasts batch by batch as they are computed. | |
| Defaults to `False`. | |
| **predict_kwargs: Additional keyword arguments that are passed directly to the underlying | |
| prediction mechanism of the pre-trained model. Refer to the model's | |
| internal prediction method documentation for available options. | |
| Returns: | |
| The return type depends on `output_type` and `yield_per_batch`: | |
| - If `yield_per_batch` is `True`: An iterator that yields forecasts. Each yielded item | |
| will correspond to a batch of forecasts in the format specified by `output_type`. | |
| - If `yield_per_batch` is `False`: A single object containing all forecasts. | |
| - If `output_type="torch"`: `Tuple[torch.Tensor, torch.Tensor]` (quantiles, mean). | |
| - If `output_type="numpy"`: `Tuple[numpy.ndarray, numpy.ndarray]` (quantiles, mean). | |
| - If `output_type="gluonts"`: A `List[gluonts.model.forecast.Forecast]` of all forecasts. | |
| """ | |
| return common_doc | |
def _with_common_forecast_doc(method):
    """Prepend the shared forecast documentation to ``method``'s docstring.

    An f-string placed first in a function body is an ordinary expression, not a docstring:
    ``__doc__`` stays ``None`` and the string is rebuilt (then discarded) on every call.
    Attaching the shared text once, at definition time, avoids both problems.
    """
    import inspect  # function-scope import keeps the module's import block untouched

    shared_doc = inspect.cleandoc(_common_forecast_doc())
    own_doc = inspect.cleandoc(method.__doc__ or "")  # ``__doc__`` is None under ``python -OO``
    method.__doc__ = f"{shared_doc}\n\n{own_doc}" if own_doc else shared_doc
    return method


class ForecastModel(ABC):
    """Base class that turns a per-batch quantile predictor into the public forecasting API.

    Subclasses implement ``_forecast_quantiles``. The ``forecast*`` methods split their input
    into batches, call ``_forecast_quantiles`` for each batch and format the results.
    """

    @abstractmethod
    def _forecast_quantiles(self, batch, **predict_kwargs):
        """Forecast a single batch; must be implemented by subclasses.

        Args:
            batch: One batch of context data as produced by the batching helpers.
            **predict_kwargs: Forwarded unchanged from the public ``forecast*`` call.

        Returns:
            A ``(quantiles, mean)`` pair. The non-streaming path concatenates the pairs of
            all batches along dim 0 with ``torch.cat``, so both elements must be tensors.
        """

    @_with_common_forecast_doc
    def forecast(
        self,
        context: ContextType,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        **predict_kwargs,
    ):
        """
        Args:
            context (ContextType): The historical "context" data of the time series:
                - `torch.Tensor`: 1D `[context_length]` or 2D `[batch_dim, context_length]` tensor
                - `np.ndarray`: 1D `[context_length]` or 2D `[batch_dim, context_length]` array
                - `List[torch.Tensor]`: List of 1D tensors (samples with different lengths get padded per batch)
                - `List[np.ndarray]`: List of 1D arrays (samples with different lengths get padded per batch)
        """
        # NOTE(review): ``assert`` is skipped under ``python -O``; kept so callers still see AssertionError.
        assert batch_size >= 1, "Batch size must be >= 1"
        batches = get_batches(context, batch_size)
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )

    @_with_common_forecast_doc
    def forecast_gluon(
        self,
        gluonDataset,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        data_kwargs: "dict | None" = None,
        **predict_kwargs,
    ):
        """
        Args:
            gluonDataset (gluonts.dataset.common.Dataset): A GluonTS dataset object containing the
                historical time series data.
            data_kwargs (dict, optional): Additional keyword arguments passed to the
                GluonTS data processing function (`get_gluon_batches`). Defaults to no extra arguments.

        Raises:
            ValueError: If the GluonTS adapter could not be imported.
        """
        assert batch_size >= 1, "Batch size must be >= 1"
        if not _GLUONTS_AVAILABLE:
            raise ValueError("forecast_gluon requires GluonTS, but GluonTS is not available (not installed)!")
        # ``None`` replaces the former shared mutable ``{}`` default; behaviour is unchanged.
        batches = get_gluon_batches(gluonDataset, batch_size, **(data_kwargs or {}))
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )

    @_with_common_forecast_doc
    def forecast_hfdata(
        self,
        hf_dataset,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        data_kwargs: "dict | None" = None,
        **predict_kwargs,
    ):
        """
        Args:
            hf_dataset (datasets.Dataset): A Hugging Face `Dataset` object containing the
                historical time series data.
            data_kwargs (dict, optional): Additional keyword arguments passed to the
                datasets data processing function (`get_hfdata_batches`). Defaults to no extra arguments.

        Raises:
            ValueError: If the Hugging Face datasets adapter could not be imported.
        """
        assert batch_size >= 1, "Batch size must be >= 1"
        if not _HF_DATASETS_AVAILABLE:
            raise ValueError(
                "forecast_hfdata requires HuggingFace datasets, but datasets is not available (not installed)!"
            )
        # ``None`` replaces the former shared mutable ``{}`` default; behaviour is unchanged.
        batches = get_hfdata_batches(hf_dataset, batch_size, **(data_kwargs or {}))
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )