# Source code for pytorch_forecasting.models.nbeats._nbeatskan

"""
N-Beats model with KAN blocks for timeseries forecasting without covariates.
"""

from typing import Optional

import torch
from torch import nn

from pytorch_forecasting.layers._nbeats._blocks import (
    NBEATSGenericBlockKAN,
    NBEATSSeasonalBlockKAN,
    NBEATSTrendBlockKAN,
)
from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric
from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter


class NBeatsKAN(NBeatsAdapter):
    """
    Initialize NBeatsKAN Model - use its :py:meth:`~from_dataset` method if possible.

    Based on the article
    `N-BEATS: Neural basis expansion analysis for interpretable time series
    forecasting <http://arxiv.org/abs/1905.10437>`_. The network has (if used as
    ensemble) outperformed all other methods including ensembles of traditional
    statistical methods in the M4 competition. The M4 competition is arguably the
    most important benchmark for univariate time series forecasting.

    The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently
    shown to consistently outperform N-BEATS.

    Parameters
    ----------
    stack_types : list of str, optional
        Type of each stack, one of "generic", "seasonality" or "trend". The
        number of stacks (``num_stacks`` below) is ``len(stack_types)``.
        Recommended value for generic mode: ["generic"].
        Recommended value for interpretable mode: ["trend", "seasonality"].
        Defaults to ["trend", "seasonality"].
    num_blocks : list of int, optional
        The number of blocks per stack. A list of ints of length 1 (the single
        value is repeated for every stack) or ``num_stacks``.
        Recommended value for generic mode: [1].
        Recommended value for interpretable mode: [3].
        Defaults to [3, 3].
    num_block_layers : list of int, optional
        Number of layers per block, passed to each block as
        ``num_block_layers``. A list of ints of length 1 or ``num_stacks``.
        Recommended value for generic and interpretable mode: [4].
        Defaults to [3, 3].
    widths : list of int, optional
        Width (``units``) of the layers in the blocks. A list of ints of
        length 1 or ``num_stacks``.
        Recommended value for generic mode: [512].
        Recommended value for interpretable mode: [256, 2048].
        Defaults to [32, 512].
    sharing : list of bool, optional
        Whether the weights are shared with the other blocks per stack. A list
        of bools of length 1 or ``num_stacks``.
        Recommended value for generic mode: [False].
        Recommended value for interpretable mode: [True].
        Defaults to [True, True]. Note that ``__init__`` only records this
        value in ``hparams``; it is not used when constructing the blocks.
    expansion_coefficient_lengths : list of int, optional
        For a "generic" stack, the length of the expansion coefficient
        (``thetas_dim``). For a "trend" stack, the degree of the polynomial
        (also passed as ``thetas_dim``). For a "seasonality" stack, the minimum
        period allowed (passed as ``min_period``), e.g. 2 for changes every
        timestep. A list of ints of length 1 or ``num_stacks``.
        Recommended value for generic mode: [32].
        Recommended value for interpretable mode: [3].
        Defaults to [3, 7].
    prediction_length : int
        Length of the prediction. Also known as 'horizon'. Defaults to 1.
    context_length : int
        Number of time units that condition the predictions. Also known as
        'lookback period'. Should be between 1-10 times the prediction length.
        Defaults to 1.
    dropout : float
        Dropout rate passed to every block. Defaults to 0.1.
    learning_rate : float
        Learning rate. Defaults to 1e-2.
    log_interval : int
        Logging interval in batches (recorded in ``hparams``). Defaults to -1.
    log_gradient_flow : bool
        If to log gradient flow, this takes time and should be only done to
        diagnose training failures. Defaults to False.
    log_val_interval : int, optional
        Logging interval used during validation (recorded in ``hparams``).
        Defaults to None.
    weight_decay : float
        Weight decay. Defaults to 1e-3.
    loss : MultiHorizonMetric, optional
        Loss to optimize. Defaults to MASE().
    reduce_on_plateau_patience : int
        Patience after which learning rate is reduced by a factor of 10.
        Defaults to 1000.
    backcast_loss_ratio : float
        Weight of backcast in comparison to forecast when calculating the loss.
        A weight of 1.0 means that forecast and backcast loss is weighted the
        same (regardless of backcast and forecast lengths). Defaults to 0.0,
        i.e. no weight.
    logging_metrics : nn.ModuleList of MultiHorizonMetric, optional
        List of metrics that are logged during training. Defaults to
        nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]).
    num : int
        Parameter for KAN layer. The number of grid intervals (G). Default: 5.
    k : int
        Parameter for KAN layer. The order of the piecewise polynomial.
        Default: 3.
    noise_scale : float
        Parameter for KAN layer. The scale of noise injected at
        initialization. Default: 0.5.
    scale_base_mu : float
        Parameter for KAN layer. The scale of the residual function b(x) is
        initialized to be N(scale_base_mu, scale_base_sigma^2). Default: 0.0.
    scale_base_sigma : float
        Parameter for KAN layer. The scale of the residual function b(x) is
        initialized to be N(scale_base_mu, scale_base_sigma^2). Default: 1.0.
    scale_sp : float
        Parameter for KAN layer. The scale of the base function spline(x).
        Default: 1.0.
    base_fun : callable, optional
        Parameter for KAN layer. Residual function b(x). Defaults to
        ``torch.nn.SiLU()`` when None.
    grid_eps : float
        Parameter for KAN layer. When grid_eps = 1, the grid is uniform; when
        grid_eps = 0, the grid is partitioned using percentiles of samples.
        0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02.
    grid_range : list of int, optional
        Parameter for KAN layer. list/np.array of shape (2,) setting the range
        of grids. Defaults to [-1, 1] when None.
    sp_trainable : bool
        Parameter for KAN layer. If true, scale_sp is trainable. Default: True.
    sb_trainable : bool
        Parameter for KAN layer. If true, scale_base is trainable.
        Default: True.
    sparse_init : bool
        Parameter for KAN layer. If sparse_init = True, sparse initialization
        is applied. Default: False.
    **kwargs
        Additional arguments to :py:class:`~BaseModel`.

    Examples
    --------
    See the full example in: `examples/nbeats_with_kan.py`

    Notes
    -----
    The KAN blocks are based on the Kolmogorov-Arnold representation theorem
    and replace fixed MLP edge weights with learnable univariate spline
    functions. This allows KAN-augmented N-BEATS to better capture complex
    patterns, improve interpretability, and achieve parameter efficiency.
    Additionally, when applied in a doubly-residual adversarial framework, the
    model excels at zero-shot time-series forecasting across markets.

    Key differences from original N-BEATS:

    - MLP layers are replaced by KAN layers with spline-based edge functions.
    - Each weight is a trainable function, not a scalar.
    - Enables visualization of learned functions and better domain adaptation.
    - Yields improved accuracy and interpretability with fewer parameters.

    References
    ----------
    .. [1] Z. Liu et al. (2024), "KAN: Kolmogorov-Arnold Networks" propose
       replacing MLP weights with spline-based learnable edge functions,
       enabling improved accuracy, interpretability, and scaling behavior
       compared to standard MLPs.
    .. [2] A. Bhattacharya & N. Haq (2024), "Zero Shot Time Series Forecasting
       Using Kolmogorov Arnold Networks" incorporate KAN layers into a
       doubly-residual N-BEATS architecture with adversarial domain
       adaptation, achieving strong zero-shot cross-market electricity price
       forecasting performance.
    """  # noqa: E501

    @classmethod
    def _pkg(cls):
        """Package for the model."""
        from pytorch_forecasting.models.nbeats._nbeatskan_pkg import NBeatsKAN_pkg

        return NBeatsKAN_pkg

    @staticmethod
    def _expand_per_stack(values, num_stacks):
        """Repeat a single-entry per-stack list so it has one entry per stack.

        Lists of any other length are returned unchanged.
        """
        if len(values) == 1 and num_stacks > 1:
            return list(values) * num_stacks
        return values

    def __init__(
        self,
        stack_types: list[str] | None = None,
        num_blocks: list[int] | None = None,
        num_block_layers: list[int] | None = None,
        widths: list[int] | None = None,
        sharing: list[bool] | None = None,
        expansion_coefficient_lengths: list[int] | None = None,
        prediction_length: int = 1,
        context_length: int = 1,
        dropout: float = 0.1,
        learning_rate: float = 1e-2,
        log_interval: int = -1,
        log_gradient_flow: bool = False,
        log_val_interval: int = None,
        weight_decay: float = 1e-3,
        loss: MultiHorizonMetric = None,
        reduce_on_plateau_patience: int = 1000,
        backcast_loss_ratio: float = 0.0,
        logging_metrics: nn.ModuleList = None,
        num: int = 5,
        k: int = 3,
        noise_scale: float = 0.5,
        scale_base_mu: float = 0.0,
        scale_base_sigma: float = 1.0,
        scale_sp: float = 1.0,
        base_fun: callable = None,
        grid_eps: float = 0.02,
        grid_range: list[int] = None,
        sp_trainable: bool = True,
        sb_trainable: bool = True,
        sparse_init: bool = False,
        **kwargs,
    ):
        if base_fun is None:
            base_fun = torch.nn.SiLU()
        if grid_range is None:
            grid_range = [-1, 1]
        if expansion_coefficient_lengths is None:
            expansion_coefficient_lengths = [3, 7]
        if sharing is None:
            sharing = [True, True]
        if widths is None:
            widths = [32, 512]
        if num_block_layers is None:
            num_block_layers = [3, 3]
        if num_blocks is None:
            num_blocks = [3, 3]
        if stack_types is None:
            stack_types = ["trend", "seasonality"]
        if logging_metrics is None:
            logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
        if loss is None:
            loss = MASE()

        # Per-stack settings given with a single entry apply to every stack.
        # Expand them before saving the hyperparameters so that ``self.hparams``
        # records exactly the lists used to build the blocks below.
        num_stacks = len(stack_types)
        num_blocks = self._expand_per_stack(num_blocks, num_stacks)
        num_block_layers = self._expand_per_stack(num_block_layers, num_stacks)
        widths = self._expand_per_stack(widths, num_stacks)
        sharing = self._expand_per_stack(sharing, num_stacks)
        expansion_coefficient_lengths = self._expand_per_stack(
            expansion_coefficient_lengths, num_stacks
        )

        self.save_hyperparameters(ignore=["loss", "logging_metrics"])
        super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)

        # Bundle KAN parameters into a dictionary passed to every block.
        self.kan_params = {
            "num": num,
            "k": k,
            "noise_scale": noise_scale,
            "scale_base_mu": scale_base_mu,
            "scale_base_sigma": scale_base_sigma,
            "scale_sp": scale_sp,
            "base_fun": base_fun,
            "grid_eps": grid_eps,
            "grid_range": grid_range,
            "sp_trainable": sp_trainable,
            "sb_trainable": sb_trainable,
            "sparse_init": sparse_init,
        }

        # setup stacks
        self.net_blocks = nn.ModuleList()
        for stack_id, stack_type in enumerate(stack_types):
            for _ in range(num_blocks[stack_id]):
                if stack_type == "generic":
                    net_block = NBEATSGenericBlockKAN(
                        units=widths[stack_id],
                        thetas_dim=expansion_coefficient_lengths[stack_id],
                        num_block_layers=num_block_layers[stack_id],
                        backcast_length=context_length,
                        forecast_length=prediction_length,
                        dropout=dropout,
                        **self.kan_params,
                    )
                elif stack_type == "seasonality":
                    net_block = NBEATSSeasonalBlockKAN(
                        units=widths[stack_id],
                        num_block_layers=num_block_layers[stack_id],
                        backcast_length=context_length,
                        forecast_length=prediction_length,
                        min_period=expansion_coefficient_lengths[stack_id],
                        dropout=dropout,
                        **self.kan_params,
                    )
                elif stack_type == "trend":
                    net_block = NBEATSTrendBlockKAN(
                        units=widths[stack_id],
                        thetas_dim=expansion_coefficient_lengths[stack_id],
                        num_block_layers=num_block_layers[stack_id],
                        backcast_length=context_length,
                        forecast_length=prediction_length,
                        dropout=dropout,
                        **self.kan_params,
                    )
                else:
                    raise ValueError(f"Unknown stack type {stack_type}")

                self.net_blocks.append(net_block)

    def update_kan_grid(self):
        """
        Update the grids of the KAN layers in every block.

        Does nothing unless the model is in training mode.

        WARNING: This relies on the ``outputs`` each block stored during the
        last forward pass. Ensure this is called immediately after a TRAINING
        forward pass.
        """
        if not self.training:
            return

        for block in self.net_blocks:
            # updation logic taken from
            # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682
            for layer_idx, layer in enumerate(block.fc):
                # update basis KAN layers' grid
                layer.update_grid_from_samples(block.outputs[layer_idx])

            # update theta backward and theta forward KAN layers' grid with the
            # sample that follows the last basis layer's sample. An explicit
            # index is used instead of the loop variable, which would be
            # unbound (or stale from the previous block) if ``block.fc`` were
            # empty.
            theta_samples = block.outputs[len(block.fc)]
            block.theta_b_fc.update_grid_from_samples(theta_samples)
            block.theta_f_fc.update_grid_from_samples(theta_samples)