Source code for pytorch_forecasting.models.xlstm._xlstm

from copy import copy
from typing import Literal, Optional, Union

import torch
from torch import nn

from pytorch_forecasting.layers import SeriesDecomposition, mLSTMNetwork, sLSTMNetwork
from pytorch_forecasting.metrics import SMAPE, Metric
from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel


class xLSTMTime(AutoRegressiveBaseModel):
    """
    xLSTMTime is a long-term time series forecasting architecture built on the
    extended LSTM (xLSTM) design, incorporating either the scalar-memory
    stabilized LSTM (sLSTM) or the matrix-memory mLSTM variant.

    This model enhances classical LSTM by adding exponential gating and richer
    memory dynamics, and combines series decomposition and normalization layers
    to produce robust forecasts over extended horizons.

    It is based on this paper: https://arxiv.org/pdf/2407.10240
    and https://github.com/muslehal/xLSTMTime
    """

    @classmethod
    def _pkg(cls):
        """Package for the model."""
        # Imported lazily, inside the method, exactly as in the original code.
        from pytorch_forecasting.models.xlstm._xlstm_pkg import xLSTMTime_pkg

        return xLSTMTime_pkg

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        xlstm_type: Literal["slstm", "mlstm"] = "slstm",
        num_layers: int = 1,
        decomposition_kernel: int = 25,
        input_projection_size: int | None = None,
        dropout: float = 0.1,
        # NOTE(review): this default is evaluated once at definition time, so
        # every model relying on the default receives the same SMAPE instance.
        loss: Metric = SMAPE(),
        **kwargs,
    ):
        """
        Initialise the model.

        Parameters
        ----------
        input_size : int
            Number of input continuous features per time step.
        hidden_size : int
            Hidden size of the xLSTM network.
        output_size : int
            Number of output features per time step (forecast horizon).
        xlstm_type : {"slstm", "mlstm"}, default "slstm"
            Specifies which xLSTM variant to use:

            - "slstm": stabilized LSTM with scalar memory,
            - "mlstm": matrix-memory variant for higher capacity and
              scalability.
        num_layers : int, default 1
            Number of recurrent layers in the sLSTM or mLSTM network.
        decomposition_kernel : int, default 25
            Kernel size for series decomposition into trend and seasonal
            components.
        input_projection_size : int, optional
            If specified, the encoded input (trend + seasonal) is projected to
            this size before batch normalisation and the xLSTM network;
            otherwise ``hidden_size`` is used.
        dropout : float, default 0.1
            Dropout rate applied within the recurrent layers.
        loss : pytorch_forecasting.metrics.Metric, default SMAPE()
            Loss (and evaluation metric) used during training.
        **kwargs
            Forwarded to the parent class constructor. The keys ``"target"``
            and ``"target_lags"`` are removed first.

        Raises
        ------
        ValueError
            If ``xlstm_type`` is neither ``"slstm"`` nor ``"mlstm"``.
        """
        # Drop these two keys so they are neither stored as hyperparameters
        # nor forwarded to the parent constructor.
        kwargs.pop("target", None)
        kwargs.pop("target_lags", None)

        self.save_hyperparameters()
        super().__init__(loss=loss, **kwargs)

        if xlstm_type not in ["slstm", "mlstm"]:
            raise ValueError("xlstm_type must be either 'slstm' or 'mlstm'")

        self.xlstm_type = xlstm_type
        self.decomposition = SeriesDecomposition(decomposition_kernel)

        # A falsy value (None or 0) falls back to hidden_size.
        self.input_projection_size = input_projection_size or hidden_size

        # Trend and seasonal parts are concatenated on the feature axis in
        # forward(), hence twice the number of input features.
        self.input_linear = nn.Linear(input_size * 2, self.input_projection_size)

        # Batch norm and the recurrent network both consume the projected
        # features, so they must be sized by the projection width. Using
        # hidden_size here breaks whenever input_projection_size differs.
        self.batch_norm = nn.BatchNorm1d(self.input_projection_size)

        network_cls = mLSTMNetwork if xlstm_type == "mlstm" else sLSTMNetwork
        self.lstm = network_cls(
            input_size=self.input_projection_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            output_size=hidden_size,
            dropout=dropout,
        )

        self.output_linear = nn.Linear(hidden_size, output_size)
        self.instance_norm = nn.InstanceNorm1d(output_size)

    def forward(
        self,
        x: dict[str, torch.Tensor],
        hidden_states: tuple[torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass for the model.

        Parameters
        ----------
        x : dict[str, torch.Tensor]
            Network input; only the ``"encoder_cont"`` entry is read. It must
            be a 3-dimensional tensor whose first dimension is the batch.
        hidden_states : tuple of torch.Tensor, optional
            Initial recurrent state, unpacked positionally into the xLSTM
            network. If ``None``, it is created with
            ``self.lstm.init_hidden``.

        Returns
        -------
        dict[str, torch.Tensor]
            Result of ``self.to_network_output`` with the forecast passed as
            ``prediction``.
        """
        encoder_cont = x["encoder_cont"]
        # Unpacking also enforces that the input is exactly 3-dimensional.
        batch_size, _, _ = encoder_cont.shape

        seasonal, trend = self.decomposition(encoder_cont)
        features = torch.cat([trend, seasonal], dim=-1)
        features = self.input_linear(features)

        # BatchNorm1d normalises over dim 1, so move the feature axis there
        # and restore the layout afterwards.
        features = self.batch_norm(features.transpose(1, 2)).transpose(1, 2)

        if hidden_states is None:
            hidden_states = self.lstm.init_hidden(batch_size, device=features.device)

        # Swap the first two axes before the recurrent network
        # (presumably it expects sequence-first input - TODO confirm).
        features = features.transpose(0, 1)
        output, hidden_states = self.lstm(features, *hidden_states)

        if isinstance(output, tuple):
            output = output[0]
        if output.dim() == 2:
            # Add a leading axis so the following ops see a 3-D tensor.
            output = output.unsqueeze(0)

        output = self.output_linear(output)

        # InstanceNorm1d also treats dim 1 as the channel axis.
        output = self.instance_norm(output.transpose(1, 2)).transpose(1, 2)

        # Keep only the first entry of the leading axis and at most
        # output_size values on the last axis, then append a trailing
        # singleton dimension.
        output = output[0, ..., : self.hparams.output_size]
        output = output.unsqueeze(-1)

        return self.to_network_output(prediction=output)

    @classmethod
    def from_dataset(cls, dataset, **kwargs):
        """
        Create model from dataset and set parameters related to covariates.

        Parameters
        ----------
        dataset : timeseries dataset
            Dataset from which the default output parameters are deduced.
        **kwargs
            Additional arguments such as hyperparameters for the model.

        Returns
        -------
        xLSTMTime
        """
        from pytorch_forecasting.data.encoders import NaNLabelEncoder

        assert not isinstance(
            dataset.target_normalizer, NaNLabelEncoder
        ), "only regression tasks are supported - target must not be categorical"

        new_kwargs = copy(kwargs)
        new_kwargs.update(
            cls.deduce_default_output_parameters(dataset, kwargs, SMAPE())
        )

        # Pass the merged arguments. Forwarding the bare ``kwargs`` would
        # silently discard the deduced defaults.
        return super().from_dataset(dataset, **new_kwargs)