Source code for pytorch_forecasting.models.timexer._timexer_pkg_v2

"""
Metadata container for TimeXer v2.
"""

from pytorch_forecasting.base._base_pkg import Base_pkg


class TimeXer_pkg_v2(Base_pkg):
    """TimeXer metadata container.

    Bundles the tags describing the TimeXer v2 model together with
    class-level accessors for the model class, its data module class and
    the parameter sets used to build test instances.
    """

    _tags = {
        "info:name": "TimeXer",
        "authors": ["PranavBhatP"],
        "capability:exogenous": True,
        "capability:multivariate": True,
        "capability:pred_int": True,
        "capability:flexible_history_length": False,
    }

    @classmethod
    def get_cls(cls):
        """Get model class.

        Returns
        -------
        type
            The ``TimeXer`` class, imported lazily from
            ``pytorch_forecasting.models.timexer._timexer_v2``.
        """
        from pytorch_forecasting.models.timexer._timexer_v2 import TimeXer

        return TimeXer

    @classmethod
    def get_datamodule_cls(cls):
        """Get the underlying DataModule class.

        Returns
        -------
        type
            The ``TslibDataModule`` class, imported lazily from
            ``pytorch_forecasting.data._tslib_data_module``.
        """
        from pytorch_forecasting.data._tslib_data_module import TslibDataModule

        return TslibDataModule

    @classmethod
    def get_test_train_params(cls):
        """Return testing parameter settings for the trainer.

        Every returned dictionary carries its own ``datamodule_cfg`` entry:
        the defaults ``context_length=12`` and ``prediction_length=4``,
        overridden by whatever that parameter set specifies itself.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class
            Each dict are parameters to construct an "interesting" test
            instance, i.e., `MyClass(**params)` or `MyClass(**params[i])`
            creates a valid test instance.
            `create_test_instance` uses the first (or only) dictionary in
            `params`
        """
        from pytorch_forecasting.metrics import QuantileLoss

        params = [
            {},
            dict(
                hidden_size=64,
                n_heads=4,
            ),
            dict(datamodule_cfg=dict(context_length=12, prediction_length=3)),
            dict(
                hidden_size=32,
                n_heads=2,
                datamodule_cfg=dict(
                    context_length=12,
                    prediction_length=3,
                    add_relative_time_idx=False,
                ),
            ),
            dict(
                hidden_size=128,
                patch_length=12,
                datamodule_cfg=dict(context_length=16, prediction_length=4),
            ),
            dict(
                n_heads=2,
                e_layers=1,
                patch_length=6,
            ),
            dict(
                hidden_size=256,
                n_heads=8,
                e_layers=3,
                d_ff=1024,
                patch_length=8,
                factor=3,
                activation="gelu",
                dropout=0.2,
            ),
            dict(
                hidden_size=32,
                n_heads=2,
                e_layers=1,
                d_ff=64,
                patch_length=4,
                factor=2,
                activation="relu",
                dropout=0.05,
                datamodule_cfg=dict(
                    context_length=16,
                    prediction_length=4,
                ),
                loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
            ),
            dict(
                hidden_size=32,
                patch_length=1,
                n_heads=4,
                e_layers=1,
                d_ff=32,
                dropout=0.1,
                use_efficient_attention=True,
            ),
        ]

        default_dm_cfg = {"context_length": 12, "prediction_length": 4}

        for param in params:
            # Merge into a NEW dict for each entry. Updating and assigning the
            # shared defaults dict would alias one object across all entries,
            # so overrides from one parameter set would leak into the others.
            param["datamodule_cfg"] = {
                **default_dm_cfg,
                **param.get("datamodule_cfg", {}),
            }

        return params