{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Demand forecasting with the Temporal Fusion Transformer\n" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "In this tutorial, we will train the :py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer` on a very small dataset to demonstrate that it does a good job even with only 20k samples. Generally speaking, it is a large model and will therefore perform much better with more data." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Our example is a demand forecast from the [Stallion Kaggle competition](https://www.kaggle.com/utathya/future-volume-prediction).\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import os\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\") # avoid printing out absolute paths\n", "\n", "os.chdir(\"../../..\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import copy\n", "from pathlib import Path\n", "import warnings\n", "\n", "import lightning.pytorch as pl\n", "from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor\n", "from lightning.pytorch.loggers import TensorBoardLogger\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "\n", "from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet\n", "from pytorch_forecasting.data import GroupNormalizer\n", "from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss\n", "from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters" ] }, { "attachments": {}, "cell_type": 
"markdown", "metadata": {}, "source": [ "## Load data\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "First, we need to transform our time series into a pandas dataframe where each row can be identified by a time step and a time series. Fortunately, most datasets are already in this format. For this tutorial, we will use the [Stallion dataset from Kaggle](https://www.kaggle.com/utathya/future-volume-prediction) describing sales of various beverages. Our task is to make a six-month forecast of the volume sold per stock keeping unit (SKU), that is, a product, at each agency, that is, a store. There are about 21,000 monthly historical sales records. In addition to historical sales, we have information about the sales price, the location of the agency, special days such as holidays, and the volume sold in the entire industry.\n", "\n", "The dataset is already in the correct format but lacks some important features. Most importantly, we need to add a time index that is incremented by one for each time step. Further, it is beneficial to add date features, which in this case means extracting the month from the date record.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/html": [ "
| | agency | sku | volume | date | industry_volume | soda_volume | avg_max_temp | price_regular | price_actual | discount | ... | football_gold_cup | beer_capital | music_fest | discount_in_percent | timeseries | time_idx | month | log_volume | avg_volume_by_sku | avg_volume_by_agency |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 291 | Agency_25 | SKU_03 | 0.5076 | 2013-01-01 | 492612703 | 718394219 | 25.845238 | 1264.162234 | 1152.473405 | 111.688829 | ... | - | - | - | 8.835008 | 228 | 0 | 1 | -0.678062 | 1225.306376 | 99.650400 |
| 871 | Agency_29 | SKU_02 | 8.7480 | 2015-01-01 | 498567142 | 762225057 | 27.584615 | 1316.098485 | 1296.804924 | 19.293561 | ... | - | - | - | 1.465966 | 177 | 24 | 1 | 2.168825 | 1634.434615 | 11.397086 |
| 19532 | Agency_47 | SKU_01 | 4.9680 | 2013-09-01 | 454252482 | 789624076 | 30.665957 | 1269.250000 | 1266.490490 | 2.759510 | ... | - | - | - | 0.217413 | 322 | 8 | 9 | 1.603017 | 2625.472644 | 48.295650 |
| 2089 | Agency_53 | SKU_07 | 21.6825 | 2013-10-01 | 480693900 | 791658684 | 29.197727 | 1193.842373 | 1128.124395 | 65.717978 | ... | - | beer_capital | - | 5.504745 | 240 | 9 | 10 | 3.076505 | 38.529107 | 2511.035175 |
| 9755 | Agency_17 | SKU_02 | 960.5520 | 2015-03-01 | 515468092 | 871204688 | 23.608120 | 1338.334248 | 1232.128069 | 106.206179 | ... | - | - | music_fest | 7.935699 | 259 | 26 | 3 | 6.867508 | 2143.677462 | 396.022140 |
| 7561 | Agency_05 | SKU_03 | 1184.6535 | 2014-02-01 | 425528909 | 734443953 | 28.668254 | 1369.556376 | 1161.135214 | 208.421162 | ... | - | - | - | 15.218151 | 21 | 13 | 2 | 7.077206 | 1566.643589 | 1881.866367 |
| 19204 | Agency_11 | SKU_05 | 5.5593 | 2017-08-01 | 623319783 | 1049868815 | 31.915385 | 1922.486644 | 1651.307674 | 271.178970 | ... | - | - | - | 14.105636 | 17 | 55 | 8 | 1.715472 | 1385.225478 | 109.699200 |
| 8781 | Agency_48 | SKU_04 | 4275.1605 | 2013-03-01 | 509281531 | 892192092 | 26.767857 | 1761.258209 | 1546.059670 | 215.198539 | ... | - | - | music_fest | 12.218455 | 151 | 2 | 3 | 8.360577 | 1757.950603 | 1925.272108 |
| 2540 | Agency_07 | SKU_21 | 0.0000 | 2015-10-01 | 544203593 | 761469815 | 28.987755 | 0.000000 | 0.000000 | 0.000000 | ... | - | - | - | 0.000000 | 300 | 33 | 10 | -18.420681 | 0.000000 | 2418.719550 |
| 12084 | Agency_21 | SKU_03 | 46.3608 | 2017-04-01 | 589969396 | 940912941 | 32.478910 | 1675.922116 | 1413.571789 | 262.350327 | ... | - | - | - | 15.654088 | 181 | 51 | 4 | 3.836454 | 2034.293024 | 109.381800 |

10 rows × 31 columns

| | volume | date | industry_volume | soda_volume | avg_max_temp | price_regular | price_actual | discount | avg_population_2017 | avg_yearly_household_income_2017 | discount_in_percent | timeseries | time_idx | log_volume | avg_volume_by_sku | avg_volume_by_agency |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 21000.000000 | 21000 | 2.100000e+04 | 2.100000e+04 | 21000.000000 | 21000.000000 | 21000.000000 | 21000.000000 | 2.100000e+04 | 21000.000000 | 21000.000000 | 21000.00000 | 21000.000000 | 21000.000000 | 21000.000000 | 21000.000000 |
| mean | 1492.403982 | 2015-06-16 20:48:00 | 5.439214e+08 | 8.512000e+08 | 28.612404 | 1451.536344 | 1267.347450 | 184.374146 | 1.045065e+06 | 151073.494286 | 10.574884 | 174.50000 | 29.500000 | 2.464118 | 1492.403982 | 1492.403982 |
| min | 0.000000 | 2013-01-01 00:00:00 | 4.130518e+08 | 6.964015e+08 | 16.731034 | 0.000000 | -3121.690141 | 0.000000 | 1.227100e+04 | 90240.000000 | 0.000000 | 0.00000 | 0.000000 | -18.420681 | 0.000000 | 0.000000 |
| 25% | 8.272388 | 2014-03-24 06:00:00 | 5.090553e+08 | 7.890880e+08 | 25.374816 | 1311.547158 | 1178.365653 | 54.935108 | 6.018900e+04 | 110057.000000 | 3.749628 | 87.00000 | 14.750000 | 2.112923 | 932.285496 | 113.420250 |
| 50% | 158.436000 | 2015-06-16 00:00:00 | 5.512000e+08 | 8.649196e+08 | 28.479272 | 1495.174592 | 1324.695705 | 138.307225 | 1.232242e+06 | 131411.000000 | 8.948990 | 174.50000 | 29.500000 | 5.065351 | 1402.305264 | 1730.529771 |
| 75% | 1774.793475 | 2016-09-08 12:00:00 | 5.893715e+08 | 9.005551e+08 | 31.568405 | 1725.652080 | 1517.311427 | 272.298630 | 1.729177e+06 | 206553.000000 | 15.647058 | 262.00000 | 44.250000 | 7.481439 | 2195.362302 | 2595.316500 |
| max | 22526.610000 | 2017-12-01 00:00:00 | 6.700157e+08 | 1.049869e+09 | 45.290476 | 19166.625000 | 4925.404000 | 19166.625000 | 3.137874e+06 | 247220.000000 | 226.740147 | 349.00000 | 59.000000 | 10.022453 | 4332.363750 | 5884.717375 |
| std | 2711.496882 | NaN | 6.288022e+07 | 7.824340e+07 | 3.972833 | 683.362417 | 587.757323 | 257.469968 | 9.291926e+05 | 50409.593114 | 9.590813 | 101.03829 | 17.318515 | 8.178218 | 1051.790829 | 1328.239698 |