NBeatsAdapter

class pytorch_forecasting.models.nbeats._nbeats_adapter.NBeatsAdapter(**kwargs)

Bases: BaseModel

Initialize NBeats Adapter.

Parameters:

**kwargs – additional arguments passed to BaseModel.

BaseModel for time series forecasting, from which models inherit.

Parameters:
  • log_interval (Union[int, float], optional) – Number of batches after which predictions are logged. If < 1.0, multiple entries are logged per batch. Defaults to -1.

  • log_val_interval (Union[int, float], optional) – Number of batches after which validation predictions are logged. Defaults to None, in which case log_interval is used.

  • learning_rate (float, optional) – Learning rate. Defaults to 1e-3.

  • log_gradient_flow (bool) – Whether to log gradient flow. This takes time and should only be done to diagnose training failures. Defaults to False.

  • loss (Metric, optional) – Metric to optimize; can also be a list of metrics. Defaults to SMAPE().

  • logging_metrics (nn.ModuleList[MultiHorizonMetric]) – list of metrics that are logged during training. Defaults to [].

  • reduce_on_plateau_patience (int) – Patience after which the learning rate is reduced by the factor reduce_on_plateau_reduction. Defaults to 1000.

  • reduce_on_plateau_reduction (float) – Factor by which the learning rate is divided when a plateau is encountered. Defaults to 2.0.

  • reduce_on_plateau_min_lr (float) – Minimum learning rate for the reduce-on-plateau learning rate scheduler. Defaults to 1e-5.

  • weight_decay (float) – weight decay. Defaults to 0.0.

  • optimizer_params (Dict[str, Any]) – additional parameters for the optimizer. Defaults to {}.

  • monotone_constraints (Dict[str, int]) – Dictionary of monotonicity constraints for continuous decoder variables, mapping position (e.g. "0" for the first position) to constraint (-1 for negative, +1 for positive; larger magnitudes weight the constraint more heavily relative to the loss but are usually unnecessary). This constraint significantly slows down training. Defaults to {}.

  • output_transformer (Callable) – Transformer that takes the network output and transforms it to prediction space. Defaults to None, which is equivalent to lambda out: out["prediction"].

  • optimizer (str) – Optimizer: “ranger”, “sgd”, “adam”, “adamw”, or the class name of an optimizer in torch.optim or pytorch_optimizer. Alternatively, a class or function can be passed that takes the parameters as its first argument and an lr argument (optionally also weight_decay). Defaults to “adam”.
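The reduce-on-plateau parameters above interact roughly as sketched below. This is a simplified plain-Python model written for illustration, not the scheduler that BaseModel actually configures. It assumes the learning rate is divided by reduce_on_plateau_reduction once more than reduce_on_plateau_patience consecutive epochs pass without a new best validation loss, and that it never falls below reduce_on_plateau_min_lr; cooldown and improvement thresholds are ignored. The function simulate_reduce_on_plateau is hypothetical.

```python
def simulate_reduce_on_plateau(
    val_losses, learning_rate=1e-3, patience=1000, reduction=2.0, min_lr=1e-5
):
    """Return the learning rate in effect after each epoch (simplified model).

    Hypothetical helper: once more than `patience` consecutive epochs pass
    without a new best validation loss, the rate is divided by `reduction`,
    but it never drops below `min_lr`. Cooldown and thresholds are ignored.
    """
    lr = learning_rate
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            # New best loss: reset the plateau counter.
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            # Plateau detected: shrink the learning rate, clamped at min_lr.
            lr = max(lr / reduction, min_lr)
            bad_epochs = 0
        history.append(lr)
    return history


losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.95, 0.95, 0.95]
print(simulate_reduce_on_plateau(losses, learning_rate=1e-3, patience=2, reduction=2.0))
```

With patience=2 and reduction=2.0, the loss stalls at 0.9 from the second epoch on, so the rate is halved at the fifth epoch (0.001 to 0.0005); the later stall halves it again.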

Methods

forward(x)

Pass forward of network.

from_dataset(dataset, **kwargs)

Convenience function to create the network from a TimeSeriesDataSet.

log_interpretation(x, out, batch_idx)

Log interpretation of network predictions in tensorboard.

plot_interpretation(x, output, idx[, ax, ...])

Plot interpretation.

step(x, y, batch_idx)

Take training / validation step.

classmethod from_dataset(dataset: TimeSeriesDataSet, **kwargs)

Convenience function to create the network from a TimeSeriesDataSet.

Parameters:
  • dataset (TimeSeriesDataSet) – Dataset in which the target is the sole predictor.

  • **kwargs – Additional arguments passed to the __init__ method.

Return type:

NBeats
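from_dataset expects a dataset whose only predictor is the target. A simplified picture of what that means for each training sample: the model sees a fixed-length window of past target values and must predict the values that follow. The sketch below is hypothetical. make_windows is not part of pytorch_forecasting, its parameter names only echo the max_encoder_length and max_prediction_length arguments of TimeSeriesDataSet, and it ignores groups, covariates and normalization.

```python
from typing import List, Sequence, Tuple


def make_windows(
    target: Sequence[float], max_encoder_length: int, max_prediction_length: int
) -> List[Tuple[List[float], List[float]]]:
    """Split one univariate series into (past window, future window) pairs.

    Hypothetical helper for illustration only: the input of every sample is
    made of past target values, since the target is the sole predictor.
    """
    window = max_encoder_length + max_prediction_length
    samples = []
    for start in range(len(target) - window + 1):
        past = list(target[start : start + max_encoder_length])
        future = list(target[start + max_encoder_length : start + window])
        samples.append((past, future))
    return samples


print(make_windows([1, 2, 3, 4, 5, 6], max_encoder_length=3, max_prediction_length=2))
# [([1, 2, 3], [4, 5]), ([2, 3, 4], [5, 6])]
```

For a series of length n, this yields n - max_encoder_length - max_prediction_length + 1 samples, and none if the series is shorter than one full window.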

forward(x: dict[str, Tensor]) → dict[str, Tensor]

Pass forward of network.

Parameters:

x (dict of str to torch.Tensor) – input from a dataloader generated from a TimeSeriesDataSet.

Returns:

output of the model

Return type:

dict of str to torch.Tensor
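To make the notion of a backcast concrete, the doubly residual stacking principle of the N-BEATS architecture can be sketched with plain lists: each block receives what earlier blocks left unexplained, emits a backcast (subtracted from its input) and a partial forecast (added to the total). This is a conceptual illustration, not the code behind forward. The two toy blocks (a level block and a linear-trend block) and all function names are hypothetical stand-ins for the learned fully connected blocks.

```python
from typing import Callable, List, Sequence, Tuple

# A block maps its input window to (backcast, forecast).
Block = Callable[[List[float]], Tuple[List[float], List[float]]]


def make_level_block(horizon: int) -> Block:
    """Toy block: explains the window by its mean and forecasts that mean."""
    def block(window: List[float]) -> Tuple[List[float], List[float]]:
        mean = sum(window) / len(window)
        return [mean] * len(window), [mean] * horizon
    return block


def make_trend_block(horizon: int) -> Block:
    """Toy block: draws a line through the first and last point and extends it."""
    def block(window: List[float]) -> Tuple[List[float], List[float]]:
        n = len(window)
        slope = (window[-1] - window[0]) / (n - 1) if n > 1 else 0.0
        backcast = [window[0] + slope * i for i in range(n)]
        forecast = [window[0] + slope * (n - 1 + k) for k in range(1, horizon + 1)]
        return backcast, forecast
    return block


def nbeats_forward(
    x: Sequence[float], blocks: Sequence[Block], horizon: int
) -> Tuple[List[float], List[float]]:
    """Doubly residual stacking: return (final residual, summed forecast)."""
    residual = list(x)
    forecast = [0.0] * horizon
    for block in blocks:
        backcast, block_forecast = block(residual)
        # Remove what this block explained from the next block's input ...
        residual = [r - b for r, b in zip(residual, backcast)]
        # ... and add this block's partial forecast to the total.
        forecast = [f + g for f, g in zip(forecast, block_forecast)]
    return residual, forecast


residual, forecast = nbeats_forward(
    [1.0, 2.0, 3.0, 4.0], [make_level_block(2), make_trend_block(2)], horizon=2
)
print(residual, forecast)  # [0.0, 0.0, 0.0, 0.0] [5.0, 6.0]
```

In the N-BEATS architecture each block is a learned network, but the residual bookkeeping follows this pattern: after the last block, the residual holds whatever no block explained, and the forecast is the sum of all partial forecasts.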

log_interpretation(x, out, batch_idx)

Log interpretation of network predictions in tensorboard.

plot_interpretation(x: dict[str, Tensor], output: dict[str, Tensor], idx: int, ax=None, plot_seasonality_and_generic_on_secondary_axis: bool = False)

Plot interpretation.

Plots two panels: (1) prediction and backcast vs. actuals, and (2) decomposition of the prediction into trend, seasonality and generic forecast.

Parameters:
  • x (dict of str to torch.Tensor) – network input

  • output (dict of str to torch.Tensor) – network output

  • idx (int) – index of sample for which to plot the interpretation.

  • ax (list of matplotlib.axes) – list of two matplotlib axes onto which to plot the interpretation. Defaults to None.

  • plot_seasonality_and_generic_on_secondary_axis (bool) – Whether to plot the seasonality and generic forecast on a secondary axis in the second panel. Defaults to False.

Returns:

matplotlib figure

Return type:

matplotlib.figure.Figure
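The decomposition in the second panel rests on the interpretable N-BEATS design, in which the forecast is a sum of a trend component (a low-degree polynomial in normalized time), a seasonality component (a Fourier series) and a generic component. The sketch below evaluates such bases from hand-picked coefficients. The coefficient values, the choice of harmonics 1..n, and all function names are assumptions for illustration; it does not reproduce how NBeats computes or scales its outputs.

```python
import math
from typing import List, Sequence


def trend_component(theta: Sequence[float], horizon: int) -> List[float]:
    """Polynomial basis: sum_i theta[i] * t**i on normalized time t = k / horizon."""
    ts = [k / horizon for k in range(horizon)]
    return [sum(c * t**i for i, c in enumerate(theta)) for t in ts]


def seasonality_component(
    cos_coef: Sequence[float], sin_coef: Sequence[float], horizon: int
) -> List[float]:
    """Fourier basis with harmonics 1..len(cos_coef) on normalized time."""
    ts = [k / horizon for k in range(horizon)]
    return [
        sum(
            a * math.cos(2 * math.pi * (i + 1) * t)
            + b * math.sin(2 * math.pi * (i + 1) * t)
            for i, (a, b) in enumerate(zip(cos_coef, sin_coef))
        )
        for t in ts
    ]


def compose_prediction(trend, seasonality, generic) -> List[float]:
    """Interpretable forecast as the element-wise sum of its components."""
    return [tr + s + g for tr, s, g in zip(trend, seasonality, generic)]


trend = trend_component([1.0, 2.0], horizon=4)  # [1.0, 1.5, 2.0, 2.5]
seasonality = seasonality_component([1.0], [0.0], horizon=4)
generic = [0.1] * 4
print(compose_prediction(trend, seasonality, generic))
```

Plotting each component next to their sum gives a rough analogue of the second panel; seasonality and generic terms are often much smaller than the trend, which is when a secondary axis helps.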

step(x, y, batch_idx) → dict[str, Tensor]

Take training / validation step.