class pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer(hidden_size: int = 16, lstm_layers: int = 1, dropout: float = 0.1, output_size: Union[int, List[int]] = 7, loss: Optional[pytorch_forecasting.metrics.base_metrics.MultiHorizonMetric] = None, attention_head_size: int = 4, max_encoder_length: int = 10, static_categoricals: List[str] = [], static_reals: List[str] = [], time_varying_categoricals_encoder: List[str] = [], time_varying_categoricals_decoder: List[str] = [], categorical_groups: Dict[str, List[str]] = {}, time_varying_reals_encoder: List[str] = [], time_varying_reals_decoder: List[str] = [], x_reals: List[str] = [], x_categoricals: List[str] = [], hidden_continuous_size: int = 8, hidden_continuous_sizes: Dict[str, int] = {}, embedding_sizes: Dict[str, Tuple[int, int]] = {}, embedding_paddings: List[str] = [], embedding_labels: Dict[str, numpy.ndarray] = {}, learning_rate: float = 0.001, log_interval: Union[int, float] = -1, log_val_interval: Optional[Union[float, int]] = None, log_gradient_flow: bool = False, reduce_on_plateau_patience: int = 1000, monotone_constaints: Dict[str, int] = {}, share_single_variable_networks: bool = False, causal_attention: bool = True, logging_metrics: Optional[torch.nn.modules.container.ModuleList] = None, **kwargs)

Bases: pytorch_forecasting.models.base_model.BaseModelWithCovariates

Temporal Fusion Transformer for forecasting time series. Use its from_dataset() method if possible.

Implementation of the article "Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting". The network outperforms Amazon's DeepAR by 36-69% in benchmarks.

Enhancements compared to the original implementation (apart from capabilities added through the base model, such as monotone constraints):

  • static variables can be continuous

  • multiple categorical variables can be summarized with an EmbeddingBag

  • encoder and decoder lengths can vary by sample

  • categorical embeddings are not transformed by variable selection network (because it is a redundant operation)

  • variable dimensions in the variable selection network are scaled up via linear interpolation to reduce the number of parameters

  • non-linear variable processing in the variable selection network can be shared between encoder and decoder (not shared by default)
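The interpolation idea in the list above can be illustrated in plain Python. This is a minimal sketch, not the library's code: it assumes endpoint-aligned linear interpolation of a 1-D vector, and the helper `interpolate_linear` is hypothetical.

```python
from typing import List, Sequence


def interpolate_linear(values: Sequence[float], output_size: int) -> List[float]:
    """Resize a 1-D vector to ``output_size`` by linear interpolation.

    Endpoints are aligned: the first and last outputs equal the first and
    last inputs. Hypothetical helper, for illustration only.
    """
    n = len(values)
    if n == 0 or output_size <= 0:
        raise ValueError("need a non-empty input and a positive output size")
    if n == 1 or output_size == 1:
        # nothing to interpolate between: repeat the first value
        return [float(values[0])] * output_size
    scale = (n - 1) / (output_size - 1)
    out = []
    for i in range(output_size):
        pos = i * scale                  # fractional position in the input
        left = min(int(pos), n - 2)      # index of the left neighbour
        frac = pos - left                # weight of the right neighbour
        out.append(values[left] * (1 - frac) + values[left + 1] * frac)
    return out


# A small representation (size 3) stretched to a larger hidden size (5)
print(interpolate_linear([0.0, 1.0, 4.0], 5))  # [0.0, 0.5, 1.0, 2.5, 4.0]
```

Stretching a small output to the target width this way involves no trainable weights, which is where a parameter saving can come from compared with a full-width learned projection.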

Tune its hyperparameters with optimize_hyperparameters().

  • hidden_size – hidden size of the network; this is its main hyperparameter and can range from 8 to 512

  • lstm_layers – number of LSTM layers (2 is usually optimal)

  • dropout – dropout rate

  • output_size – number of outputs (e.g. the number of quantiles for QuantileLoss with one target) or a list of output sizes (one per target).

  • loss – loss function taking prediction and targets

  • attention_head_size – number of attention heads (4 is a good default)

  • max_encoder_length – length to encode (can be far longer than the decoder length but does not have to be)

  • static_categoricals – names of static categorical variables

  • static_reals – names of static continuous variables

  • time_varying_categoricals_encoder – names of categorical variables for encoder

  • time_varying_categoricals_decoder – names of categorical variables for decoder

  • time_varying_reals_encoder – names of continuous variables for encoder

  • time_varying_reals_decoder – names of continuous variables for decoder

  • categorical_groups – dictionary whose values are lists of categorical variables that together form a new categorical variable, named by the corresponding key

  • x_reals – order of continuous variables in tensor passed to forward function

  • x_categoricals – order of categorical variables in tensor passed to forward function

  • hidden_continuous_size – default hidden size for processing continuous variables (similar to categorical embedding size)

  • hidden_continuous_sizes – dictionary mapping continuous input indices to sizes for variable selection (falls back to hidden_continuous_size if an index is not in the dictionary)

  • embedding_sizes – dictionary mapping (string) indices to tuple of number of categorical classes and embedding size

  • embedding_paddings – list of indices of embeddings for which the embedding of value zero is a zero vector

  • embedding_labels – dictionary mapping (string) indices to list of categorical labels

  • learning_rate – learning rate

  • log_interval – log predictions every x batches; do not log if 0 or less; log interpretation if > 0. If < 1.0, logs multiple entries per batch. Defaults to -1.

  • log_val_interval – frequency with which to log validation set metrics, defaults to log_interval

  • log_gradient_flow – whether to log gradient flow; this takes time and should only be done to diagnose training failures

  • reduce_on_plateau_patience (int) – patience after which learning rate is reduced by a factor of 10

  • monotone_constaints (Dict[str, int]) – dictionary of monotonicity constraints for continuous decoder variables, mapping position (e.g. "0" for the first position) to constraint (-1 for negative and +1 for positive; larger numbers give the constraint more weight relative to the loss but are usually not necessary). This constraint significantly slows down training. Defaults to {}.

  • share_single_variable_networks (bool) – whether to share the single variable networks between the encoder and decoder. Defaults to False.

  • causal_attention (bool) – whether to attend only to previous time steps in the decoder or also to future predictions. Defaults to True.

  • logging_metrics (nn.ModuleList[LightningMetric]) – list of metrics that are logged during training. Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]).

  • **kwargs – additional arguments to BaseModel.
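With a QuantileLoss, output_size is the number of quantiles, i.e. one network output per quantile level. The sketch below computes the standard quantile (pinball) loss for a single target value in plain Python. The seven levels in `QUANTILES` and the function `pinball_loss` are illustrative assumptions, chosen so that the count matches the default output_size of 7; they are not taken from the library.

```python
from typing import Sequence

# Illustrative quantile levels: seven of them, matching output_size=7
QUANTILES = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]


def pinball_loss(target: float, predictions: Sequence[float],
                 quantiles: Sequence[float] = QUANTILES) -> float:
    """Average quantile (pinball) loss over all quantile outputs.

    predictions[i] is the network's output for quantiles[i], so the network
    needs exactly len(quantiles) outputs per time step.
    """
    if len(predictions) != len(quantiles):
        raise ValueError("need one prediction per quantile (output_size == len(quantiles))")
    total = 0.0
    for q, pred in zip(quantiles, predictions):
        error = target - pred
        # under-prediction is weighted by q, over-prediction by (1 - q)
        total += max(q * error, (q - 1) * error)
    return total / len(quantiles)


print(pinball_loss(10.0, [10.0] * 7))  # 0.0
```

A high quantile such as 0.9 is penalized more for under-predicting than for over-predicting, which pushes that output toward the upper part of the predictive distribution.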


create_log(x, y, out, batch_idx, **kwargs)

Create the log used in the training and validation step.

  • x (Dict[str, torch.Tensor]) – x as passed to the network by the dataloader

  • y (Tuple[torch.Tensor, torch.Tensor]) – y as passed to the loss function by the dataloader

  • out (Dict[str, torch.Tensor]) – output of the network

  • batch_idx (int) – batch number

  • prediction_kwargs (Dict[str, Any], optional) – arguments to pass to to_prediction(). Defaults to {}.

  • quantiles_kwargs (Dict[str, Any], optional) – arguments to pass to to_quantiles(). Defaults to {}.


log dictionary to be returned by training and validation steps

Return type

Dict[str, Any]


run at epoch end for training or validation

expand_static_context(context, timesteps)

add time dimension to static context
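Adding a time dimension amounts to repeating each sample's context vector at every time step. A minimal NumPy sketch, assuming a context of shape (n_samples, hidden_size); the function name `expand_static_context_np` is hypothetical and the code is not the library's implementation.

```python
import numpy as np


def expand_static_context_np(context: np.ndarray, timesteps: int) -> np.ndarray:
    """Repeat a static context of shape (n_samples, hidden) along a new time axis.

    Returns an array of shape (n_samples, timesteps, hidden). np.broadcast_to
    returns a read-only view, so no data is copied.
    """
    n_samples, hidden = context.shape
    return np.broadcast_to(context[:, None, :], (n_samples, timesteps, hidden))


context = np.arange(6, dtype=float).reshape(2, 3)   # 2 samples, hidden size 3
expanded = expand_static_context_np(context, timesteps=4)
print(expanded.shape)  # (2, 4, 3)
```

In PyTorch, `context[:, None].expand(-1, timesteps, -1)` produces the analogous view without copying.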

forward(x: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]

input dimensions: n_samples x time x variables

classmethod from_dataset(dataset: pytorch_forecasting.data.timeseries.TimeSeriesDataSet, allowed_encoder_known_variable_names: Optional[List[str]] = None, **kwargs)

Create model from dataset.

  • dataset – timeseries dataset

  • allowed_encoder_known_variable_names – list of known variables that are allowed in the encoder; defaults to all

  • **kwargs – additional arguments such as hyperparameters for model (see __init__())



get_attention_mask(encoder_lengths: torch.LongTensor, decoder_lengths: torch.LongTensor)

Returns causal mask to apply for self-attention layer.
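One way to build such a mask is sketched below with NumPy. Assumptions, not taken from the library: True marks a key position that a decoder query must not attend to; keys are the encoder steps followed by the decoder steps; padding sits at the end of each sequence; and, with causal attention, a decoder step sees only strictly earlier decoder steps. The function `attention_mask_sketch` and its `causal` flag are hypothetical.

```python
import numpy as np


def attention_mask_sketch(encoder_lengths, decoder_lengths, causal: bool = True) -> np.ndarray:
    """Boolean mask of shape (batch, max_dec, max_enc + max_dec).

    True marks a key position that a decoder query must NOT attend to.
    Assumes padding sits at the end of each sequence.
    """
    encoder_lengths = np.asarray(encoder_lengths)
    decoder_lengths = np.asarray(decoder_lengths)
    batch = encoder_lengths.shape[0]
    max_enc = int(encoder_lengths.max())
    max_dec = int(decoder_lengths.max())

    # Encoder keys: mask padded steps (position >= that sample's encoder length)
    enc_pad = np.arange(max_enc)[None, :] >= encoder_lengths[:, None]      # (batch, max_enc)
    encoder_mask = np.broadcast_to(enc_pad[:, None, :], (batch, max_dec, max_enc))

    if causal:
        # Decoder query step i may only see decoder steps j < i
        attend_step = np.arange(max_dec)[None, :]
        predict_step = np.arange(max_dec)[:, None]
        decoder_mask = np.broadcast_to(attend_step >= predict_step, (batch, max_dec, max_dec))
    else:
        # Non-causal: only padded decoder steps are masked
        dec_pad = np.arange(max_dec)[None, :] >= decoder_lengths[:, None]  # (batch, max_dec)
        decoder_mask = np.broadcast_to(dec_pad[:, None, :], (batch, max_dec, max_dec))

    return np.concatenate([encoder_mask, decoder_mask], axis=2)


mask = attention_mask_sketch([3, 2], [2, 2])
print(mask.shape)  # (2, 2, 5)
```

For the second sample (encoder length 2), the second decoder step may attend to both real encoder steps and to the first decoder step, but not to the padded encoder step or to itself.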

interpret_output(out: Dict[str, torch.Tensor], reduction: str = 'none', attention_prediction_horizon: int = 0) → Dict[str, torch.Tensor]

interpret output of model

  • out – output as produced by forward()

  • reduction – “none” for no averaging over batches, “sum” for summing attentions, “mean” for normalizing by encoder lengths

  • attention_prediction_horizon – which prediction horizon to use for attention


interpretations that can be plotted with plot_interpretation()
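The three reductions can be illustrated on ragged per-sample attention weights. This is a simplified sketch under stated assumptions: encoder positions are aligned at index 0, and "mean" is read as dividing each position's summed attention by the number of samples whose encoder is long enough to contain that position. The library's exact computation may differ; `reduce_attention` is hypothetical.

```python
from typing import List, Union


def reduce_attention(attention: List[List[float]],
                     reduction: str = "none") -> Union[List[List[float]], List[float]]:
    """Reduce per-sample encoder attention weights over the batch.

    attention[s] holds sample s's weights over its own encoder steps, so
    samples may have different encoder lengths (ragged lists).
    """
    if reduction == "none":
        return attention                      # no aggregation over samples
    max_len = max(len(a) for a in attention)
    totals = [0.0] * max_len
    counts = [0] * max_len
    for weights in attention:
        for pos, w in enumerate(weights):
            totals[pos] += w
            counts[pos] += 1                  # samples that cover this position
    if reduction == "sum":
        return totals
    if reduction == "mean":
        # normalize each position by how many encoders reach it
        return [t / c for t, c in zip(totals, counts)]
    raise ValueError(f"unknown reduction: {reduction!r}")


batch = [[0.5, 0.5], [0.2, 0.3, 0.5]]        # encoder lengths 2 and 3
print(reduce_attention(batch, "mean"))
```

Normalizing per position keeps long-lookback positions, which only a few samples reach, from looking unimportant merely because fewer samples contribute to them.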


Log embeddings to tensorboard


Log interpretation metrics to tensorboard.


Called at the very end of fit.

If running with DDP, it is called on every process.

plot_interpretation(interpretation: Dict[str, torch.Tensor]) → Dict[str, matplotlib.figure.Figure]

Make figures that interpret model.

  • Attention

  • Variable selection weights / importances


interpretation – as obtained from interpret_output()


dictionary of matplotlib figures

plot_prediction(x: Dict[str, torch.Tensor], out: Dict[str, torch.Tensor], idx: int, plot_attention: bool = True, add_loss_to_title: bool = False, show_future_observed: bool = True, ax=None, **kwargs) → matplotlib.figure.Figure

Plot actuals vs prediction and attention

  • x (Dict[str, torch.Tensor]) – network input

  • out (Dict[str, torch.Tensor]) – network output

  • idx (int) – sample index

  • plot_attention – whether to plot attention on a secondary axis

  • add_loss_to_title – whether to add the loss to the title. Defaults to False.

  • show_future_observed – whether to show actuals for the future. Defaults to True.

  • ax – matplotlib axes to plot on


matplotlib figure

Return type