optimize_hyperparameters

pytorch_forecasting.models.temporal_fusion_transformer.tuning.optimize_hyperparameters(train_dataloaders: DataLoader, val_dataloaders: DataLoader, model_path: str, max_epochs: int = 20, n_trials: int = 100, timeout: float = 28800.0, gradient_clip_val_range: tuple[float, float] = (0.01, 100.0), hidden_size_range: tuple[int, int] = (16, 265), hidden_continuous_size_range: tuple[int, int] = (8, 64), attention_head_size_range: tuple[int, int] = (1, 4), dropout_range: tuple[float, float] = (0.1, 0.3), learning_rate_range: tuple[float, float] = (1e-05, 1.0), use_learning_rate_finder: bool = True, trainer_kwargs: dict[str, Any] = {}, log_dir: str = 'lightning_logs', study=None, verbose: int | bool = None, pruner=None, **kwargs)

Optimize hyperparameters of a Temporal Fusion Transformer model.

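Runs hyperparameter optimization with Optuna: each trial trains a Temporal Fusion Transformer with hyperparameters sampled from the ranges below and evaluates it on the validation data. Instead of being sampled, the learning rate can optionally be determined with the PyTorch Lightning learning rate finder.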

Parameters:
  • train_dataloaders (DataLoader) – Dataloader for training.

  • val_dataloaders (DataLoader) – Dataloader for validation.

  • model_path (str) – Directory where model checkpoints are saved.

  • max_epochs (int, optional) – Maximum number of training epochs. Default is 20.

  • n_trials (int, optional) – Number of hyperparameter trials. Default is 100.

  • timeout (float, optional) – Maximum time in seconds for the optimization. Default is 28800.0 (8 hours).

  • gradient_clip_val_range (tuple of float, optional) – (min, max) range for the gradient clipping value. Default is (0.01, 100.0).

  • hidden_size_range (tuple of int, optional) – (min, max) range for the hidden size. Default is (16, 265).

  • hidden_continuous_size_range (tuple of int, optional) – (min, max) range for the hidden size used to process continuous variables. Default is (8, 64).

  • attention_head_size_range (tuple of int, optional) – (min, max) range for the number of attention heads. Default is (1, 4).

  • dropout_range (tuple of float, optional) – (min, max) range for the dropout rate. Default is (0.1, 0.3).

  • learning_rate_range (tuple of float, optional) – (min, max) range for the learning rate. Default is (1e-05, 1.0).

  • use_learning_rate_finder (bool, optional) – Whether to determine the learning rate with the PyTorch Lightning learning rate finder. Default is True.

  • trainer_kwargs (dict of str to Any, optional) – Additional arguments passed to the PyTorch Lightning Trainer. Default is an empty dict.

  • log_dir (str, optional) – Directory for TensorBoard logs. Default is 'lightning_logs'.

  • study (optuna.Study, optional) – Existing Optuna study to resume. If None, a new study is created.

  • verbose (int or bool, optional) – Verbosity level.

  • pruner (optuna.pruners.BasePruner, optional) – Optuna pruner to use.

  • **kwargs – Additional keyword arguments passed to TemporalFusionTransformer.

Returns:

The resulting Optuna study.

Return type:

optuna.Study

Raises:

ImportError – If required optional dependencies are not installed.