# Source code for pytorch_forecasting.models.timexer._timexer

"""
Time Series Transformer with eXogenous variables (TimeXer)
---------------------------------------------------------
"""

from copy import copy
from typing import Optional, Union
import warnings as warn

import lightning.pytorch as pl
from lightning.pytorch import LightningModule, Trainer
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.metrics import (
    MAE,
    MAPE,
    MASE,
    RMSE,
    SMAPE,
    MultiHorizonMetric,
    QuantileLoss,
)
from pytorch_forecasting.metrics.base_metrics import MultiLoss
from pytorch_forecasting.models.base import BaseModelWithCovariates
from pytorch_forecasting.models.timexer.sub_modules import (
    AttentionLayer,
    DataEmbedding_inverted,
    Encoder,
    EncoderLayer,
    EnEmbedding,
    FlattenHead,
    FullAttention,
)


class TimeXer(BaseModelWithCovariates):
    """TimeXer model for time series forecasting with exogenous variables."""

    @classmethod
    def _pkg(cls):
        """Package for the model."""
        # Imported lazily inside the method, as in the original code.
        from pytorch_forecasting.models.timexer._timexer_pkg import TimeXer_pkg

        return TimeXer_pkg

    def __init__(
        self,
        context_length: int,
        prediction_length: int,
        task_name: str = "long_term_forecast",
        features: str = "MS",
        enc_in: int | None = None,
        hidden_size: int = 256,
        n_heads: int = 4,
        e_layers: int = 2,
        d_ff: int = 1024,
        dropout: float = 0.2,
        activation: str = "relu",
        use_efficient_attention: bool = False,
        patch_length: int = 16,
        factor: int = 5,
        embed_type: str = "fixed",
        freq: str = "h",
        output_size: int | list[int] = 1,
        loss: MultiHorizonMetric | None = None,
        learning_rate: float = 1e-3,
        static_categoricals: list[str] | None = None,
        static_reals: list[str] | None = None,
        time_varying_categoricals_encoder: list[str] | None = None,
        time_varying_categoricals_decoder: list[str] | None = None,
        time_varying_reals_encoder: list[str] | None = None,
        time_varying_reals_decoder: list[str] | None = None,
        x_reals: list[str] | None = None,
        x_categoricals: list[str] | None = None,
        embedding_sizes: dict[str, tuple[int, int]] | None = None,
        embedding_labels: dict[str, list[str]] | None = None,
        embedding_paddings: list[str] | None = None,
        categorical_groups: dict[str, list[str]] | None = None,
        logging_metrics: nn.ModuleList | None = None,
        **kwargs,
    ):
        """An implementation of the TimeXer model.

        TimeXer empowers the canonical transformer with the ability to reconcile
        endogenous and exogenous information without any architectural
        modifications and achieves consistent state-of-the-art performance across
        twelve real-world forecasting benchmarks.

        TimeXer employs patch-level and variate-level representations respectively
        for endogenous and exogenous variables, with an endogenous global token as
        a bridge in-between. With this design, TimeXer can jointly capture
        intra-endogenous temporal dependencies and exogenous-to-endogenous
        correlations.

        Parameters
        ----------
        context_length : int
            Length of input sequence used for making predictions.
        prediction_length : int
            Number of future time steps to predict.
        task_name : str, optional
            Type of forecasting task, either 'long_term_forecast' or
            'short_term_forecast', which corresponds to forecasting scenarios
            implied by the task names.
        features : str, optional
            Type of features used in the model ('MS' for multivariate
            forecasting with single target, 'M' for multivariate forecasting
            with multiple targets and 'S' for univariate forecasting).
        enc_in : int, optional
            Number of input variables for encoder. When ``None``, the number of
            continuous variables (``len(self.reals)``) is used.
        hidden_size : int, optional
            Dimension of model embeddings and hidden representations. Must be
            divisible by ``n_heads``.
        n_heads : int, optional
            Number of attention heads in multi-head attention layers.
        e_layers : int, optional
            Number of encoder layers with dual attention mechanism. Must be
            positive.
        d_ff : int, optional
            Dimension of feedforward network in transformer layers.
        dropout : float, optional
            Dropout rate applied throughout the network for regularization.
        activation : str, optional
            Activation function used in feedforward networks
            ('relu' or 'gelu').
        use_efficient_attention : bool, optional
            If set to True, will use PyTorch's native, optimized Scaled Dot
            Product Attention implementation which can reduce computation time
            and memory consumption for longer sequences. PyTorch automatically
            selects the optimal backend (FlashAttention-2, Memory-Efficient
            Attention, or their own C++ implementation) based on user's input
            properties, hardware capabilities, and build configuration.
        patch_length : int, optional
            Length of each non-overlapping patch for endogenous variable
            tokenization. Must not exceed ``context_length``.
        factor : int, optional
            Scaling factor for attention scores.
        embed_type : str, optional
            Type of time feature embedding ('timeF' for time-based features).
        freq : str, optional
            Frequency of the time series data ('h' for hourly, 'd' for daily,
            etc.).
        output_size : int or list[int], optional
            Stored as a hyperparameter; not read directly in this class.
        loss : MultiHorizonMetric, optional
            Loss to optimise. Defaults to ``MAE()``, or to a ``MultiLoss`` of
            ``MAE`` metrics when ``features == "M"``.
        learning_rate : float, optional
            Stored as a hyperparameter; not read directly in this class.
        static_categoricals : list[str], optional
            Names of static categorical variables.
        static_reals : list[str], optional
            Names of static continuous variables.
        time_varying_categoricals_encoder : list[str], optional
            Names of categorical variables for encoder.
        time_varying_categoricals_decoder : list[str], optional
            Names of categorical variables for decoder.
        time_varying_reals_encoder : list[str], optional
            Names of continuous variables for encoder.
        time_varying_reals_decoder : list[str], optional
            Names of continuous variables for decoder.
        x_reals : list[str], optional
            Order of continuous variables in tensor passed to forward function.
        x_categoricals : list[str], optional
            Order of categorical variables in tensor passed to forward function.
        embedding_sizes : dict[str, tuple[int, int]], optional
            Dictionary mapping categorical variables to tuple of integers where
            the first integer denotes the number of categorical classes and the
            second the embedding size.
        embedding_labels : dict[str, list[str]], optional
            Dictionary mapping (string) indices to list of categorical labels.
        embedding_paddings : list[str], optional
            Names of categorical variables for which label 0 is always mapped
            to an embedding vector filled with zeros.
        categorical_groups : dict[str, list[str]], optional
            Dictionary of categorical variables that are grouped together and
            can also take multiple values simultaneously (e.g. holiday during
            octoberfest). They should be implemented as bag of embeddings.
        logging_metrics : nn.ModuleList, optional
            List of metrics that are logged during training. Defaults to
            ``nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()])``.
        **kwargs
            Additional arguments to :py:class:`~BaseModel`.

        Raises
        ------
        ValueError
            If ``context_length < patch_length``, if ``hidden_size`` is not
            divisible by ``n_heads``, or if ``e_layers`` is not positive.
        """
        # The parameter names are rebound in place (rather than copied to new
        # locals) so that ``save_hyperparameters`` below records the resolved
        # defaults instead of ``None``.
        if static_categoricals is None:
            static_categoricals = []
        if static_reals is None:
            static_reals = []
        if time_varying_categoricals_encoder is None:
            time_varying_categoricals_encoder = []
        if time_varying_categoricals_decoder is None:
            time_varying_categoricals_decoder = []
        if categorical_groups is None:
            categorical_groups = {}
        if time_varying_reals_encoder is None:
            time_varying_reals_encoder = []
        if time_varying_reals_decoder is None:
            time_varying_reals_decoder = []
        if embedding_sizes is None:
            embedding_sizes = {}
        if embedding_paddings is None:
            embedding_paddings = []
        if embedding_labels is None:
            embedding_labels = {}
        if x_reals is None:
            x_reals = []
        if x_categoricals is None:
            x_categoricals = []
        if logging_metrics is None:
            logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()])
        if loss is None:
            if features == "M":
                # NOTE(review): ``self.target_positions`` is read here before
                # ``save_hyperparameters()`` and ``super().__init__()`` have
                # run; confirm against the base class that this branch works
                # when the model is built directly (not via ``from_dataset``).
                loss = MultiLoss([MAE()] * len(self.target_positions))
            else:
                loss = MAE()
        self.save_hyperparameters(ignore=["loss", "logging_metrics"])
        # loss is a standalone module and is stored separately.
        super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)

        if self.hparams.context_length < self.hparams.patch_length:
            raise ValueError(
                f"context_length ({context_length}) must be greater than or equal to"
                f" patch_length ({patch_length}). Model cannot create patches larger"
                " than the sequence length."
            )

        if self.hparams.context_length % self.hparams.patch_length != 0:
            warn.warn(
                f"In the input sequence, the context_length ({context_length}) is not a"
                f" multiple of the patch_length ({patch_length}). This may lead to some"
                " patches being ignored during training."
            )

        self.patch_num = max(
            1, int(self.hparams.context_length // self.hparams.patch_length)
        )
        self.n_target_vars = len(self.target_positions)

        self.enc_in = enc_in
        if enc_in is None:
            self.enc_in = len(self.reals)

        # NOTE: assume point prediction as default here,
        # with single median quantile being the point prediction.
        # hence self.n_quantiles = 1 for point predictions.
        self.n_quantiles = 1

        # set n_quantiles to the length of the quantiles list passed
        # into the "quantiles" parameter when QuantileLoss is used.
        if isinstance(loss, QuantileLoss):
            self.n_quantiles = len(loss.quantiles)

        if hidden_size % n_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by n_heads ({n_heads}) "
                f"for the multi-head attention mechanism to work properly."
            )

        self.en_embedding = EnEmbedding(
            self.n_target_vars,
            self.hparams.hidden_size,
            self.hparams.patch_length,
            self.hparams.dropout,
        )

        self.ex_embedding = DataEmbedding_inverted(
            self.hparams.context_length,
            self.hparams.hidden_size,
            self.hparams.embed_type,
            self.hparams.freq,
            self.hparams.dropout,
        )

        if e_layers <= 0:
            raise ValueError(f"e_layers ({e_layers}) must be positive.")
        elif e_layers > 12:
            warn.warn(
                f"e_layers ({e_layers}) is quite high. This might lead to overfitting "
                f"and high computational cost. Consider using 2-6 layers.",
                UserWarning,
            )

        # Each encoder layer receives two separately constructed attention
        # modules with identical settings.
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(
                            False,
                            self.hparams.factor,
                            attention_dropout=self.hparams.dropout,
                            output_attention=False,
                            use_efficient_attention=self.hparams.use_efficient_attention,
                        ),
                        self.hparams.hidden_size,
                        self.hparams.n_heads,
                    ),
                    AttentionLayer(
                        FullAttention(
                            False,
                            self.hparams.factor,
                            attention_dropout=self.hparams.dropout,
                            output_attention=False,
                            use_efficient_attention=self.hparams.use_efficient_attention,
                        ),
                        self.hparams.hidden_size,
                        self.hparams.n_heads,
                    ),
                    self.hparams.hidden_size,
                    self.hparams.d_ff,
                    dropout=self.hparams.dropout,
                    activation=self.hparams.activation,
                )
                for _ in range(self.hparams.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(self.hparams.hidden_size),
        )

        # The head input width is hidden_size * (patch_num + 1); the extra
        # slot presumably corresponds to the global token described in the
        # model summary above -- confirm in ``EnEmbedding``.
        self.head_nf = self.hparams.hidden_size * (self.patch_num + 1)
        self.head = FlattenHead(
            self.enc_in,
            self.head_nf,
            self.hparams.prediction_length,
            head_dropout=self.hparams.dropout,
            n_quantiles=self.n_quantiles,
        )

    @classmethod
    def from_dataset(
        cls,
        dataset: TimeSeriesDataSet,
        allowed_encoder_known_variable_names: list[str] | None = None,
        **kwargs,
    ):
        """
        Create model from dataset and set parameters related to covariates.

        Args:
            dataset: timeseries dataset
            allowed_encoder_known_variable_names: list of known variables that are
                allowed in encoder, defaults to all
            **kwargs: additional arguments such as hyperparameters for model
                (see ``__init__()``)

        Returns:
            TimeXer
        """
        # Work on a shallow copy so the caller's kwargs dict is not mutated.
        new_kwargs = copy(kwargs)
        new_kwargs.update(
            {
                "context_length": dataset.max_encoder_length,
                "prediction_length": dataset.max_prediction_length,
            }
        )

        new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, MAE()))

        return super().from_dataset(
            dataset,
            allowed_encoder_known_variable_names=allowed_encoder_known_variable_names,
            **new_kwargs,
        )

    def _encode_and_project(
        self,
        x: dict[str, torch.Tensor],
        endogenous: torch.Tensor,
        target_pos: torch.Tensor,
    ) -> torch.Tensor:
        """Shared encoder/head pipeline used by both forecasting variants.

        Args:
            x: Dictionary containing ``encoder_cont`` and, optionally,
                ``encoder_time_idx``.
            endogenous: Target series already arranged for ``en_embedding``.
            target_pos: Positions of the target variables in the last
                dimension of ``encoder_cont``; those columns are excluded from
                the exogenous input.

        Returns:
            Output of the head with dimensions 1 and 2 swapped.
        """
        encoder_cont = x["encoder_cont"]
        encoder_time_idx = x.get("encoder_time_idx", None)

        en_embed, n_vars = self.en_embedding(endogenous)

        # Boolean mask that drops the target column(s). It is created on the
        # same device as the data so it can be indexed with ``target_pos`` and
        # applied to ``encoder_cont`` regardless of where the model lives.
        mask = torch.ones(
            encoder_cont.shape[-1], dtype=torch.bool, device=encoder_cont.device
        )
        mask[target_pos] = False
        exog_data = encoder_cont[..., mask]
        ex_embed = self.ex_embedding(exog_data, encoder_time_idx)

        # batch_size x sequence_length x hidden_size
        enc_out = self.encoder(en_embed, ex_embed)
        enc_out = torch.reshape(
            enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
        )  # batch_size x n_vars x sequence_length x hidden_size

        enc_out = enc_out.permute(0, 1, 3, 2)

        dec_out = self.head(enc_out)
        return dec_out.permute(0, 2, 1, 3)

    def _forecast(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Forecast for univariate or multivariate with single target (MS) case.

        Only the last entry of ``self.target_positions`` is used as the
        endogenous series.

        Args:
            x: Dictionary containing entries for encoder_cat, encoder_cont

        Returns:
            Tensor of predictions produced by the head.
        """
        target_pos = self.target_positions
        encoder_cont = x["encoder_cont"]
        endogenous = (
            encoder_cont[:, :, target_pos[-1]].unsqueeze(-1).permute(0, 2, 1)
        )
        return self._encode_and_project(x, endogenous, target_pos)

    def _forecast_multi(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Forecast for multivariate with multiple targets (M) case.

        All entries of ``self.target_positions`` are used as endogenous series.

        Args:
            x: Dictionary containing entries for encoder_cat, encoder_cont

        Returns:
            Tensor of predictions produced by the head.
        """
        target_pos = self.target_positions
        encoder_cont = x["encoder_cont"]
        endogenous = encoder_cont[..., target_pos].permute(0, 2, 1)
        return self._encode_and_project(x, endogenous, target_pos)

    @property
    def decoder_covariate_size(self) -> int:
        """Decoder covariates size.

        Returns:
            int: size of time-dependent covariates used by the decoder
        """
        return len(
            set(self.hparams.time_varying_reals_decoder) - set(self.target_names)
        ) + sum(
            self.embeddings.output_size[name]
            for name in self.hparams.time_varying_categoricals_decoder
        )

    @property
    def encoder_covariate_size(self) -> int:
        """Encoder covariate size.

        Returns:
            int: size of time-dependent covariates used by the encoder
        """
        return len(
            set(self.hparams.time_varying_reals_encoder) - set(self.target_names)
        ) + sum(
            self.embeddings.output_size[name]
            for name in self.hparams.time_varying_categoricals_encoder
        )

    @property
    def static_size(self) -> int:
        """Static covariate size.

        Returns:
            int: size of static covariates
        """
        return len(self.hparams.static_reals) + sum(
            self.embeddings.output_size[name]
            for name in self.hparams.static_categoricals
        )

    def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Forward pass of the model.

        Args:
            x: Dictionary containing model inputs

        Returns:
            Dictionary with model outputs, or ``None`` when
            ``hparams.task_name`` is neither ``"long_term_forecast"`` nor
            ``"short_term_forecast"``.
        """
        if self.hparams.task_name not in (
            "long_term_forecast",
            "short_term_forecast",
        ):
            # Unsupported task names yield no output (behaviour kept as-is).
            return None

        if self.hparams.features == "M":
            out = self._forecast_multi(x)
        else:
            out = self._forecast(x)

        prediction = out[:, : self.hparams.prediction_length, :]

        n_targets = len(self.target_positions)

        # note: prediction.size(2) is the number of target variables i.e n_targets
        if prediction.size(2) != n_targets:
            prediction = prediction[:, :, :n_targets]

        # output format is (batch_size, prediction_length, n_quantiles)
        # in case of quantile loss, the output n_quantiles = self.n_quantiles
        # which is the length of a list of float. In case of MAE, MSE, etc.
        # n_quantiles = 1 and it mimics the behavior of a point prediction.
        # for multi-target forecasting, the output is a list of tensors.
        if n_targets == 1:
            prediction = prediction[..., 0, :]
        else:
            # Indices are taken from the (possibly trimmed) prediction so that
            # they can never run past its target dimension.
            prediction = [
                prediction[..., i, :] for i in range(prediction.size(2))
            ]

        prediction = self.transform_output(
            prediction=prediction, target_scale=x["target_scale"]
        )
        return self.to_network_output(prediction=prediction)