TFT

class pytorch_forecasting.models.temporal_fusion_transformer._tft_v2.TFT(loss: Module, logging_metrics: list[Module] | None = None, optimizer: Optimizer | str | None = 'adam', optimizer_params: dict | None = None, lr_scheduler: str | None = None, lr_scheduler_params: dict | None = None, hidden_size: int = 64, num_layers: int = 2, attention_head_size: int = 4, dropout: float = 0.1, metadata: dict | None = None, output_size: int = 1)

Bases: BaseModel

Temporal Fusion Transformer (TFT) model for time series forecasting.

Parameters:
  • loss (nn.Module) – Loss function to use for training.

  • logging_metrics (Optional[List[nn.Module]], optional) – List of metrics to log during training, validation, and testing.

  • optimizer (Optional[Union[Optimizer, str]], optional) – Optimizer to use for training. Can be a string (“adam”, “sgd”) or an instance of torch.optim.Optimizer.

  • optimizer_params (Optional[Dict], optional) – Parameters for the optimizer.

  • lr_scheduler (Optional[str], optional) – Learning rate scheduler to use. Supported values: “reduce_lr_on_plateau”, “step_lr”.

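  • lr_scheduler_params (Optional[Dict], optional) – Parameters for the learning rate scheduler.

  • hidden_size (int, optional) – Hidden size of the network.

  • num_layers (int, optional) – Number of layers.

  • attention_head_size (int, optional) – Number of attention heads.

  • dropout (float, optional) – Dropout rate.

  • metadata (Optional[Dict], optional) – Metadata describing the dataset, used to configure the model.

  • output_size (int, optional) – Size of the last dimension of the prediction output.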
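The string options for optimizer and lr_scheduler above are easy to mistype. A minimal, stdlib-only sketch of checking them before constructing the model follows. build_training_config is a hypothetical helper, not part of pytorch_forecasting. It accepts only the strings named on this page ("adam", "sgd", "reduce_lr_on_plateau", "step_lr"), which may be stricter than the library itself. The example parameter keys (lr, step_size, gamma) are keyword arguments of torch.optim.Adam and torch.optim.lr_scheduler.StepLR. That TFT forwards these dictionaries unchanged to those constructors is an assumption.

```python
from typing import Any, Optional

# String values documented for TFT. The torch classes in the comments are the
# assumed targets of each string; this sketch does not import torch.
SUPPORTED_OPTIMIZERS = frozenset({"adam", "sgd"})  # torch.optim.Adam, torch.optim.SGD
SUPPORTED_LR_SCHEDULERS = frozenset({"reduce_lr_on_plateau", "step_lr"})  # ReduceLROnPlateau, StepLR


def build_training_config(
    optimizer: Any = "adam",
    optimizer_params: Optional[dict] = None,
    lr_scheduler: Optional[str] = None,
    lr_scheduler_params: Optional[dict] = None,
) -> dict[str, Any]:
    """Hypothetical helper: validate string options, return TFT keyword arguments."""
    # Optimizer instances (or None) are passed through; only strings are checked.
    if isinstance(optimizer, str) and optimizer not in SUPPORTED_OPTIMIZERS:
        raise ValueError(
            f"unsupported optimizer {optimizer!r}; "
            f"expected one of {sorted(SUPPORTED_OPTIMIZERS)}"
        )
    if lr_scheduler is not None and lr_scheduler not in SUPPORTED_LR_SCHEDULERS:
        raise ValueError(
            f"unsupported lr_scheduler {lr_scheduler!r}; "
            f"expected one of {sorted(SUPPORTED_LR_SCHEDULERS)}"
        )
    # Scheduler parameters without a scheduler are most likely a configuration mistake.
    if lr_scheduler is None and lr_scheduler_params:
        raise ValueError("lr_scheduler_params given without lr_scheduler")
    return {
        "optimizer": optimizer,
        "optimizer_params": optimizer_params,
        "lr_scheduler": lr_scheduler,
        "lr_scheduler_params": lr_scheduler_params,
    }


config = build_training_config(
    optimizer="adam",
    optimizer_params={"lr": 1e-3},  # keyword of torch.optim.Adam
    lr_scheduler="step_lr",
    lr_scheduler_params={"step_size": 10, "gamma": 0.5},  # keywords of StepLR
)
print(config["lr_scheduler"])  # step_lr
```

The returned dictionary could then be unpacked alongside loss and the other constructor arguments. Keeping validation separate from construction lets a misspelled string fail immediately rather than when the optimizer is first configured.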

Methods

forward(x) – Forward pass of the TFT model.

forward(x: dict[str, Tensor]) → dict[str, Tensor]

Forward pass of the TFT model.

Parameters:

x (Dict[str, torch.Tensor]) – Dictionary containing input tensors:

  • encoder_cat: Categorical encoder features

  • encoder_cont: Continuous encoder features

  • decoder_cat: Categorical decoder features

  • decoder_cont: Continuous decoder features

  • static_categorical_features: Static categorical features

  • static_continuous_features: Static continuous features

Returns:

Dictionary containing output tensors:

  • prediction: Prediction output of shape (batch_size, prediction_length, output_size)

Return type:

Dict[str, torch.Tensor]
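The input and output dictionaries described above imply a shape contract between the encoder, decoder, and prediction tensors. It can be sketched as a stdlib-only check on shape tuples. expected_prediction_shape is a hypothetical helper, not part of pytorch_forecasting. It assumes that time-varying inputs are laid out as (batch_size, time_steps, n_features) and that the static tensors are optional and batch-first. This page states neither assumption.

```python
# Assumed layout (not stated on this page): time-varying tensors are
# (batch_size, time_steps, n_features); static tensors are batch-first.
Shape = tuple[int, ...]

ENCODER_KEYS = ("encoder_cat", "encoder_cont")
DECODER_KEYS = ("decoder_cat", "decoder_cont")
STATIC_KEYS = ("static_categorical_features", "static_continuous_features")


def expected_prediction_shape(shapes: dict[str, Shape], output_size: int = 1) -> Shape:
    """Hypothetical helper: check input shapes, return the expected prediction shape."""
    missing = [key for key in ENCODER_KEYS + DECODER_KEYS if key not in shapes]
    if missing:
        raise KeyError(f"missing input tensors: {missing}")

    # Batch size and sequence lengths are read from the continuous inputs.
    batch_size, encoder_length = shapes["encoder_cont"][:2]
    prediction_length = shapes["decoder_cont"][1]

    # Encoder tensors share (batch_size, encoder_length); decoder tensors share
    # (batch_size, prediction_length).
    for keys, length in ((ENCODER_KEYS, encoder_length), (DECODER_KEYS, prediction_length)):
        for key in keys:
            if tuple(shapes[key][:2]) != (batch_size, length):
                raise ValueError(
                    f"{key} has shape {shapes[key]}, "
                    f"expected ({batch_size}, {length}, ...)"
                )

    # Static tensors are treated as optional here (an assumption); when present,
    # their first dimension must match the batch size.
    for key in STATIC_KEYS:
        if key in shapes and shapes[key][0] != batch_size:
            raise ValueError(f"{key} has batch size {shapes[key][0]}, expected {batch_size}")

    return (batch_size, prediction_length, output_size)


batch = {
    "encoder_cat": (32, 30, 2),
    "encoder_cont": (32, 30, 5),
    "decoder_cat": (32, 10, 2),
    "decoder_cont": (32, 10, 5),
    "static_categorical_features": (32, 1),
    "static_continuous_features": (32, 3),
}
print(expected_prediction_shape(batch))  # (32, 10, 1)
```

With real tensors, the shapes could be collected as {k: tuple(v.shape) for k, v in x.items()} before calling the check. Under the stated assumptions, the returned tuple is what the shape of the "prediction" entry in the output dictionary would be expected to equal.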