{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "rzVbXsEBxnF-" }, "source": [ "# Example Notebook for a basic vignette for `pytorch-forecasting v2` Model Training and Inference" ] }, { "cell_type": "markdown", "metadata": { "id": "yt0uZV7Px-40" }, "source": [ "
\n", ":warning: The \"Data Pipeline\" showcased here is part of an experimental rework of the `pytorch-forecasting` data layer, planned for release in v2.0.0. The API is currently unstable and subject to change without prior notice. This notebook serves as a basic demonstration of the intended workflow and is not recommended for use in production environments. Feedback and suggestions are highly encouraged — please share them in issue 1736.\n", "
\n" ] }, { "cell_type": "markdown", "metadata": { "id": "r15UunnLoxnK" }, "source": [ "In this notebook, we demonstrate how to train and evaluate the **Temporal Fusion Transformer (TFT)** using the new `TimeSeries` and `DataModule` API from the v2 pipeline.\n", "There are two ways to do this:\n", "1. **High-level package API:**\n", "\n", " This approach handles data loading, dataloader creation, and model training internally. It provides a simple, `scikit-learn`-like `fit` → `predict` workflow.\n", " Users can still configure key training options (such as the `trainer`, callbacks, and training parameters) but cannot plug in fully custom `trainer` implementations or override internal pipeline logic.\n", "\n", "2. **Low-level 3-stage pipeline:**\n", "This involves explicitly constructing:\n", " * a `TimeSeries` object\n", "\n", " * a `DataModule`\n", "\n", " * the model (e.g., `TFT`)\n", " \n", " This workflow is ideal if you need custom setups such as custom trainers, callbacks, or advanced data preprocessing.\n", " It requires a deeper understanding of how the three layers (`TimeSeries`, `DataModule`, and the model) interact, but offers maximum flexibility." ] }, { "cell_type": "markdown", "metadata": { "id": "QyMFNk4MyY_b" }, "source": [ "## Create Synthetic data\n", "We generate a synthetic dataset with `load_toydata`, which returns a `pandas` DataFrame containing only numerical values, because for now **the pipeline assumes the data to be numerical only**."
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "RkgOT4kiy_RU" }, "outputs": [], "source": [ "from pytorch_forecasting.data.examples import load_toydata" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "WX-FRdusJSVN", "outputId": "2ad916b8-2fd9-4318-afb1-2bda84d284d7" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6712252870750063,\n \"min\": -1.2780952045426857,\n \"max\": 1.3163602917006327,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.19335967827533446,\n 0.8492207493147326,\n -0.9687640491099185\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6753351884449413,\n \"min\": -1.2780952045426857,\n \"max\": 1.3163602917006327,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.6981263626070341,\n 0.7052787051636003,\n -0.861386757323439\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": 
\"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2792423704109133,\n \"min\": 0.031153133884698536,\n \"max\": 0.9662188410416612,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.24602577096925082,\n 0.8680231736929984,\n 0.6913124004679789\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", "type": "dataframe", "variable_name": "data_df" }, "text/html": [ "\n", "
\n" ], "text/plain": [ " series_id time_idx x y category future_known_feature \\\n", "0 0 0 -0.030643 0.148280 0 1.000000 \n", "1 0 1 0.148280 0.433029 0 0.995004 \n", "2 0 2 0.433029 0.742511 0 0.980067 \n", "3 0 3 0.742511 0.729270 0 0.955336 \n", "4 0 4 0.729270 0.628604 0 0.921061 \n", "\n", " static_feature static_feature_cat \n", "0 0.039213 0 \n", "1 0.039213 0 \n", "2 0.039213 0 \n", "3 0.039213 0 \n", "4 0.039213 0 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num_series = 100 # Number of individual time series to generate\n", "seq_length = 50 # Length of each time series\n", "data_df = load_toydata(num_series, seq_length)\n", "data_df.head()" ] }, { "cell_type": "markdown", "metadata": { "id": "_8TgLH82runO" }, "source": [ "## High-level API\n" ] }, { "cell_type": "markdown", "metadata": { "id": "A1cqKCRur4oj" }, "source": [ "### Steps\n", "* Create the `TimeSeries` object\n", "* Create the configs for the model, `datamodule`, `trainer`, etc.\n", "* Create the `model_pkg` object\n", "* Call `fit` and `predict` on the `model_pkg` object.\n", "\n", "### Create the dataset object\n", "\n", "`TimeSeries` returns the raw data as tensors.\n", "\n", "---\n", "\n", "Key arguments of the `TimeSeries` dataset:\n", "- `data`: DataFrame with sequence data.\n", "- `time`: Integer-typed column denoting the time index within `data`.\n", "- `target`: Column(s) in `data` denoting the forecasting target.\n", "- `group`: List of column names identifying a time series instance within `data`.\n", "- `num`: List of numerical features.\n", "- `cat`: List of categorical features.\n", "- `known`: Features known in the future.\n", "- `unknown`: Features not known in the future.\n", "- `static`: List of variables that do not change over time.\n", "\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "u8OPR0HntXqR" }, "outputs": [], "source": [ "from pytorch_forecasting.data.timeseries import TimeSeries" ] }, { "cell_type": "code", 
"execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6a_oy4VjtrHQ", "outputId": "54678fb8-864e-4f32-eeb9-83697946a3e5" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/content/pytorch-forecasting/pytorch_forecasting/data/timeseries/_timeseries_v2.py:105: UserWarning: TimeSeries is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", " warn(\n" ] } ], "source": [ "# create `TimeSeries` dataset that returns the raw data in terms of tensors\n", "dataset = TimeSeries(\n", " data=data_df,\n", " time=\"time_idx\",\n", " target=\"y\",\n", " group=[\"series_id\"],\n", " num=[\"x\", \"future_known_feature\", \"static_feature\"],\n", " cat=[\"category\", \"static_feature_cat\"],\n", " known=[\"future_known_feature\"],\n", " unknown=[\"x\", \"category\"],\n", " static=[\"static_feature\", \"static_feature_cat\"],\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "EoS6W9zh6wCj" }, "source": [ "### Create the configs\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "MKPXPUcC5dTY" }, "outputs": [], "source": [ "from sklearn.preprocessing import StandardScaler\n", "from pytorch_forecasting.data.encoders import (\n", " NaNLabelEncoder,\n", " TorchNormalizer,\n", ")\n", "from pytorch_forecasting.metrics import MAE, SMAPE" ] }, { "cell_type": "markdown", "metadata": { "id": "WYl9-oZz6nk6" }, "source": [ "Here we use `EncoderDecoderTimeSeriesDataModule`.\n", "\n", "\n", "Key arguments of `EncoderDecoderTimeSeriesDataModule`:\n", "- `time_series_dataset`: `TimeSeries` dataset instance.\n", "- `max_encoder_length`: Maximum length of the encoder input sequence.\n", "- `max_prediction_length`: Maximum length of the decoder output sequence.\n", "- `batch_size`: Batch size for the DataLoader.\n", "- `categorical_encoders`: Dictionary of categorical encoders.\n", "- `scalers`: Dictionary of feature scalers.\n", "- `target_normalizer`: Normalizer for the target variable." ] }, 
{ "cell_type": "code", "execution_count": 14, "metadata": { "id": "YGMShzfyttp_" }, "outputs": [], "source": [ "datamodule_cfg = dict(\n", " max_encoder_length=30,\n", " max_prediction_length=1,\n", " batch_size=32,\n", " categorical_encoders={\n", " \"category\": NaNLabelEncoder(add_nan=True),\n", " \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n", " },\n", " scalers={\n", " \"x\": StandardScaler(),\n", " \"future_known_feature\": StandardScaler(),\n", " \"static_feature\": StandardScaler(),\n", " },\n", " target_normalizer=TorchNormalizer(),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "Pi5Qkznh6t3y" }, "source": [ "We use the `TFT` model in this tutorial." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "q6Thm13ct7OV" }, "outputs": [], "source": [ "model_cfg = dict(\n", " loss=MAE(),\n", " logging_metrics=[MAE(), SMAPE()],\n", " optimizer=\"adam\",\n", " optimizer_params={\"lr\": 1e-3},\n", " lr_scheduler=\"reduce_lr_on_plateau\",\n", " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", " hidden_size=64,\n", " num_layers=2,\n", " attention_head_size=4,\n", " dropout=0.1,\n", ")" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "Stfuc_xCuON6" }, "outputs": [], "source": [ "trainer_cfg = dict(\n", " max_epochs=5,\n", " accelerator=\"auto\",\n", " devices=1,\n", " enable_progress_bar=True,\n", " log_every_n_steps=10,\n", ")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "XS_ND8UAubdN" }, "outputs": [], "source": [ "from pytorch_forecasting.models.temporal_fusion_transformer._tft_pkg_v2 import (\n", " TFT_pkg_v2,\n", ")" ] }, { "cell_type": "markdown", "metadata": { 
"id": "6yoqI8907DG4" }, "source": [ "### Create the `model_pkg` object\n", "\n", "The `pkg` class wraps the whole ML pipeline in `pytorch-forecasting`: we instantiate it once and then call `pkg.fit` and `pkg.predict` to train the model and generate predictions." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aOxng4Rguwj2", "outputId": "2c50fcad-f990-4aae-f0bb-5dbdd6a87377" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'loss': MAE(), 'logging_metrics': [MAE(), SMAPE()], 'optimizer': 'adam', 'optimizer_params': {'lr': 0.001}, 'lr_scheduler': 'reduce_lr_on_plateau', 'lr_scheduler_params': {'mode': 'min', 'factor': 0.1, 'patience': 10}, 'hidden_size': 64, 'num_layers': 2, 'attention_head_size': 4, 'dropout': 0.1}\n" ] } ], "source": [ "model_pkg = TFT_pkg_v2(\n", " model_cfg=model_cfg,\n", " trainer_cfg=trainer_cfg,\n", " datamodule_cfg=datamodule_cfg,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 976, "referenced_widgets": [ "4ecdea6764d145118ab53e59451d2b0c", "e7b969aa6d8e433d9aeeac4357bc425d", "42707b895305490b82cd644250e689fa", "19c3c106d5a9489cae445a0b5fc88183", "a6e5908902eb40e997e6086287f28f2a", "f23d99cc4f01426eb4fc8d41fc8f4b16", "df8d7458b0fa4f508d4ce357fc95c609", "464464a47d604d708be37e18edec4810", "c7ca9662eae04999b21de91c183a0856", "e82943533ad54539a777d6adae271d0f", "e0fa236745204d5ba2dbc1ff2c51f1a2", "c52bb8ff12db4df3a05cf1da7b5470f7", "273fb7ccddeb476f9c76bd1be44a6ae0", "922dfb34c7494b20a874e294c07447e0", "96eda6cd2fbc47b9a29d2cf176332058", "d2ed70b544924436b185f79d0cd90862", "aff8baef21494ccf99bf092fc3daaae0", "640d3876c2ce49de9757962b5b5b0e32", "f99327891eeb424b9df4fe04a6bedcfb", "04804ca4c425464db2754abf3cf95568", "386a2097c02f4af6ba239e385f4b0b47", "7e95bdbc3a6c46bf8b1e96bdecfa7303", 
"572629c64cfd46a983f1a8c6483a2cf1", "61f356a200c5470db777e7e0f9e8c520", "4905c6d809274aa39a984e3e458fc89a", "b16522c88ebb435f96c05315ef91ebbb", "65cce98698644285896363e509ae6139", "a4a4c07f117e46989cb4877a5d2dd9e0", "3f8cc20607db40c7acb4110feba9ab0d", "3c5c1f55d5a64838b94fc0fbd85097c6", "4f5cce37b6ac430e85757dd06b06953a", "942645c506ea436fb598455d84c8a970", "969a3ddfaed84150944d697307ababe4", "955c5e9c139148a1a352d17202fe097f", "f0fbcbcf02e443bc99a469cf4c7f8131", "72fb23e179594f35a68418e2e0ee65fc", "7808bf48e45940cfa0d4bccae784d730", "ed0358a45ec14ce687fc02904a815e38", "67a3d79f1b2e4e03b9a564286c04d5d4", "27fd0da590314bb68dfac5b7c72d6584", "15e539660a2547f49fb2cf8a6143f5fa", "c46b831b37a347868f1d35d0dbbfd923", "3f985da9d6a245c5b54dbb47926a4fd4", "5ab64c01efb84e75af1a8aaf6675f5bf", "8d0747756fd2434399ae8d233a82d607", "f77d800d097b494ca3e945abdaedd75c", "05a14444ea4043dea69a4e7185e66cb1", "b775518f409449928c3211260d7223c0", "2db9a1e74ad14139af235f1a2a146e0a", "2bcdfdf1b12c495aa8b425c88fcfbd1b", "d71a01b309e948239a16062097ee76a2", "006cdf49ce55411bb072c2670d87773f", "5b3be082628244948284a40bea451ff1", "bf96e25bd5d64892a329b624961abeb8", "b5e879fa1fee4d0ba30ac5af07d1d8c5", "78f2d725dcb34deca5407c277c384d8d", "e1dc997d76d54eb9a9245530a60c2cd9", "8fa00a6415b74091a012bc5fff543f42", "c74c472cdf174a28a4ca3fed1b312332", "409d65e2b79f49318c580d9835ffc29c", "568141e0b45f44ff9b497c6474d8019f", "1a889d0b55bb4e6d80d5f170297a6262", "679072fe36f5404588879eab670e01e2", "ab72364dc8cf433e907c40df3e7be9e9", "d56c627297bd435fbbc60317066084f9", "6a62c54d7d7b4f689dc31e57aaa20411", "3ab523d60fd249a7bcece32280872abd", "ef8e78e4f8a248dcbbc2ea3e464c5922", "56685ebbfd244154bd1829dec6f0db0b", "dc5f9a923d27492cb382691ec01a1ddc", "58995b1bd1c24433a3aca0ab53c6b8bf", "9082a1b6eb3a4d14b4c92eedec1c2404", "13eb0c5265ad48d08b3f8e46a55896f0", "f91e202108684aa9af76fdf3e9d83206", "f617590dcd184fd99f98515940ac85af", "33a1e5f21b694e3bb62d2c0d73aa65e3", "8c8c8832e16c4d489e0df7514dc78f6a" ] }, "id": 
"c27Qj4QAvFwx", "outputId": "21bbd594-d92e-498b-bd02-71829295c483" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/content/pytorch-forecasting/pytorch_forecasting/data/data_module.py:129: UserWarning: EncoderDecoderTimeSeriesDataModule is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", " warn(\n", "/content/pytorch-forecasting/pytorch_forecasting/models/base/_base_model_v2.py:64: UserWarning: The Model 'TFT' is part of an experimental reworkof the pytorch-forecasting model layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. This class is intended for beta testing and as a basic skeleton, but not for stable production use. 
Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", " warn(\n", "INFO: GPU available: True (cuda), used: True\n", "INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True\n", "INFO: TPU available: False, using: 0 TPU cores\n", "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", "INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "INFO: \n", " | Name | Type | Params | Mode \n", "---------------------------------------------------------------------\n", "0 | loss | MAE | 0 | train\n", "1 | encoder_var_selection | Sequential | 709 | train\n", "2 | decoder_var_selection | Sequential | 193 | train\n", "3 | static_context_linear | Linear | 192 | train\n", "4 | lstm_encoder | LSTM | 51.5 K | train\n", "5 | lstm_decoder | LSTM | 50.4 K | train\n", "6 | self_attention | MultiheadAttention | 16.6 K | train\n", "7 | pre_output | Linear | 4.2 K | train\n", "8 | output_layer | Linear | 65 | train\n", "---------------------------------------------------------------------\n", "123 K Trainable params\n", "0 Non-trainable params\n", "123 K Total params\n", "0.495 Total estimated model params size (MB)\n", "18 Modules in train mode\n", "0 Modules in eval mode\n", "INFO:lightning.pytorch.callbacks.model_summary:\n", " | Name | Type | Params | Mode \n", "---------------------------------------------------------------------\n", "0 | loss | MAE | 0 | train\n", "1 | encoder_var_selection | Sequential | 709 | train\n", "2 | decoder_var_selection | Sequential | 193 | train\n", "3 | static_context_linear | Linear | 192 | train\n", "4 | lstm_encoder | LSTM | 51.5 K | train\n", "5 | lstm_decoder | LSTM | 50.4 K | train\n", "6 | self_attention | MultiheadAttention | 16.6 K | train\n", "7 | pre_output | Linear | 4.2 K | train\n", "8 | output_layer | Linear | 65 
| train\n", "---------------------------------------------------------------------\n", "123 K Trainable params\n", "0 Non-trainable params\n", "123 K Total params\n", "0.495 Total estimated model params size (MB)\n", "18 Modules in train mode\n", "0 Modules in eval mode\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4ecdea6764d145118ab53e59451d2b0c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sanity Checking: | | 0/? [00:00