# Source code for pytorch_forecasting.metrics.quantile

"""Quantile metrics for forecasting multiple quantiles per time step."""
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch

from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric


class QuantileLoss(MultiHorizonMetric):
    """
    Quantile loss (also known as pinball loss).

    For every quantile ``q`` the loss is defined as
    ``2 * max(q * (y - y_pred), (1 - q) * (y_pred - y))``.
    Because of the factor of 2, the ``q=0.5`` quantile yields exactly the
    absolute error ``|y - y_pred|``; averaged, that is the mean absolute error.
    """

    def __init__(
        self,
        quantiles: Optional[List[float]] = None,
        **kwargs,
    ):
        """
        Quantile loss

        Args:
            quantiles: quantiles for metric. Defaults to
                ``[0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]``.
            **kwargs: passed on unchanged to the parent metric's constructor.
        """
        if quantiles is None:
            # Build a fresh list per instance: a list literal as the default
            # argument would be one object shared by every instance.
            quantiles = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
        super().__init__(quantiles=quantiles, **kwargs)

    def loss(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Calculate the quantile loss separately for every quantile.

        Args:
            y_pred: prediction output of network; its last dimension indexes
                ``self.quantiles`` (one prediction per quantile)
            target: actual values; must broadcast against ``y_pred`` with its
                last (quantile) dimension removed

        Returns:
            torch.Tensor: element-wise loss with one entry per quantile along
            the last dimension
        """
        # Compare the target against every quantile prediction at once by
        # adding a trailing axis that broadcasts over the quantile dimension.
        errors = target.unsqueeze(-1) - y_pred
        # Quantile levels laid out along the last dimension, matching y_pred.
        quantiles = torch.tensor(self.quantiles, dtype=errors.dtype, device=errors.device)
        # Pinball loss: under-prediction weighted by q, over-prediction by (1 - q);
        # the factor 2 makes the median (q=0.5) equal to the absolute error.
        losses = 2 * torch.max((quantiles - 1) * errors, quantiles * errors)
        return losses

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a point prediction.

        For a 3-dimensional prediction the median (``0.5``) quantile is
        selected from the last dimension; any other input is returned as-is.

        Args:
            y_pred: prediction output of network

        Returns:
            torch.Tensor: point prediction

        Raises:
            ValueError: if a 3-dimensional prediction is given but ``0.5`` is
                not one of ``self.quantiles``
        """
        if y_pred.ndim == 3:
            try:
                idx = self.quantiles.index(0.5)
            except ValueError as err:
                raise ValueError(
                    "cannot derive a point prediction: the median quantile 0.5 "
                    f"is not among the configured quantiles {self.quantiles}"
                ) from err
            y_pred = y_pred[..., idx]
        return y_pred

    def to_quantiles(self, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a quantile prediction.

        The network already predicts one value per quantile, so the
        prediction is returned unchanged.

        Args:
            y_pred: prediction output of network

        Returns:
            torch.Tensor: prediction quantiles
        """
        return y_pred