TFT

class pytorch_forecasting.models.temporal_fusion_transformer._tft_v2.TFT(loss: Module, logging_metrics: list[Module] | None = None, optimizer: Optimizer | str | None = 'adam', optimizer_params: dict | None = None, lr_scheduler: str | None = None, lr_scheduler_params: dict | None = None, hidden_size: int = 64, num_layers: int = 2, attention_head_size: int = 4, dropout: float = 0.1, metadata: dict | None = None, output_size: int = 1)

Bases: BaseModel

Methods

forward(x)

Forward pass of the TFT model.

forward(x: dict[str, Tensor]) → dict[str, Tensor]

Forward pass of the TFT model.

Parameters:

x (Dict[str, torch.Tensor]) – Dictionary containing input tensors:

- encoder_cat: Categorical encoder features
- encoder_cont: Continuous encoder features
- decoder_cat: Categorical decoder features
- decoder_cont: Continuous decoder features
- static_categorical_features: Static categorical features
- static_continuous_features: Static continuous features

Returns:

Dictionary containing output tensors:

- prediction: Prediction output with shape (batch_size, prediction_length, output_size)

Return type:

Dict[str, torch.Tensor]
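
Example

The following is a minimal sketch of how an input dictionary for forward() might be assembled from the keys listed under Parameters. It does not construct the model. The helper missing_keys is hypothetical and not part of pytorch_forecasting, and every tensor size and layout used here (batch first, then time, then features; one static feature of each kind) is an assumption made for illustration, since this page documents only the shape of the prediction output. Whether every key is mandatory is also not stated here.

```python
import torch

# Keys listed in the forward() documentation.
EXPECTED_KEYS = (
    "encoder_cat",
    "encoder_cont",
    "decoder_cat",
    "decoder_cont",
    "static_categorical_features",
    "static_continuous_features",
)


def missing_keys(x: dict[str, torch.Tensor]) -> list[str]:
    """Hypothetical helper: return the documented keys that are absent from x."""
    return [key for key in EXPECTED_KEYS if key not in x]


# Illustrative sizes only (assumptions, not values taken from the library).
batch_size, encoder_length, prediction_length = 4, 24, 6

# Assumed layouts: (batch, time, features) for encoder/decoder tensors,
# (batch, features) for static tensors.
x = {
    "encoder_cat": torch.zeros(batch_size, encoder_length, 2, dtype=torch.long),
    "encoder_cont": torch.randn(batch_size, encoder_length, 3),
    "decoder_cat": torch.zeros(batch_size, prediction_length, 2, dtype=torch.long),
    "decoder_cont": torch.randn(batch_size, prediction_length, 3),
    "static_categorical_features": torch.zeros(batch_size, 1, dtype=torch.long),
    "static_continuous_features": torch.randn(batch_size, 1),
}

# Report any documented key that was left out of the dictionary.
print(missing_keys(x))

# With an already constructed TFT instance `model` (construction not shown),
# the documented call and return value would be:
#   out = model(x)
#   out["prediction"]  # shape (batch_size, prediction_length, output_size)
```

A check like this before calling the model can make a misspelled or forgotten key easier to spot than an error raised from inside forward().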