Source code for pytorch_forecasting.models.temporal_fusion_transformer.tuning

"""
Hyperparameters can be efficiently tuned with `optuna <https://optuna.readthedocs.io/>`_.
"""
import copy
import logging
import os
from typing import Any, Dict, Tuple, Union

import lightning.pytorch as pl
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner import Tuner
import numpy as np
import optuna
from optuna.integration import PyTorchLightningPruningCallback
import optuna.logging
import statsmodels.api as sm
import torch
from torch.utils.data import DataLoader

from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.metrics import QuantileLoss

optuna_logger = logging.getLogger("optuna")


# need to inherit from callback for this to work
class PyTorchLightningPruningCallbackAdjusted(pl.Callback, PyTorchLightningPruningCallback):
    pass

def optimize_hyperparameters(
    train_dataloaders: DataLoader,
    val_dataloaders: DataLoader,
    model_path: str,
    max_epochs: int = 20,
    n_trials: int = 100,
    timeout: float = 3600 * 8.0,  # 8 hours
    gradient_clip_val_range: Tuple[float, float] = (0.01, 100.0),
    hidden_size_range: Tuple[int, int] = (16, 265),
    hidden_continuous_size_range: Tuple[int, int] = (8, 64),
    attention_head_size_range: Tuple[int, int] = (1, 4),
    dropout_range: Tuple[float, float] = (0.1, 0.3),
    learning_rate_range: Tuple[float, float] = (1e-5, 1.0),
    use_learning_rate_finder: bool = True,
    trainer_kwargs: Dict[str, Any] = {},
    log_dir: str = "lightning_logs",
    study: optuna.Study = None,
    verbose: Union[int, bool] = None,
    pruner: optuna.pruners.BasePruner = optuna.pruners.SuccessiveHalvingPruner(),
    **kwargs,
) -> optuna.Study:
    """
    Optimize Temporal Fusion Transformer hyperparameters.

    Run hyperparameter optimization. The learning rate is determined with the
    PyTorch Lightning learning rate finder.

    Args:
        train_dataloaders (DataLoader): dataloader for training model
        val_dataloaders (DataLoader): dataloader for validating model
        model_path (str): folder to which model checkpoints are saved
        max_epochs (int, optional): Maximum number of epochs to run training. Defaults to 20.
        n_trials (int, optional): Number of hyperparameter trials to run. Defaults to 100.
        timeout (float, optional): Time in seconds after which training is stopped regardless of number of
            epochs or validation metric. Defaults to 3600*8.0.
        gradient_clip_val_range (Tuple[float, float], optional): Minimum and maximum of ``gradient_clip_val``
            hyperparameter. Defaults to (0.01, 100.0).
        hidden_size_range (Tuple[int, int], optional): Minimum and maximum of ``hidden_size`` hyperparameter.
            Defaults to (16, 265).
        hidden_continuous_size_range (Tuple[int, int], optional): Minimum and maximum of
            ``hidden_continuous_size`` hyperparameter. Defaults to (8, 64).
        attention_head_size_range (Tuple[int, int], optional): Minimum and maximum of ``attention_head_size``
            hyperparameter. Defaults to (1, 4).
        dropout_range (Tuple[float, float], optional): Minimum and maximum of ``dropout`` hyperparameter.
            Defaults to (0.1, 0.3).
        learning_rate_range (Tuple[float, float], optional): Learning rate range. Defaults to (1e-5, 1.0).
        use_learning_rate_finder (bool): Whether to use the learning rate finder or to optimize the learning
            rate as part of the hyperparameters. Defaults to True.
        trainer_kwargs (Dict[str, Any], optional): Additional arguments to the
            `PyTorch Lightning trainer <https://pytorch-lightning.readthedocs.io/en/latest/trainer.html>`_
            such as ``limit_train_batches``. Defaults to {}.
        log_dir (str, optional): Folder into which to log results for tensorboard. Defaults to
            "lightning_logs".
        study (optuna.Study, optional): Study to resume. Will create a new study by default.
        verbose (Union[int, bool]): Level of verbosity.

            * None: no change in verbosity level (equivalent to verbose=1 by optuna-set default).
            * 0 or False: log only warnings.
            * 1 or True: log pruning events.
            * 2: optuna logging level at debug level.

            Defaults to None.
        pruner (optuna.pruners.BasePruner, optional): The optuna pruner to use. Defaults to
            ``optuna.pruners.SuccessiveHalvingPruner()``.
        **kwargs: Additional arguments for the :py:class:`~TemporalFusionTransformer`.

    Returns:
        optuna.Study: optuna study results
    """
    assert isinstance(train_dataloaders.dataset, TimeSeriesDataSet) and isinstance(
        val_dataloaders.dataset, TimeSeriesDataSet
    ), "dataloaders must be built from timeseriesdataset"

    logging_level = {
        None: optuna.logging.get_verbosity(),
        0: optuna.logging.WARNING,
        1: optuna.logging.INFO,
        2: optuna.logging.DEBUG,
    }
    optuna_verbose = logging_level[verbose]
    optuna.logging.set_verbosity(optuna_verbose)

    loss = kwargs.get(
        "loss", QuantileLoss()
    )  # need a deepcopy of loss as it will otherwise propagate from one trial to the next

    # create objective function
    def objective(trial: optuna.Trial) -> float:
        # Filenames for each trial must be made unique in order to access each checkpoint.
        checkpoint_callback = ModelCheckpoint(
            dirpath=os.path.join(model_path, "trial_{}".format(trial.number)),
            filename="{epoch}",
            monitor="val_loss",
        )

        learning_rate_callback = LearningRateMonitor()
        logger = TensorBoardLogger(log_dir, name="optuna", version=trial.number)
        gradient_clip_val = trial.suggest_loguniform("gradient_clip_val", *gradient_clip_val_range)
        default_trainer_kwargs = dict(
            accelerator="auto",
            max_epochs=max_epochs,
            gradient_clip_val=gradient_clip_val,
            callbacks=[
                learning_rate_callback,
                checkpoint_callback,
                PyTorchLightningPruningCallbackAdjusted(trial, monitor="val_loss"),
            ],
            logger=logger,
            enable_progress_bar=optuna_verbose < optuna.logging.INFO,
            enable_model_summary=[False, True][optuna_verbose < optuna.logging.INFO],
        )
        default_trainer_kwargs.update(trainer_kwargs)
        trainer = pl.Trainer(
            **default_trainer_kwargs,
        )

        # create model
        hidden_size = trial.suggest_int("hidden_size", *hidden_size_range, log=True)
        kwargs["loss"] = copy.deepcopy(loss)
        model = TemporalFusionTransformer.from_dataset(
            train_dataloaders.dataset,
            dropout=trial.suggest_uniform("dropout", *dropout_range),
            hidden_size=hidden_size,
            hidden_continuous_size=trial.suggest_int(
                "hidden_continuous_size",
                hidden_continuous_size_range[0],
                min(hidden_continuous_size_range[1], hidden_size),
                log=True,
            ),
            attention_head_size=trial.suggest_int("attention_head_size", *attention_head_size_range),
            log_interval=-1,
            **kwargs,
        )
        # find a good learning rate
        if use_learning_rate_finder:
            lr_trainer = pl.Trainer(
                gradient_clip_val=gradient_clip_val,
                accelerator="auto",
                logger=False,
                enable_progress_bar=False,
                enable_model_summary=False,
            )
            tuner = Tuner(lr_trainer)
            res = tuner.lr_find(
                model,
                train_dataloaders=train_dataloaders,
                val_dataloaders=val_dataloaders,
                early_stop_threshold=10000,
                min_lr=learning_rate_range[0],
                num_training=100,
                max_lr=learning_rate_range[1],
            )

            loss_finite = np.isfinite(res.results["loss"])
            if loss_finite.sum() > 3:  # at least 3 valid values required for learning rate finder
                lr_smoothed, loss_smoothed = sm.nonparametric.lowess(
                    np.asarray(res.results["loss"])[loss_finite],
                    np.asarray(res.results["lr"])[loss_finite],
                    frac=1.0 / 10.0,
                )[min(loss_finite.sum() - 3, 10) : -1].T
                optimal_idx = np.gradient(loss_smoothed).argmin()
                optimal_lr = lr_smoothed[optimal_idx]
            else:
                optimal_idx = np.asarray(res.results["loss"]).argmin()
                optimal_lr = res.results["lr"][optimal_idx]
            optuna_logger.info(f"Using learning rate of {optimal_lr:.3g}")
            # add learning rate artificially
            model.hparams.learning_rate = trial.suggest_uniform("learning_rate", optimal_lr, optimal_lr)
        else:
            model.hparams.learning_rate = trial.suggest_loguniform("learning_rate", *learning_rate_range)

        # fit
        trainer.fit(model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)

        # report result
        return trainer.callback_metrics["val_loss"].item()

    # setup optuna and run
    if study is None:
        study = optuna.create_study(direction="minimize", pruner=pruner)
    study.optimize(objective, n_trials=n_trials, timeout=timeout)
    return study
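

# ---------------------------------------------------------------------------
# Usage sketch (not part of the library source): a minimal, self-contained
# illustration of how optimize_hyperparameters might be called. The synthetic
# dataset, column names ("time_idx", "series", "value"), output folder
# "optuna_test" and all trial/epoch counts below are illustrative assumptions,
# not values prescribed by the library.
if __name__ == "__main__":
    import pandas as pd

    # build a small synthetic panel: 10 series of 60 time steps each
    data = pd.DataFrame(
        dict(
            time_idx=np.tile(np.arange(60), 10),
            series=np.repeat([str(i) for i in range(10)], 60),
        )
    )
    data["value"] = np.sin(data["time_idx"] / 5) + np.random.normal(scale=0.1, size=len(data))

    max_encoder_length = 24
    max_prediction_length = 6

    # training cutoff: keep the last steps of each series for validation
    training = TimeSeriesDataSet(
        data[data["time_idx"] <= 50],
        time_idx="time_idx",
        target="value",
        group_ids=["series"],
        max_encoder_length=max_encoder_length,
        max_prediction_length=max_prediction_length,
        time_varying_unknown_reals=["value"],
    )
    validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)

    train_dataloader = training.to_dataloader(train=True, batch_size=64, num_workers=0)
    val_dataloader = validation.to_dataloader(train=False, batch_size=64, num_workers=0)

    # run a short study; checkpoints land in optuna_test/trial_<n>/
    study = optimize_hyperparameters(
        train_dataloader,
        val_dataloader,
        model_path="optuna_test",
        n_trials=5,
        max_epochs=5,
        trainer_kwargs=dict(limit_train_batches=10),
        use_learning_rate_finder=False,  # let optuna tune the learning rate as well
        reduce_on_plateau_patience=4,  # forwarded to TemporalFusionTransformer via **kwargs
    )

    # best hyperparameters found by the study
    print(study.best_trial.params)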