########################################################################################
# Disclaimer: This base class is still a work in progress and experimental; please
# use with care. This class is a basic skeleton of what the base classes may look like
# in version 2.
########################################################################################
from typing import Any, Optional, Union
from warnings import warn
from lightning import Trainer
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import STEP_OUTPUT
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from pytorch_forecasting.callbacks.predict import PredictCallback
from pytorch_forecasting.metrics import Metric, MultiLoss
from pytorch_forecasting.utils._classproperty import classproperty
[docs]
class BaseModel(LightningModule):
"""Base model for time series forecasting.
Parameters
----------
loss : Descendants of ``pytorch_forecasting.metrics.Metric`` class
Loss function to use for training.
logging_metrics : Optional[List[nn.Module]], optional
List of metrics to log during training, validation, and testing.
optimizer : Optional[Union[Optimizer, str]], optional
Optimizer to use for training.
Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`.
optimizer_params : Optional[Dict], optional
Parameters for the optimizer.
lr_scheduler : Optional[str], optional
Learning rate scheduler to use.
Supported values: "reduce_lr_on_plateau", "step_lr".
lr_scheduler_params : Optional[Dict], optional
Parameters for the learning rate scheduler.
"""
def __init__(
    self,
    loss: Metric,
    logging_metrics: list[nn.Module] | None = None,
    optimizer: Optimizer | str | None = "adam",
    optimizer_params: dict | None = None,
    lr_scheduler: str | None = None,
    lr_scheduler_params: dict | None = None,
):
    """
    Store the training configuration and emit the experimental-API warning.

    Parameters
    ----------
    loss : Metric
        Loss function; stored as ``self.loss``.
    logging_metrics : list[nn.Module], optional
        Additional metrics; wrapped in an ``nn.ModuleList``. ``None`` is
        treated as an empty list.
    optimizer : Optimizer or str, optional
        Optimizer name (``"adam"``, ``"sgd"``) or a ready
        ``torch.optim.Optimizer`` instance; stored as given.
    optimizer_params : dict, optional
        Keyword arguments for the optimizer. ``None`` becomes ``{}``.
    lr_scheduler : str, optional
        Learning rate scheduler name (``"reduce_lr_on_plateau"``,
        ``"step_lr"``); stored as given.
    lr_scheduler_params : dict, optional
        Keyword arguments for the scheduler. ``None`` becomes ``{}``.

    Warns
    -----
    UserWarning
        Always, because this class belongs to the experimental v2 model layer.
    """
    super().__init__()
    self.loss = loss
    self.logging_metrics = nn.ModuleList(
        logging_metrics if logging_metrics is not None else []
    )
    self.optimizer = optimizer
    self.optimizer_params = optimizer_params if optimizer_params is not None else {}
    self.lr_scheduler = lr_scheduler
    self.lr_scheduler_params = (
        lr_scheduler_params if lr_scheduler_params is not None else {}
    )
    self.model_name = self.__class__.__name__
    # The pieces below are joined by implicit string concatenation, so every
    # fragment has to carry its own separating space.
    warn(
        f"The Model '{self.model_name}' is part of an experimental rework "
        "of the pytorch-forecasting model layer, scheduled for release with v2.0.0."
        " The API is not stable and may change without prior warning. "
        "This class is intended for beta testing and as a basic skeleton, "
        "but not for stable production use. "
        "Feedback and suggestions are very welcome in "
        "pytorch-forecasting issue 1736, "
        "https://github.com/sktime/pytorch-forecasting/issues/1736",
        UserWarning,
    )
@classproperty
def pkg(cls):
    """
    Package class for the model.

    Returns
    -------
    object
        Whatever ``cls._pkg()`` returns. ``_pkg`` is not defined in the
        visible part of this class and is looked up on ``cls``.
    """
    package = cls._pkg()
    return package
[docs]
def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Forward pass of the model; a placeholder in this base class.

    Parameters
    ----------
    x : dict[str, torch.Tensor]
        Dictionary containing input tensors.

    Returns
    -------
    dict[str, torch.Tensor]
        Dictionary containing output tensors.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    message = "Forward method must be implemented by subclass."
    raise NotImplementedError(message)
[docs]
def predict(
    self,
    dataloader: DataLoader,
    mode: str = "prediction",
    return_info: list[str] | None = None,
    mode_kwargs: dict[str, Any] | None = None,
    trainer_kwargs: dict[str, Any] | None = None,
) -> dict[str, torch.Tensor]:
    """
    Generate predictions for new data using the `lightning.Trainer`.

    Parameters
    ----------
    dataloader : DataLoader
        The dataloader containing the data to predict on.
    mode : str
        The prediction mode ("prediction", "quantiles", or "raw").
    return_info : list[str], optional
        A list of additional information to return.
    mode_kwargs : dict[str, Any], optional
        Additional arguments for `to_prediction`/`to_quantiles`.
    trainer_kwargs : dict[str, Any], optional
        Additional arguments for `Trainer`. Neither this dict nor a
        ``callbacks`` list inside it is modified.

    Returns
    -------
    dict[str, torch.Tensor]
        A dictionary of prediction results, taken from the ``result``
        attribute of the ``PredictCallback`` after the run.
    """
    # Work on a shallow copy so the caller's dict is never written to.
    trainer_kwargs = dict(trainer_kwargs) if trainer_kwargs else {}
    predict_callback = PredictCallback(
        mode=mode, return_info=return_info, mode_kwargs=mode_kwargs
    )

    # Normalise user-supplied callbacks into a fresh list: appending to the
    # caller's own list would accumulate stale PredictCallbacks across calls.
    callbacks = trainer_kwargs.get("callbacks")
    if callbacks is None:
        callbacks = []
    elif isinstance(callbacks, (list, tuple)):
        callbacks = list(callbacks)
    else:
        callbacks = [callbacks]
    callbacks.append(predict_callback)
    trainer_kwargs["callbacks"] = callbacks

    trainer = Trainer(**trainer_kwargs)
    trainer.predict(self, dataloaders=dataloader)
    return predict_callback.result
[docs]
def to_prediction(self, out: dict[str, Any], **kwargs) -> torch.Tensor:
    """
    Converts raw model output to point forecasts.

    Parameters
    ----------
    out : dict[str, Any]
        Model output; only the ``"prediction"`` entry is used.
    **kwargs
        Forwarded to ``self.loss.to_prediction``. If that call raises
        ``TypeError``, the conversion is retried without them.

    Returns
    -------
    torch.Tensor
        Result of ``self.loss.to_prediction``.
    """
    # todo: add MultiLoss support
    raw_prediction = out["prediction"]
    convert = self.loss.to_prediction
    try:
        return convert(raw_prediction, **kwargs)
    except TypeError:  # in case passed kwargs do not exist
        return convert(raw_prediction)
[docs]
def to_quantiles(self, out: dict[str, Any], **kwargs) -> torch.Tensor:
    """
    Converts raw model output to quantile forecasts.

    Parameters
    ----------
    out : dict[str, Any]
        Model output; only the ``"prediction"`` entry is used.
    **kwargs
        Forwarded to ``self.loss.to_quantiles``. If that call raises
        ``TypeError``, the conversion is retried without them.

    Returns
    -------
    torch.Tensor
        Result of ``self.loss.to_quantiles``.
    """
    # todo: add MultiLoss support
    raw_prediction = out["prediction"]
    convert = self.loss.to_quantiles
    try:
        return convert(raw_prediction, **kwargs)
    except TypeError:  # in case passed kwargs do not exist
        return convert(raw_prediction)
[docs]
def training_step(
    self, batch: tuple[dict[str, torch.Tensor]], batch_idx: int
) -> STEP_OUTPUT:
    """
    Training step for the model.

    Runs the forward pass, computes the loss on the ``"prediction"`` output,
    logs it as ``"train_loss"`` (per step and per epoch) and logs the
    additional metrics with the ``"train"`` prefix.

    Parameters
    ----------
    batch : Tuple[Dict[str, torch.Tensor]]
        Batch of data, unpacked as ``(inputs, target)``.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    STEP_OUTPUT
        Dictionary with the loss under the key ``"loss"``.
    """
    inputs, target = batch
    prediction = self(inputs)["prediction"]
    loss = self.loss(prediction, target)

    log_options = {"on_step": True, "on_epoch": True, "prog_bar": True, "logger": True}
    self.log("train_loss", loss, **log_options)
    self.log_metrics(prediction, target, prefix="train")
    return {"loss": loss}
