In this tutorial, we will train the Temporal Fusion Transformer on a very small dataset to demonstrate that it does a good job even with only 20k samples. Generally speaking, it will perform much better with more data.
Our example is a demand forecast from the Stallion Kaggle competition.
[1]:
import os
import warnings

warnings.filterwarnings("ignore")
os.chdir("../../..")
[2]:
import warnings
from pathlib import Path

import pandas as pd
import numpy as np
import torch
import copy

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger
from pytorch_lightning.loggers import TensorBoardLogger

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, Baseline
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import PoissonLoss, QuantileLoss, SMAPE
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
[3]:
from pytorch_forecasting.data.examples import get_stallion_data

data = get_stallion_data()

# add time index
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min()

# add additional features
data["month"] = data.date.dt.month.astype(str).astype("category")  # categories have to be strings
data["log_volume"] = np.log(data.volume + 1e-8)
data["avg_volume_by_sku"] = data.groupby(["time_idx", "sku"], observed=True).volume.transform("mean")
data["avg_volume_by_agency"] = data.groupby(["time_idx", "agency"], observed=True).volume.transform("mean")

# we want to encode special days as one variable and thus need to first reverse one-hot encoding
special_days = [
    "easter_day",
    "good_friday",
    "new_year",
    "christmas",
    "labor_day",
    "independence_day",
    "revolution_day_memorial",
    "regional_games",
    "fifa_u_17_world_cup",
    "football_gold_cup",
    "beer_capital",
    "music_fest",
]
data[special_days] = data[special_days].apply(lambda x: x.map({0: "-", 1: x.name})).astype("category")
data.sample(10, random_state=521)
10 rows × 31 columns
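To make the two less obvious transformations above concrete, here is a minimal sketch on a hypothetical three-row frame (not the Stallion data). The month-based `time_idx` counts months since the earliest date, and the one-hot special-day flags become categorical columns holding either `"-"` or the column's own name.

```python
import pandas as pd

# Hypothetical toy frame standing in for the Stallion data
df = pd.DataFrame({
    "date": pd.to_datetime(["2013-01-01", "2013-02-01", "2014-01-01"]),
    "easter_day": [0, 1, 0],
    "christmas": [0, 0, 1],
})

# Monthly time index: number of months since the earliest date
df["time_idx"] = df["date"].dt.year * 12 + df["date"].dt.month
df["time_idx"] -= df["time_idx"].min()

# Reverse the one-hot encoding: 0 -> "-", 1 -> the column's own name
flags = ["easter_day", "christmas"]
df[flags] = df[flags].apply(lambda x: x.map({0: "-", 1: x.name})).astype("category")
print(df)
```

Note that January 2014 gets `time_idx` 12, so gaps between years are preserved rather than compressed.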
[4]:
data.describe()
[5]:
max_prediction_length = 6
max_encoder_length = 24
training_cutoff = data["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="volume",
    group_ids=["agency", "sku"],
    min_encoder_length=0,  # allow encoder lengths from 0 to max_encoder_length
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["agency", "sku"],
    static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
    time_varying_known_categoricals=["special_days", "month"],
    variable_groups={"special_days": special_days},  # group of categorical variables can be treated as one variable
    time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=[
        "volume",
        "log_volume",
        "industry_volume",
        "soda_volume",
        "avg_max_temp",
        "avg_volume_by_agency",
        "avg_volume_by_sku",
    ],
    target_normalizer=GroupNormalizer(
        groups=["agency", "sku"], coerce_positive=1.0
    ),  # use softplus with beta=1.0 and normalize by group
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# create validation set (predict=True), which means predicting the last max_prediction_length points in time for each series
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)

# create dataloaders for model
batch_size = 128  # set this between 32 and 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)
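The cutoff and the `predict=True` validation set can be pictured with plain index arithmetic. The sketch below assumes a single hypothetical series of 60 time steps and only enumerates full-length windows (24 encoder steps followed by 6 decoder steps). Because the dataset above allows `min_encoder_length=0` and `min_prediction_length=1`, `TimeSeriesDataSet` will also produce shorter samples, so its real sample count is larger. The helper `full_windows` is hypothetical and not part of pytorch-forecasting.

```python
import numpy as np

max_encoder_length = 24
max_prediction_length = 6

# Hypothetical single series with 60 time steps (time_idx 0..59)
time_idx = np.arange(60)

# Same cutoff rule as above: hold out the last max_prediction_length steps
training_cutoff = time_idx.max() - max_prediction_length
train_idx = time_idx[time_idx <= training_cutoff]


def full_windows(idx, enc_len, pred_len):
    """Slide over idx and yield (encoder, decoder) slices of full length."""
    for start in range(len(idx) - enc_len - pred_len + 1):
        yield idx[start:start + enc_len], idx[start + enc_len:start + enc_len + pred_len]


train_windows = list(full_windows(train_idx, max_encoder_length, max_prediction_length))

# predict=True: keep only the last window of the full series for validation
val_encoder, val_decoder = list(full_windows(time_idx, max_encoder_length, max_prediction_length))[-1]

print(training_cutoff, len(train_windows))  # 53 25
print(val_encoder[0], val_decoder.tolist())  # 30 [54, 55, 56, 57, 58, 59]
```

The validation decoder covers exactly the steps excluded from training by the cutoff, so the model is evaluated on time points it never saw as targets.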
[6]:
# calculate baseline mean absolute error, i.e. predict next value as the last available value from the history
actuals = torch.cat([y for x, y in iter(val_dataloader)])
baseline_predictions = Baseline().predict(val_dataloader)
(actuals - baseline_predictions).abs().mean().item()
293.0088195800781
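The baseline simply repeats the last observed value over the whole forecast horizon. The following sketch reproduces that idea and its mean absolute error on a tiny hypothetical batch in plain `torch`; it does not call `Baseline` itself.

```python
import torch

# Hypothetical batch: 3 series with 3 observed steps each, and a 6-step actual future
history = torch.tensor([[10.0, 12.0, 11.0], [5.0, 7.0, 9.0], [0.0, 3.0, 4.0]])
actuals = torch.tensor([[11.0] * 6, [8.0] * 6, [6.0] * 6])

# Naive forecast: repeat the last observed value across the horizon
last_value = history[:, -1:]                      # shape (3, 1)
naive = last_value.expand(-1, actuals.size(1))    # shape (3, 6)

mae = (actuals - naive).abs().mean().item()
print(mae)  # 1.0
```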
[7]:
# configure network and trainer
pl.seed_everything(42)
trainer = pl.Trainer(
    gpus=0,
    # clipping gradients is a hyperparameter and important to prevent divergence
    # of the gradient for recurrent neural networks
    gradient_clip_val=0.1,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    # not meaningful for finding the learning rate but otherwise very important
    learning_rate=0.03,
    hidden_size=16,  # most important hyperparameter apart from learning rate
    # number of attention heads. Set to up to 4 for large datasets
    attention_head_size=1,
    dropout=0.1,  # between 0.1 and 0.3 are good values
    hidden_continuous_size=8,  # set to <= hidden_size
    output_size=7,  # 7 quantiles by default
    loss=QuantileLoss(),
    # reduce learning rate if no improvement in validation loss after x epochs
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Number of parameters in network: 29.6k
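With `loss=QuantileLoss()`, each of the `output_size=7` outputs is trained as one quantile using the pinball loss, which penalizes under- and over-prediction asymmetrically. The sketch below is a plain `torch` version of that loss. The quantile list `[0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]` is assumed to be the library default, and averaging over all elements is a choice of this sketch rather than a guarantee about how `QuantileLoss` reduces.

```python
import torch


def quantile_loss(y_pred, target, quantiles):
    """Pinball loss averaged over all elements.

    y_pred has shape (..., n_quantiles); target has the same shape without the last dim.
    """
    q = torch.tensor(quantiles, dtype=y_pred.dtype)
    errors = target.unsqueeze(-1) - y_pred  # positive where the model under-predicts
    loss = torch.maximum((q - 1) * errors, q * errors)
    return loss.mean()


quantiles = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]  # assumed default quantiles

# A perfect forecast for every quantile gives zero loss
perfect = quantile_loss(torch.full((1, 7), 10.0), torch.tensor([10.0]), quantiles)

# For the 0.9 quantile, under-predicting by 2 costs about 9x more than over-predicting by 2
under = quantile_loss(torch.tensor([[8.0]]), torch.tensor([10.0]), [0.9])
over = quantile_loss(torch.tensor([[12.0]]), torch.tensor([10.0]), [0.9])
print(perfect.item(), under.item(), over.item())
```

This asymmetry is what pushes the 0.9 output above the median and the 0.1 output below it, producing the prediction intervals shown in the plots later on.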
[8]:
# find optimal learning rate
res = trainer.lr_find(
    tft,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
    max_lr=10.0,
    min_lr=1e-6,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()
   | Name                               | Type                            | Params
-----------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0
1  | input_embeddings                   | ModuleDict                      | 1 K
2  | prescalers                         | ModuleDict                      | 256
3  | static_variable_selection          | VariableSelectionNetwork        | 3 K
4  | encoder_variable_selection         | VariableSelectionNetwork        | 8 K
5  | decoder_variable_selection         | VariableSelectionNetwork        | 2 K
6  | static_context_variable_selection  | GatedResidualNetwork            | 1 K
7  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 1 K
8  | static_context_initial_cell_lstm   | GatedResidualNetwork            | 1 K
9  | static_context_enrichment          | GatedResidualNetwork            | 1 K
10 | lstm_encoder                       | LSTM                            | 2 K
11 | lstm_decoder                       | LSTM                            | 2 K
12 | post_lstm_gate_encoder             | GatedLinearUnit                 | 544
13 | post_lstm_add_norm_encoder         | AddNorm                         | 32
14 | static_enrichment                  | GatedResidualNetwork            | 1 K
15 | multihead_attn                     | InterpretableMultiHeadAttention | 1 K
16 | post_attn_gate_norm                | GateAddNorm                     | 576
17 | pos_wise_ff                        | GatedResidualNetwork            | 1 K
18 | pre_output_gate_norm               | GateAddNorm                     | 576
19 | output_layer                       | Linear                          | 119
Saving latest checkpoint..
LR finder stopped early due to diverging loss.
suggested learning rate: 5.888436553555889e-06
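A learning rate range test records the loss while the learning rate grows from `min_lr` to `max_lr`, and a common way to pick a suggestion is the point where the loss falls most steeply. The function below is a minimal numpy sketch of that heuristic under that assumption, not PyTorch Lightning's exact implementation; `suggest_lr` and its trimming arguments are hypothetical. Since it only reads the recorded curve, a run that stops early because the loss diverges (as logged above) can yield a suggestion worth double-checking against the plot.

```python
import numpy as np


def suggest_lr(lrs, losses, skip_begin=0, skip_end=1):
    """Return the learning rate at which the recorded loss decreases most steeply.

    skip_begin / skip_end trim noisy points at either end of the curve.
    """
    lrs = np.asarray(lrs, dtype=float)
    losses = np.asarray(losses, dtype=float)
    end = len(lrs) - skip_end
    lrs, losses = lrs[skip_begin:end], losses[skip_begin:end]
    grads = np.gradient(losses)  # slope per step; steps are log-spaced in lr
    return lrs[np.argmin(grads)]


# Hypothetical range-test record: loss drops, then diverges at large learning rates
lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0]
losses = [1.0, 0.98, 0.7, 0.65, 0.9, 5.0]
print(suggest_lr(lrs, losses))
```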
[9]:
# configure network and trainer
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
lr_logger = LearningRateLogger()  # log the learning rate
logger = TensorBoardLogger("lightning_logs")  # log results to TensorBoard

trainer = pl.Trainer(
    max_epochs=30,
    gpus=0,
    weights_summary="top",
    gradient_clip_val=0.1,
    early_stop_callback=early_stop_callback,
    limit_train_batches=30,  # comment in for training, running validation every 30 batches
    # fast_dev_run=True,  # comment in to check that the network or dataset has no serious bugs
    callbacks=[lr_logger],
    logger=logger,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,  # 7 quantiles by default
    loss=QuantileLoss(),
    log_interval=30,  # log every 30 batches; comment out when running the learning rate finder
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")
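The `EarlyStopping` callback configured above monitors `val_loss` in `"min"` mode: a validation result counts as an improvement only if it beats the best value so far by more than `min_delta`, and training stops after `patience` consecutive checks without one. The class below is a hypothetical, framework-free sketch of that rule for illustration; it is not the Lightning callback.

```python
class EarlyStopper:
    """Hypothetical early-stopping rule for a metric that should decrease (mode="min")."""

    def __init__(self, min_delta=1e-4, patience=10):
        self.min_delta = min_delta
        self.patience = patience
        self.best = float("inf")
        self.wait = 0  # consecutive checks without sufficient improvement

    def step(self, val_loss):
        """Record one validation loss and return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# With patience=2, the tiny gain at the third check (less than min_delta) does not reset the counter
stopper = EarlyStopper(min_delta=1e-4, patience=2)
decisions = [stopper.step(loss) for loss in [1.0, 0.9, 0.89995, 0.95]]
print(decisions)  # [False, False, False, True]
```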
We can now fit the network and monitor its training with TensorBoard. To bring up the TensorBoard dashboard and visualize training, run the following command in a terminal:
tensorboard --logdir=lightning_logs
[10]:
# fit network
trainer.fit(
    tft,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
)
Saving latest checkpoint..
1
[11]:
# load the best model according to the validation loss
# (given that we use early stopping, this is not necessarily the last epoch)
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
Evaluated on the validation set, our model achieves a lower mean absolute error than the baseline (about 253 versus 293), which is not trivial given the noisy data.
[12]:
# calculate mean absolute error on validation set
actuals = torch.cat([y for x, y in iter(val_dataloader)])
predictions = best_tft.predict(val_dataloader)
(actuals - predictions).abs().mean()
tensor(253.2648)
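To put "better than the baseline" into numbers, the relative reduction in mean absolute error follows directly from the two values printed in cells [6] and [12]:

```python
baseline_mae = 293.0088195800781  # Baseline MAE from cell [6]
tft_mae = 253.2648                # best TFT MAE from cell [12]

# relative reduction in MAE compared with the baseline
improvement = (baseline_mae - tft_mae) / baseline_mae
print(f"{improvement:.1%}")  # 13.6%
```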
We can now also look at sample predictions directly and even see which input time steps the model attends to most when making a forecast.
[13]:
# raw predictions are a dictionary from which all kinds of information, including quantiles, can be extracted
raw_predictions, x = best_tft.predict(val_dataloader, mode="raw", return_x=True)
[14]:
for idx in range(10):  # plot 10 examples
    best_tft.plot_prediction(x, raw_predictions, idx=idx, add_loss_to_title=True);
Looking at the samples where our model performs worst, we see that these are genuinely difficult cases to forecast.
[15]:
# calculate metric by which to display
predictions = best_tft.predict(val_dataloader)
mean_losses = SMAPE(reduction="none")(predictions, actuals).mean(1)
indices = mean_losses.argsort(descending=True)  # sort losses
for idx in range(10):  # plot 10 examples
    best_tft.plot_prediction(x, raw_predictions, idx=indices[idx], add_loss_to_title=SMAPE());
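The ranking above relies on a per-sample SMAPE averaged over the prediction horizon. Here is a sketch in plain `torch`, assuming the common definition 2|y_pred - y| / (|y_pred| + |y|) with a small epsilon for stability; the helper `smape_per_sample` is hypothetical and only mirrors the `SMAPE(reduction="none")(...).mean(1)` step conceptually.

```python
import torch


def smape_per_sample(y_pred, target, eps=1e-8):
    """SMAPE per series: 2|y_pred - y| / (|y_pred| + |y|), averaged over the horizon (dim 1)."""
    losses = 2 * (y_pred - target).abs() / (y_pred.abs() + target.abs() + eps)
    return losses.mean(1)


# Hypothetical forecasts for 3 series over a 2-step horizon
predictions = torch.tensor([[10.0, 10.0], [0.0, 5.0], [9.0, 11.0]])
actuals = torch.tensor([[10.0, 10.0], [5.0, 5.0], [10.0, 10.0]])

mean_losses = smape_per_sample(predictions, actuals)
worst_first = mean_losses.argsort(descending=True)  # indices of the hardest series first
print(worst_first.tolist())  # [1, 2, 0]
```

SMAPE is bounded by 2, which it reaches when the forecast is 0 for a nonzero actual (the first step of series 1), so a few such cases dominate the top of the ranking.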
Checking how the model performs across different slices of the data allows us to detect weaknesses.
[16]:
predictions, x = best_tft.predict(val_dataloader, return_x=True)
predictions_vs_actuals = best_tft.calculate_prediction_actual_by_variable(x, predictions)
best_tft.plot_prediction_actual_by_variable(predictions_vs_actuals);
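Conceptually, comparing predictions and actuals by variable means binning a variable and checking, per bin, whether the average prediction tracks the average actual. The pandas sketch below illustrates that idea on hypothetical numbers; it is not how `calculate_prediction_actual_by_variable` is implemented internally.

```python
import pandas as pd

# Hypothetical per-sample results for one input variable
df = pd.DataFrame({
    "discount_in_percent": [0, 5, 12, 18, 25, 28],
    "actual": [100, 110, 130, 150, 170, 180],
    "prediction": [105, 108, 125, 160, 160, 175],
})

# Bin the variable, then compare mean prediction with mean actual per bin
df["bin"] = pd.cut(df["discount_in_percent"], bins=[0, 10, 20, 30], include_lowest=True)
summary = df.groupby("bin", observed=True)[["actual", "prediction"]].mean()
summary["bias"] = summary["prediction"] - summary["actual"]  # > 0 means over-prediction
print(summary)
```

In this toy data, the model under-predicts in the highest-discount bin, which is the kind of weakness such a slice-wise view is meant to surface.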
The model has built-in interpretation capabilities due to how its architecture is built. Let's see what they show.
[17]:
interpretation = best_tft.interpret_output(raw_predictions, reduction="sum")
best_tft.plot_interpretation(interpretation)
{'attention': <Figure size 432x288 with 1 Axes>, 'static_variables': <Figure size 504x270 with 1 Axes>, 'encoder_variables': <Figure size 504x378 with 1 Axes>, 'decoder_variables': <Figure size 504x252 with 1 Axes>}
Partial dependency plots are often used to interpret the model better (assuming independence of features). They can also be useful to understand what to expect in simulations.
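A partial dependence curve sets one feature to a fixed value for every sample, predicts, and averages; repeating this over a grid of values traces how the prediction responds to that feature alone, which is where the independence assumption enters. The sketch below applies this to a hypothetical linear `toy_predict` function with numpy. `predict_dependency` in the tutorial works on the dataset and model instead, and its internals are not reproduced here.

```python
import numpy as np


def partial_dependence(predict, X, feature, values):
    """Average prediction when column `feature` is set to each value for every row."""
    averages = []
    for v in values:
        X_mod = X.copy()
        X_mod[:, feature] = v  # override the feature for all samples
        averages.append(predict(X_mod).mean())
    return np.array(averages)


def toy_predict(X):
    """Hypothetical model: +2 units per discount point (column 0) plus column 1."""
    return 100 + 2 * X[:, 0] + X[:, 1]


# Columns: [discount_in_percent, some other feature] (hypothetical values)
X = np.array([[5.0, 1.0], [10.0, 3.0], [0.0, 2.0]])
curve = partial_dependence(toy_predict, X, feature=0, values=np.array([0.0, 10.0, 20.0]))
print(curve)  # [102. 122. 142.]
```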
[28]:
dependency = best_tft.predict_dependency(val_dataloader.dataset, "discount_in_percent", np.linspace(0, 30, 30), show_progress_bar=True, mode="dataframe")
[30]:
# plot the median and the 25th and 75th percentiles
agg_dependency = dependency.groupby("discount_in_percent").normalized_prediction.agg(
    median="median", q25=lambda x: x.quantile(0.25), q75=lambda x: x.quantile(0.75)
)
ax = agg_dependency.plot(y="median")
ax.fill_between(agg_dependency.index, agg_dependency.q25, agg_dependency.q75, alpha=0.3);