NBeats

class pytorch_forecasting.models.nbeats.NBeats(stack_types: List[str] = ['trend', 'seasonality'], num_blocks=[3, 3], num_block_layers=[3, 3], widths=[32, 512], sharing: List[int] = [True, True], expansion_coefficient_lengths: List[int] = [3, 7], prediction_length: int = 1, context_length: int = 1, dropout: float = 0.1, learning_rate: float = 0.01, log_interval: int = -1, log_gradient_flow: bool = False, log_val_interval: int | None = None, weight_decay: float = 0.001, loss: MultiHorizonMetric | None = None, reduce_on_plateau_patience: int = 1000, backcast_loss_ratio: float = 0.0, logging_metrics: ModuleList | None = None, **kwargs)

Bases: BaseModel

Initialize the NBeats model. Use its from_dataset() method if possible.

Based on the article N-BEATS: Neural basis expansion analysis for interpretable time series forecasting. When used as an ensemble, the network outperformed all other methods in the M4 competition, including ensembles of traditional statistical methods. The M4 competition is arguably the most important benchmark for univariate time series forecasting.

The NHiTS network has recently been shown to consistently outperform N-BEATS.

Parameters:
  • stack_types – Type of each stack, each one of “generic”, “seasonality” or “trend”. A list of strings of length 1 or ‘num_stacks’. Recommended value for generic mode: [“generic”]. Recommended value for interpretable mode, which is also the default: [“trend”, “seasonality”].

  • num_blocks – The number of blocks per stack. A list of ints of length 1 or ‘num_stacks’. Recommended value for generic mode: [1]. Recommended value for interpretable mode: [3]. Defaults to [3, 3].

  • num_block_layers – Number of fully connected layers with ReLU activation per block. A list of ints of length 1 or ‘num_stacks’. Recommended value for both generic and interpretable mode: [4]. Defaults to [3, 3].

  • widths – Widths of the fully connected layers with ReLU activation in the blocks. A list of ints of length 1 or ‘num_stacks’. Recommended value for generic mode: [512]. Recommended value for interpretable mode: [256, 2048]. Defaults to [32, 512].

  • sharing – Whether the weights are shared with the other blocks in the same stack. A list of bools of length 1 or ‘num_stacks’. Recommended value for generic mode: [False]. Recommended value for interpretable mode: [True]. Defaults to [True, True].

  • expansion_coefficient_lengths – If the stack type is “generic”, the length of the expansion coefficient. If the type is “trend”, it corresponds to the degree of the polynomial. If the type is “seasonality”, it is the minimum period allowed, e.g. 2 for changes every timestep. A list of ints of length 1 or ‘num_stacks’. Recommended value for generic mode: [32]. Recommended value for interpretable mode: [3]. Defaults to [3, 7].

  • prediction_length – Length of the prediction. Also known as ‘horizon’.

  • context_length – Number of time units that condition the predictions. Also known as ‘lookback period’. Should typically be between 1 and 10 times the prediction length.

  • backcast_loss_ratio – weight of the backcast relative to the forecast when calculating the loss. A weight of 1.0 means that forecast and backcast losses are weighted the same (regardless of backcast and forecast lengths). Defaults to 0.0, i.e. the backcast does not contribute to the loss.

  • loss – loss to optimize. Defaults to MASE().

  • log_gradient_flow – whether to log gradient flow; this takes time and should only be done to diagnose training failures.

  • reduce_on_plateau_patience (int) – patience after which learning rate is reduced by a factor of 10

  • logging_metrics (nn.ModuleList[MultiHorizonMetric]) – list of metrics that are logged during training. Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])

  • **kwargs – additional arguments to BaseModel.
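The per-stack parameters above accept either a single value, applied to every stack, or one value per stack. The following plain-Python sketch illustrates that expansion and the context-length rule of thumb. It does not call pytorch_forecasting: broadcast_per_stack, resolve_stack_config and context_length_ok are hypothetical helpers, and taking the number of stacks from the longest list supplied is an assumption.

```python
from typing import Dict, List, Sequence

STACK_TYPES = {"generic", "seasonality", "trend"}


def broadcast_per_stack(values: Sequence, num_stacks: int, name: str) -> List:
    """Repeat a length-1 setting for every stack, or check it has one entry per stack."""
    if len(values) == 1:
        return list(values) * num_stacks
    if len(values) != num_stacks:
        raise ValueError(f"{name} must have length 1 or {num_stacks}, got {len(values)}")
    return list(values)


def resolve_stack_config(stack_types: Sequence[str], **per_stack: Sequence) -> List[Dict]:
    """Return one settings dict per stack (hypothetical helper, not library API)."""
    unknown = set(stack_types) - STACK_TYPES
    if unknown:
        raise ValueError(f"unknown stack types: {sorted(unknown)}")
    # Assumption: the number of stacks is the length of the longest list given.
    num_stacks = max([len(stack_types)] + [len(v) for v in per_stack.values()])
    types = broadcast_per_stack(stack_types, num_stacks, "stack_types")
    expanded = {k: broadcast_per_stack(v, num_stacks, k) for k, v in per_stack.items()}
    return [
        {"stack_type": types[i], **{k: v[i] for k, v in expanded.items()}}
        for i in range(num_stacks)
    ]


def context_length_ok(context_length: int, prediction_length: int) -> bool:
    """Rule of thumb: context_length between 1 and 10 times prediction_length."""
    return prediction_length <= context_length <= 10 * prediction_length


# Interpretable configuration built from the recommended values listed above.
stacks = resolve_stack_config(
    ["trend", "seasonality"],
    num_blocks=[3],
    num_block_layers=[4],
    widths=[256, 2048],
    sharing=[True],
)
print(stacks[1])  # {'stack_type': 'seasonality', 'num_blocks': 3, 'num_block_layers': 4, 'widths': 2048, 'sharing': True}
print(context_length_ok(context_length=24, prediction_length=6))  # True
```

A length-1 list is simply repeated. Any other length must match the number of stacks, which mirrors the "length 1 or ‘num_stacks’" wording of the parameter descriptions.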

Methods

forward(x)

Forward pass of the network.

from_dataset(dataset, **kwargs)

Convenience function to create the network from a TimeSeriesDataSet (pytorch_forecasting.data.timeseries.TimeSeriesDataSet).

log_interpretation(x, out, batch_idx)

Log interpretation of network predictions in TensorBoard.

plot_interpretation(x, output, idx[, ax, ...])

Plot interpretation.

step(x, y, batch_idx)

Take training / validation step.

forward(x: Dict[str, Tensor]) → Dict[str, Tensor]

Forward pass of the network.

Parameters:

x (Dict[str, torch.Tensor]) – input from dataloader generated from TimeSeriesDataSet.

Returns:

output of the model

Return type:

Dict[str, torch.Tensor]
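N-BEATS, as described in the paper, uses a doubly residual scheme. Each block receives the part of the input that earlier blocks have not yet explained. Its backcast is subtracted from that residual, and its partial forecast is added to the running prediction. The plain-Python sketch below illustrates only that scheme. Blocks are arbitrary callables standing in for the fully connected blocks, and the dictionary keys are illustrative. They are not a description of the exact output of forward().

```python
from typing import Callable, Dict, List, Tuple

# A block maps an input window to (backcast, forecast); real blocks are neural networks.
Block = Callable[[List[float]], Tuple[List[float], List[float]]]


def doubly_residual_forward(
    x: List[float], blocks: List[Block], prediction_length: int
) -> Dict[str, List[float]]:
    residual = list(x)
    forecast = [0.0] * prediction_length
    for block in blocks:
        block_backcast, block_forecast = block(residual)
        # The next block only sees the part of the input this block did not explain.
        residual = [r - b for r, b in zip(residual, block_backcast)]
        # Partial forecasts of all blocks are summed into the final prediction.
        forecast = [f + g for f, g in zip(forecast, block_forecast)]
    # Total backcast = input minus whatever remains unexplained.
    backcast = [xi - ri for xi, ri in zip(x, residual)]
    return {"prediction": forecast, "backcast": backcast}


def mean_block(window: List[float]) -> Tuple[List[float], List[float]]:
    """Toy block: explains the window by its mean and forecasts that mean for 2 steps."""
    m = sum(window) / len(window)
    return [m] * len(window), [m] * 2


out = doubly_residual_forward([1.0, 2.0, 3.0], [mean_block, mean_block], prediction_length=2)
print(out)  # {'prediction': [2.0, 2.0], 'backcast': [2.0, 2.0, 2.0]}
```

The second toy block sees a zero-mean residual, so it adds nothing. This is the intended effect of stacking: later blocks only model what is left over.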

classmethod from_dataset(dataset: TimeSeriesDataSet, **kwargs)

Convenience function to create the network from a TimeSeriesDataSet (pytorch_forecasting.data.timeseries.TimeSeriesDataSet).

Parameters:
  • dataset (TimeSeriesDataSet) – dataset whose sole predictor is the target.

  • **kwargs – additional arguments to be passed to __init__ method.

Returns:

NBeats

log_interpretation(x, out, batch_idx)

Log interpretation of network predictions in TensorBoard.

plot_interpretation(x: Dict[str, Tensor], output: Dict[str, Tensor], idx: int, ax=None, plot_seasonality_and_generic_on_secondary_axis: bool = False) → Figure

Plot interpretation.

Plot two panels: prediction and backcast vs. actuals, and the decomposition of the prediction into trend, seasonality and generic forecast.

Parameters:
  • x (Dict[str, torch.Tensor]) – network input

  • output (Dict[str, torch.Tensor]) – network output

  • idx (int) – index of sample for which to plot the interpretation.

  • ax (List[matplotlib axes], optional) – list of two matplotlib axes onto which to plot the interpretation. Defaults to None.

  • plot_seasonality_and_generic_on_secondary_axis (bool, optional) – whether to plot seasonality and generic forecast on a secondary axis in the second panel. Defaults to False.

Returns:

matplotlib figure

Return type:

plt.Figure
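The second panel of plot_interpretation() shows the prediction split into trend, seasonality and generic contributions, which add up to the forecast. The plain-Python sketch below shows how a trend stack and a seasonality stack could turn a few expansion coefficients into such components. It uses the polynomial and Fourier bases described in the N-BEATS paper. The normalized time index t = j / horizon and the coefficient layout are assumptions, and pytorch_forecasting's blocks may parametrize the bases differently.

```python
import math
from typing import List


def trend_component(theta: List[float], horizon: int) -> List[float]:
    """Polynomial basis: sum_i theta[i] * t**i with normalized time t = j / horizon."""
    return [sum(c * (j / horizon) ** i for i, c in enumerate(theta)) for j in range(horizon)]


def seasonality_component(cos_coef: List[float], sin_coef: List[float], horizon: int) -> List[float]:
    """Fourier basis: harmonic k + 1 weighted by cos_coef[k] and sin_coef[k]."""
    values = []
    for j in range(horizon):
        t = j / horizon
        cos_part = sum(a * math.cos(2 * math.pi * (k + 1) * t) for k, a in enumerate(cos_coef))
        sin_part = sum(b * math.sin(2 * math.pi * (k + 1) * t) for k, b in enumerate(sin_coef))
        values.append(cos_part + sin_part)
    return values


horizon = 4
trend = trend_component([1.0, 2.0], horizon)  # level 1.0 plus a linear term (degree 1)
season = seasonality_component([1.0], [], horizon)  # one cosine cycle over the horizon
prediction = [tr + s for tr, s in zip(trend, season)]  # components add up to the forecast
print(trend)  # [1.0, 1.5, 2.0, 2.5]
```

Because the components are additive, each one can be plotted on its own while their sum still equals the prediction. This is what makes the trend/seasonality configuration interpretable.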

step(x, y, batch_idx) → Dict[str, Tensor]

Take training / validation step.
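The training and validation step is where backcast_loss_ratio comes into play. The plain-Python sketch below shows one way to get the documented behaviour: 0.0 uses the forecast loss only, and 1.0 weighs forecast and backcast equally regardless of window lengths. It does this by combining per-time-step mean losses. The normalization ratio / (1 + ratio) is an assumption, and mean absolute error stands in for the default MASE. This is not the exact code of step().

```python
from typing import Sequence


def mean_absolute_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Per-time-step mean, so longer windows do not get more weight (stand-in for MASE)."""
    return sum(abs(a - b) for a, b in zip(y_true, y_pred)) / len(y_true)


def combined_loss(forecast_loss: float, backcast_loss: float, backcast_loss_ratio: float) -> float:
    """Blend forecast and backcast losses; ratio 0.0 ignores the backcast, 1.0 weighs both equally."""
    if backcast_loss_ratio < 0:
        raise ValueError("backcast_loss_ratio must be non-negative")
    backcast_weight = backcast_loss_ratio / (1.0 + backcast_loss_ratio)
    return (1.0 - backcast_weight) * forecast_loss + backcast_weight * backcast_loss


forecast_loss = mean_absolute_error([1.0, 2.0], [1.5, 2.5])  # 0.5 over a 2-step horizon
backcast_loss = mean_absolute_error([0.0, 1.0, 2.0, 3.0], [1.0, 2.0, 3.0, 4.0])  # 1.0 over a 4-step context
print(combined_loss(forecast_loss, backcast_loss, 0.0))  # 0.5
print(combined_loss(forecast_loss, backcast_loss, 1.0))  # 0.75
```

Averaging each loss over its own window is what keeps a long context from dominating a short horizon.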