Interpretable forecasting with N-Beats

[1]:
import os
import warnings

# Suppress every warning category so that library notices do not clutter the cell outputs.
warnings.simplefilter("ignore")

# Move the working directory three levels up from where the notebook is started
# (presumably the repository root - confirm). Note that re-running this cell
# moves up another three levels, so run it only once per kernel session.
os.chdir(os.path.join(os.pardir, os.pardir, os.pardir))
[2]:
import pandas as pd
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

from pytorch_forecasting import TimeSeriesDataSet, NBeats, Baseline
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.data.examples import generate_ar_data
from pytorch_forecasting.metrics import SMAPE

Load data

[3]:
# Generate 100 synthetic series with 400 time steps each, then attach a constant
# "static" column and a daily calendar date derived from the integer time index.
data = generate_ar_data(n_series=100, timesteps=400, seasonality=10.0).assign(
    static=2,
    date=lambda frame: pd.Timestamp("2020-01-01") + pd.to_timedelta(frame["time_idx"], unit="D"),
)
data.head()
[3]:
series time_idx value static date
0 0 0 -0.000000 2 2020-01-01
1 0 1 0.033712 2 2020-01-02
2 0 2 0.061853 2 2020-01-03
3 0 3 0.099868 2 2020-01-04
4 0 4 0.139527 2 2020-01-05
[4]:
# Build the training / validation datasets and wrap them in dataloaders.
max_encoder_length = 60  # number of past time steps passed to the model
max_prediction_length = 20  # number of future time steps to forecast

# The last `max_prediction_length` time steps are kept out of the training data.
training_cutoff = data["time_idx"].max() - max_prediction_length

context_length = max_encoder_length
prediction_length = max_prediction_length

training = TimeSeriesDataSet(
    data.loc[data["time_idx"] <= training_cutoff],
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    # fit the encoder on all series ids of the full data set
    categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
    # "value" is the only unknown variable; N-Beats cannot use any additional variables
    time_varying_unknown_reals=["value"],
    max_encoder_length=context_length,
    max_prediction_length=prediction_length,
)

# Re-use the training dataset's configuration on the full data, but only predict
# time steps that lie after the training cutoff.
validation = TimeSeriesDataSet.from_dataset(
    training,
    data,
    min_prediction_idx=training_cutoff + 1,
)

batch_size = 128
train_dataloader = training.to_dataloader(batch_size=batch_size, train=True, num_workers=0)
val_dataloader = validation.to_dataloader(batch_size=batch_size, train=False, num_workers=0)

Calculate baseline error

[5]:
# Score the Baseline model on the validation set with SMAPE
# (symmetric mean absolute percentage error).
actuals = torch.cat([target for _, target in val_dataloader])
baseline_predictions = Baseline().predict(val_dataloader)
SMAPE()(baseline_predictions, actuals)
[5]:
tensor(0.4990)

Train network

Find optimal learning rate

[6]:
# Trainer (no GPUs, clipped gradients) and a fresh N-Beats network
# configured from the training dataset.
trainer = pl.Trainer(gradient_clip_val=0.1, gpus=0)
net = NBeats.from_dataset(
    training,
    weight_decay=1e-2,
    learning_rate=3e-2,
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
[7]:
# Run the learning rate finder over the range 1e-5 ... 1e2.
res = trainer.lr_find(
    net,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
    min_lr=1e-5,
    max_lr=1e2,
)

# Report the suggestion, plot the search curve and store the suggestion on the network.
suggested_lr = res.suggestion()
print(f"suggested learning rate: {suggested_lr}")
fig = res.plot(show=True, suggest=True)
fig.show()
net.hparams.learning_rate = suggested_lr

  | Name       | Type       | Params
------------------------------------------
0 | loss       | SMAPE      | 0
1 | net_blocks | ModuleList | 134 K
Saving latest checkpoint..
LR finder stopped early due to diverging loss.
suggested learning rate: 0.007413102413009173
../_images/tutorials_ar_11_4.png

Fit model

[8]:
# Stop training once "val_loss" has not improved by at least 1e-4 for 10 checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    min_delta=1e-4,
    patience=10,
    verbose=False,
)

# At most 100 epochs, each limited to 30 training batches.
trainer = pl.Trainer(
    gpus=0,
    max_epochs=100,
    limit_train_batches=30,
    gradient_clip_val=0.1,
    weights_summary="top",
    early_stop_callback=early_stop_callback,
)

# NOTE(review): a new network is created here with a fixed learning rate of 1.2e-2,
# so a learning rate stored on a previously created `net` is not carried over.
net = NBeats.from_dataset(
    training,
    learning_rate=1.2e-2,
    weight_decay=1e-2,
    log_interval=10,
    log_val_interval=1,
)

trainer.fit(net, train_dataloader=train_dataloader, val_dataloaders=val_dataloader)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name       | Type       | Params
------------------------------------------
0 | loss       | SMAPE      | 0
1 | net_blocks | ModuleList | 134 K

Saving latest checkpoint..

[8]:
1

Evaluate Results

[9]:
# Load the best checkpoint recorded by the trainer's checkpoint callback during fitting.
# The path is taken from the trainer instead of a hard-coded absolute path on one
# developer's machine, so the notebook also runs elsewhere after a fresh
# "Restart & Run All".
best_model_path = trainer.checkpoint_callback.best_model_path
best_model = NBeats.load_from_checkpoint(best_model_path)
print(best_model_path)
/Users/beitnerjan/Documents/Github/temporal_fusion_transformer_pytorch/lightning_logs/version_212/checkpoints/epoch=19.ckpt

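We calculate the mean absolute error on the validation set, which is approximately half of the baseline error (note that the baseline above was scored with SMAPE, whereas this value is a mean absolute error).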

[10]:
# Mean absolute error of the best model's predictions on the validation set.
actuals = torch.cat([target for _, target in val_dataloader])
predictions = best_model.predict(val_dataloader)
torch.abs(actuals - predictions).mean()
[10]:
tensor(0.2058)
[11]:
# Request the raw model output together with the corresponding network inputs;
# both are needed by the plotting cells.
raw_predictions, x = best_model.predict(
    val_dataloader,
    return_x=True,
    mode="raw",
)
[12]:
# Plot the predictions for samples 0-9, with the loss added to each plot title.
for sample_idx in range(10):
    best_model.plot_prediction(x, raw_predictions, idx=sample_idx, add_loss_to_title=True);
../_images/tutorials_ar_19_0.png
../_images/tutorials_ar_19_1.png
../_images/tutorials_ar_19_2.png
../_images/tutorials_ar_19_3.png
../_images/tutorials_ar_19_4.png
../_images/tutorials_ar_19_5.png
../_images/tutorials_ar_19_6.png
../_images/tutorials_ar_19_7.png
../_images/tutorials_ar_19_8.png
../_images/tutorials_ar_19_9.png

Interpret model

[13]:
# Plot the model's interpretation for samples 0-9.
# The trailing semicolon sits on the call itself: a semicolon alone on its own line
# is not a valid Python statement.
for idx in range(10):
    best_model.plot_interpretation(x, raw_predictions, idx=idx);
[13]:
''
../_images/tutorials_ar_21_1.png
../_images/tutorials_ar_21_2.png
../_images/tutorials_ar_21_3.png
../_images/tutorials_ar_21_4.png
../_images/tutorials_ar_21_5.png
../_images/tutorials_ar_21_6.png
../_images/tutorials_ar_21_7.png
../_images/tutorials_ar_21_8.png
../_images/tutorials_ar_21_9.png
../_images/tutorials_ar_21_10.png
[ ]: