Source code for pytorch_forecasting.models.baseline

"""
Baseline model.
"""
from typing import Any, Dict

import torch

from pytorch_forecasting.models import BaseModel


class Baseline(BaseModel):
    """
    Baseline model that uses the last known target value to make a prediction.

    Example:

        .. code-block:: python

            from pytorch_forecasting import Baseline, MAE

            # generating predictions
            predictions = Baseline().predict(dataloader)

            # calculate baseline performance in terms of mean absolute error (MAE)
            metric = MAE()
            model = Baseline()
            for x, y in dataloader:
                metric.update(model(x), y)

            metric.compute()
    """
    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Network forward pass.

        Args:
            x (Dict[str, torch.Tensor]): network input

        Returns:
            Dict[str, torch.Tensor]: network outputs
        """
        if isinstance(x["encoder_target"], (list, tuple)):  # multiple targets
            prediction = [
                self.forward_one_target(
                    encoder_lengths=x["encoder_lengths"],
                    decoder_lengths=x["decoder_lengths"],
                    encoder_target=encoder_target,
                )
                for encoder_target in x["encoder_target"]
            ]
        else:  # one target
            prediction = self.forward_one_target(
                encoder_lengths=x["encoder_lengths"],
                decoder_lengths=x["decoder_lengths"],
                encoder_target=x["encoder_target"],
            )
        return self.to_network_output(prediction=prediction)
    def forward_one_target(
        self, encoder_lengths: torch.Tensor, decoder_lengths: torch.Tensor, encoder_target: torch.Tensor
    ):
        max_prediction_length = decoder_lengths.max()
        assert encoder_lengths.min() > 0, "Encoder lengths of at least 1 required to obtain last value"
        # take the last observed value of each sequence ...
        last_values = encoder_target[torch.arange(encoder_target.size(0)), encoder_lengths - 1]
        # ... and repeat it over the full prediction horizon
        prediction = last_values[:, None].expand(-1, max_prediction_length)
        return prediction
    def to_prediction(self, out: Dict[str, Any], use_metric: bool = True, **kwargs):
        return out.prediction

    def to_quantiles(self, out: Dict[str, Any], use_metric: bool = True, **kwargs):
        return out.prediction[..., None]
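
A minimal sketch of the indexing used by ``forward_one_target``, run on made-up toy tensors (the values, lengths, and shapes below are illustrative assumptions, not taken from the library): the last observed value of each padded series is gathered and repeated over the prediction horizon.

    import torch

    # Toy batch: 3 series, encoder window of 5 steps, predict 2 steps ahead.
    encoder_target = torch.tensor(
        [
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [10.0, 20.0, 30.0, 0.0, 0.0],  # padded series, only 3 valid steps
            [7.0, 7.0, 7.0, 7.0, 8.0],
        ]
    )
    encoder_lengths = torch.tensor([5, 3, 5])
    decoder_lengths = torch.tensor([2, 2, 2])

    # Same indexing as Baseline.forward_one_target: pick the last observed
    # value of each series and broadcast it over the prediction horizon.
    last_values = encoder_target[torch.arange(encoder_target.size(0)), encoder_lengths - 1]
    prediction = last_values[:, None].expand(-1, int(decoder_lengths.max()))
    print(prediction)
    # tensor([[ 5.,  5.],
    #         [30., 30.],
    #         [ 8.,  8.]])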