# Source code for pytorch_forecasting.models.xlstm._xlstm_pkg

"""xLSTMTime package container."""

from pytorch_forecasting.models.base._base_object import _BasePtForecaster


class xLSTMTime_pkg(_BasePtForecaster):
    """xLSTMTime package container.

    Bundles the metadata tags (``_tags``) for the xLSTMTime model together
    with class-level helpers that return the model class, the parameter sets
    used to build test instances, and the dataloaders used by those tests.
    """

    _tags = {
        "info:name": "xLSTMTime",
        "info:compute": 3,
        "info:pred_type": ["point"],
        "info:y_type": ["numeric"],
        "authors": ["muslehal", "phoeenniixx"],
        "capability:exogenous": True,
        "capability:multivariate": True,
        "capability:pred_int": False,
        "capability:flexible_history_length": True,
        "capability:cold_start": False,
    }

    @classmethod
    def get_cls(cls):
        """Get model class.

        Returns
        -------
        type
            The ``xLSTMTime`` class from ``pytorch_forecasting.models``.
        """
        # Deferred import: the model is only loaded when requested.
        # NOTE(review): presumably this avoids an import cycle with
        # ``pytorch_forecasting.models`` -- confirm.
        from pytorch_forecasting.models import xLSTMTime

        return xLSTMTime

    @classmethod
    def get_base_test_params(cls):
        """Return testing parameter settings for the model.

        Each returned dict holds the constructor keyword arguments for one
        test instance. The shared base values ``hidden_size``, ``input_size``
        and ``output_size`` are added to every case. A case that sets one of
        those keys explicitly keeps its own value.

        Returns
        -------
        params : list of dict
            Parameters to create testing instances of the class.
        """
        cases = [
            {},
            {"xlstm_type": "mlstm"},
            {"num_layers": 2},
            {"xlstm_type": "slstm", "input_projection_size": 32},
            {
                "xlstm_type": "mlstm",
                "decomposition_kernel": 13,
                "dropout": 0.2,
            },
        ]
        defaults = {"hidden_size": 32, "input_size": 1, "output_size": 1}
        # Defaults are unpacked first so that per-case values win on key
        # collisions; a fresh dict is built per case.
        return [{**defaults, **case} for case in cases]

    @classmethod
    def _get_test_dataloaders_from(cls, params):
        """Get dataloaders from parameters.

        Parameters
        ----------
        params : dict
            Parameters to create dataloaders. One of the elements in the list
            returned by ``get_base_test_params``. Currently unused: the result
            of ``dataloaders_fixed_window_without_covariates()`` is returned
            for every parameter set.

        Returns
        -------
        dataloaders : Dict[str, DataLoader]
            Dict of train, validation, and test dataloaders.
        """
        from pytorch_forecasting.tests._data_scenarios import (
            dataloaders_fixed_window_without_covariates,
        )

        return dataloaders_fixed_window_without_covariates()