Source code for pytorch_forecasting.models.nbeats._nbeatskan_pkg

"""NBeatsKAN package container."""

from pytorch_forecasting.models.base._base_object import _BasePtForecaster


[docs] class NBeatsKAN_pkg(_BasePtForecaster): """NBeatsKAN package container.""" _tags = { "info:name": "NBeatsKAN", "info:compute": 1, "info:pred_type": ["point"], "info:y_type": ["numeric"], "authors": ["Sohaib-Ahmed21"], "capability:exogenous": False, "capability:multivariate": False, "capability:pred_int": False, "capability:flexible_history_length": False, "capability:cold_start": False, }
[docs] @classmethod def get_cls(cls): """Get model class.""" from pytorch_forecasting.models import NBeatsKAN return NBeatsKAN
[docs] @classmethod def get_base_test_params(cls): """Return testing parameter settings for the trainer. Returns ------- params : dict or list of dict, default = {} Parameters to create testing instances of the class Each dict are parameters to construct an "interesting" test instance, i.e., `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. `create_test_instance` uses the first (or only) dictionary in `params` """ return [ {"backcast_loss_ratio": 0.0}, # pure forecast loss {"backcast_loss_ratio": 1.0}, # equal forecast/backcast { "stack_types": ["generic"], "expansion_coefficient_lengths": [16], }, { "num_blocks": [1, 2], "num_block_layers": [2, 3], }, # varying block structure { "num": 7, "k": 4, "sparse_init": True, "grid_range": [-0.5, 0.5], "sp_trainable": False, }, # complex KAN config ]
@classmethod def _get_test_dataloaders_from(cls, params): loss = params.get("loss", None) data_loader_kwargs = params.get("data_loader_kwargs", {}) from pytorch_forecasting.metrics import TweedieLoss from pytorch_forecasting.tests._data_scenarios import ( data_with_covariates, dataloaders_fixed_window_without_covariates, make_dataloaders, ) if isinstance(loss, TweedieLoss): dwc = data_with_covariates() dl_default_kwargs = dict( target="target", time_varying_unknown_reals=["target"], add_relative_time_idx=False, ) dl_default_kwargs.update(data_loader_kwargs) dataloaders_with_covariates = make_dataloaders(dwc, **dl_default_kwargs) return dataloaders_with_covariates return dataloaders_fixed_window_without_covariates()