TemporalFusionTransformer

class pytorch_forecasting.models.temporal_fusion_transformer._tft.TemporalFusionTransformer(hidden_size: int = 16, lstm_layers: int = 1, dropout: float = 0.1, output_size: int | list[int] = 7, loss: MultiHorizonMetric = None, attention_head_size: int = 4, max_encoder_length: int = 10, static_categoricals: list[str] | None = None, static_reals: list[str] | None = None, time_varying_categoricals_encoder: list[str] | None = None, time_varying_categoricals_decoder: list[str] | None = None, categorical_groups: dict | list[str] | None = None, time_varying_reals_encoder: list[str] | None = None, time_varying_reals_decoder: list[str] | None = None, x_reals: list[str] | None = None, x_categoricals: list[str] | None = None, hidden_continuous_size: int = 8, hidden_continuous_sizes: dict[str, int] | None = None, embedding_sizes: dict[str, tuple[int, int]] | None = None, embedding_paddings: list[str] | None = None, embedding_labels: dict[str, ndarray] | None = None, learning_rate: float = 0.001, log_interval: int | float = -1, log_val_interval: int | float = None, log_gradient_flow: bool = False, reduce_on_plateau_patience: int = 1000, monotone_constaints: dict[str, int] | None = None, share_single_variable_networks: bool = False, causal_attention: bool = True, logging_metrics: ModuleList = None, mask_bias: float = -1000000000.0, **kwargs)

Bases: BaseModelWithCovariates

Temporal Fusion Transformer for forecasting timeseries.

Initialize via the from_dataset() method if possible.

Implementation of the paper "Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting".

Enhancements compared to the original implementation:

  • static variables can be continuous

  • multiple categorical variables can be summarized with an EmbeddingBag

  • encoder and decoder lengths can vary by sample

  • categorical embeddings are not transformed by the variable selection network (because it is a redundant operation)

  • variable dimensions in the variable selection network are scaled up via linear interpolation to reduce the number of parameters

  • non-linear variable processing in the variable selection network can be shared between the encoder and decoder (not shared by default)

  • capabilities added through the base model, such as monotone constraints
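The EmbeddingBag enhancement (configured through the categorical_groups parameter) can be pictured as several categorical variables sharing one embedding table whose rows are pooled into a single vector. The following is a minimal pure-Python sketch, assuming sum pooling; the group name, member names, table values and the embed_group helper are all hypothetical, and the model itself uses a PyTorch embedding bag rather than this code.

```python
from typing import Dict, List

# Hypothetical group in the shape expected by categorical_groups:
# the key names the new combined variable, the value lists its member variables.
categorical_groups: Dict[str, List[str]] = {
    "special_days": ["christmas", "easter", "black_friday"],
}

# Hypothetical shared embedding table for the group: one row per member variable.
EMBEDDING_TABLE: List[List[float]] = [
    [0.1, 0.2, 0.3],   # christmas
    [0.0, -0.5, 0.4],  # easter
    [1.0, 0.0, 0.0],   # black_friday
]


def embed_group(active: List[str], members: List[str],
                table: List[List[float]]) -> List[float]:
    """Sum the embedding rows of all active members into a single vector."""
    pooled = [0.0] * len(table[0])
    for name in active:
        row = table[members.index(name)]
        for d, value in enumerate(row):
            pooled[d] += value
    return pooled


members = categorical_groups["special_days"]
# A time step on which both "christmas" and "black_friday" are flagged.
vector = embed_group(["christmas", "black_friday"], members, EMBEDDING_TABLE)
print(vector)  # [1.1, 0.2, 0.3]
```

Whatever number of members is active, the group yields one fixed-size vector, so it can be treated downstream as a single categorical variable named by the dictionary key.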

Tune its hyperparameters with optimize_hyperparameters().

Parameters:
  • hidden_size (int, default=16) – hidden size of the network, which is its main hyperparameter. Can range from 8 to 512.

  • lstm_layers (int, default=1) – number of LSTM layers (2 is mostly optimal)

  • dropout (float, default=0.1) – dropout rate

  • output_size (int or list of int, default=7) – number of outputs (e.g. the number of quantiles for QuantileLoss with one target), or a list of output sizes when there are multiple targets.

  • loss (MultiHorizonMetric, default=QuantileLoss()) – loss function taking prediction and targets

  • attention_head_size (int, default=4) – number of attention heads (4 is a good default)

  • max_encoder_length (int, default=10) – length to encode, can be far longer than the decoder length but does not have to be

  • static_categoricals (list of str, optional) – names of static categorical variables

  • static_reals (list of str, optional) – names of static continuous variables

  • time_varying_categoricals_encoder (list of str, optional) – names of categorical variables for the encoder

  • time_varying_categoricals_decoder (list of str, optional) – names of categorical variables for the decoder

  • time_varying_reals_encoder (list of str, optional) – names of continuous variables for the encoder

  • time_varying_reals_decoder (list of str, optional) – names of continuous variables for the decoder

  • categorical_groups (dict, optional) – dictionary whose values are lists of categorical variables that together form a new categorical variable; the key is the name of that new variable

  • x_reals (list of str, optional) – order of continuous variables in the tensor passed to the forward function

  • x_categoricals (list of str, optional) – order of categorical variables in the tensor passed to the forward function

  • hidden_continuous_size (int, default=8) – default hidden size for processing continuous variables (similar to the categorical embedding size)

  • hidden_continuous_sizes (dict, optional) – dictionary mapping continuous input indices to sizes for variable selection (falls back to hidden_continuous_size if an index is not in the dictionary)

  • embedding_sizes (dict, optional) – dictionary mapping (string) indices to a tuple of the number of categorical classes and the embedding size

  • embedding_paddings (list of str, optional) – list of indices for embeddings which transform the zero's embedding to a zero vector

  • embedding_labels (dict, optional) – dictionary mapping (string) indices to a list of categorical labels

  • learning_rate (float, default=0.001) – learning rate

  • log_interval (int or float, default=-1) – log predictions every x batches; do not log if 0 or less; log interpretation if > 0. If < 1.0, will log multiple entries per batch. Defaults to -1.

  • log_val_interval (int or float, optional) – frequency with which to log validation set metrics. Defaults to log_interval.

  • log_gradient_flow (bool, default=False) – whether to log gradient flow; this takes time and should only be done to diagnose training failures

  • reduce_on_plateau_patience (int, default=1000) – patience after which the learning rate is reduced

  • monotone_constaints (Dict[str, int]) – dictionary of monotonicity constraints for continuous decoder variables mapping position (e.g. "0" for first position) to constraint (-1 for negative and +1 for positive; larger numbers add more weight to the constraint vs. the loss but are usually not necessary). This constraint significantly slows down training. Defaults to {}.

  • share_single_variable_networks (bool) – whether to share the single variable networks between the encoder and decoder. Defaults to False.

  • causal_attention (bool) – whether to attend only to previous timesteps in the decoder or also include future predictions. Defaults to True.

  • logging_metrics (nn.ModuleList[LightningMetric]) – metrics to log during training. Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]).

  • mask_bias (float, optional) – bias for the mask in ScaledDotProductAttention.forward, by default -1e9. Set to -float("inf") to allow mixed precision training.

  • **kwargs – additional arguments passed to BaseModel.
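The default output_size=7 lines up with the number of quantiles predicted when training with QuantileLoss on one target. The sketch below computes a quantile (pinball) loss for a single target value in plain Python, assuming the seven quantiles [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]; the pinball_loss helper is hypothetical and averages over quantiles for illustration, so it is not the library's QuantileLoss implementation.

```python
from typing import List

# Assumed quantile set; seven quantiles -> output_size=7 for a single target.
QUANTILES: List[float] = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]


def pinball_loss(target: float, predictions: List[float],
                 quantiles: List[float]) -> float:
    """Mean quantile (pinball) loss of one target value over all predicted quantiles."""
    if len(predictions) != len(quantiles):
        raise ValueError("need exactly one prediction per quantile (output_size)")
    total = 0.0
    for q, pred in zip(quantiles, predictions):
        error = target - pred
        # Under-prediction (error > 0) costs q per unit, over-prediction costs (1 - q).
        total += max(q * error, (q - 1) * error)
    return total / len(quantiles)


output_size = len(QUANTILES)  # 7, matching the documented default
# Predicting the target exactly at every quantile gives zero loss.
print(pinball_loss(10.0, [10.0] * output_size, QUANTILES))  # 0.0
# Under-predicting by 2.0 everywhere: the average works out to about 1.0 here.
print(pinball_loss(10.0, [8.0] * output_size, QUANTILES))
```

The asymmetric weights are what push each of the seven outputs toward its own quantile, which is why the network needs one output per quantile.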

BaseModel for timeseries forecasting from which to inherit.

Parameters:
  • log_interval (Union[int, float], optional) – Batches after which predictions are logged. If < 1.0, will log multiple entries per batch. Defaults to -1.

  • log_val_interval (Union[int, float], optional) – batches after which predictions for validation are logged. Defaults to None/log_interval.

  • learning_rate (float, optional) – Learning rate. Defaults to 1e-3.

  • log_gradient_flow (bool) – Whether to log gradient flow; this takes time and should only be done to diagnose training failures. Defaults to False.

  • loss (Metric, optional) – metric to optimize, can also be list of metrics. Defaults to SMAPE().

  • logging_metrics (nn.ModuleList[MultiHorizonMetric]) – list of metrics that are logged during training. Defaults to [].

  • reduce_on_plateau_patience (int) – patience after which the learning rate is reduced by the factor reduce_on_plateau_reduction. Defaults to 1000.

  • reduce_on_plateau_reduction (float) – reduction in learning rate when encountering plateau. Defaults to 2.0.

  • reduce_on_plateau_min_lr (float) – minimum learning rate for the reduce-on-plateau learning rate scheduler. Defaults to 1e-5.

  • weight_decay (float) – weight decay. Defaults to 0.0.

  • optimizer_params (Dict[str, Any]) – additional parameters for the optimizer. Defaults to {}.

  • monotone_constraints (Dict[str, int]) – dictionary of monotonicity constraints for continuous decoder variables mapping position (e.g. "0" for first position) to constraint (-1 for negative and +1 for positive, larger numbers add more weight to the constraint vs. the loss but are usually not necessary). This constraint significantly slows down training. Defaults to {}.

  • output_transformer (Callable) – transformer that takes network output and transforms it to prediction space. Defaults to None which is equivalent to lambda out: out["prediction"].

  • optimizer (str) – Optimizer, “ranger”, “sgd”, “adam”, “adamw” or class name of optimizer in torch.optim or pytorch_optimizer. Alternatively, a class or function can be passed which takes parameters as first argument and a lr argument (optionally also weight_decay). Defaults to “adam”.
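The interplay of reduce_on_plateau_patience, reduce_on_plateau_reduction and reduce_on_plateau_min_lr can be sketched as a simplified plateau scheduler: once more than patience consecutive validation checks fail to improve the loss, the learning rate is divided by the reduction factor, but never below the minimum. This is a pure-Python approximation assuming behaviour like PyTorch's ReduceLROnPlateau with factor = 1 / reduce_on_plateau_reduction; it ignores improvement thresholds and cooldown, and the PlateauScheduler class is hypothetical.

```python
class PlateauScheduler:
    """Simplified plateau scheduler (hypothetical helper, not the library's class)."""

    def __init__(self, lr: float, patience: int,
                 reduction: float = 2.0, min_lr: float = 1e-5):
        self.lr = lr
        self.patience = patience
        self.reduction = reduction
        self.min_lr = min_lr
        self.best = float("inf")
        self.bad_checks = 0  # consecutive validation checks without improvement

    def step(self, val_loss: float) -> float:
        if val_loss < self.best:
            # Improvement: remember it and reset the counter.
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
            if self.bad_checks > self.patience:
                # Plateau: divide the learning rate, clamped at min_lr.
                self.lr = max(self.lr / self.reduction, self.min_lr)
                self.bad_checks = 0
        return self.lr


sched = PlateauScheduler(lr=1e-3, patience=2, reduction=2.0)
lrs = [sched.step(loss) for loss in [1.0, 0.9, 0.95, 0.93, 0.92, 0.91]]
print(lrs)  # the rate halves once, after the third non-improving check
```

With reduction=2.0 (the documented default) each plateau halves the learning rate, and min_lr stops it from shrinking indefinitely.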

Methods

create_log(x, y, out, batch_idx, **kwargs)

Create the log used in the training and validation step.

expand_static_context(context, timesteps)

add time dimension to static context

forward(x)

input dimensions: n_samples x time x variables

from_dataset(dataset[, ...])

Create model from dataset.

get_attention_mask(encoder_lengths, ...)

Returns causal mask to apply for self-attention layer.

interpret_output(out[, reduction, ...])

interpret output of model

log_embeddings()

Log embeddings to tensorboard

log_interpretation(outputs)

Log interpretation metrics to tensorboard.

on_epoch_end(outputs)

run at epoch end for training or validation

on_fit_end()

Called at the very end of fit.

plot_interpretation(interpretation)

Make figures that interpret model.

plot_prediction(x, out, idx[, ...])

Plot actuals vs prediction and attention

create_log(x, y, out, batch_idx, **kwargs)

Create the log used in the training and validation step.

Parameters:
  • x (Dict[str, torch.Tensor]) – x as passed to the network by the dataloader

  • y (Tuple[torch.Tensor, torch.Tensor]) – y as passed to the loss function by the dataloader

  • out (Dict[str, torch.Tensor]) – output of the network

  • batch_idx (int) – batch number

  • prediction_kwargs (Dict[str, Any], optional) – arguments to pass to to_prediction(). Defaults to {}.

  • quantiles_kwargs (Dict[str, Any], optional) – arguments to pass to to_quantiles(). Defaults to {}.

Returns:

log dictionary to be returned by training and validation steps

Return type:

Dict[str, Any]

expand_static_context(context, timesteps)

add time dimension to static context
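Adding a time dimension to the static context amounts to repeating each sample's context vector once per time step, turning an n_samples x hidden input into n_samples x timesteps x hidden. A plain-Python sketch with nested lists, assuming that shape convention (in PyTorch the same effect is typically obtained without copying via context[:, None].expand(-1, timesteps, -1)); the helper below is a hypothetical stand-in, not the method itself.

```python
from typing import List


def expand_static_context(context: List[List[float]],
                          timesteps: int) -> List[List[List[float]]]:
    """Repeat each sample's static vector along a new time axis."""
    # context: n_samples x hidden  ->  result: n_samples x timesteps x hidden
    # list(vector) makes an independent copy for every time step.
    return [[list(vector) for _ in range(timesteps)] for vector in context]


static_context = [[0.5, -1.0], [2.0, 0.0]]  # 2 samples, hidden size 2
expanded = expand_static_context(static_context, timesteps=3)
print(len(expanded), len(expanded[0]), len(expanded[0][0]))  # 2 3 2
```

The expanded context can then be combined step by step with time-varying inputs, since both now share the n_samples x time layout used by forward().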

forward(x: dict[str, Tensor]) -> dict[str, Tensor]

input dimensions: n_samples x time x variables

classmethod from_dataset(dataset: TimeSeriesDataSet, allowed_encoder_known_variable_names: list[str] = None, **kwargs)

Create model from dataset.

Parameters:
  • dataset – timeseries dataset

  • allowed_encoder_known_variable_names – List of known variables that are allowed in encoder, defaults to all

  • **kwargs – additional arguments such as hyperparameters for model (see __init__())

Returns:

TemporalFusionTransformer

get_attention_mask(encoder_lengths: LongTensor, decoder_lengths: LongTensor)

Returns causal mask to apply for self-attention layer.
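A causal mask for one sample can be sketched as follows: each decoder (query) step may attend to the observed encoder steps (up to that sample's encoder length) and, when causal_attention is on, only to strictly earlier decoder steps; blocked scores then receive mask_bias before the softmax, so their weights become effectively zero. This is a pure-Python sketch under several assumptions: True marks a blocked position, a decoder step does not attend to itself, and there is no decoder padding. The helpers attention_mask, masked_softmax and fits_float16 are hypothetical; the real get_attention_mask works on batched tensors.

```python
import math
import struct
from typing import List

MASK_BIAS = -1e9  # same value as the documented default of mask_bias


def attention_mask(encoder_length: int, max_encoder_length: int,
                   decoder_length: int, causal: bool = True) -> List[List[bool]]:
    """Rows are decoder (query) steps; columns are the encoder steps followed
    by the decoder steps. True means "do not attend"."""
    rows = []
    for query in range(decoder_length):
        # Encoder columns at or beyond this sample's encoder length are padding.
        enc = [step >= encoder_length for step in range(max_encoder_length)]
        if causal:
            # Block the query's own step and every later decoder step.
            dec = [step >= query for step in range(decoder_length)]
        else:
            dec = [False] * decoder_length
        rows.append(enc + dec)
    return rows


def masked_softmax(scores: List[float], mask: List[bool],
                   bias: float = MASK_BIAS) -> List[float]:
    """Add `bias` to blocked scores, then apply a numerically stable softmax."""
    biased = [s + bias if blocked else s for s, blocked in zip(scores, mask)]
    top = max(biased)
    exps = [math.exp(b - top) for b in biased]
    total = sum(exps)
    return [e / total for e in exps]


def fits_float16(value: float) -> bool:
    """True if `value` can be stored as an IEEE half-precision float."""
    try:
        struct.pack("<e", value)
        return True
    except OverflowError:
        return False


mask = attention_mask(encoder_length=2, max_encoder_length=3, decoder_length=2)
weights = masked_softmax([0.3, 0.1, 5.0, 1.0, 2.0], mask[1])
print(mask[1])                 # [False, False, True, False, True]
print(weights[2], weights[4])  # 0.0 0.0
print(fits_float16(MASK_BIAS), fits_float16(float("-inf")))  # False True
```

The last line hints at why the mask_bias documentation recommends -float("inf") for mixed precision: -1e9 lies far outside the half-precision range, whereas -inf is representable.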

interpret_output(out: dict[str, Tensor], reduction: str = 'none', attention_prediction_horizon: int = 0) -> dict[str, Tensor]

interpret output of model

Parameters:
  • out – output as produced by forward()

  • reduction – “none” for no averaging over batches, “sum” for summing attentions, “mean” for normalizing by encoder lengths

  • attention_prediction_horizon – which prediction horizon to use for attention

Returns:

interpretations that can be plotted with plot_interpretation()

log_embeddings()

Log embeddings to tensorboard

log_interpretation(outputs)

Log interpretation metrics to tensorboard.

on_epoch_end(outputs)

run at epoch end for training or validation

on_fit_end()

Called at the very end of fit.

If running on DDP, it is called on every process.

plot_interpretation(interpretation: dict[str, Tensor])

Make figures that interpret model.

  • Attention

  • Variable selection weights / importances

Parameters:

interpretation – as obtained from interpret_output()

Returns:

dictionary of matplotlib figures

plot_prediction(x: dict[str, Tensor], out: dict[str, Tensor], idx: int, plot_attention: bool = True, add_loss_to_title: bool = False, show_future_observed: bool = True, ax=None, **kwargs)

Plot actuals vs prediction and attention

Parameters:
  • x (Dict[str, torch.Tensor]) – network input

  • out (Dict[str, torch.Tensor]) – network output

  • idx (int) – sample index

  • plot_attention – whether to plot attention on a secondary axis

  • add_loss_to_title – whether to add the loss to the title. Defaults to False.

  • show_future_observed – whether to show actuals for the future. Defaults to True.

  • ax – matplotlib axes to plot on

Returns:

matplotlib figure

Return type:

plt.Figure