Source code for pytorch_forecasting.models.base._base_model_v2

########################################################################################
# Disclaimer: This base class is still a work in progress and experimental; please
# use it with care. This class is a basic skeleton of what the base classes may look
# like in version 2.
########################################################################################


from typing import Any, Optional, Union
from warnings import warn

from lightning import Trainer
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import STEP_OUTPUT
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from pytorch_forecasting.callbacks.predict import PredictCallback
from pytorch_forecasting.metrics import Metric, MultiLoss
from pytorch_forecasting.utils._classproperty import classproperty


class BaseModel(LightningModule):
    """Base model for time series forecasting.

    Parameters
    ----------
    loss : Descendants of ``pytorch_forecasting.metrics.Metric`` class
        Loss function to use for training.
    logging_metrics : Optional[List[nn.Module]], optional
        List of metrics to log during training, validation, and testing.
    optimizer : Optional[Union[Optimizer, str]], optional
        Optimizer to use for training.
        Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`.
    optimizer_params : Optional[Dict], optional
        Parameters for the optimizer.
    lr_scheduler : Optional[str], optional
        Learning rate scheduler to use.
        Supported values: "reduce_lr_on_plateau", "step_lr".
    lr_scheduler_params : Optional[Dict], optional
        Parameters for the learning rate scheduler.
    """

    def __init__(
        self,
        loss: Metric,
        logging_metrics: list[nn.Module] | None = None,
        optimizer: Optimizer | str | None = "adam",
        optimizer_params: dict | None = None,
        lr_scheduler: str | None = None,
        lr_scheduler_params: dict | None = None,
    ):
        super().__init__()
        self.loss = loss
        # ``None`` is normalised to an empty container so later code can iterate
        # over / unpack these attributes without extra ``None`` checks.
        self.logging_metrics = nn.ModuleList(
            logging_metrics if logging_metrics is not None else []
        )
        self.optimizer = optimizer
        self.optimizer_params = optimizer_params if optimizer_params is not None else {}
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_params = (
            lr_scheduler_params if lr_scheduler_params is not None else {}
        )
        self.model_name = self.__class__.__name__
        warn(
            f"The Model '{self.model_name}' is part of an experimental rework "
            "of the pytorch-forecasting model layer, scheduled for release with v2.0.0."
            " The API is not stable and may change without prior warning. "
            "This class is intended for beta testing and as a basic skeleton, "
            "but not for stable production use. "
            "Feedback and suggestions are very welcome in "
            "pytorch-forecasting issue 1736, "
            "https://github.com/sktime/pytorch-forecasting/issues/1736",
            UserWarning,
        )

    @classproperty
    def pkg(cls):
        """Package class for the model."""
        # ``_pkg`` is not defined in this class; it is expected to be provided
        # by concrete subclasses.
        return cls._pkg()

    def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Forward pass of the model.

        Parameters
        ----------
        x : Dict[str, torch.Tensor]
            Dictionary containing input tensors

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary containing output tensors

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Forward method must be implemented by subclass.")

    def predict(
        self,
        dataloader: DataLoader,
        mode: str = "prediction",
        return_info: list[str] | None = None,
        mode_kwargs: dict[str, Any] | None = None,
        trainer_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Generate predictions for new data using the `lightning.Trainer`.

        Parameters
        ----------
        dataloader : DataLoader
            The dataloader containing the data to predict on.
        mode : str
            The prediction mode ("prediction", "quantiles", or "raw").
        return_info : list[str], optional
            A list of additional information to return.
        mode_kwargs : dict[str, Any], optional
            Additional arguments for `to_prediction`/`to_quantiles`.
        trainer_kwargs : dict[str, Any], optional
            Additional arguments for `Trainer`. The passed dictionary and any
            ``callbacks`` list inside it are not modified.

        Returns
        -------
        dict[str, torch.Tensor]
            A dictionary of prediction results.
        """
        # Work on a shallow copy so the caller's dictionary is left untouched.
        trainer_kwargs = dict(trainer_kwargs) if trainer_kwargs else {}

        predict_callback = PredictCallback(
            mode=mode, return_info=return_info, mode_kwargs=mode_kwargs
        )

        # Normalise user callbacks to a *new* list: ``None`` means no callbacks,
        # a list/tuple is copied, and a single callback is wrapped.
        callbacks = trainer_kwargs.get("callbacks")
        if callbacks is None:
            callbacks = []
        elif isinstance(callbacks, (list, tuple)):
            callbacks = list(callbacks)
        else:
            callbacks = [callbacks]
        callbacks.append(predict_callback)
        trainer_kwargs["callbacks"] = callbacks

        trainer = Trainer(**trainer_kwargs)
        trainer.predict(self, dataloaders=dataloader)

        return predict_callback.result

    def to_prediction(self, out: dict[str, Any], **kwargs) -> torch.Tensor:
        """Converts raw model output to point forecasts.

        Parameters
        ----------
        out : dict[str, Any]
            Raw model output; must contain the key ``"prediction"``.
        **kwargs
            Forwarded to ``self.loss.to_prediction``. If that call raises a
            ``TypeError``, the call is retried without them.

        Returns
        -------
        torch.Tensor
            Result of ``self.loss.to_prediction``.
        """
        # todo: add MultiLoss support
        try:
            out = self.loss.to_prediction(out["prediction"], **kwargs)
        except TypeError:  # in case passed kwargs do not exist
            out = self.loss.to_prediction(out["prediction"])
        return out

    def to_quantiles(self, out: dict[str, Any], **kwargs) -> torch.Tensor:
        """Converts raw model output to quantile forecasts.

        Parameters
        ----------
        out : dict[str, Any]
            Raw model output; must contain the key ``"prediction"``.
        **kwargs
            Forwarded to ``self.loss.to_quantiles``. If that call raises a
            ``TypeError``, the call is retried without them.

        Returns
        -------
        torch.Tensor
            Result of ``self.loss.to_quantiles``.
        """
        # todo: add MultiLoss support
        try:
            out = self.loss.to_quantiles(out["prediction"], **kwargs)
        except TypeError:  # in case passed kwargs do not exist
            out = self.loss.to_quantiles(out["prediction"])
        return out

    def training_step(
        self, batch: tuple[dict[str, torch.Tensor]], batch_idx: int
    ) -> STEP_OUTPUT:
        """
        Training step for the model.

        Parameters
        ----------
        batch : Tuple[Dict[str, torch.Tensor]]
            Batch of data containing input and target tensors.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        STEP_OUTPUT
            Dictionary containing the loss under the key ``"loss"``.
        """
        x, y = batch
        y_hat_dict = self(x)
        y_hat = y_hat_dict["prediction"]
        loss = self.loss(y_hat, y)
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        self.log_metrics(y_hat, y, prefix="train")
        return {"loss": loss}

    def validation_step(
        self, batch: tuple[dict[str, torch.Tensor]], batch_idx: int
    ) -> STEP_OUTPUT:
        """
        Validation step for the model.

        Parameters
        ----------
        batch : Tuple[Dict[str, torch.Tensor]]
            Batch of data containing input and target tensors.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        STEP_OUTPUT
            Dictionary containing the loss under the key ``"val_loss"``.
        """
        x, y = batch
        y_hat_dict = self(x)
        y_hat = y_hat_dict["prediction"]
        loss = self.loss(y_hat, y)
        self.log(
            "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        self.log_metrics(y_hat, y, prefix="val")
        return {"val_loss": loss}

    def test_step(
        self, batch: tuple[dict[str, torch.Tensor]], batch_idx: int
    ) -> STEP_OUTPUT:
        """
        Test step for the model.

        Parameters
        ----------
        batch : Tuple[Dict[str, torch.Tensor]]
            Batch of data containing input and target tensors.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        STEP_OUTPUT
            Dictionary containing the loss under the key ``"test_loss"``.
        """
        x, y = batch
        y_hat_dict = self(x)
        y_hat = y_hat_dict["prediction"]
        loss = self.loss(y_hat, y)
        self.log(
            "test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        self.log_metrics(y_hat, y, prefix="test")
        return {"test_loss": loss}

    def predict_step(
        self,
        batch: tuple[dict[str, torch.Tensor]],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> dict[str, torch.Tensor]:
        """
        Prediction step for the model.

        Parameters
        ----------
        batch : Tuple[Dict[str, torch.Tensor]]
            Batch of data containing input tensors.
        batch_idx : int
            Index of the batch.
        dataloader_idx : int
            Index of the dataloader.

        Returns
        -------
        Dict[str, torch.Tensor]
            Output of the forward pass for the batch inputs.
        """
        # The second element of the batch is not used for prediction.
        x, _ = batch
        y_hat = self(x)
        return y_hat

    def configure_optimizers(self) -> dict:
        """
        Configure the optimizer and learning rate scheduler.

        Returns
        -------
        Dict
            Dictionary containing the optimizer and scheduler configuration.
        """
        optimizer = self._get_optimizer()
        if self.lr_scheduler is not None:
            scheduler = self._get_scheduler(optimizer)
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # ReduceLROnPlateau is configured with the metric to monitor;
                # "val_loss" is the name logged by ``validation_step``.
                return {
                    "optimizer": optimizer,
                    "lr_scheduler": {
                        "scheduler": scheduler,
                        "monitor": "val_loss",
                    },
                }
            else:
                return {"optimizer": optimizer, "lr_scheduler": scheduler}
        return {"optimizer": optimizer}

    def _get_optimizer(self) -> Optimizer:
        """
        Get the optimizer based on the specified optimizer name and parameters.

        Returns
        -------
        Optimizer
            The optimizer instance.

        Raises
        ------
        ValueError
            If ``self.optimizer`` is an unsupported name, or is neither a string
            nor an instance of ``torch.optim.Optimizer``.
        """
        if isinstance(self.optimizer, str):
            optimizer_name = self.optimizer.lower()
            if optimizer_name == "adam":
                return torch.optim.Adam(self.parameters(), **self.optimizer_params)
            elif optimizer_name == "sgd":
                return torch.optim.SGD(self.parameters(), **self.optimizer_params)
            else:
                raise ValueError(f"Optimizer {self.optimizer} not supported.")
        elif isinstance(self.optimizer, Optimizer):
            # A ready-made optimizer instance is returned unchanged.
            return self.optimizer
        else:
            raise ValueError(
                "Optimizer must be either a string or "
                "an instance of torch.optim.Optimizer."
            )

    def _get_scheduler(
        self, optimizer: Optimizer
    ) -> torch.optim.lr_scheduler._LRScheduler:
        """
        Get the lr scheduler based on the specified scheduler name and params.

        Parameters
        ----------
        optimizer : Optimizer
            The optimizer instance.

        Returns
        -------
        torch.optim.lr_scheduler._LRScheduler
            The learning rate scheduler instance.

        Raises
        ------
        ValueError
            If ``self.lr_scheduler`` is not a supported scheduler name.
        """
        scheduler_name = self.lr_scheduler.lower()
        if scheduler_name == "reduce_lr_on_plateau":
            return torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, **self.lr_scheduler_params
            )
        elif scheduler_name == "step_lr":
            return torch.optim.lr_scheduler.StepLR(
                optimizer, **self.lr_scheduler_params
            )
        else:
            raise ValueError(f"Scheduler {self.lr_scheduler} not supported.")

    def log_metrics(
        self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = "val"
    ) -> None:
        """
        Log additional metrics during training, validation, or testing.

        Parameters
        ----------
        y_hat : torch.Tensor
            Predicted output tensor.
        y : torch.Tensor
            Target output tensor.
        prefix : str
            Prefix for the logged metrics (e.g., "train", "val", "test").
        """
        for metric in self.logging_metrics:
            metric_value = metric(y_hat, y)
            # Each metric is logged as "<prefix>_<MetricClassName>".
            self.log(
                f"{prefix}_{metric.__class__.__name__}",
                metric_value,
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                logger=True,
            )