class pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer(hidden_size: int = 16, lstm_layers: int = 1, dropout: float = 0.1, output_size: int | List[int] = 7, loss: MultiHorizonMetric | None = None, attention_head_size: int = 4, max_encoder_length: int = 10, static_categoricals: List[str] = [], static_reals: List[str] = [], time_varying_categoricals_encoder: List[str] = [], time_varying_categoricals_decoder: List[str] = [], categorical_groups: Dict[str, List[str]] = {}, time_varying_reals_encoder: List[str] = [], time_varying_reals_decoder: List[str] = [], x_reals: List[str] = [], x_categoricals: List[str] = [], hidden_continuous_size: int = 8, hidden_continuous_sizes: Dict[str, int] = {}, embedding_sizes: Dict[str, Tuple[int, int]] = {}, embedding_paddings: List[str] = [], embedding_labels: Dict[str, ndarray] = {}, learning_rate: float = 0.001, log_interval: int | float = -1, log_val_interval: float | int | None = None, log_gradient_flow: bool = False, reduce_on_plateau_patience: int = 1000, monotone_constaints: Dict[str, int] = {}, share_single_variable_networks: bool = False, causal_attention: bool = True, logging_metrics: ModuleList | None = None, **kwargs)

Bases: BaseModelWithCovariates

Temporal Fusion Transformer for forecasting time series. Use its from_dataset() method if possible.

Implementation of the article Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting. The network outperforms Amazon's DeepAR by 36-69% in benchmarks.

Enhancements compared to the original implementation (apart from capabilities added through the base model, such as monotone constraints):

  • static variables can be continuous

  • multiple categorical variables can be summarized with an EmbeddingBag

  • encoder and decoder lengths can vary by sample

  • categorical embeddings are not transformed by the variable selection network (because it is a redundant operation)

  • variable dimensions in the variable selection network are scaled up via linear interpolation to reduce the number of parameters

  • non-linear variable processing in the variable selection network can be shared between decoder and encoder (not shared by default)
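The interpolation-based scaling mentioned above can be illustrated without any library. The stdlib sketch below stretches a short vector to a larger size by linear interpolation with aligned endpoints. That these match the library's exact interpolation settings is an assumption, and `interpolate_linear` is a hypothetical helper. The point is that interpolation itself has no trainable weights, whereas a dense projection from n to m features needs n × m of them.

```python
from typing import List


def interpolate_linear(values: List[float], size: int) -> List[float]:
    """Stretch `values` to `size` points by linear interpolation.

    Endpoints are aligned: the first and last outputs equal the first and
    last inputs. Hypothetical helper, not part of pytorch-forecasting.
    """
    if not values or size < 1:
        raise ValueError("need at least one input value and size >= 1")
    n = len(values)
    if n == 1 or size == 1:
        # degenerate cases: repeat the first value
        return [values[0]] * size
    out = []
    for i in range(size):
        pos = i * (n - 1) / (size - 1)  # fractional position in the input
        lo = int(pos)
        hi = min(lo + 1, n - 1)
        frac = pos - lo
        out.append(values[lo] * (1 - frac) + values[hi] * frac)
    return out


# A 2-dimensional variable representation stretched to 5 dimensions
print(interpolate_linear([0.0, 1.0], 5))  # [0.0, 0.25, 0.5, 0.75, 1.0]
```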

Tune its hyperparameters with optimize_hyperparameters().

Parameters:

  • hidden_size – hidden size of the network; this is its main hyperparameter and can range from 8 to 512

  • lstm_layers – number of LSTM layers (2 is mostly optimal)

  • dropout – dropout rate

  • output_size – number of outputs (e.g. the number of quantiles for QuantileLoss with one target) or a list of output sizes for multiple targets.

  • loss – loss function taking prediction and targets

  • attention_head_size – number of attention heads (4 is a good default)

  • max_encoder_length – length to encode (can be far longer than the decoder length but does not have to be)

  • static_categoricals – names of static categorical variables

  • static_reals – names of static continuous variables

  • time_varying_categoricals_encoder – names of categorical variables for encoder

  • time_varying_categoricals_decoder – names of categorical variables for decoder

  • time_varying_reals_encoder – names of continuous variables for encoder

  • time_varying_reals_decoder – names of continuous variables for decoder

  • categorical_groups – dictionary mapping the name of a new categorical variable (key) to the list of categorical variables that together form it (value)

  • x_reals – order of continuous variables in tensor passed to forward function

  • x_categoricals – order of categorical variables in tensor passed to forward function

  • hidden_continuous_size – default hidden size for processing continuous variables (similar to categorical embedding size)

  • hidden_continuous_sizes – dictionary mapping continuous input indices to sizes for variable selection (fallback to hidden_continuous_size if index is not in dictionary)

  • embedding_sizes – dictionary mapping (string) indices to tuple of number of categorical classes and embedding size

  • embedding_paddings – list of indices for embeddings that map category 0 to a zero vector

  • embedding_labels – dictionary mapping (string) indices to list of categorical labels

  • learning_rate – learning rate

  • log_interval – log predictions every x batches; do not log if 0 or less; log interpretation if > 0. If < 1.0, multiple entries are logged per batch. Defaults to -1.

  • log_val_interval – frequency with which to log validation set metrics, defaults to log_interval

  • log_gradient_flow – whether to log gradient flow; this takes time and should only be done to diagnose training failures

  • reduce_on_plateau_patience (int) – patience after which learning rate is reduced by a factor of 10

  • monotone_constaints (Dict[str, int]) – dictionary of monotonicity constraints for continuous decoder variables mapping position (e.g. "0" for first position) to constraint (-1 for negative and +1 for positive, larger numbers add more weight to the constraint vs. the loss but are usually not necessary). This constraint significantly slows down training. Defaults to {}.

  • share_single_variable_networks (bool) – whether to share the single variable networks between the encoder and decoder. Defaults to False.

  • causal_attention (bool) – whether to attend only to previous timesteps in the decoder or to also include future predictions. Defaults to True.

  • logging_metrics (nn.ModuleList[LightningMetric]) – list of metrics that are logged during training. Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]).

  • **kwargs – additional arguments to BaseModel.
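When the loss is a quantile loss, output_size is the number of quantiles predicted per target. The sketch below assumes the seven quantile levels [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98], which we believe are the defaults of pytorch-forecasting's QuantileLoss and would explain the default output_size of 7; verify this against your installed version. It computes the standard quantile (pinball) loss for one target and one prediction per quantile in plain Python. `pinball_loss` is a hypothetical helper, not the library's implementation.

```python
from typing import Sequence

# Assumed default quantile levels: seven of them, one per output (output_size=7).
QUANTILES = (0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98)


def pinball_loss(
    target: float,
    predictions: Sequence[float],
    quantiles: Sequence[float] = QUANTILES,
) -> float:
    """Mean quantile (pinball) loss of one target against one prediction per quantile."""
    if len(predictions) != len(quantiles):
        raise ValueError("expected one prediction per quantile (output_size)")
    total = 0.0
    for q, pred in zip(quantiles, predictions):
        error = target - pred
        # under-prediction (error > 0) costs q per unit, over-prediction costs (1 - q)
        total += max(q * error, (q - 1) * error)
    return total / len(quantiles)


# Perfect forecasts at every quantile give zero loss
print(pinball_loss(10.0, [10.0] * len(QUANTILES)))  # 0.0
```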


create_log(x, y, out, batch_idx, **kwargs)

Create the log used in the training and validation step.

expand_static_context(context, timesteps)

Add a time dimension to the static context.

forward(x)

Input dimensions: n_samples x time x variables.

from_dataset(dataset[, ...])

Create model from dataset.

get_attention_mask(encoder_lengths, ...)

Returns causal mask to apply for self-attention layer.

interpret_output(out[, reduction, ...])

Interpret the output of the model.


Log embeddings to tensorboard


Log interpretation metrics to tensorboard.


Run at epoch end for training or validation.


Called at the very end of fit.

plot_interpretation(interpretation)

Make figures that interpret the model.

plot_prediction(x, out, idx[, ...])

Plot actuals vs. predictions and attention.

create_log(x, y, out, batch_idx, **kwargs)

Create the log used in the training and validation step.

Parameters:

  • x (Dict[str, torch.Tensor]) – x as passed to the network by the dataloader

  • y (Tuple[torch.Tensor, torch.Tensor]) – y as passed to the loss function by the dataloader

  • out (Dict[str, torch.Tensor]) – output of the network

  • batch_idx (int) – batch number

  • prediction_kwargs (Dict[str, Any], optional) – arguments to pass to to_prediction(). Defaults to {}.

  • quantiles_kwargs (Dict[str, Any], optional) – arguments to pass to to_quantiles(). Defaults to {}.

Returns:

log dictionary to be returned by training and validation steps

Return type:

Dict[str, Any]

expand_static_context(context, timesteps)

Add a time dimension to the static context.
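Adding a time dimension amounts to repeating each sample's static context vector once per timestep. The nested-list sketch below shows the shape change from n_samples x hidden to n_samples x timesteps x hidden. The helper name is hypothetical; a tensor implementation would more likely broadcast (for example with `torch.Tensor.expand`) than copy, but that is an assumption, not something stated here.

```python
from typing import List


def expand_static_context_lists(
    context: List[List[float]], timesteps: int
) -> List[List[List[float]]]:
    """Repeat each sample's static context along a new time axis.

    Shape: (n_samples x hidden) -> (n_samples x timesteps x hidden).
    Hypothetical helper for illustration only.
    """
    return [[list(sample) for _ in range(timesteps)] for sample in context]


context = [[1.0, 2.0], [3.0, 4.0]]  # 2 samples, hidden size 2
expanded = expand_static_context_lists(context, timesteps=3)
print(len(expanded), len(expanded[0]), len(expanded[0][0]))  # 2 3 2
```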

forward(x: Dict[str, Tensor]) → Dict[str, Tensor]

Input dimensions: n_samples x time x variables.
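Along the variables axis, continuous inputs follow the order given by x_reals (and categorical inputs the order given by x_categoricals). The sketch below picks one named continuous variable out of a nested-list batch laid out as n_samples x time x variables. The variable names and values are made up, and real inputs are tensors inside a dictionary whose exact keys are not shown here.

```python
from typing import List

Batch = List[List[List[float]]]  # n_samples x time x variables


def select_variable(batch: Batch, x_reals: List[str], name: str) -> List[List[float]]:
    """Return the n_samples x time series of one continuous variable.

    The position on the variables axis is looked up in `x_reals`.
    """
    index = x_reals.index(name)  # raises ValueError for unknown names
    return [[step[index] for step in sample] for sample in batch]


x_reals = ["price", "volume"]  # hypothetical variable order
batch = [
    [[1.0, 10.0], [1.5, 12.0], [2.0, 9.0]],  # sample 0, three timesteps
    [[5.0, 3.0], [4.5, 4.0], [4.0, 6.0]],    # sample 1
]
print(select_variable(batch, x_reals, "volume"))  # [[10.0, 12.0, 9.0], [3.0, 4.0, 6.0]]
```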

classmethod from_dataset(dataset: TimeSeriesDataSet, allowed_encoder_known_variable_names: List[str] | None = None, **kwargs)

Create model from dataset.

Parameters:

  • dataset – timeseries dataset

  • allowed_encoder_known_variable_names – List of known variables that are allowed in encoder, defaults to all

  • **kwargs – additional arguments such as hyperparameters for model (see __init__())
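The allowed_encoder_known_variable_names argument can be pictured as a filter on the known variables the encoder receives, where None keeps all of them. The plain-Python sketch below is a hypothetical illustration of that documented default, not the library's code; any further filtering the library applies is not covered.

```python
from typing import List, Optional


def encoder_known_variables(
    known_variables: List[str], allowed: Optional[List[str]] = None
) -> List[str]:
    """Keep the known variables permitted in the encoder, preserving their order.

    allowed=None mirrors the documented default of allowing all of them.
    """
    if allowed is None:
        return list(known_variables)
    allowed_set = set(allowed)
    return [name for name in known_variables if name in allowed_set]


known = ["time_idx", "price", "holiday"]  # hypothetical variable names
print(encoder_known_variables(known))               # ['time_idx', 'price', 'holiday']
print(encoder_known_variables(known, ["holiday"]))  # ['holiday']
```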



get_attention_mask(encoder_lengths: LongTensor, decoder_lengths: LongTensor)

Returns causal mask to apply for self-attention layer.
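A causal mask for one sample can be sketched as a boolean grid with one row per decoder (query) step and one column per key: the encoder steps followed by the decoder steps. The conventions here are assumptions, not the library's: True marks a blocked position, the encoder is right-padded up to max_encoder_length, and "previous timesteps" is read literally, so a decoder step does not attend to itself. Passing causal=False mirrors causal_attention=False, where decoder steps may also attend to future predictions.

```python
from typing import List


def attention_mask(
    encoder_length: int,
    max_encoder_length: int,
    decoder_length: int,
    causal: bool = True,
) -> List[List[bool]]:
    """Attention mask for one sample; True means the key position is blocked.

    Rows are decoder (query) steps. Columns are the max_encoder_length encoder
    steps followed by the decoder steps. Assumes the encoder is right-padded.
    """
    # Encoder keys beyond this sample's encoder length are padding.
    encoder_part = [j >= encoder_length for j in range(max_encoder_length)]
    mask = []
    for i in range(decoder_length):
        if causal:
            # only strictly earlier decoder steps are visible to step i
            decoder_part = [j >= i for j in range(decoder_length)]
        else:
            # non-causal: every decoder step is visible
            decoder_part = [False] * decoder_length
        mask.append(encoder_part + decoder_part)
    return mask


for row in attention_mask(encoder_length=2, max_encoder_length=3, decoder_length=2):
    print(row)
# [False, False, True, True, True]
# [False, False, True, False, True]
```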

interpret_output(out: Dict[str, Tensor], reduction: str = 'none', attention_prediction_horizon: int = 0) → Dict[str, Tensor]

Interpret the output of the model.

Parameters:

  • out – output as produced by forward()

  • reduction – “none” for no averaging over batches, “sum” for summing attentions, “mean” for normalizing by encoder lengths

  • attention_prediction_horizon – which prediction horizon to use for attention

Returns:

interpretations that can be plotted with plot_interpretation()
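The three reduction modes can be pictured on a batch of per-sample attention vectors over encoder positions. The sketch below is one reading of the descriptions above, not the library's code: "sum" adds the vectors, and "mean" divides each position's sum by the number of samples that actually have an encoder step at that position (a masked mean). Padding is marked with None so the sketch does not depend on which side sequences are padded.

```python
from typing import List, Optional

Attention = List[Optional[float]]  # one weight per encoder position, None = padding


def reduce_attention(batch: List[Attention], reduction: str = "none") -> list:
    """Reduce per-sample attention over the batch ("none", "sum" or "mean").

    "mean" divides each position's sum by the number of samples that actually
    have an encoder step there (a masked mean); padded positions are None.
    """
    if reduction == "none":
        return batch
    if reduction not in ("sum", "mean"):
        raise ValueError(f"unknown reduction: {reduction!r}")
    if not batch:
        raise ValueError("empty batch")
    width = len(batch[0])
    sums = [0.0] * width
    counts = [0] * width
    for sample in batch:
        for j, weight in enumerate(sample):
            if weight is not None:
                sums[j] += weight
                counts[j] += 1
    if reduction == "sum":
        return sums
    # max(..., 1) avoids division by zero for positions no sample covers
    return [s / max(c, 1) for s, c in zip(sums, counts)]


batch = [
    [None, 0.25, 0.75],  # encoder length 2 (first position is padding)
    [0.5, 0.25, 0.25],   # encoder length 3
]
print(reduce_attention(batch, "sum"))   # [0.5, 0.5, 1.0]
print(reduce_attention(batch, "mean"))  # [0.5, 0.25, 0.5]
```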


Log embeddings to tensorboard


Log interpretation metrics to tensorboard.


Run at epoch end for training or validation.


Called at the very end of fit.

If running on DDP, it is called on every process.

plot_interpretation(interpretation: Dict[str, Tensor]) → Dict[str, Figure]

Make figures that interpret the model:

  • Attention

  • Variable selection weights / importances

Parameters:

interpretation – as obtained from interpret_output()

Returns:

dictionary of matplotlib figures
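Turning variable selection weights into importances is essentially an averaging and normalizing step. The sketch below accumulates per-sample weights for each variable and rescales them to percentages that sum to 100. The exact scaling the library applies before plotting is an assumption, and the variable names are made up.

```python
from typing import Dict, List


def importances_percent(weights: List[Dict[str, float]]) -> Dict[str, float]:
    """Aggregate per-sample variable selection weights and scale them to percent."""
    totals: Dict[str, float] = {}
    for sample in weights:
        for name, w in sample.items():
            totals[name] = totals.get(name, 0.0) + w
    # Summing instead of averaging is fine: the 1/n_samples factor cancels here.
    grand_total = sum(totals.values())
    if grand_total == 0:
        raise ValueError("weights sum to zero")
    return {name: 100.0 * t / grand_total for name, t in totals.items()}


weights = [
    {"price": 0.5, "volume": 0.5},  # per-sample weights sum to 1
    {"price": 1.0, "volume": 0.0},
]
print(importances_percent(weights))  # {'price': 75.0, 'volume': 25.0}
```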

plot_prediction(x: Dict[str, Tensor], out: Dict[str, Tensor], idx: int, plot_attention: bool = True, add_loss_to_title: bool = False, show_future_observed: bool = True, ax=None, **kwargs) → Figure

Plot actuals vs. predictions and attention.

Parameters:

  • x (Dict[str, torch.Tensor]) – network input

  • out (Dict[str, torch.Tensor]) – network output

  • idx (int) – sample index

  • plot_attention – whether to plot attention on a secondary axis

  • add_loss_to_title – whether to add the loss to the title. Defaults to False.

  • show_future_observed – whether to show actuals for the future. Defaults to True.

  • ax – matplotlib axes to plot on

Returns:

matplotlib figure

Return type: