# Source code for pytorch_forecasting.metrics.quantile

```"""Quantile metrics for forecasting multiple quantiles per time step."""
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch

from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric

class QuantileLoss(MultiHorizonMetric):
    """
    Quantile (pinball) loss evaluated for several quantiles at once.

    For a single quantile ``q`` the loss is defined as
    ``2 * max(q * (y - y_pred), (1 - q) * (y_pred - y))``.
    Because of the factor 2, a quantile of ``q=0.5`` yields exactly the absolute error,
    i.e. its mean equals the mean absolute error.
    """

    def __init__(
        self,
        quantiles: Optional[List[float]] = None,
        **kwargs,
    ):
        """
        Quantile loss

        Args:
            quantiles: quantiles for metric. Defaults to
                ``[0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]``.
            **kwargs: forwarded unchanged to ``MultiHorizonMetric.__init__``.
        """
        if quantiles is None:
            # Build a fresh list per instance: a list literal used as the default
            # argument would be one shared, mutable object across all instances.
            quantiles = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
        super().__init__(quantiles=quantiles, **kwargs)

    def loss(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Calculate the quantile loss for every configured quantile.

        Args:
            y_pred: network prediction whose last dimension indexes the quantiles;
                entry ``i`` of that dimension corresponds to ``self.quantiles[i]``
            target: actual values; must broadcast against ``y_pred[..., i]``

        Returns:
            torch.Tensor: losses with the shape of ``target - y_pred[..., 0]`` plus one
            trailing dimension holding one entry per quantile
        """
        losses = []
        for i, q in enumerate(self.quantiles):
            errors = target - y_pred[..., i]
            # pinball loss: under-prediction (errors > 0) is weighted by q,
            # over-prediction (errors < 0) by 1 - q
            losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
        # Concatenate along the trailing axis created by unsqueeze(-1) instead of a
        # hard-coded dim=2, so inputs of any rank work. The factor 2 makes q=0.5
        # equal to the absolute error.
        losses = 2 * torch.cat(losses, dim=-1)
        return losses

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a point prediction.

        A 3-dimensional prediction is treated as a quantile prediction and reduced to
        the median (``q=0.5``) along its last dimension; any other input is returned
        unchanged.

        Args:
            y_pred: prediction output of network

        Returns:
            torch.Tensor: point prediction

        Raises:
            ValueError: if a 3-dimensional prediction is given but 0.5 is not one of
                ``self.quantiles``
        """
        if y_pred.ndim == 3:
            try:
                idx = self.quantiles.index(0.5)
            except ValueError:
                raise ValueError(
                    "Cannot convert quantile prediction to a point prediction: "
                    f"0.5 is not among the quantiles {self.quantiles}"
                ) from None
            y_pred = y_pred[..., idx]
        return y_pred

    def to_quantiles(self, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a quantile prediction.

        The network already outputs one value per quantile, so the prediction is
        returned as is.

        Args:
            y_pred: prediction output of network

        Returns:
            torch.Tensor: prediction quantiles
        """
        return y_pred
```