{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Interpretable forecasting with N-BEATS" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "os.chdir(\"../../..\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import pytorch_lightning as pl\n", "from pytorch_lightning.callbacks import EarlyStopping\n", "import torch\n", "\n", "from pytorch_forecasting import Baseline, NBeats, TimeSeriesDataSet\n", "from pytorch_forecasting.data import NaNLabelEncoder\n", "from pytorch_forecasting.data.examples import generate_ar_data\n", "from pytorch_forecasting.metrics import SMAPE" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We generate a synthetic dataset to demonstrate the network's capabilities. The data consists of a quadratic trend and a seasonality component." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | series | \n", "time_idx | \n", "value | \n", "static | \n", "date | \n", "
---|---|---|---|---|---|
0 | \n", "0 | \n", "0 | \n", "-0.000000 | \n", "2 | \n", "2020-01-01 | \n", "
1 | \n", "0 | \n", "1 | \n", "-0.046501 | \n", "2 | \n", "2020-01-02 | \n", "
2 | \n", "0 | \n", "2 | \n", "-0.097796 | \n", "2 | \n", "2020-01-03 | \n", "
3 | \n", "0 | \n", "3 | \n", "-0.144397 | \n", "2 | \n", "2020-01-04 | \n", "
4 | \n", "0 | \n", "4 | \n", "-0.177954 | \n", "2 | \n", "2020-01-05 | \n", "