Source code for pytorch_forecasting.models.nbeats._nbeats_adapter

"""
N-Beats model adapter for timeseries forecasting without covariates.
"""

from typing import Optional

import torch

from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.data.encoders import NaNLabelEncoder
from pytorch_forecasting.layers._nbeats._blocks import (
    NBEATSSeasonalBlock,
    NBEATSTrendBlock,
)
from pytorch_forecasting.metrics import MASE
from pytorch_forecasting.models.base_model import BaseModel
from pytorch_forecasting.utils._dependencies import _check_matplotlib


class NBeatsAdapter(BaseModel):
    """
    Initialize NBeats Adapter.

    Shared forward / training-step / interpretation logic for N-Beats style
    networks. The methods below read ``self.net_blocks`` and the
    hyperparameters ``context_length``, ``prediction_length``,
    ``backcast_loss_ratio`` and ``stack_types``; none of them is created in
    this class, so they have to be provided elsewhere (presumably by a
    subclass - confirm against the concrete model).

    Parameters
    ----------
    **kwargs
        additional arguments to :py:class:`~BaseModel`.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)

    def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Pass forward of network.

        Every block in ``self.net_blocks`` is evaluated on the running
        residual backcast; block forecasts are summed into the prediction and,
        detached, into per-type interpretation series.

        Parameters
        ----------
        x : dict of str to torch.Tensor
            input from dataloader generated from
            :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.
            The keys ``"encoder_cont"`` and ``"target_scale"`` are read.

        Returns
        -------
        dict of str to torch.Tensor
            output of model, built by ``to_network_output`` with the fields
            ``prediction``, ``backcast``, ``trend``, ``seasonality`` and
            ``generic``.
        """
        # first continuous encoder feature; ``from_dataset`` enforces that the
        # target is the only input variable
        target = x["encoder_cont"][..., 0]
        batch_size = target.size(0)
        timesteps = self.hparams.context_length + self.hparams.prediction_length

        def _zeros(width: int) -> torch.Tensor:
            # fresh zero tensor on the model device
            return torch.zeros(
                (batch_size, width), dtype=torch.float32, device=self.device
            )

        # each list is seeded with zeros so that stacking below also works
        # when no block of that type exists
        generic_forecast = [_zeros(timesteps)]
        trend_forecast = [_zeros(timesteps)]
        seasonal_forecast = [_zeros(timesteps)]
        forecast = _zeros(self.hparams.prediction_length)

        backcast = target  # initialize backcast
        for block in self.net_blocks:
            # evaluate block
            backcast_block, forecast_block = block(backcast)

            # add for interpretation (detached: not part of the autograd graph)
            full = torch.cat(
                [backcast_block.detach(), forecast_block.detach()], dim=1
            )
            if isinstance(block, NBEATSTrendBlock):
                trend_forecast.append(full)
            elif isinstance(block, NBEATSSeasonalBlock):
                seasonal_forecast.append(full)
            else:
                generic_forecast.append(full)

            # update backcast and forecast
            # do not use ``backcast -= backcast_block`` as this signifies an
            # inline operation
            backcast = backcast - backcast_block
            forecast = forecast + forecast_block

        target_scale = x["target_scale"]
        return self.to_network_output(
            prediction=self.transform_output(
                forecast.unsqueeze(-1), target_scale=target_scale
            ),
            backcast=self.transform_output(
                prediction=(target - backcast).unsqueeze(-1),
                target_scale=target_scale,
            ),
            trend=self.transform_output(
                torch.stack(trend_forecast, dim=0).sum(0).unsqueeze(-1),
                target_scale=target_scale,
            ),
            seasonality=self.transform_output(
                torch.stack(seasonal_forecast, dim=0).sum(0).unsqueeze(-1),
                target_scale=target_scale,
            ),
            generic=self.transform_output(
                torch.stack(generic_forecast, dim=0).sum(0).unsqueeze(-1),
                target_scale=target_scale,
            ),
        )

    @classmethod
    def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
        """
        Convenience function to create network from
        :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.

        Parameters
        ----------
        dataset : TimeSeriesDataSet
            dataset where sole predictor is the target.
        **kwargs
            additional arguments to be passed to ``__init__`` method. They
            override the ``prediction_length`` / ``context_length`` values
            derived from the dataset.

        Returns
        -------
        NBeatsAdapter
            whatever the parent ``from_dataset`` returns for ``cls``
            (presumably an instance of the calling class).

        Raises
        ------
        AssertionError
            if the dataset has more than one target, a categorical target,
            variable encoder / prediction lengths, ``randomize_length`` set,
            ``add_relative_time_idx`` enabled, or any input variable other
            than the target.
        """
        new_kwargs = {
            "prediction_length": dataset.max_prediction_length,
            "context_length": dataset.max_encoder_length,
        }
        new_kwargs.update(kwargs)

        # validate arguments
        # NOTE(review): ``assert`` is stripped under ``python -O``; kept as-is
        # because callers may rely on ``AssertionError``.
        assert isinstance(
            dataset.target, str
        ), "only one target is allowed (passed as string to dataset)"
        assert not isinstance(
            dataset.target_normalizer, NaNLabelEncoder
        ), "only regression tasks are supported - target must not be categorical"
        assert dataset.min_encoder_length == dataset.max_encoder_length, (
            "only fixed encoder length is allowed,"
            " but min_encoder_length != max_encoder_length"
        )
        assert dataset.max_prediction_length == dataset.min_prediction_length, (
            "only fixed prediction length is allowed,"
            " but max_prediction_length != min_prediction_length"
        )
        assert (
            dataset.randomize_length is None
        ), "length has to be fixed, but randomize_length is not None"
        assert (
            not dataset.add_relative_time_idx
        ), "add_relative_time_idx has to be False"
        assert (
            len(dataset.flat_categoricals) == 0
            and len(dataset.reals) == 1
            and len(dataset._time_varying_unknown_reals) == 1
            and dataset._time_varying_unknown_reals[0] == dataset.target
        ), (
            "The only variable as input should be the"
            " target which is part of time_varying_unknown_reals"
        )

        # initialize class
        return super().from_dataset(dataset, **new_kwargs)

    def step(
        self, x, y, batch_idx
    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """
        Take training / validation step.

        Runs the parent ``step`` and, when ``backcast_loss_ratio > 0`` and the
        model is not predicting, blends a backcast loss into ``log["loss"]``.

        Parameters
        ----------
        x : dict of str to torch.Tensor
            network input; ``"encoder_target"`` and ``"decoder_target"`` are
            read when the backcast loss is active.
        y
            targets, forwarded unchanged to the parent ``step``.
        batch_idx : int
            index of the batch.

        Returns
        -------
        tuple
            ``(log, out)`` as returned by the parent ``step``, with
            ``log["loss"]`` possibly replaced by the weighted combination of
            forecast and backcast loss.
        """
        log, out = super().step(x, y, batch_idx=batch_idx)

        if self.hparams.backcast_loss_ratio > 0 and not self.predicting:
            # add loss from backcast
            backcast = out["backcast"]
            backcast_weight = (
                self.hparams.backcast_loss_ratio
                * self.hparams.prediction_length
                / self.hparams.context_length
            )
            backcast_weight = backcast_weight / (backcast_weight + 1)  # normalize
            forecast_weight = 1 - backcast_weight

            if isinstance(self.loss, MASE):
                # MASE is called with an additional third argument
                backcast_loss = (
                    self.loss(backcast, x["encoder_target"], x["decoder_target"])
                    * backcast_weight
                )
            else:
                backcast_loss = (
                    self.loss(backcast, x["encoder_target"]) * backcast_weight
                )

            label = ["val", "train"][self.training]
            batch_size = len(x["decoder_target"])
            self.log(
                f"{label}_backcast_loss",
                backcast_loss,
                on_epoch=True,
                on_step=self.training,
                batch_size=batch_size,
            )
            self.log(
                f"{label}_forecast_loss",
                log["loss"],
                on_epoch=True,
                on_step=self.training,
                batch_size=batch_size,
            )
            log["loss"] = log["loss"] * forecast_weight + backcast_loss

        self.log_interpretation(x, out, batch_idx=batch_idx)
        return log, out

    def log_interpretation(self, x, out, batch_idx):
        """
        Log interpretation of network predictions in tensorboard.

        Does nothing when matplotlib is unavailable, when the logger does not
        support ``add_figure``, or when ``batch_idx`` is not a multiple of a
        positive ``self.log_interval``.

        Parameters
        ----------
        x : dict of str to torch.Tensor
            network input
        out : dict of str to torch.Tensor
            network output
        batch_idx : int
            index of the batch, used for the logging interval and figure name.
        """
        mpl_available = _check_matplotlib("log_interpretation", raise_error=False)

        # Don't log figures if matplotlib or add_figure is not available
        if not mpl_available or not self._logger_supports("add_figure"):
            return None

        label = ["val", "train"][self.training]
        if not (self.log_interval > 0 and batch_idx % self.log_interval == 0):
            return None

        fig = self.plot_interpretation(x, out, idx=0)
        name = f"{label.capitalize()} interpretation of item 0 in "
        if self.training:
            name += f"step {self.global_step}"
        else:
            name += f"batch {batch_idx}"
        self.logger.experiment.add_figure(name, fig, global_step=self.global_step)

    def plot_interpretation(
        self,
        x: dict[str, torch.Tensor],
        output: dict[str, torch.Tensor],
        idx: int,
        ax=None,
        plot_seasonality_and_generic_on_secondary_axis: bool = False,
    ):
        """
        Plot interpretation.

        Plot two panels: prediction and backcast vs actuals and
        decomposition of prediction into trend, seasonality and generic
        forecast.

        Parameters
        ----------
        x : dict of str to torch.Tensor
            network input
        output : dict of str to torch.Tensor
            network output
        idx : int
            index of sample for which to plot the interpretation.
        ax : list of matplotlib.axes
            list of two matplotlib axes onto which to plot the
            interpretation. Defaults to None, in which case a new
            2x1 figure is created.
        plot_seasonality_and_generic_on_secondary_axis : bool
            if to plot seasonality and generic forecast on secondary axis
            in second panel. Defaults to False.

        Returns
        -------
        matplotlib.figure.Figure
            matplotlib figure
        """
        _check_matplotlib("plot_interpretation")

        import matplotlib.pyplot as plt

        if ax is None:
            fig, ax = plt.subplots(2, 1, figsize=(6, 8))
        else:
            fig = ax[0].get_figure()

        # x-axis runs from -context_length up to prediction_length - 1
        time = torch.arange(
            -self.hparams.context_length, self.hparams.prediction_length
        )

        # plot target vs prediction
        ax[0].plot(
            time,
            torch.cat([x["encoder_target"][idx], x["decoder_target"][idx]])
            .detach()
            .cpu(),
            label="target",
        )
        ax[0].plot(
            time,
            torch.cat(
                [
                    output["backcast"][idx].detach(),
                    output["prediction"][idx].detach(),
                ],
                dim=0,
            ).cpu(),
            label="prediction",
        )
        ax[0].set_xlabel("Time")

        # plot blocks
        prop_cycle = iter(plt.rcParams["axes.prop_cycle"])
        # skip the two colours already consumed by the first panel
        next(prop_cycle)  # prediction
        next(prop_cycle)  # observations
        if plot_seasonality_and_generic_on_secondary_axis:
            ax2 = ax[1].twinx()
            ax2.set_ylabel("Seasonality / Generic")
        else:
            ax2 = ax[1]

        for title in ["trend", "seasonality", "generic"]:
            if title not in self.hparams.stack_types:
                continue
            # trend always goes on the primary axis, the rest on ``ax2``
            axis = ax[1] if title == "trend" else ax2
            axis.plot(
                time,
                output[title][idx].detach().cpu(),
                label=title.capitalize(),
                c=next(prop_cycle)["color"],
            )
        ax[1].set_xlabel("Time")
        ax[1].set_ylabel("Decomposition")

        fig.legend()
        return fig