pytorch_forecasting.models.temporal_fusion_transformer.tuning.optimize_hyperparameters
Optimize Temporal Fusion Transformer hyperparameters.
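Run hyperparameter optimization. The learning rate is determined with the PyTorch Lightning learning rate finder, unless use_learning_rate_finder is False, in which case it is optimized as one of the hyperparameters.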
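As a rough mental model of how the n_trials and timeout parameters bound the search, the loop can be sketched with plain random search standing in for optuna's sampler, which uses smarter strategies and pruning. This is a hypothetical illustration, not the library's code: random_search, sample, and objective are made-up names, and the toy objective replaces "train a TemporalFusionTransformer, then report its validation loss".

```python
import math
import random
import time
from typing import Callable, Dict, Optional, Tuple

Params = Dict[str, float]


def random_search(
    objective: Callable[[Params], float],
    sample: Callable[[random.Random], Params],
    n_trials: int = 100,
    timeout: Optional[float] = 3600 * 8.0,
    seed: int = 0,
) -> Tuple[Params, float]:
    """Minimize `objective` over configurations drawn by `sample`.

    Stops after `n_trials` trials or once `timeout` seconds have elapsed,
    whichever comes first. The clock is only checked before each trial,
    so in this sketch a trial that is already running is not interrupted.
    """
    rng = random.Random(seed)
    start = time.monotonic()
    best_params: Params = {}
    best_value = math.inf
    for _ in range(n_trials):
        if timeout is not None and time.monotonic() - start >= timeout:
            break
        params = sample(rng)
        value = objective(params)  # stands in for "train, then return validation loss"
        if value < best_value:
            best_params, best_value = params, value
    return best_params, best_value


if __name__ == "__main__":
    # Toy objective: pretend validation loss is lowest at dropout = 0.2.
    best, loss = random_search(
        objective=lambda p: (p["dropout"] - 0.2) ** 2,
        sample=lambda rng: {"dropout": rng.uniform(0.1, 0.3)},
        n_trials=20,
    )
    print(best, loss)
```

The sketch returns only the best configuration; the real function returns the whole optuna study, which also records every trial.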
train_dataloader (DataLoader) – dataloader for training the model
val_dataloader (DataLoader) – dataloader for validating the model
model_path (str) – folder to which model checkpoints are saved
max_epochs (int, optional) – Maximum number of epochs to run training. Defaults to 20.
n_trials (int, optional) – Number of hyperparameter trials to run. Defaults to 100.
timeout (float, optional) – Time in seconds after which training is stopped regardless of the number of epochs or the validation metric. Defaults to 3600*8.0 (8 hours).
hidden_size_range (Tuple[int, int], optional) – Minimum and maximum of hidden_size hyperparameter. Defaults to (16, 265).
hidden_continuous_size_range (Tuple[int, int], optional) – Minimum and maximum of hidden_continuous_size hyperparameter. Defaults to (8, 64).
attention_head_size_range (Tuple[int, int], optional) – Minimum and maximum of attention_head_size hyperparameter. Defaults to (1, 4).
dropout_range (Tuple[float, float], optional) – Minimum and maximum of dropout hyperparameter. Defaults to (0.1, 0.3).
learning_rate_range (Tuple[float, float], optional) – Minimum and maximum of the learning rate. Defaults to (1e-5, 1.0).
use_learning_rate_finder (bool, optional) – Whether to use the learning rate finder or to optimize the learning rate as part of the hyperparameters. Defaults to True.
trainer_kwargs (Dict[str, Any], optional) – Additional arguments to the PyTorch Lightning trainer such as limit_train_batches. Defaults to {}.
log_dir (str, optional) – Folder into which results are logged for TensorBoard. Defaults to "lightning_logs".
study (optuna.Study, optional) – Study to resume. By default, a new study is created.
verbose (Union[int, bool], optional) – Level of verbosity:
* None: no change in verbosity level (equivalent to verbose=1 by optuna's default).
* 0 or False: log only warnings.
* 1 or True: log pruning events.
* 2: set the optuna logging level to debug.
Defaults to None.
**kwargs – Additional arguments for the TemporalFusionTransformer.
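The search space spanned by the *_range arguments, and the effect of use_learning_rate_finder, can be pictured with the stdlib sketch below. It is illustrative only: the library delegates sampling to optuna, the helper names (log_uniform, log_uniform_int, sample_tft_params) are hypothetical, and both the log-scale sampling of hidden_size, hidden_continuous_size, and the learning rate and the capping of hidden_continuous_size at hidden_size are assumptions. When use_learning_rate_finder is True, no learning rate is sampled, because the learning rate finder supplies it.

```python
import math
import random
from typing import Any, Dict, Tuple


def log_uniform(rng: random.Random, low: float, high: float) -> float:
    """Draw a float in [low, high] whose logarithm is uniformly distributed."""
    value = math.exp(rng.uniform(math.log(low), math.log(high)))
    return min(max(value, low), high)  # clamp floating-point round-off


def log_uniform_int(rng: random.Random, low: int, high: int) -> int:
    """Integer version of log_uniform; rounding keeps it within [low, high]."""
    return round(log_uniform(rng, low, high))


def sample_tft_params(
    rng: random.Random,
    hidden_size_range: Tuple[int, int] = (16, 265),
    hidden_continuous_size_range: Tuple[int, int] = (8, 64),
    attention_head_size_range: Tuple[int, int] = (1, 4),
    dropout_range: Tuple[float, float] = (0.1, 0.3),
    learning_rate_range: Tuple[float, float] = (1e-5, 1.0),
    use_learning_rate_finder: bool = True,
) -> Dict[str, Any]:
    """Hypothetical sampler for one trial's hyperparameters (defaults as documented)."""
    hidden_size = log_uniform_int(rng, *hidden_size_range)
    params: Dict[str, Any] = {
        "hidden_size": hidden_size,
        # Assumption: the continuous-variable size never exceeds hidden_size.
        "hidden_continuous_size": log_uniform_int(
            rng,
            hidden_continuous_size_range[0],
            min(hidden_continuous_size_range[1], hidden_size),
        ),
        "attention_head_size": rng.randint(*attention_head_size_range),  # inclusive
        "dropout": rng.uniform(*dropout_range),
    }
    if not use_learning_rate_finder:
        # Learning rate searched on a log scale like any other hyperparameter.
        params["learning_rate"] = log_uniform(rng, *learning_rate_range)
    # Otherwise the learning rate is left to the learning rate finder.
    return params


if __name__ == "__main__":
    rng = random.Random(0)
    print(sample_tft_params(rng))
    print(sample_tft_params(rng, use_learning_rate_finder=False))
```

A log scale is a common choice for sizes and learning rates whose ranges span an order of magnitude or more, since it spends as many trials on 16 to 64 as on 64 to 256.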
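The verbose values listed above can be read as a lookup from a small integer to a logging level. The sketch below is a hypothetical helper (resolve_verbosity is not part of the library) and does not call optuna itself; it assumes that optuna's WARNING, INFO, and DEBUG levels are the standard-library logging constants, which optuna re-exports, so the result could be passed to optuna.logging.set_verbosity.

```python
import logging
from typing import Optional, Union

# Assumption: these equal optuna.logging.WARNING / INFO / DEBUG.
_VERBOSITY_TO_LEVEL = {
    0: logging.WARNING,  # 0 or False: warnings only
    1: logging.INFO,     # 1 or True: pruning events and other info messages
    2: logging.DEBUG,    # 2: debug output
}


def resolve_verbosity(verbose: Optional[Union[int, bool]], current_level: int) -> int:
    """Map a documented `verbose` value to a logging level.

    None keeps `current_level` unchanged. False and True behave like 0 and 1
    because bool is a subclass of int.
    """
    if verbose is None:
        return current_level
    try:
        return _VERBOSITY_TO_LEVEL[int(verbose)]
    except KeyError:
        raise ValueError(f"verbose must be None, a bool, or 0-2, got {verbose!r}") from None


if __name__ == "__main__":
    for value in (None, False, True, 2):
        print(value, logging.getLevelName(resolve_verbosity(value, logging.INFO)))
```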
Returns: optuna study results
Return type: optuna.Study