optimize_hyperparameters
- pytorch_forecasting.models.temporal_fusion_transformer.tuning.optimize_hyperparameters(train_dataloaders: DataLoader, val_dataloaders: DataLoader, model_path: str, max_epochs: int = 20, n_trials: int = 100, timeout: float = 28800.0, gradient_clip_val_range: tuple[float, float] = (0.01, 100.0), hidden_size_range: tuple[int, int] = (16, 265), hidden_continuous_size_range: tuple[int, int] = (8, 64), attention_head_size_range: tuple[int, int] = (1, 4), dropout_range: tuple[float, float] = (0.1, 0.3), learning_rate_range: tuple[float, float] = (1e-05, 1.0), use_learning_rate_finder: bool = True, trainer_kwargs: dict[str, Any] = {}, log_dir: str = 'lightning_logs', study=None, verbose: int | bool = None, pruner=None, **kwargs)
Optimize hyperparameters of a Temporal Fusion Transformer model.
Runs hyperparameter optimization using Optuna. The learning rate can optionally be determined using the PyTorch Lightning learning rate finder.
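To make the control flow concrete, here is a stdlib-only sketch of a budgeted search loop. It is not the library's implementation: it swaps Optuna's default TPE sampler for plain random search, replaces model training with a hypothetical `toy_validation_loss` function, and assumes log-uniform sampling for the learning rate. It assumes `n_trials` and `timeout` behave like the budgets of Optuna's `Study.optimize`, where the run ends at whichever limit is hit first and the time limit is checked only between trials.

```python
import math
import random
import time


def toy_validation_loss(params: dict[str, float]) -> float:
    """Hypothetical stand-in for training a model and returning its validation loss."""
    # Minimum near learning_rate=1e-2 and dropout=0.2, purely for illustration.
    return (math.log10(params["learning_rate"]) + 2) ** 2 + (params["dropout"] - 0.2) ** 2


def run_search(n_trials: int = 100, timeout: float = 28800.0, seed: int = 0) -> dict:
    """Random search that stops after n_trials or timeout seconds, whichever comes first."""
    rng = random.Random(seed)
    start = time.monotonic()
    best = {"value": math.inf, "params": None, "n_completed": 0}
    for _ in range(n_trials):
        # The time budget is only checked between trials; a running trial is never cut short.
        if time.monotonic() - start >= timeout:
            break
        params = {
            # Assumption: learning rate sampled log-uniformly from the default (1e-05, 1.0).
            "learning_rate": 10 ** rng.uniform(math.log10(1e-05), math.log10(1.0)),
            # Dropout sampled uniformly from the default (0.1, 0.3).
            "dropout": rng.uniform(0.1, 0.3),
        }
        value = toy_validation_loss(params)
        best["n_completed"] += 1
        if value < best["value"]:
            best["value"], best["params"] = value, params
    return best


result = run_search(n_trials=50, timeout=5.0)
print(result["n_completed"], result["params"])
```

With a real model each trial takes minutes rather than microseconds, which is why the 8-hour default `timeout` usually ends the search before the 100 default trials are exhausted on large datasets.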
- Parameters:
train_dataloaders (DataLoader) – Dataloader for training.
val_dataloaders (DataLoader) – Dataloader for validation.
model_path (str) – Directory where model checkpoints are saved.
max_epochs (int, optional) – Maximum number of training epochs. Default is 20.
n_trials (int, optional) – Number of hyperparameter trials. Default is 100.
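timeout (float, optional) – Maximum time in seconds for the whole optimization. Default is 28800.0 seconds (8 hours).
gradient_clip_val_range (tuple of float, optional) – (low, high) bounds of the search range for the gradient clipping value. Default is (0.01, 100.0).
hidden_size_range (tuple of int, optional) – (low, high) bounds of the search range for the hidden size. Default is (16, 265).
hidden_continuous_size_range (tuple of int, optional) – (low, high) bounds of the search range for the hidden continuous size. Default is (8, 64).
attention_head_size_range (tuple of int, optional) – (low, high) bounds of the search range for the number of attention heads. Default is (1, 4).
dropout_range (tuple of float, optional) – (low, high) bounds of the search range for the dropout rate. Default is (0.1, 0.3).
learning_rate_range (tuple of float, optional) – (low, high) bounds of the search range for the learning rate. Default is (1e-05, 1.0).
use_learning_rate_finder (bool, optional) – Whether to determine the learning rate with the PyTorch Lightning learning rate finder. Default is True.
trainer_kwargs (dict of str to Any, optional) – Additional keyword arguments passed to the PyTorch Lightning Trainer. Default is an empty dict.
log_dir (str, optional) – Directory for TensorBoard logs. Default is 'lightning_logs'.
study (optuna.Study, optional) – Existing Optuna study to resume. If not given, a new study is created.
verbose (int or bool, optional) – Verbosity level of Optuna's logging output. Default is None.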
pruner (optuna.pruners.BasePruner, optional) – Optuna pruner to use.
**kwargs – Additional keyword arguments passed to TemporalFusionTransformer.
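Each `*_range` argument is a `(low, high)` pair. The helper below is a hypothetical pre-flight check (not part of pytorch_forecasting) that a caller might run before launching a long study. The defaults are copied from the signature above; the rules it enforces (ordered bounds, integer bounds for the size ranges, positive bounds for everything except dropout, dropout inside [0, 1)) are assumptions about sensible inputs, not documented constraints of the function.

```python
from typing import Union

Number = Union[int, float]

# Defaults copied from the optimize_hyperparameters signature.
DEFAULT_RANGES: dict[str, tuple[Number, Number]] = {
    "gradient_clip_val_range": (0.01, 100.0),
    "hidden_size_range": (16, 265),
    "hidden_continuous_size_range": (8, 64),
    "attention_head_size_range": (1, 4),
    "dropout_range": (0.1, 0.3),
    "learning_rate_range": (1e-05, 1.0),
}

INT_RANGES = {"hidden_size_range", "hidden_continuous_size_range", "attention_head_size_range"}


def check_ranges(**overrides: tuple[Number, Number]) -> dict[str, tuple[Number, Number]]:
    """Hypothetical helper: merge overrides into the defaults and validate every (low, high) pair."""
    unknown = set(overrides) - set(DEFAULT_RANGES)
    if unknown:
        raise KeyError(f"unknown range arguments: {sorted(unknown)}")
    ranges = {**DEFAULT_RANGES, **overrides}
    for name, (low, high) in ranges.items():
        if low > high:
            raise ValueError(f"{name}: low={low} exceeds high={high}")
        if name in INT_RANGES and not (isinstance(low, int) and isinstance(high, int)):
            raise TypeError(f"{name}: bounds must be integers")
        if name == "dropout_range":
            if not (0 <= low and high < 1):
                raise ValueError(f"{name}: dropout bounds must lie in [0, 1)")
        elif low <= 0:
            raise ValueError(f"{name}: bounds must be positive")
    return ranges


ranges = check_ranges(hidden_size_range=(8, 128), dropout_range=(0.0, 0.5))
print(ranges["hidden_size_range"])
# Because the keys match the parameter names, the result can be spread into the call,
# e.g. (train_dl and val_dl are placeholder dataloaders):
# optimize_hyperparameters(train_dl, val_dl, model_path="optuna_test", **ranges)
```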
- Returns:
The resulting Optuna study.
- Return type:
optuna.Study
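The returned object is a regular `optuna.Study`, so the usual Optuna attributes apply: `best_trial.params` holds the best hyperparameter combination, `best_value` its objective value, and `trials` the full trial history. The sketch below is duck-typed so that it runs without Optuna installed; `summarize_study` is a hypothetical helper, and the `types.SimpleNamespace` objects are placeholders standing in for a real study, not results from an actual run.

```python
from types import SimpleNamespace
from typing import Any


def summarize_study(study: Any) -> dict[str, Any]:
    """Hypothetical helper: pull the key results out of an optuna.Study-like object."""
    best = study.best_trial
    return {
        "n_trials": len(study.trials),
        "best_value": study.best_value,
        "best_params": dict(best.params),
    }


# Placeholder stand-in for the study returned by optimize_hyperparameters.
fake_trial = SimpleNamespace(params={"hidden_size": 32, "dropout": 0.15}, value=0.5)
fake_study = SimpleNamespace(best_trial=fake_trial, best_value=fake_trial.value, trials=[fake_trial])
print(summarize_study(fake_study))
```

With a real study, the same function works unchanged because it only reads attributes that `optuna.Study` provides.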
- Raises:
ImportError – If required optional dependencies are not installed.
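Since the function raises ImportError when optional dependencies are missing, a caller may prefer to check up front before building dataloaders. The sketch below uses `importlib.util.find_spec` to report which top-level modules cannot be found. The module list is an assumption: Optuna is clearly needed because the search runs on it, but the exact set of optional packages depends on the installed pytorch_forecasting version.

```python
import importlib.util


def missing_modules(names: list[str]) -> list[str]:
    """Return the subset of top-level module names that cannot be imported."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Assumption: "optuna" is the key optional dependency; extend the list for your version.
missing = missing_modules(["optuna"])
if missing:
    print("Install before tuning:", ", ".join(missing))
```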