TemporalFusionTransformer#
- class pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer(hidden_size: int = 16, lstm_layers: int = 1, dropout: float = 0.1, output_size: int | List[int] = 7, loss: MultiHorizonMetric = None, attention_head_size: int = 4, max_encoder_length: int = 10, static_categoricals: List[str] | None = None, static_reals: List[str] | None = None, time_varying_categoricals_encoder: List[str] | None = None, time_varying_categoricals_decoder: List[str] | None = None, categorical_groups: Dict | List[str] | None = None, time_varying_reals_encoder: List[str] | None = None, time_varying_reals_decoder: List[str] | None = None, x_reals: List[str] | None = None, x_categoricals: List[str] | None = None, hidden_continuous_size: int = 8, hidden_continuous_sizes: Dict[str, int] | None = None, embedding_sizes: Dict[str, Tuple[int, int]] | None = None, embedding_paddings: List[str] | None = None, embedding_labels: Dict[str, ndarray] | None = None, learning_rate: float = 0.001, log_interval: int | float = -1, log_val_interval: int | float = None, log_gradient_flow: bool = False, reduce_on_plateau_patience: int = 1000, monotone_constaints: Dict[str, int] | None = None, share_single_variable_networks: bool = False, causal_attention: bool = True, logging_metrics: ModuleList = None, **kwargs)[source]#
Bases:
BaseModelWithCovariates
Temporal Fusion Transformer for forecasting timeseries - use its from_dataset() method if possible.
Implementation of the article Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting. The network outperforms DeepAR by Amazon by 36-69% in benchmarks.
Enhancements compared to the original implementation (apart from capabilities added through base model such as monotone constraints):
static variables can be continuous
multiple categorical variables can be summarized with an EmbeddingBag
variable encoder and decoder length by sample
categorical embeddings are not transformed by variable selection network (because it is a redundant operation)
variable dimension in variable selection network is scaled up via linear interpolation to reduce number of parameters
non-linear variable processing in variable selection network can be shared among decoder and encoder (not shared by default)
Tune its hyperparameters with optimize_hyperparameters().
- Parameters:
hidden_size – hidden size of network which is its main hyperparameter and can range from 8 to 512
lstm_layers – number of LSTM layers (2 is mostly optimal)
dropout – dropout rate
output_size – number of outputs (e.g. number of quantiles for QuantileLoss and one target or list of output sizes).
loss – loss function taking prediction and targets
attention_head_size – number of attention heads (4 is a good default)
max_encoder_length – length to encode (can be far longer than the decoder length but does not have to be)
static_categoricals – names of static categorical variables
static_reals – names of static continuous variables
time_varying_categoricals_encoder – names of categorical variables for encoder
time_varying_categoricals_decoder – names of categorical variables for decoder
time_varying_reals_encoder – names of continuous variables for encoder
time_varying_reals_decoder – names of continuous variables for decoder
categorical_groups – dictionary where values are lists of categorical variables that together form a new categorical variable which is the key in the dictionary
x_reals – order of continuous variables in tensor passed to forward function
x_categoricals – order of categorical variables in tensor passed to forward function
hidden_continuous_size – default for hidden size for processing continuous variables (similar to categorical embedding size)
hidden_continuous_sizes – dictionary mapping continuous input indices to sizes for variable selection (fallback to hidden_continuous_size if index is not in dictionary)
embedding_sizes – dictionary mapping (string) indices to tuple of number of categorical classes and embedding size
embedding_paddings – list of indices for embeddings which transform the zero’s embedding to a zero vector
embedding_labels – dictionary mapping (string) indices to list of categorical labels
learning_rate – learning rate
log_interval – log predictions every x batches, do not log if 0 or less, log interpretation if > 0. If < 1.0, will log multiple entries per batch. Defaults to -1.
log_val_interval – frequency with which to log validation set metrics, defaults to log_interval
log_gradient_flow – if to log gradient flow, this takes time and should be only done to diagnose training failures
reduce_on_plateau_patience (int) – patience after which learning rate is reduced by a factor of 10
monotone_constaints (Dict[str, int]) – dictionary of monotonicity constraints for continuous decoder variables mapping position (e.g. "0" for first position) to constraint (-1 for negative and +1 for positive; larger numbers add more weight to the constraint vs. the loss but are usually not necessary). This constraint significantly slows down training. Defaults to {}.
share_single_variable_networks (bool) – if to share the single variable networks between the encoder and decoder. Defaults to False.
causal_attention (bool) – If to attend only at previous timesteps in the decoder or also include future predictions. Defaults to True.
logging_metrics (nn.ModuleList[LightningMetric]) – list of metrics that are logged during training. Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]).
**kwargs – additional arguments to BaseModel.
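Example
A minimal construction and training sketch, assuming an existing TimeSeriesDataSet named training and matching train_dataloader / val_dataloader; hyperparameter values are illustrative only, and depending on the installed versions the Lightning import may be pytorch_lightning instead of lightning.pytorch.

import lightning.pytorch as pl
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# build the model from the dataset so that variable names, encoders and
# embedding sizes are inferred automatically
tft = TemporalFusionTransformer.from_dataset(
    training,                      # TimeSeriesDataSet used for training
    hidden_size=16,                # main capacity hyperparameter (8 to 512)
    lstm_layers=2,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=QuantileLoss(),           # output_size is deduced from the loss by from_dataset()
    learning_rate=0.03,
    reduce_on_plateau_patience=4,
)

trainer = pl.Trainer(max_epochs=30, gradient_clip_val=0.1)
trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)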
Methods
create_log(x, y, out, batch_idx, **kwargs) – Create the log used in the training and validation step.
expand_static_context(context, timesteps) – Add time dimension to static context.
forward(x) – Input dimensions: n_samples x time x variables.
from_dataset(dataset[, ...]) – Create model from dataset.
get_attention_mask(encoder_lengths, ...) – Returns causal mask to apply for self-attention layer.
interpret_output(out[, reduction, ...]) – Interpret output of model.
log_embeddings() – Log embeddings to tensorboard.
log_interpretation(outputs) – Log interpretation metrics to tensorboard.
on_epoch_end(outputs) – Run at epoch end for training or validation.
on_fit_end() – Called at the very end of fit.
plot_interpretation(interpretation) – Make figures that interpret model.
plot_prediction(x, out, idx[, ...]) – Plot actuals vs prediction and attention.
- create_log(x, y, out, batch_idx, **kwargs)[source]#
Create the log used in the training and validation step.
- Parameters:
x (Dict[str, torch.Tensor]) – x as passed to the network by the dataloader
y (Tuple[torch.Tensor, torch.Tensor]) – y as passed to the loss function by the dataloader
out (Dict[str, torch.Tensor]) – output of the network
batch_idx (int) – batch number
prediction_kwargs (Dict[str, Any], optional) – arguments to pass to to_prediction(). Defaults to {}.
quantiles_kwargs (Dict[str, Any], optional) – arguments to pass to to_quantiles(). Defaults to {}.
- Returns:
log dictionary to be returned by training and validation steps
- Return type:
Dict[str, Any]
- forward(x: Dict[str, Tensor]) Dict[str, Tensor] [source]#
input dimensions: n_samples x time x variables
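In practice x is the batch dictionary produced by a TimeSeriesDataSet dataloader. A minimal sketch, assuming a model tft created with from_dataset() and a matching train_dataloader:

# take one batch from the dataloader; x is the input dict, y is (target, weight)
x, y = next(iter(train_dataloader))
out = tft(x)                        # calls forward()
print(out["prediction"].shape)      # n_samples x decoder time steps x output_size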
- classmethod from_dataset(dataset: TimeSeriesDataSet, allowed_encoder_known_variable_names: List[str] = None, **kwargs)[source]#
Create model from dataset.
- Parameters:
dataset – timeseries dataset
allowed_encoder_known_variable_names – List of known variables that are allowed in encoder, defaults to all
**kwargs – additional arguments such as hyperparameters for model (see __init__())
- Returns:
TemporalFusionTransformer
- get_attention_mask(encoder_lengths: LongTensor, decoder_lengths: LongTensor)[source]#
Returns causal mask to apply for self-attention layer.
- interpret_output(out: Dict[str, Tensor], reduction: str = 'none', attention_prediction_horizon: int = 0) Dict[str, Tensor] [source]#
interpret output of model
- Parameters:
out – output as produced by forward()
reduction – “none” for no averaging over batches, “sum” for summing attentions, “mean” for normalizing by encoder lengths
attention_prediction_horizon – which prediction horizon to use for attention
- Returns:
interpretations that can be plotted with plot_interpretation()
- plot_interpretation(interpretation: Dict[str, Tensor])[source]#
Make figures that interpret model.
Attention
Variable selection weights / importances
- Parameters:
interpretation – as obtained from interpret_output()
- Returns:
dictionary of matplotlib figures
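A sketch of computing and plotting interpretations, assuming a fitted model tft and a val_dataloader; the object returned by predict(..., mode="raw", return_x=True) has changed across pytorch-forecasting versions, so the attribute access below may need adjusting.

# raw network output as produced by forward()
raw_output = tft.predict(val_dataloader, mode="raw", return_x=True)

# aggregate attention and variable selection weights over the batch
interpretation = tft.interpret_output(raw_output.output, reduction="sum")

# dictionary of matplotlib figures (attention, variable importances)
figures = tft.plot_interpretation(interpretation)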
- plot_prediction(x: Dict[str, Tensor], out: Dict[str, Tensor], idx: int, plot_attention: bool = True, add_loss_to_title: bool = False, show_future_observed: bool = True, ax=None, **kwargs)[source]#
Plot actuals vs prediction and attention
- Parameters:
x (Dict[str, torch.Tensor]) – network input
out (Dict[str, torch.Tensor]) – network output
idx (int) – sample index
plot_attention – if to plot attention on secondary axis
add_loss_to_title – if to add loss to title. Defaults to False.
show_future_observed – if to show actuals for future. Defaults to True.
ax – matplotlib axes to plot on
- Returns:
matplotlib figure
- Return type:
plt.Figure
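A sketch of plotting a single sample, assuming the same raw_output as in the interpretation example above (predict() called with mode="raw" and return_x=True):

# plot actuals vs. prediction (and attention) for the first sample of the batch
fig = tft.plot_prediction(
    raw_output.x,          # network input
    raw_output.output,     # raw network output
    idx=0,                 # sample index within the batch
    add_loss_to_title=True,
)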