[docs]
def validation_step(
    self, batch: tuple[dict[str, torch.Tensor]], batch_idx: int
) -> STEP_OUTPUT:
    """
    Validation step for the model.

    Runs the forward pass, computes the loss on the ``"prediction"`` output,
    logs it as ``"val_loss"`` (per epoch only) and logs the additional
    metrics with the ``"val"`` prefix.

    Parameters
    ----------
    batch : Tuple[Dict[str, torch.Tensor]]
        Batch of data, unpacked as ``(inputs, target)``.
    batch_idx : int
        Index of the batch.

    Returns
    -------
    STEP_OUTPUT
        Dictionary with the loss under the key ``"val_loss"``.
    """
    inputs, target = batch
    prediction = self(inputs)["prediction"]
    loss = self.loss(prediction, target)

    log_options = {"on_step": False, "on_epoch": True, "prog_bar": True, "logger": True}
    self.log("val_loss", loss, **log_options)
    self.log_metrics(prediction, target, prefix="val")
    return {"val_loss": loss}
[docs]
def test_step(
self, batch: tuple[dict[str, torch.Tensor]], batch_idx: int
) -> STEP_OUTPUT:
"""
Test step for the model.
Parameters
----------
batch : Tuple[Dict[str, torch.Tensor]]
Batch of data containing input and target tensors.
batch_idx : int
Index of the batch.
Returns
-------
STEP_OUTPUT
Dictionary containing the loss and other metrics.
"""
x, y = batch
y_hat_dict = self(x)
y_hat = y_hat_dict["prediction"]
loss = self.loss(y_hat, y)
self.log(
"test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
)
self.log_metrics(y_hat, y, prefix="test")
return {"test_loss": loss}
[docs]
def predict_step(
    self,
    batch: tuple[dict[str, torch.Tensor]],
    batch_idx: int,
    dataloader_idx: int = 0,
) -> dict[str, torch.Tensor]:
    """
    Prediction step for the model.

    Parameters
    ----------
    batch : Tuple[Dict[str, torch.Tensor]]
        Batch of data, unpacked as ``(inputs, _)``; the second element is
        ignored.
    batch_idx : int
        Index of the batch.
    dataloader_idx : int
        Index of the dataloader.

    Returns
    -------
    dict[str, torch.Tensor]
        The output of the forward pass, returned unchanged. ``forward`` is
        declared to return a dictionary of output tensors, not a bare tensor.
    """
    inputs, _ = batch
    return self(inputs)
def _get_optimizer(self) -> Optimizer:
"""
Get the optimizer based on the specified optimizer name and parameters.
Returns
-------
Optimizer
The optimizer instance.
"""
if isinstance(self.optimizer, str):
if self.optimizer.lower() == "adam":
return torch.optim.Adam(self.parameters(), **self.optimizer_params)
elif self.optimizer.lower() == "sgd":
return torch.optim.SGD(self.parameters(), **self.optimizer_params)
else:
raise ValueError(f"Optimizer {self.optimizer} not supported.")
elif isinstance(self.optimizer, Optimizer):
return self.optimizer
else:
raise ValueError(
"Optimizer must be either a string or "
"an instance of torch.optim.Optimizer."
)
def _get_scheduler(
self, optimizer: Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:
"""
Get the lr scheduler based on the specified scheduler name and params.
Parameters
----------
optimizer : Optimizer
The optimizer instance.
Returns
-------
torch.optim.lr_scheduler._LRScheduler
The learning rate scheduler instance.
"""
if self.lr_scheduler.lower() == "reduce_lr_on_plateau":
return torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, **self.lr_scheduler_params
)
elif self.lr_scheduler.lower() == "step_lr":
return torch.optim.lr_scheduler.StepLR(
optimizer, **self.lr_scheduler_params
)
else:
raise ValueError(f"Scheduler {self.lr_scheduler} not supported.")
[docs]
def log_metrics(
    self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = "val"
) -> None:
    """
    Log additional metrics during training, validation, or testing.

    Every entry of ``self.logging_metrics`` is evaluated on ``(y_hat, y)``
    and logged per epoch under ``"<prefix>_<MetricClassName>"``.

    Parameters
    ----------
    y_hat : torch.Tensor
        Predicted output tensor.
    y : torch.Tensor
        Target output tensor.
    prefix : str
        Prefix for the logged metrics (e.g., "train", "val", "test").
    """
    # Identical for every metric, so build the options once.
    log_options = {"on_step": False, "on_epoch": True, "prog_bar": True, "logger": True}
    for metric_fn in self.logging_metrics:
        metric_name = f"{prefix}_{metric_fn.__class__.__name__}"
        self.log(metric_name, metric_fn(y_hat, y), **log_options)