"""
Hyperparameters can be efficiently tuned with `optuna <https://optuna.readthedocs.io/>`_.
"""
import copy
import logging
import os
from typing import Any, Dict, Tuple, Union
import numpy as np
import optuna
from optuna.integration import PyTorchLightningPruningCallback, TensorBoardCallback
import optuna.logging
import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import statsmodels.api as sm
import torch
from torch.utils.data import DataLoader
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.metrics import QuantileLoss
optuna_logger = logging.getLogger("optuna")


class MetricsCallback(Callback):
"""PyTorch Lightning metric callback."""
def __init__(self):
super().__init__()
self.metrics = []

    def on_validation_end(self, trainer, pl_module):
self.metrics.append(trainer.callback_metrics)


def optimize_hyperparameters(
train_dataloader: DataLoader,
val_dataloader: DataLoader,
model_path: str,
max_epochs: int = 20,
n_trials: int = 100,
timeout: float = 3600 * 8.0, # 8 hours
gradient_clip_val_range: Tuple[float, float] = (0.01, 100.0),
hidden_size_range: Tuple[int, int] = (16, 265),
hidden_continuous_size_range: Tuple[int, int] = (8, 64),
attention_head_size_range: Tuple[int, int] = (1, 4),
dropout_range: Tuple[float, float] = (0.1, 0.3),
learning_rate_range: Tuple[float, float] = (1e-5, 1.0),
use_learning_rate_finder: bool = True,
trainer_kwargs: Dict[str, Any] = {},
log_dir: str = "lightning_logs",
study: optuna.Study = None,
verbose: Union[int, bool] = None,
pruner: optuna.pruners.BasePruner = optuna.pruners.SuccessiveHalvingPruner(),
**kwargs,
) -> optuna.Study:
"""
    Optimize Temporal Fusion Transformer hyperparameters.

    Run hyperparameter optimization. The learning rate is determined with
    the PyTorch Lightning learning rate finder.

Args:
train_dataloader (DataLoader): dataloader for training model
val_dataloader (DataLoader): dataloader for validating model
model_path (str): folder to which model checkpoints are saved
max_epochs (int, optional): Maximum number of epochs to run training. Defaults to 20.
n_trials (int, optional): Number of hyperparameter trials to run. Defaults to 100.
        timeout (float, optional): Time in seconds after which training is stopped regardless of number of epochs
            or validation metric. Defaults to 3600*8.0.
        gradient_clip_val_range (Tuple[float, float], optional): Minimum and maximum of ``gradient_clip_val``
            hyperparameter. Defaults to (0.01, 100.0).
hidden_size_range (Tuple[int, int], optional): Minimum and maximum of ``hidden_size`` hyperparameter. Defaults
to (16, 265).
hidden_continuous_size_range (Tuple[int, int], optional): Minimum and maximum of ``hidden_continuous_size``
hyperparameter. Defaults to (8, 64).
attention_head_size_range (Tuple[int, int], optional): Minimum and maximum of ``attention_head_size``
hyperparameter. Defaults to (1, 4).
dropout_range (Tuple[float, float], optional): Minimum and maximum of ``dropout`` hyperparameter. Defaults to
(0.1, 0.3).
learning_rate_range (Tuple[float, float], optional): Learning rate range. Defaults to (1e-5, 1.0).
use_learning_rate_finder (bool): If to use learning rate finder or optimize as part of hyperparameters.
Defaults to True.
trainer_kwargs (Dict[str, Any], optional): Additional arguments to the
`PyTorch Lightning trainer <https://pytorch-lightning.readthedocs.io/en/latest/trainer.html>`_ such
as ``limit_train_batches``. Defaults to {}.
log_dir (str, optional): Folder into which to log results for tensorboard. Defaults to "lightning_logs".
study (optuna.Study, optional): study to resume. Will create new study by default.
verbose (Union[int, bool]): level of verbosity.
* None: no change in verbosity level (equivalent to verbose=1 by optuna-set default).
* 0 or False: log only warnings.
* 1 or True: log pruning events.
* 2: optuna logging level at debug level.
Defaults to None.
pruner (optuna.pruners.BasePruner, optional): The optuna pruner to use.
Defaults to optuna.pruners.SuccessiveHalvingPruner().
**kwargs: Additional arguments for the :py:class:`~TemporalFusionTransformer`.
Returns:
optuna.Study: optuna study results
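
    Example:
        A minimal sketch of a tuning run, assuming ``train_dataloader`` and ``val_dataloader``
        were created from a :py:class:`~pytorch_forecasting.data.TimeSeriesDataSet` and that
        ``model_path`` and the trial/epoch counts below are placeholders::

            study = optimize_hyperparameters(
                train_dataloader,
                val_dataloader,
                model_path="optuna_checkpoints",
                n_trials=50,
                max_epochs=20,
            )
            print(study.best_trial.params)  # best hyperparameters found so far

        Checkpoints of every trial are written to ``model_path/trial_<number>``.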
"""
    assert isinstance(train_dataloader.dataset, TimeSeriesDataSet) and isinstance(
        val_dataloader.dataset, TimeSeriesDataSet
    ), "dataloaders must be built from a TimeSeriesDataSet"
logging_level = {
None: optuna.logging.get_verbosity(),
0: optuna.logging.WARNING,
1: optuna.logging.INFO,
2: optuna.logging.DEBUG,
}
optuna_verbose = logging_level[verbose]
optuna.logging.set_verbosity(optuna_verbose)
loss = kwargs.get(
"loss", QuantileLoss()
) # need a deepcopy of loss as it will otherwise propagate from one trial to the next
# create objective function
def objective(trial: optuna.Trial) -> float:
# Filenames for each trial must be made unique in order to access each checkpoint.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
dirpath=os.path.join(model_path, "trial_{}".format(trial.number)), filename="{epoch}", monitor="val_loss"
)
        # Collect the metrics of every validation run with a simple callback so that the final
        # "val_loss" can be reported back to optuna at the end of the trial. A TensorBoardLogger
        # is attached in addition, so individual runs can be inspected manually.
metrics_callback = MetricsCallback()
learning_rate_callback = LearningRateMonitor()
logger = TensorBoardLogger(log_dir, name="optuna", version=trial.number)
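        # sample gradient clipping on a log scale - its search range spans several orders of magnitude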
gradient_clip_val = trial.suggest_loguniform("gradient_clip_val", *gradient_clip_val_range)
default_trainer_kwargs = dict(
gpus=[0] if torch.cuda.is_available() else None,
max_epochs=max_epochs,
gradient_clip_val=gradient_clip_val,
callbacks=[
metrics_callback,
learning_rate_callback,
checkpoint_callback,
PyTorchLightningPruningCallback(trial, monitor="val_loss"),
],
logger=logger,
            progress_bar_refresh_rate=1 if optuna_verbose < optuna.logging.INFO else 0,
            weights_summary="top" if optuna_verbose < optuna.logging.INFO else None,
)
default_trainer_kwargs.update(trainer_kwargs)
trainer = pl.Trainer(
**default_trainer_kwargs,
)
# create model
hidden_size = trial.suggest_int("hidden_size", *hidden_size_range, log=True)
kwargs["loss"] = copy.deepcopy(loss)
model = TemporalFusionTransformer.from_dataset(
train_dataloader.dataset,
dropout=trial.suggest_uniform("dropout", *dropout_range),
hidden_size=hidden_size,
hidden_continuous_size=trial.suggest_int(
"hidden_continuous_size",
hidden_continuous_size_range[0],
min(hidden_continuous_size_range[1], hidden_size),
log=True,
),
attention_head_size=trial.suggest_int("attention_head_size", *attention_head_size_range),
log_interval=-1,
**kwargs,
)
# find good learning rate
if use_learning_rate_finder:
lr_trainer = pl.Trainer(
gradient_clip_val=gradient_clip_val,
gpus=[0] if torch.cuda.is_available() else None,
logger=False,
progress_bar_refresh_rate=0,
weights_summary=None,
)
res = lr_trainer.tuner.lr_find(
model,
train_dataloader=train_dataloader,
val_dataloaders=val_dataloader,
early_stop_threshold=10000,
min_lr=learning_rate_range[0],
num_training=100,
max_lr=learning_rate_range[1],
)
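            # smooth the loss-vs-learning-rate curve with LOWESS (dropping the first points and
            # the last one) and pick the learning rate at which the smoothed loss falls fastest,
            # i.e. where its gradient is most negative; if too few finite losses are available,
            # fall back to the learning rate with the lowest raw loss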
loss_finite = np.isfinite(res.results["loss"])
            if loss_finite.sum() > 3:  # need more than 3 finite loss values to smooth the curve
lr_smoothed, loss_smoothed = sm.nonparametric.lowess(
np.asarray(res.results["loss"])[loss_finite],
np.asarray(res.results["lr"])[loss_finite],
frac=1.0 / 10.0,
)[min(loss_finite.sum() - 3, 10) : -1].T
optimal_idx = np.gradient(loss_smoothed).argmin()
optimal_lr = lr_smoothed[optimal_idx]
else:
optimal_idx = np.asarray(res.results["loss"]).argmin()
optimal_lr = res.results["lr"][optimal_idx]
optuna_logger.info(f"Using learning rate of {optimal_lr:.3g}")
            # register the learning rate found by the finder as a fixed trial parameter
            # (identical lower and upper bound) so that it is recorded in the optuna study
model.hparams.learning_rate = trial.suggest_uniform("learning_rate", optimal_lr, optimal_lr)
else:
model.hparams.learning_rate = trial.suggest_loguniform("learning_rate", *learning_rate_range)
# fit
trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloader)
# report result
return metrics_callback.metrics[-1]["val_loss"].item()
# setup optuna and run
if study is None:
study = optuna.create_study(direction="minimize", pruner=pruner)
study.optimize(objective, n_trials=n_trials, timeout=timeout)
return study