In this tutorial, we will train the TemporalFusionTransformer on a very small dataset to demonstrate that it does a good job even on only 20k samples. Generally speaking, it is a large model and will therefore perform much better with more data.
Our example is a demand forecast from the Stallion Kaggle competition.
[1]:
import os
import warnings

warnings.filterwarnings("ignore")  # avoid printing out absolute paths

os.chdir("../../..")
[2]:
import copy
from pathlib import Path
import warnings

import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import torch

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
First, we need to transform our time series into a pandas dataframe where each row can be identified with a time step and a time series. Fortunately, most datasets are already in this format. For this tutorial, we will use the Stallion dataset from Kaggle describing sales of various beverages. Our task is to make a six-month forecast of the sold volume by stock keeping units (SKU), that is products, sold by an agency, that is a store. There are about 21 000 monthly historic sales records. In addition to historic sales we have information about the sales price, the location of the agency, special days such as holidays, and volume sold in the entire industry.
The dataset is already in the correct format but lacks some important features. Most importantly, we need to add a time index that is incremented by one for each time step. Further, it is beneficial to add date features, which in this case means extracting the month from the date record.
[3]:
from pytorch_forecasting.data.examples import get_stallion_data

data = get_stallion_data()

# add time index
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min()

# add additional features
data["month"] = data.date.dt.month.astype(str).astype("category")  # categories have to be strings
data["log_volume"] = np.log(data.volume + 1e-8)
data["avg_volume_by_sku"] = data.groupby(["time_idx", "sku"], observed=True).volume.transform("mean")
data["avg_volume_by_agency"] = data.groupby(["time_idx", "agency"], observed=True).volume.transform("mean")

# we want to encode special days as one variable and thus need to first reverse one-hot encoding
special_days = [
    "easter_day",
    "good_friday",
    "new_year",
    "christmas",
    "labor_day",
    "independence_day",
    "revolution_day_memorial",
    "regional_games",
    "fifa_u_17_world_cup",
    "football_gold_cup",
    "beer_capital",
    "music_fest",
]
data[special_days] = data[special_days].apply(lambda x: x.map({0: "-", 1: x.name})).astype("category")
data.sample(10, random_state=521)
10 rows × 31 columns
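The time index construction used above can be checked on a handful of dates. The following is a minimal pure-Python sketch, assuming hypothetical monthly dates that cross a year boundary; it mirrors the `year * 12 + month` arithmetic without requiring pandas.

```python
from datetime import date


def month_index(d: date) -> int:
    """Absolute month counter: year * 12 + month, as in the cell above."""
    return d.year * 12 + d.month


# Hypothetical monthly dates spanning a year boundary
dates = [date(2013, 11, 1), date(2013, 12, 1), date(2014, 1, 1), date(2014, 2, 1)]

raw = [month_index(d) for d in dates]
offset = min(raw)
time_idx = [r - offset for r in raw]  # shift so the first step is 0

print(time_idx)  # [0, 1, 2, 3]
```

Because the counter is built from year and month together, the step from December to January still increments by exactly one.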
[4]:
data.describe()
The next step is to convert the dataframe into a PyTorch Forecasting TimeSeriesDataSet. Apart from telling the dataset which features are categorical vs continuous and which are static vs varying in time, we also have to decide how we normalise the data. Here, we standard scale each time series separately and indicate that values are always positive. Generally, the EncoderNormalizer, which scales dynamically on each encoder sequence as you train, is preferred to avoid look-ahead bias induced by normalisation. However, you might accept look-ahead bias if you have trouble finding a reasonably stable normalisation, for example because there are a lot of zeros in your data, or if you expect a more stable normalisation at inference. In the latter case, you ensure that you do not learn “weird” jumps that will not be present when running inference, thus training on a more realistic data set.
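To make the look-ahead bias concrete, here is a minimal pure-Python sketch. It is not how GroupNormalizer or EncoderNormalizer are implemented; the series and the helper `standard_scale` are hypothetical, and the softplus transformation is left out. It only contrasts scaling with statistics of the whole series against scaling with statistics of the encoder window alone.

```python
from statistics import mean, pstdev


def standard_scale(values, reference):
    """Scale `values` using the mean and standard deviation of `reference`."""
    center, scale = mean(reference), pstdev(reference)
    return [(v - center) / scale for v in values]


series = [10.0, 12.0, 11.0, 13.0, 50.0, 55.0]  # hypothetical series with a late jump
encoder = series[:4]  # the part the model would see as history

# Statistics from the whole series: the future jump leaks into the scaling
scaled_full = standard_scale(encoder, reference=series)

# Statistics from the encoder window only: no look-ahead
scaled_encoder = standard_scale(encoder, reference=encoder)

print([round(v, 2) for v in scaled_full])
print([round(v, 2) for v in scaled_encoder])
```

With full-series statistics, every encoder value is pushed below zero because the later jump raises the mean, which is information the model would not have at prediction time.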
We also choose to use the last six months as a validation set.
[5]:
max_prediction_length = 6
max_encoder_length = 24
training_cutoff = data["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="volume",
    group_ids=["agency", "sku"],
    min_encoder_length=max_encoder_length // 2,  # keep encoder length long (as it is in the validation set)
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["agency", "sku"],
    static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
    time_varying_known_categoricals=["special_days", "month"],
    variable_groups={"special_days": special_days},  # group of categorical variables can be treated as one variable
    time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=[
        "volume",
        "log_volume",
        "industry_volume",
        "soda_volume",
        "avg_max_temp",
        "avg_volume_by_agency",
        "avg_volume_by_sku",
    ],
    target_normalizer=GroupNormalizer(
        groups=["agency", "sku"], transformation="softplus"
    ),  # use softplus and normalize by group
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# create validation set (predict=True) which means to predict the last max_prediction_length points in time
# for each series
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)

# create dataloaders for model
batch_size = 128  # set this between 32 to 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)
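The cutoff arithmetic in the cell above can be traced with plain integers. This sketch assumes a hypothetical series of 60 monthly steps indexed 0 to 59; the actual range in the dataset may differ.

```python
max_prediction_length = 6
max_encoder_length = 24
time_idx = list(range(60))  # hypothetical: 60 monthly steps, indexed 0..59

training_cutoff = max(time_idx) - max_prediction_length

# Steps available for training
train_steps = [t for t in time_idx if t <= training_cutoff]

# Steps the validation set has to predict (the last max_prediction_length points)
validation_targets = [t for t in time_idx if t > training_cutoff]

# History fed to the encoder when predicting the validation targets
validation_encoder = train_steps[-max_encoder_length:]

print(training_cutoff, validation_targets, validation_encoder[0], validation_encoder[-1])
```

Under these assumptions the cutoff is 53, the validation targets are steps 54 to 59, and the encoder for the validation forecast covers steps 30 to 53.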
To learn more about the TimeSeriesDataSet, visit its documentation or the tutorial explaining how to pass datasets to models.
Evaluating a Baseline model that predicts the next 6 months by simply repeating the last observed volume gives us a simple benchmark that we want to outperform.
[6]:
# calculate baseline mean absolute error, i.e. predict next value as the last available value from the history
actuals = torch.cat([y for x, (y, weight) in iter(val_dataloader)])
baseline_predictions = Baseline().predict(val_dataloader)
(actuals - baseline_predictions).abs().mean().item()
293.0088195800781
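What the baseline computes can be sketched in a few lines of pure Python. The helper names and the numbers below are hypothetical; the sketch only illustrates "repeat the last observed value" followed by a mean absolute error, not the Baseline class itself.

```python
def last_value_forecast(history, horizon):
    """Repeat the last observed value for every step of the horizon."""
    return [history[-1]] * horizon


def mean_absolute_error(actuals, predictions):
    """Average absolute difference between actuals and predictions."""
    return sum(abs(a - p) for a, p in zip(actuals, predictions)) / len(actuals)


history = [120.0, 135.0, 128.0]  # hypothetical observed volumes
future = [130.0, 126.0, 140.0]  # hypothetical volumes to be predicted

forecast = last_value_forecast(history, horizon=len(future))
mae = mean_absolute_error(future, forecast)

print(forecast, round(mae, 3))
```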
It is now time to create our TemporalFusionTransformer model. We train the model with PyTorch Lightning.
Prior to training, you can identify the optimal learning rate with the PyTorch Lightning learning rate finder.
[7]:
# configure network and trainer
pl.seed_everything(42)
trainer = pl.Trainer(
    gpus=0,
    # clipping gradients is a hyperparameter and important to prevent divergence
    # of the gradient for recurrent neural networks
    gradient_clip_val=0.1,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    # not meaningful for finding the learning rate but otherwise very important
    learning_rate=0.03,
    hidden_size=16,  # most important hyperparameter apart from learning rate
    # number of attention heads. Set to up to 4 for large datasets
    attention_head_size=1,
    dropout=0.1,  # between 0.1 and 0.3 are good values
    hidden_continuous_size=8,  # set to <= hidden_size
    output_size=7,  # 7 quantiles by default
    loss=QuantileLoss(),
    # reduce learning rate if no improvement in validation loss after x epochs
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")
GPU available: False, used: False TPU available: None, using: 0 TPU cores
Number of parameters in network: 29.7k
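The network above is trained with a quantile loss over 7 quantiles. The following pure-Python sketch shows the standard quantile (pinball) loss for a single target value. The quantile levels are assumed to match the QuantileLoss defaults, the forecast values are hypothetical, and the library's implementation may scale or reduce the loss differently.

```python
def pinball_loss(actual, predicted, quantile):
    """Quantile (pinball) loss for a single prediction."""
    error = actual - predicted
    return max(quantile * error, (quantile - 1) * error)


# Assumed quantile levels (7 values, matching output_size=7)
quantiles = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
# Hypothetical forecast, one value per quantile
forecast = [70.0, 100.0, 120.0, 140.0, 165.0, 190.0, 210.0]
actual = 150.0

losses = [pinball_loss(actual, p, q) for p, q in zip(forecast, quantiles)]
total = sum(losses)

print([round(loss, 2) for loss in losses], round(total, 2))
```

The loss is asymmetric: for the 0.9 quantile, predicting 2 units too low costs 1.8 while predicting 2 units too high costs only 0.2, which pushes that output towards the upper end of the distribution.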
[8]:
# find optimal learning rate
res = trainer.tuner.lr_find(
    tft,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
    max_lr=10.0,
    min_lr=1e-6,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()
   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0
1  | logging_metrics                    | ModuleList                      | 0
2  | input_embeddings                   | MultiEmbedding                  | 1.3 K
3  | prescalers                         | ModuleDict                      | 256
4  | static_variable_selection          | VariableSelectionNetwork        | 3.4 K
5  | encoder_variable_selection         | VariableSelectionNetwork        | 8.0 K
6  | decoder_variable_selection         | VariableSelectionNetwork        | 2.7 K
7  | static_context_variable_selection  | GatedResidualNetwork            | 1.1 K
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 1.1 K
9  | static_context_initial_cell_lstm   | GatedResidualNetwork            | 1.1 K
10 | static_context_enrichment          | GatedResidualNetwork            | 1.1 K
11 | lstm_encoder                       | LSTM                            | 2.2 K
12 | lstm_decoder                       | LSTM                            | 2.2 K
13 | post_lstm_gate_encoder             | GatedLinearUnit                 | 544
14 | post_lstm_add_norm_encoder         | AddNorm                         | 32
15 | static_enrichment                  | GatedResidualNetwork            | 1.4 K
16 | multihead_attn                     | InterpretableMultiHeadAttention | 1.1 K
17 | post_attn_gate_norm                | GateAddNorm                     | 576
18 | pos_wise_ff                        | GatedResidualNetwork            | 1.1 K
19 | pre_output_gate_norm               | GateAddNorm                     | 576
20 | output_layer                       | Linear                          | 119
----------------------------------------------------------------------------------------
29.7 K    Trainable params
0         Non-trainable params
29.7 K    Total params
Restored states from the checkpoint file at /Users/beitnerjan/Documents/Github/temporal_fusion_transformer_pytorch/lr_find_temp_model.ckpt
suggested learning rate: 0.06760829753919811
For the TemporalFusionTransformer, the optimal learning rate seems to be slightly lower than the suggested one. Further, we do not directly want to use the suggested learning rate because PyTorch Lightning sometimes can get confused by the noise at lower learning rates and suggests rates far too low. Manual control is essential. We decide to pick 0.03 as learning rate.
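Why a suggestion can land far too low is easier to see with a toy version of the heuristic. The sketch below assumes a simplified rule, namely picking the learning rate at which the loss drops most steeply between two consecutive steps of a log-spaced sweep. The loss curves and helper names are hypothetical, and the actual finder may smooth or trim the curve before choosing.

```python
import math


def log_spaced(min_lr, max_lr, num):
    """Learning rates spaced evenly on a log10 scale."""
    start, stop = math.log10(min_lr), math.log10(max_lr)
    step = (stop - start) / (num - 1)
    return [10 ** (start + i * step) for i in range(num)]


def steepest_descent_index(losses):
    """Index of the step after which the loss falls the most."""
    slopes = [losses[i + 1] - losses[i] for i in range(len(losses) - 1)]
    return min(range(len(slopes)), key=slopes.__getitem__)


learning_rates = log_spaced(1e-6, 10.0, 8)

# Hypothetical smooth loss curve: flat, then falling, then diverging
smooth = [5.0, 5.0, 4.9, 4.6, 3.6, 3.0, 3.4, 9.0]
# Same curve with a single noisy spike at a very low learning rate
noisy = [5.0, 6.2, 4.9, 4.6, 3.6, 3.0, 3.4, 9.0]

suggested_smooth = learning_rates[steepest_descent_index(smooth)]
suggested_noisy = learning_rates[steepest_descent_index(noisy)]

print(f"{suggested_smooth:.0e} vs {suggested_noisy:.0e}")
```

A single noisy spike early in the sweep is enough to move the steepest drop, and with it the suggestion, by two orders of magnitude, which is why we inspect the plot and choose the rate manually.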
If you have trouble training the model and get the error AttributeError: module 'tensorflow._api.v2.io.gfile' has no attribute 'get_filesystem', consider either uninstalling tensorflow or executing the following first:
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile
[9]:
# configure network and trainer
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger("lightning_logs")  # logging results to a tensorboard

trainer = pl.Trainer(
    max_epochs=30,
    gpus=0,
    weights_summary="top",
    gradient_clip_val=0.1,
    limit_train_batches=30,  # comment in for training, running validation every 30 batches
    # fast_dev_run=True,  # comment in to check that network or dataset has no serious bugs
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,  # 7 quantiles by default
    loss=QuantileLoss(),
    log_interval=10,  # log every 10 batches; leave unset when running the learning rate finder
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")
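The early stopping rule configured above (`patience` and `min_delta` in `"min"` mode) can be sketched as a small loop. This is a simplified illustration with a hypothetical helper and hypothetical validation losses, not the callback's actual code.

```python
def epoch_of_stop(val_losses, patience, min_delta):
    """Return the epoch at which training stops, or the last epoch if it never triggers."""
    best = float("inf")
    waited = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement large enough to count: remember it and reset the counter
            best = loss
            waited = 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    return len(val_losses) - 1


# Hypothetical validation losses per epoch
val_losses = [1.0, 0.8, 0.7, 0.71, 0.70, 0.72, 0.69995]

stopped_at = epoch_of_stop(val_losses, patience=3, min_delta=1e-4)
print(stopped_at)  # 5
```

Epoch 4 matches the best loss but does not beat it by more than `min_delta`, so it still counts towards the patience budget and training stops at epoch 5.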
Training takes a couple of minutes on my Macbook, but for larger networks and datasets it can take hours. Here, the training speed is mostly determined by overhead, and choosing a larger batch_size or hidden_size (i.e. network size) does not slow down training linearly, which makes training on large datasets feasible. During training, we can monitor the tensorboard, which can be spun up with tensorboard --logdir=lightning_logs. For example, we can monitor example predictions on the training and validation set.
[10]:
# fit network
trainer.fit(
    tft,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
)
1
Hyperparameter tuning with [optuna](https://optuna.org/) is directly built into pytorch-forecasting. For example, we can use the optimize_hyperparameters() function to optimize the TFT’s hyperparameters.
import pickle

from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# create study
study = optimize_hyperparameters(
    train_dataloader,
    val_dataloader,
    model_path="optuna_test",
    n_trials=200,
    max_epochs=50,
    gradient_clip_val_range=(0.01, 1.0),
    hidden_size_range=(8, 128),
    hidden_continuous_size_range=(8, 128),
    attention_head_size_range=(1, 4),
    learning_rate_range=(0.001, 0.1),
    dropout_range=(0.1, 0.3),
    trainer_kwargs=dict(limit_train_batches=30),
    reduce_on_plateau_patience=4,
    use_learning_rate_finder=False,  # use Optuna to find ideal learning rate or use in-built learning rate finder
)

# save study results - also we can resume tuning at a later point in time
with open("test_study.pkl", "wb") as fout:
    pickle.dump(study, fout)

# show best hyperparameters
print(study.best_trial.params)
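To get a feeling for what searching over these ranges means, here is a deliberately simplified sketch that uses plain random search instead of Optuna's samplers. The objective `toy_objective` is a hypothetical stand-in for the validation loss of a trained model, and sampling the learning rate on a log scale is an assumption of this sketch.

```python
import math
import random

# Ranges taken from the call above
search_space = {
    "hidden_size": (8, 128),  # integer range
    "attention_head_size": (1, 4),  # integer range
    "learning_rate": (0.001, 0.1),  # sampled on a log scale (assumption)
    "dropout": (0.1, 0.3),  # sampled uniformly
}


def sample_trial(rng):
    """Draw one hyperparameter configuration from the search space."""
    low, high = search_space["learning_rate"]
    return {
        "hidden_size": rng.randint(*search_space["hidden_size"]),
        "attention_head_size": rng.randint(*search_space["attention_head_size"]),
        "learning_rate": 10 ** rng.uniform(math.log10(low), math.log10(high)),
        "dropout": rng.uniform(*search_space["dropout"]),
    }


def toy_objective(params):
    """Hypothetical stand-in for a validation loss; lower is better."""
    return abs(math.log10(params["learning_rate"]) + 1.5) + params["dropout"]


rng = random.Random(0)  # fixed seed for reproducibility
trials = [sample_trial(rng) for _ in range(20)]
best = min(trials, key=toy_objective)

print(best)
```

In the real study each trial trains a model for up to `max_epochs`, so the number of trials is the main driver of total runtime.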
PyTorch Lightning automatically checkpoints training and thus, we can easily retrieve the best model and load it.
[11]:
# load the best model according to the validation loss
# (given that we use early stopping, this is not necessarily the last epoch)
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
After training, we can make predictions with predict(). The method allows very fine-grained control over what it returns so that, for example, you can easily match predictions to your pandas dataframe. See its documentation for details. We evaluate the metrics on the validation dataset and a couple of examples to see how well the model is doing. Given that we work with only 21 000 samples the results are very reassuring and can compete with results by a gradient booster. We also perform better than the baseline model. Given the noisy data, this is not trivial.
[12]:
# calculate mean absolute error on validation set
actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])
predictions = best_tft.predict(val_dataloader)
(actuals - predictions).abs().mean()
tensor(266.7437)
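To put this number next to the baseline, the relative improvement can be computed directly from the two mean absolute errors reported in the outputs above:

```python
baseline_mae = 293.0088195800781  # baseline output reported above
tft_mae = 266.7437  # TemporalFusionTransformer output reported above

relative_improvement = (baseline_mae - tft_mae) / baseline_mae
print(f"{relative_improvement:.1%}")
```

This corresponds to an error reduction of roughly 9% relative to repeating the last observed value.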
We can now also look at sample predictions directly which we plot with plot_prediction(). As you can see from the figures below, forecasts look rather accurate. If you wonder, the grey lines denote the amount of attention the model pays to different points in time when making the prediction. This is a special feature of the Temporal Fusion Transformer.
[13]:
# raw predictions are a dictionary from which all kind of information including quantiles can be extracted
raw_predictions, x = best_tft.predict(val_dataloader, mode="raw", return_x=True)
[14]:
for idx in range(10):  # plot 10 examples
    best_tft.plot_prediction(x, raw_predictions, idx=idx, add_loss_to_title=True);
Looking at the worst performers, for example in terms of SMAPE, gives us an idea of where the model has issues with forecasting reliably. These examples can provide important pointers about how to improve the model. Such actuals vs predictions plots are available for all models. Of course, it is also sensible to employ additional metrics, such as MASE, defined in the metrics module. However, for the sake of demonstration, we only use SMAPE here.
[15]:
# calculate metric by which to display
predictions = best_tft.predict(val_dataloader)
mean_losses = SMAPE(reduction="none")(predictions, actuals).mean(1)
indices = mean_losses.argsort(descending=True)  # sort losses
for idx in range(10):  # plot 10 examples
    best_tft.plot_prediction(x, raw_predictions, idx=indices[idx], add_loss_to_title=SMAPE());
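The ranking step above can be reproduced on toy data. The sketch below uses a common SMAPE definition, 2|p - a| / (|p| + |a|), with a small epsilon to avoid division by zero. The exact formula in the metrics module is assumed to be of this form, and the three series are hypothetical.

```python
def smape(predictions, actuals, eps=1e-8):
    """Symmetric mean absolute percentage error, averaged over the horizon."""
    terms = [2 * abs(p - a) / (abs(p) + abs(a) + eps) for p, a in zip(predictions, actuals)]
    return sum(terms) / len(terms)


# Hypothetical forecasts and actuals for three series
series_predictions = {"A": [100.0, 110.0], "B": [50.0, 10.0], "C": [5.0, 5.0]}
series_actuals = {"A": [105.0, 100.0], "B": [52.0, 40.0], "C": [5.0, 6.0]}

losses = {key: smape(series_predictions[key], series_actuals[key]) for key in series_predictions}

# Worst performers first, analogous to argsort(descending=True)
worst_first = sorted(losses, key=losses.get, reverse=True)

print(worst_first, {key: round(value, 3) for key, value in losses.items()})
```

Note that series C ranks worse than series A even though its absolute error is smaller, because SMAPE is relative to the magnitude of the values.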
Checking how the model performs across different slices of the data allows us to detect weaknesses. Plotted below are the means of predictions vs actuals across each variable divided into 100 bins using the calculate_prediction_actual_by_variable() and plot_prediction_actual_by_variable() methods. The gray bars denote the frequency of the variable by bin, i.e. are a histogram.
[16]:
predictions, x = best_tft.predict(val_dataloader, return_x=True)
predictions_vs_actuals = best_tft.calculate_prediction_actual_by_variable(x, predictions)
best_tft.plot_prediction_actual_by_variable(predictions_vs_actuals);
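The idea behind these plots, averaging predictions and actuals within bins of one variable, can be sketched in pure Python. The helper `binned_means` and the data are hypothetical, only two bins are used to keep the output short, and the sketch assumes the variable is not constant (otherwise the bin width would be zero).

```python
def binned_means(variable, predictions, actuals, n_bins):
    """Average prediction and actual per equal-width bin of `variable`."""
    low, high = min(variable), max(variable)
    width = (high - low) / n_bins
    bins = {}
    for value, prediction, actual in zip(variable, predictions, actuals):
        # The maximum value falls into the last bin
        idx = min(int((value - low) / width), n_bins - 1)
        bins.setdefault(idx, []).append((prediction, actual))
    return {
        idx: {
            "count": len(pairs),  # corresponds to the histogram bars
            "prediction": sum(p for p, _ in pairs) / len(pairs),
            "actual": sum(a for _, a in pairs) / len(pairs),
        }
        for idx, pairs in sorted(bins.items())
    }


# Hypothetical variable (e.g. a price) with predictions and observed values
price = [1.0, 1.2, 2.0, 2.9, 3.0]
predicted = [100.0, 90.0, 60.0, 35.0, 30.0]
observed = [110.0, 95.0, 55.0, 30.0, 40.0]

summary = binned_means(price, predicted, observed, n_bins=2)
print(summary)
```

A gap between the "prediction" and "actual" averages in a bin points to a slice of the data where the model is systematically off.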
[17]:
best_tft.predict(
    training.filter(lambda x: (x.agency == "Agency_01") & (x.sku == "SKU_01") & (x.time_idx_first_prediction == 15)),
    mode="quantiles",
)
tensor([[[ 73.3693, 105.9448, 120.0385, 142.9053, 167.0046, 189.0140, 211.1777],
         [ 60.1067,  96.1256, 112.0637, 129.5548, 155.1736, 176.4882, 190.9185],
         [ 54.6926,  81.0330,  95.2863, 112.1914, 134.1506, 151.7681, 177.2157],
         [ 47.4090,  79.6756,  91.1664, 103.3393, 125.8313, 144.5693, 170.5650],
         [ 43.0128,  75.8046,  88.3507, 101.4471, 125.7367, 143.2664, 170.9305],
         [ 40.6680,  61.3395,  72.7728,  87.5364, 103.9094, 116.1759, 157.6852]]])
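Each row of this output is one forecast step and each column one quantile. As a small sketch of how to read it, the first row is mapped to quantile levels below. The levels are an assumption (they are taken to be the QuantileLoss defaults used during training) and should be checked against the loss of your own model.

```python
# Assumed quantile levels, one per output column
quantiles = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]

# First forecast step, copied from the output above
first_step = [73.3693, 105.9448, 120.0385, 142.9053, 167.0046, 189.0140, 211.1777]

by_quantile = dict(zip(quantiles, first_step))

median = by_quantile[0.5]
interval_80 = (by_quantile[0.1], by_quantile[0.9])  # central 80% prediction interval

print(median, interval_80)
```

Under that assumption, the point forecast for the first month is about 143 with an 80% interval from roughly 106 to 189.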
Of course, we can also plot this prediction readily:
[18]:
raw_prediction, x = best_tft.predict(
    training.filter(lambda x: (x.agency == "Agency_01") & (x.sku == "SKU_01") & (x.time_idx_first_prediction == 15)),
    mode="raw",
    return_x=True,
)
best_tft.plot_prediction(x, raw_prediction, idx=0);
Because we have covariates in the dataset, predicting on new data requires us to define the known covariates upfront.
[19]:
# select last 24 months from data (max_encoder_length is 24)
encoder_data = data[lambda x: x.time_idx > x.time_idx.max() - max_encoder_length]

# select last known data point and create decoder data from it by repeating it and incrementing the month
# in a real world dataset, we should not just forward fill the covariates but specify them to account
# for changes in special days and prices
last_data = data[lambda x: x.time_idx == x.time_idx.max()]
decoder_data = pd.concat(
    [last_data.assign(date=lambda x: x.date + pd.offsets.MonthBegin(i)) for i in range(1, max_prediction_length + 1)],
    ignore_index=True,
)

# add time index consistent with "data"
decoder_data["time_idx"] = decoder_data["date"].dt.year * 12 + decoder_data["date"].dt.month
decoder_data["time_idx"] += encoder_data["time_idx"].max() + 1 - decoder_data["time_idx"].min()

# adjust additional time feature(s)
decoder_data["month"] = decoder_data.date.dt.month.astype(str).astype("category")  # categories have to be strings

# combine encoder and decoder data
new_prediction_data = pd.concat([encoder_data, decoder_data], ignore_index=True)
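The date and time index bookkeeping for the decoder rows can be traced without pandas. The sketch below assumes a hypothetical last observation in December 2017 with time index 59, and the helper `add_months` is introduced only for illustration.

```python
def add_months(year, month, offset):
    """Shift a (year, month) pair forward by `offset` months."""
    total = year * 12 + (month - 1) + offset
    return total // 12, total % 12 + 1


# Hypothetical last observed step
last_year, last_month, last_time_idx = 2017, 12, 59
max_prediction_length = 6

decoder_rows = []
for i in range(1, max_prediction_length + 1):
    year, month = add_months(last_year, last_month, i)
    decoder_rows.append(
        {
            "year": year,
            "month": str(month),  # categories have to be strings
            "time_idx": last_time_idx + i,  # continues the encoder index without gaps
        }
    )

print(decoder_rows[0], decoder_rows[-1])
```

The decoder rows continue the time index directly after the last encoder step, which is what keeps encoder and decoder data consistent when they are concatenated.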
Now, we can directly predict on the generated data using the predict() method.
[20]:
new_raw_predictions, new_x = best_tft.predict(new_prediction_data, mode="raw", return_x=True)

for idx in range(10):  # plot 10 examples
    best_tft.plot_prediction(new_x, new_raw_predictions, idx=idx, show_future_observed=False);
The model has inbuilt interpretation capabilities due to how its architecture is built. Let’s see how that looks. We first calculate interpretations with interpret_output() and plot them subsequently with plot_interpretation().
[21]:
interpretation = best_tft.interpret_output(raw_predictions, reduction="sum")
best_tft.plot_interpretation(interpretation)
{'attention': <Figure size 432x288 with 1 Axes>, 'static_variables': <Figure size 504x270 with 1 Axes>, 'encoder_variables': <Figure size 504x378 with 1 Axes>, 'decoder_variables': <Figure size 504x252 with 1 Axes>}
Unsurprisingly, the past observed volume features as the top variable in the encoder and price related variables are among the top predictors in the decoder.
The general attention pattern seems to be that more recent observations are more important than older ones. This confirms intuition. The average attention is often not very useful; looking at the attention by example is more insightful because patterns are not averaged out.
Partial dependency plots are often used to interpret the model better (assuming independence of features). They can also be useful to understand what to expect in case of simulations and are created with predict_dependency().
[22]:
dependency = best_tft.predict_dependency(
    val_dataloader.dataset, "discount_in_percent", np.linspace(0, 30, 30), show_progress_bar=True, mode="dataframe"
)
[23]:
# plotting median and 25% and 75% percentile
agg_dependency = dependency.groupby("discount_in_percent").normalized_prediction.agg(
    median="median", q25=lambda x: x.quantile(0.25), q75=lambda x: x.quantile(0.75)
)
ax = agg_dependency.plot(y="median")
ax.fill_between(agg_dependency.index, agg_dependency.q25, agg_dependency.q75, alpha=0.3);
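The aggregation behind this plot, a median and an interquartile band per value of the varied feature, can be sketched with the standard library. The rows below are hypothetical `(discount_in_percent, normalized_prediction)` pairs, and the `"inclusive"` method is chosen because it interpolates linearly between data points, which to our understanding matches the default behaviour of pandas' `quantile`.

```python
from statistics import median, quantiles

# Hypothetical (discount_in_percent, normalized_prediction) pairs
rows = [
    (0.0, 1.00), (0.0, 0.90), (0.0, 1.10), (0.0, 1.05),
    (15.0, 1.20), (15.0, 1.10), (15.0, 1.40), (15.0, 1.30),
]

# Group the predictions by the value of the varied feature
groups = {}
for discount, prediction in rows:
    groups.setdefault(discount, []).append(prediction)

agg = {}
for discount, values in sorted(groups.items()):
    q25, _, q75 = quantiles(values, n=4, method="inclusive")
    agg[discount] = {"median": median(values), "q25": q25, "q75": q75}

for discount, stats in agg.items():
    print(discount, {key: round(value, 4) for key, value in stats.items()})
```

The band between `q25` and `q75` shows how much the effect of the discount varies across series, which the median alone would hide.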