TemporalFusionTransformer
- class pytorch_forecasting.models.temporal_fusion_transformer._tft.TemporalFusionTransformer(hidden_size: int = 16, lstm_layers: int = 1, dropout: float = 0.1, output_size: int | list[int] = 7, loss: MultiHorizonMetric = None, attention_head_size: int = 4, max_encoder_length: int = 10, static_categoricals: list[str] | None = None, static_reals: list[str] | None = None, time_varying_categoricals_encoder: list[str] | None = None, time_varying_categoricals_decoder: list[str] | None = None, categorical_groups: dict | list[str] | None = None, time_varying_reals_encoder: list[str] | None = None, time_varying_reals_decoder: list[str] | None = None, x_reals: list[str] | None = None, x_categoricals: list[str] | None = None, hidden_continuous_size: int = 8, hidden_continuous_sizes: dict[str, int] | None = None, embedding_sizes: dict[str, tuple[int, int]] | None = None, embedding_paddings: list[str] | None = None, embedding_labels: dict[str, ndarray] | None = None, learning_rate: float = 0.001, log_interval: int | float = -1, log_val_interval: int | float = None, log_gradient_flow: bool = False, reduce_on_plateau_patience: int = 1000, monotone_constaints: dict[str, int] | None = None, share_single_variable_networks: bool = False, causal_attention: bool = True, logging_metrics: ModuleList = None, mask_bias: float = -1000000000.0, **kwargs)
Bases: BaseModelWithCovariates
Temporal Fusion Transformer for forecasting timeseries.
Initialize via the from_dataset() method if possible.
Implementation of Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting.
Enhancements compared to the original implementation:
static variables can be continuous
multiple categorical variables can be summarized with an EmbeddingBag
encoder and decoder lengths can vary by sample
categorical embeddings are not transformed by the variable selection network (because it is a redundant operation)
variable dimensions in the variable selection network are scaled up via linear interpolation to reduce the number of parameters
non-linear variable processing in the variable selection network can be shared between decoder and encoder (not shared by default)
capabilities added through the base model, such as monotone constraints
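To make the linear-interpolation point concrete, here is a minimal pure-Python sketch of resampling one variable's small representation up to a larger size. It is an illustration under assumptions, not the library's layer: the helper `interpolate_linear` is hypothetical, it uses the endpoint-aligned convention (as with `align_corners=True` in `torch.nn.functional.interpolate`), and the library's actual resampling may add learned gating or normalization on top. The point is that interpolation itself has no weights, whereas a linear projection from 8 to 16 units would need 8 * 16 + 16 = 144 parameters.

```python
def interpolate_linear(values: list[float], size: int) -> list[float]:
    """Resample `values` to `size` points with endpoint-aligned linear interpolation.

    Hypothetical helper for illustration only; it has no trainable parameters.
    """
    if not values or size < 1:
        raise ValueError("values must be non-empty and size must be >= 1")
    if size == 1:
        return [float(values[0])]
    n = len(values)
    scale = (n - 1) / (size - 1)  # maps an output index to a fractional input position
    resampled = []
    for i in range(size):
        position = i * scale
        left = int(position)  # floor, because position >= 0
        right = min(left + 1, n - 1)
        weight = position - left
        # Blend the two neighbouring input values by distance.
        resampled.append(values[left] * (1.0 - weight) + values[right] * weight)
    return resampled


if __name__ == "__main__":
    # e.g. an 8-dimensional variable representation upsampled to hidden_size=16
    small = [float(i) for i in range(8)]
    print(interpolate_linear(small, 16))
```

Because the first and last output points coincide with the first and last inputs, the sketch only stretches the representation; any learned transformation has to happen elsewhere in the network.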
Tune its hyperparameters with optimize_hyperparameters().
- Parameters:
hidden_size (int, default=16) – hidden size of the network, which is its main hyperparameter. Can range from 8 to 512.
lstm_layers (int, default=1) – number of LSTM layers (2 is usually optimal)
dropout (float, default=0.1) – dropout rate
output_size (int or list of int, default=7) – number of outputs (e.g. number of quantiles for QuantileLoss with one target), or a list of output sizes for multiple targets.
loss (MultiHorizonMetric, default=QuantileLoss()) – loss function taking prediction and targets
attention_head_size (int, default=4) – number of attention heads (4 is a good default)
max_encoder_length (int, default=10) – length to encode; can be far longer than the decoder length but does not have to be
static_categoricals (list of str, optional) – names of static categorical variables
static_reals (list of str, optional) – names of static continuous variables
time_varying_categoricals_encoder (list of str, optional) – names of categorical variables for the encoder
time_varying_categoricals_decoder (list of str, optional) – names of categorical variables for the decoder
time_varying_reals_encoder (list of str, optional) – names of continuous variables for the encoder
time_varying_reals_decoder (list of str, optional) – names of continuous variables for the decoder
categorical_groups (dict, optional) – dictionary whose values are lists of categorical variables that together form a new categorical variable, named by the corresponding key
x_reals (list of str, optional) – order of continuous variables in the tensor passed to the forward function
x_categoricals (list of str, optional) – order of categorical variables in the tensor passed to the forward function
hidden_continuous_size (int, default=8) – default hidden size for processing continuous variables (similar to the categorical embedding size)
hidden_continuous_sizes (dict of str to int, optional) – dictionary mapping continuous input indices to sizes for variable selection (falls back to hidden_continuous_size if an index is not in the dictionary)
embedding_sizes (dict of str to tuple of int, optional) – dictionary mapping (string) indices to a tuple of the number of categorical classes and the embedding size
embedding_paddings (list of str, optional) – list of indices for embeddings which transform the zero's embedding to a zero vector
embedding_labels (dict of str to ndarray, optional) – dictionary mapping (string) indices to lists of categorical labels
learning_rate (float, default=0.001) – learning rate
log_interval (int or float, default=-1) – log predictions every x batches; do not log if 0 or less; log interpretation if > 0. If < 1.0, will log multiple entries per batch.
log_val_interval (int or float, optional) – frequency with which to log validation set metrics. Defaults to log_interval.
log_gradient_flow (bool, default=False) – whether to log gradient flow; this takes time and should only be done to diagnose training failures
reduce_on_plateau_patience (int, default=1000) – patience after which the learning rate is reduced
monotone_constaints (Dict[str, int], optional) – dictionary of monotonicity constraints for continuous decoder variables mapping position (e.g. "0" for first position) to constraint (-1 for negative and +1 for positive; larger numbers add more weight to the constraint vs. the loss but are usually not necessary). This constraint significantly slows down training. Defaults to {}.
share_single_variable_networks (bool, default=False) – whether to share the single-variable networks between the encoder and decoder
causal_attention (bool, default=True) – whether to attend only to previous timesteps in the decoder or also include future predictions
logging_metrics (nn.ModuleList[LightningMetric], optional) – list of metrics that are logged during training. Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]).
mask_bias (float, optional) – Bias for the mask in ScaledDotProductAttention.forward, by default -1e9. Set to -float("inf") to allow mixed precision training.
**kwargs – additional arguments to BaseModel.
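With the default loss, output_size=7 corresponds to one output per quantile. The sketch below shows the pinball (quantile) loss for a single target and time step. It assumes the quantiles [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98] as the QuantileLoss() defaults and averages over quantiles; `pinball_loss` is a hypothetical helper, and the library's QuantileLoss may scale and reduce the loss differently.

```python
from collections.abc import Sequence

# Assumed default quantiles of QuantileLoss(); one output per quantile -> output_size=7.
DEFAULT_QUANTILES = (0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98)


def pinball_loss(
    target: float,
    predictions: Sequence[float],
    quantiles: Sequence[float] = DEFAULT_QUANTILES,
) -> float:
    """Mean pinball (quantile) loss of one target against one prediction per quantile."""
    if len(predictions) != len(quantiles):
        raise ValueError("expected one prediction per quantile (output_size == len(quantiles))")
    total = 0.0
    for q, prediction in zip(quantiles, predictions):
        error = target - prediction
        # Under-prediction (error > 0) costs q per unit, over-prediction costs (1 - q).
        total += max(q * error, (q - 1) * error)
    return total / len(quantiles)


if __name__ == "__main__":
    print(len(DEFAULT_QUANTILES))  # 7
    print(pinball_loss(10.0, [8.0, 9.0, 9.5, 10.0, 10.5, 11.0, 12.0]))
```

The asymmetric weighting is what pushes each output toward its own quantile: the 0.98 output is penalized heavily for falling below the target and lightly for overshooting it.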
BaseModel for timeseries forecasting from which to inherit.
- Parameters:
log_interval (Union[int, float], optional) – Batches after which predictions are logged. If < 1.0, will log multiple entries per batch. Defaults to -1.
log_val_interval (Union[int, float], optional) – batches after which predictions for validation are logged. Defaults to None/log_interval.
learning_rate (float, optional) – Learning rate. Defaults to 1e-3.
log_gradient_flow (bool) – Whether to log gradient flow; this takes time and should only be done to diagnose training failures. Defaults to False.
loss (Metric, optional) – metric to optimize, can also be list of metrics. Defaults to SMAPE().
logging_metrics (nn.ModuleList[MultiHorizonMetric]) – list of metrics that are logged during training. Defaults to [].
reduce_on_plateau_patience (int) – patience after which the learning rate is reduced by the factor given in reduce_on_plateau_reduction. Defaults to 1000.
reduce_on_plateau_reduction (float) – reduction in learning rate when encountering plateau. Defaults to 2.0.
reduce_on_plateau_min_lr (float) – minimum learning rate for the reduce-on-plateau learning rate scheduler. Defaults to 1e-5.
weight_decay (float) – weight decay. Defaults to 0.0.
optimizer_params (Dict[str, Any]) – additional parameters for the optimizer. Defaults to {}.
monotone_constraints (Dict[str, int]) – dictionary of monotonicity constraints for continuous decoder variables mapping position (e.g. "0" for first position) to constraint (-1 for negative and +1 for positive; larger numbers add more weight to the constraint vs. the loss but are usually not necessary). This constraint significantly slows down training. Defaults to {}.
output_transformer (Callable) – transformer that takes network output and transforms it to prediction space. Defaults to None, which is equivalent to lambda out: out["prediction"].
optimizer (str) – Optimizer: "ranger", "sgd", "adam", "adamw", or the class name of an optimizer in torch.optim or pytorch_optimizer. Alternatively, a class or function can be passed which takes parameters as first argument and a lr argument (optionally also weight_decay). Defaults to "adam".
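The interaction of reduce_on_plateau_patience, reduce_on_plateau_reduction and reduce_on_plateau_min_lr can be sketched as follows. This is a simplified stand-in for torch.optim.lr_scheduler.ReduceLROnPlateau, assuming a multiplicative factor of 1 / reduce_on_plateau_reduction; the class `PlateauSchedule` is hypothetical, and the real scheduler's improvement threshold and cooldown are left out.

```python
class PlateauSchedule:
    """Hypothetical, simplified plateau scheduler (no threshold, no cooldown)."""

    def __init__(self, lr: float, patience: int, reduction: float = 2.0, min_lr: float = 1e-5):
        self.lr = lr
        self.patience = patience
        self.factor = 1.0 / reduction  # e.g. reduction=2.0 halves the learning rate
        self.min_lr = min_lr
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> float:
        if val_loss < self.best:
            # Improvement: remember it and reset the plateau counter.
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs > self.patience:
            # Plateau lasted longer than `patience`: shrink lr, never below min_lr.
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.bad_epochs = 0
        return self.lr


if __name__ == "__main__":
    schedule = PlateauSchedule(lr=1e-3, patience=2)
    for loss in [1.0, 0.9, 0.9, 0.9, 0.9, 0.8]:
        print(schedule.step(loss))  # stays at 0.001, then drops to 0.0005 on the 5th step
```

With the large default patience of 1000, this reduction only triggers after a long stretch without improvement.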
Methods
create_log(x, y, out, batch_idx, **kwargs) – Create the log used in the training and validation step.
expand_static_context(context, timesteps) – Add time dimension to static context.
forward(x) – Input dimensions: n_samples x time x variables.
from_dataset(dataset[, ...]) – Create model from dataset.
get_attention_mask(encoder_lengths, ...) – Returns causal mask to apply for self-attention layer.
interpret_output(out[, reduction, ...]) – Interpret output of model.
Log embeddings to tensorboard
log_interpretation(outputs) – Log interpretation metrics to tensorboard.
on_epoch_end(outputs) – Run at epoch end for training or validation.
Called at the very end of fit.
plot_interpretation(interpretation) – Make figures that interpret model.
plot_prediction(x, out, idx[, ...]) – Plot actuals vs prediction and attention.
- create_log(x, y, out, batch_idx, **kwargs)
Create the log used in the training and validation step.
- Parameters:
x (Dict[str, torch.Tensor]) – x as passed to the network by the dataloader
y (Tuple[torch.Tensor, torch.Tensor]) – y as passed to the loss function by the dataloader
out (Dict[str, torch.Tensor]) – output of the network
batch_idx (int) – batch number
prediction_kwargs (Dict[str, Any], optional) – arguments to pass to to_prediction(). Defaults to {}.
quantiles_kwargs (Dict[str, Any], optional) – arguments to pass to to_quantiles(). Defaults to {}.
- Returns:
log dictionary to be returned by training and validation steps
- Return type:
Dict[str, Any]
- forward(x: dict[str, Tensor]) -> dict[str, Tensor]
input dimensions: n_samples x time x variables
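Static variables have no time axis, so their context vectors must be repeated over time before they can be combined with time-varying inputs of shape n_samples x time x variables; the methods table describes expand_static_context() as doing this ("add time dimension to static context"). Below is a standalone pure-Python sketch with nested lists; it mirrors the method's name for readability but is not the library's implementation.

```python
def expand_static_context(context: list[list[float]], timesteps: int) -> list[list[list[float]]]:
    """Add a time dimension: n_samples x hidden -> n_samples x timesteps x hidden.

    Standalone illustration; each time step gets its own copy of the sample's vector.
    """
    return [[list(vector) for _ in range(timesteps)] for vector in context]


if __name__ == "__main__":
    static_context = [[0.1, 0.2], [0.3, 0.4]]  # 2 samples, hidden size 2
    expanded = expand_static_context(static_context, timesteps=3)
    print(len(expanded), len(expanded[0]), len(expanded[0][0]))  # 2 3 2
```

In PyTorch the same shape change can be written as context[:, None].expand(-1, timesteps, -1), which returns a view instead of copying the data.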
- classmethod from_dataset(dataset: TimeSeriesDataSet, allowed_encoder_known_variable_names: list[str] = None, **kwargs)
Create model from dataset.
- Parameters:
dataset – timeseries dataset
allowed_encoder_known_variable_names – List of known variables that are allowed in encoder, defaults to all
**kwargs – additional arguments such as hyperparameters for model (see __init__())
- Returns:
TemporalFusionTransformer
- get_attention_mask(encoder_lengths: LongTensor, decoder_lengths: LongTensor)
Returns causal mask to apply for self-attention layer.
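The sketch below illustrates one way such a mask can be built and applied, under stated assumptions: True marks keys a decoder query must not attend to; encoder steps beyond a sample's encoder length are padding; with causal attention each decoder step sees decoder steps up to and including itself (whether the library lets a step see itself is an assumption here); and masked scores are replaced by mask_bias before the softmax. The function names are hypothetical.

```python
import math


def attention_mask(encoder_lengths: list[int], decoder_length: int) -> list[list[list[bool]]]:
    """Mask of shape batch x decoder_length x (max encoder length + decoder_length).

    True marks a key position that the decoder query must NOT attend to.
    """
    max_encoder_length = max(encoder_lengths)
    batch_mask = []
    for encoder_length in encoder_lengths:
        sample_mask = []
        for query in range(decoder_length):
            # Encoder keys: mask padding beyond this sample's encoder length.
            encoder_part = [step >= encoder_length for step in range(max_encoder_length)]
            # Decoder keys (causal): mask decoder steps after the query step.
            decoder_part = [step > query for step in range(decoder_length)]
            sample_mask.append(encoder_part + decoder_part)
        batch_mask.append(sample_mask)
    return batch_mask


def masked_softmax(scores: list[float], mask: list[bool], mask_bias: float = -1e9) -> list[float]:
    """Replace masked scores by mask_bias, then apply a numerically stable softmax."""
    filled = [mask_bias if masked else score for score, masked in zip(scores, mask)]
    top = max(filled)
    exps = [math.exp(value - top) for value in filled]
    total = sum(exps)
    return [value / total for value in exps]


if __name__ == "__main__":
    mask = attention_mask(encoder_lengths=[2, 3], decoder_length=2)
    print(mask[0][0])  # [False, False, True, False, True]
    print(masked_softmax([1.0, 1.0, 5.0], [False, False, True]))  # [0.5, 0.5, 0.0]
```

With mask_bias=-1e9 the masked weight underflows to exactly zero in double precision; -float("inf") gives the same result here, but produces NaN if every key in a row is masked.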
- interpret_output(out: dict[str, Tensor], reduction: str = 'none', attention_prediction_horizon: int = 0) -> dict[str, Tensor]
Interpret output of model.
- Parameters:
out – output as produced by forward()
reduction – "none" for no averaging over batches, "sum" for summing attentions, "mean" for normalizing by encoder lengths
attention_prediction_horizon – which prediction horizon to use for attention
- Returns:
interpretations that can be plotted with plot_interpretation()
- plot_interpretation(interpretation: dict[str, Tensor])
Make figures that interpret model.
Attention
Variable selection weights / importances
- Parameters:
interpretation – as obtained from interpret_output()
- Returns:
dictionary of matplotlib figures
- plot_prediction(x: dict[str, Tensor], out: dict[str, Tensor], idx: int, plot_attention: bool = True, add_loss_to_title: bool = False, show_future_observed: bool = True, ax=None, **kwargs)
Plot actuals vs prediction and attention
- Parameters:
x (Dict[str, torch.Tensor]) – network input
out (Dict[str, torch.Tensor]) – network output
idx (int) – sample index
plot_attention – whether to plot attention on a secondary axis. Defaults to True.
add_loss_to_title – whether to add the loss to the title. Defaults to False.
show_future_observed – whether to show actuals for the future. Defaults to True.
ax – matplotlib axes to plot on
- Returns:
matplotlib figure
- Return type:
plt.Figure