TFT
- class pytorch_forecasting.models.temporal_fusion_transformer._tft_v2.TFT(loss: Module, logging_metrics: list[Module] | None = None, optimizer: Optimizer | str | None = 'adam', optimizer_params: dict | None = None, lr_scheduler: str | None = None, lr_scheduler_params: dict | None = None, hidden_size: int = 64, num_layers: int = 2, attention_head_size: int = 4, dropout: float = 0.1, metadata: dict | None = None, output_size: int = 1)
Bases: BaseModel
Temporal Fusion Transformer (TFT) model for time series forecasting.
- Parameters:
loss (nn.Module) – Loss function to use for training.
logging_metrics (Optional[List[nn.Module]], optional) – List of metrics to log during training, validation, and testing.
optimizer (Optional[Union[Optimizer, str]], optional) – Optimizer to use for training. Can be a string (“adam”, “sgd”) or an instance of torch.optim.Optimizer.
optimizer_params (Optional[Dict], optional) – Parameters for the optimizer.
lr_scheduler (Optional[str], optional) – Learning rate scheduler to use. Supported values: “reduce_lr_on_plateau”, “step_lr”.
lr_scheduler_params (Optional[Dict], optional) – Parameters for the learning rate scheduler.
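The signature above fixes the defaults, and the parameter list documents the accepted string values for `optimizer` and `lr_scheduler`. As a minimal sketch, the hypothetical `TFTHyperparams` dataclass below (not part of pytorch_forecasting) collects the non-tensor keyword arguments with the same defaults and rejects unsupported strings early. The positivity and dropout-range checks are assumptions added for illustration; `loss`, `logging_metrics`, and `metadata` are left out so the sketch runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Any, Optional

# String options listed in the parameter documentation above.
SUPPORTED_OPTIMIZERS = frozenset({"adam", "sgd"})
SUPPORTED_SCHEDULERS = frozenset({"reduce_lr_on_plateau", "step_lr"})


@dataclass
class TFTHyperparams:
    """Hypothetical container mirroring TFT keyword arguments and their defaults."""

    optimizer: Any = "adam"  # a supported string or a torch.optim.Optimizer instance
    optimizer_params: Optional[dict] = None
    lr_scheduler: Optional[str] = None
    lr_scheduler_params: Optional[dict] = None
    hidden_size: int = 64
    num_layers: int = 2
    attention_head_size: int = 4
    dropout: float = 0.1
    output_size: int = 1

    def __post_init__(self) -> None:
        # Only strings are checked; non-string optimizers are passed through untouched.
        if isinstance(self.optimizer, str) and self.optimizer not in SUPPORTED_OPTIMIZERS:
            raise ValueError(f"unsupported optimizer string: {self.optimizer!r}")
        if self.lr_scheduler is not None and self.lr_scheduler not in SUPPORTED_SCHEDULERS:
            raise ValueError(f"unsupported lr_scheduler: {self.lr_scheduler!r}")
        # Assumption: size-like arguments must be positive integers.
        for name in ("hidden_size", "num_layers", "attention_head_size", "output_size"):
            if getattr(self, name) < 1:
                raise ValueError(f"{name} must be >= 1")
        # Assumption: dropout is a probability strictly below 1.
        if not 0.0 <= self.dropout < 1.0:
            raise ValueError("dropout must be in [0, 1)")

    def as_kwargs(self) -> dict:
        """Return the fields as a plain dict (shallow copy, no deep-copying of objects)."""
        return dict(vars(self))
```

Assuming `loss` and `metadata` are prepared separately, the resulting dict could then be passed as `TFT(loss=loss, metadata=metadata, **params.as_kwargs())`.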
Methods
forward(x) – Forward pass of the TFT model.
- forward(x: dict[str, Tensor]) → dict[str, Tensor]
Forward pass of the TFT model.
- Parameters:
x (Dict[str, torch.Tensor]) – Dictionary containing the input tensors:
    - encoder_cat: categorical encoder features
    - encoder_cont: continuous encoder features
    - decoder_cat: categorical decoder features
    - decoder_cont: continuous decoder features
    - static_categorical_features: static categorical features
    - static_continuous_features: static continuous features
- Returns:
Dictionary containing the output tensors:
    - prediction: prediction output of shape (batch_size, prediction_length, output_size)
- Return type:
Dict[str, torch.Tensor]
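The forward contract above (named input tensors in, a `prediction` tensor of shape `(batch_size, prediction_length, output_size)` out) can be checked without running the model. The sketch below works on plain Python shape tuples rather than tensors. The helper name and the assumed input layout (batch dimension first, decoder tensors shaped `(batch_size, prediction_length, n_features)`) are hypothetical and not taken from pytorch_forecasting.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]

# Input keys listed in the forward() documentation above.
TFT_INPUT_KEYS = frozenset({
    "encoder_cat",
    "encoder_cont",
    "decoder_cat",
    "decoder_cont",
    "static_categorical_features",
    "static_continuous_features",
})


def expected_prediction_shape(input_shapes: Dict[str, Shape], output_size: int = 1) -> Shape:
    """Return (batch_size, prediction_length, output_size) for a batch of input shapes.

    Assumption (hypothetical layout): every tensor has batch_size as its first
    dimension, and decoder tensors are (batch_size, prediction_length, n_features).
    """
    unknown = set(input_shapes) - TFT_INPUT_KEYS
    if unknown:
        raise KeyError(f"unexpected input keys: {sorted(unknown)}")
    if not input_shapes:
        raise ValueError("no input tensors given")

    # All inputs must agree on the batch dimension.
    batch_sizes = {shape[0] for shape in input_shapes.values()}
    if len(batch_sizes) != 1:
        raise ValueError(f"inconsistent batch sizes: {sorted(batch_sizes)}")

    # The decoder tensors define the prediction horizon.
    decoder_lengths = {
        input_shapes[key][1]
        for key in ("decoder_cat", "decoder_cont")
        if key in input_shapes
    }
    if len(decoder_lengths) != 1:
        raise ValueError("need decoder tensors with one shared prediction_length")

    (batch_size,) = batch_sizes
    (prediction_length,) = decoder_lengths
    return (batch_size, prediction_length, output_size)
```

In a test, one could compare `tuple(model(x)["prediction"].shape)` against `expected_prediction_shape({k: tuple(v.shape) for k, v in x.items()}, output_size)`, provided the real input layout matches the assumption above.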