Source code for pytorch_forecasting.models.nbeats._nbeats_pkg
"""NBeats package container."""
from pytorch_forecasting.models.base._base_object import _BasePtForecaster
[docs]
class NBeats_pkg(_BasePtForecaster):
    """Package container describing the NBeats forecaster.

    Groups the static metadata of the NBeats model (``_tags``) with the
    classmethod hooks that return the model class, the parameter sets used to
    build test instances, and the dataloaders those test instances run on.
    """

    # Static metadata: display name, relative compute cost, authors, and
    # capability flags. Every capability flag is declared False for NBeats.
    _tags = {
        "info:name": "NBeats",
        "info:compute": 1,
        "authors": ["jdb78"],
        "capability:exogenous": False,
        "capability:multivariate": False,
        "capability:pred_int": False,
        "capability:flexible_history_length": False,
        "capability:cold_start": False,
    }

    @classmethod
    def get_model_cls(cls):
        """Return the NBeats model class described by this container.

        Returns
        -------
        type
            ``pytorch_forecasting.models.NBeats``.
        """
        # Imported at call time rather than module import time; presumably to
        # avoid a circular import with ``pytorch_forecasting.models`` -- confirm.
        from pytorch_forecasting.models import NBeats as model_cls

        return model_cls

    @classmethod
    def get_test_train_params(cls):
        """Return the parameter sets used to build test instances.

        Returns
        -------
        params : list of dict
            Each dict holds constructor keyword arguments, so that
            ``MyClass(**params[i])`` is expected to create a valid test
            instance. ``create_test_instance`` uses the first dict.
            A single parameter set is currently returned.
        """
        # Only one configuration is exercised; it sets the backcast loss ratio.
        backcast_config = {"backcast_loss_ratio": 1.0}
        return [backcast_config]

    @classmethod
    def _get_test_dataloaders_from(cls, params):
        """Build the train/validation/test dataloaders for a parameter set.

        Parameters
        ----------
        params : dict
            One element of the list returned by ``get_test_train_params``.
            This implementation does not read it: every parameter set uses
            the same data scenario.

        Returns
        -------
        dataloaders : dict
            Keys ``"train"``, ``"val"`` and ``"test"``, each mapped to a torch
            DataLoader, in that order.
        """
        # Deferred import of the shared test data scenario; the helper's name
        # suggests a fixed-window dataset without covariates.
        from pytorch_forecasting.tests._data_scenarios import (
            dataloaders_fixed_window_without_covariates as build_loaders,
        )

        return build_loaders()