"""
Implementation of metrics for (multi-horizon) timeseries forecasting.
"""
from typing import Dict, List, Tuple, Union
import warnings
import scipy.stats
from sklearn.base import BaseEstimator
import torch
from torch import distributions
import torch.nn.functional as F
from torch.nn.utils import rnn
from torchmetrics import Metric as LightningMetric
from torchmetrics.metric import CompositionalMetric
from pytorch_forecasting.utils import create_mask, unpack_sequence, unsqueeze_like
class Metric(LightningMetric):
    """
    Base class for PyTorch Forecasting metrics.

    Offers shared helpers that convert raw network output into point or quantile
    predictions and that rescale normalized parameters with the target encoder.
    Concrete metrics inherit from this class and implement :py:meth:`update` and
    :py:meth:`compute`.

    See the `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/latest/metrics.html>`_
    for details of how to implement a new metric.
    """

    def __init__(self, name: str = None, quantiles: List[float] = None, reduction="mean", **kwargs):
        """
        Initialize metric

        Args:
            name (str): metric name. Defaults to class name.
            quantiles (List[float], optional): quantiles for probability range. Defaults to None.
            reduction (str, optional): Reduction, "none", "mean" or "sqrt-mean". Defaults to "mean".
            **kwargs: passed on unchanged to the parent class constructor.
        """
        self.quantiles = quantiles
        self.reduction = reduction
        # fall back to the concrete class name when no explicit name is given
        self.name = self.__class__.__name__ if name is None else name
        super().__init__(**kwargs)

    def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor):
        """
        Abstract method that accumulates a batch into the metric state.

        Args:
            y_pred: network output
            y_actual: actual values

        Raises:
            NotImplementedError: always - has to be overridden in derived classes.
        """
        raise NotImplementedError()

    def compute(self) -> torch.Tensor:
        """
        Abstract method that calculates the metric from the accumulated state.

        Returns:
            torch.Tensor: metric value on which backpropagation can be applied

        Raises:
            NotImplementedError: always - has to be overridden in derived classes.
        """
        raise NotImplementedError()

    def rescale_parameters(
        self, parameters: torch.Tensor, target_scale: torch.Tensor, encoder: BaseEstimator
    ) -> torch.Tensor:
        """
        Rescale normalized parameters into the scale required for the output.

        Args:
            parameters (torch.Tensor): normalized parameters (indexed by last dimension)
            target_scale (torch.Tensor): scale of parameters (n_batch_samples x (center, scale))
            encoder (BaseEstimator): original encoder that normalized the target in the first place

        Returns:
            torch.Tensor: parameters in real/not normalized space
        """
        # the encoder is called with a dict holding the prediction and its scale
        encoder_input = dict(prediction=parameters, target_scale=target_scale)
        return encoder(encoder_input)

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a point prediction.

        Input that is not 3-dimensional is returned unchanged. For 3-dimensional input,
        the last dimension is dropped (if no quantiles are defined) or averaged.

        Args:
            y_pred: prediction output of network

        Returns:
            torch.Tensor: point prediction
        """
        if y_pred.ndim != 3:
            return y_pred
        if self.quantiles is None:
            assert y_pred.size(-1) == 1, "Prediction should only have one extra dimension"
            return y_pred[..., 0]
        return y_pred.mean(-1)

    def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> torch.Tensor:
        """
        Convert network prediction into a quantile prediction.

        Args:
            y_pred: prediction output of network
            quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles as
                as defined in the class initialization.

        Returns:
            torch.Tensor: prediction quantiles

        Raises:
            ValueError: if ``y_pred`` is neither 2- nor 3-dimensional.
        """
        if quantiles is None:
            quantiles = self.quantiles

        n_dims = y_pred.ndim
        if n_dims == 2:
            # add a trailing quantile dimension of size one
            return y_pred.unsqueeze(-1)
        if n_dims != 3:
            raise ValueError(f"prediction has 1 or more than 3 dimensions: {y_pred.ndim}")

        if y_pred.size(2) > 1:  # single dimension means all quantiles are the same
            assert quantiles is not None, "quantiles are not defined"
            probabilities = torch.tensor(quantiles, device=y_pred.device)
            # torch.quantile puts the quantile axis first -> move it to the end
            y_pred = torch.quantile(y_pred, probabilities, dim=2).permute(1, 2, 0)
        return y_pred

    def __add__(self, metric: LightningMetric):
        """Combine this metric with ``metric`` into a :py:class:`CompositeMetric`."""
        return CompositeMetric(metrics=[self]) + metric

    def __mul__(self, multiplier: float):
        """Wrap this metric into a :py:class:`CompositeMetric` weighted by ``multiplier``."""
        return CompositeMetric(metrics=[self], weights=[multiplier])

    __rmul__ = __mul__
class TorchMetricWrapper(Metric):
    """
    Wrap a torchmetric to work with PyTorch Forecasting.

    Does not support weighting of errors and only supports metrics for point predictions.
    All state handling (update, compute, reset, persistence) is delegated to the wrapped metric.
    """

    def __init__(self, torchmetric: LightningMetric, reduction: str = None, **kwargs):
        """
        Args:
            torchmetric (LightningMetric): Torchmetric to wrap.
            reduction (str, optional): use reduction with torchmetric directly. Defaults to None.
            **kwargs: passed on to the parent class constructor.

        Raises:
            ValueError: if ``reduction`` is anything other than ``None``.
        """
        super().__init__(**kwargs)
        if reduction is not None:
            raise ValueError("use reduction with torchmetric directly")
        self.torchmetric = torchmetric

    def _sync_dist(self, dist_sync_fn=None, process_group=None) -> None:
        # Intentionally a no-op: this wrapper holds no state of its own.
        pass

    def reset(self) -> None:
        """Reset the wrapped torchmetric."""
        self.torchmetric.reset()

    def persistent(self, mode: bool = False) -> None:
        """Forward the persistence flag to the wrapped torchmetric."""
        self.torchmetric.persistent(mode=mode)

    def _convert(self, y_pred: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Turn prediction and target into flat tensors the wrapped torchmetric can consume.

        Args:
            y_pred: network output
            target: actual values; either a tensor, a ``PackedSequence`` or a ``(target, weight)`` pair

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: flattened point prediction and flattened target

        Raises:
            NotImplementedError: if a weight other than ``None`` is supplied.
        """
        # a PackedSequence is itself a tuple, so it has to be excluded explicitly
        is_target_weight_pair = isinstance(target, (list, tuple)) and not isinstance(target, rnn.PackedSequence)
        if is_target_weight_pair:
            target, weight = target
            if weight is not None:
                raise NotImplementedError(
                    "Weighting is not supported for pure torchmetrics - "
                    "implement a custom version or use pytorch-forecasting metrics"
                )

        # convert to point prediction - limits applications of class
        point_prediction = self.to_prediction(y_pred)

        if isinstance(target, rnn.PackedSequence):
            target, lengths = unpack_sequence(target)
            # select entries according to the mask built from the sequence lengths
            selection = create_mask(target.size(1), lengths, inverse=True)
            target = target.masked_select(selection)
            point_prediction = point_prediction.masked_select(selection)

        return point_prediction.flatten(), target.flatten()

    def update(self, y_pred: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Flatten prediction and target and update the wrapped torchmetric with them.

        Args:
            y_pred: network output
            target: actual values
            **kwargs: additional arguments for the wrapped metric's ``update()``
        """
        flat_prediction, flat_target = self._convert(y_pred, target)
        self.torchmetric.update(flat_prediction, flat_target, **kwargs)

    def forward(self, y_pred, target, **kwargs):
        """
        Flatten prediction and target and call ``forward()`` of the wrapped torchmetric.

        Args:
            y_pred: network output
            target: actual values
            **kwargs: additional arguments for the wrapped metric's ``forward()``

        Returns:
            whatever the wrapped torchmetric's ``forward()`` returns
        """
        # need this explicitly to avoid backpropagation errors because of sketchy caching
        flat_prediction, flat_target = self._convert(y_pred, target)
        return self.torchmetric.forward(flat_prediction, flat_target, **kwargs)

    def compute(self):
        """Return the result of the wrapped torchmetric's ``compute()``."""
        return self.torchmetric.compute()

    def __repr__(self):
        return f"WrappedTorchmetric({repr(self.torchmetric)})"
def convert_torchmetric_to_pytorch_forecasting_metric(metric: LightningMetric) -> Metric:
    """
    Ensure a metric can be used with PyTorch Forecasting models.

    Instances of :py:class:`Metric`, :py:class:`MultiLoss` and :py:class:`CompositeMetric`
    are handed back untouched; anything else is wrapped in a :py:class:`TorchMetricWrapper`.

    Args:
        metric (LightningMetric): metric to (potentially) convert

    Returns:
        Metric: PyTorch Forecasting metric
    """
    native_types = (Metric, MultiLoss, CompositeMetric)
    if isinstance(metric, native_types):
        return metric
    return TorchMetricWrapper(metric)
class MultiLoss(LightningMetric):
    """
    Metric that can be used with multiple metrics.

    Each contained metric is applied to its own slice of the prediction and target
    (selected by position), and the weighted results are summed in :py:meth:`compute`.
    """

    def __init__(self, metrics: List[LightningMetric], weights: List[float] = None):
        """
        Args:
            metrics (List[LightningMetric], optional): list of metrics to combine.
            weights (List[float], optional): list of weights / multipliers for weights. Defaults to 1.0 for all metrics.
        """
        assert len(metrics) > 0, "at least one metric has to be specified"
        if weights is None:
            weights = [1.0] * len(metrics)
        assert len(weights) == len(metrics), "Number of weights has to match number of metrics"
        # plain torchmetrics are wrapped so that every entry offers the PyTorch Forecasting interface
        self.metrics = [convert_torchmetric_to_pytorch_forecasting_metric(m) for m in metrics]
        self.weights = weights
        super().__init__()

    def __repr__(self):
        parts = []
        for weight, metric in zip(self.weights, self.metrics):
            # a weight of exactly 1.0 is omitted from the representation
            parts.append(repr(metric) if weight == 1.0 else f"{weight:.3g} * {repr(metric)}")
        return f"{self.__class__.__name__}(" + ", ".join(parts) + ")"

    def __iter__(self):
        """
        Iterate over metrics.
        """
        return iter(self.metrics)

    def __len__(self) -> int:
        """
        Number of metrics.

        Returns:
            int: number of metrics
        """
        return len(self.metrics)

    def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs):
        """
        Update every contained metric with its part of prediction and target.

        Args:
            y_pred: network output, indexed by metric position
            y_actual: actual values; ``y_actual[0]`` is indexed by metric position while
                ``y_actual[1]`` is passed to all metrics unchanged
            **kwargs: arguments to update function. List and tuple values are indexed by
                metric position, everything else is passed as is.
        """
        for idx, metric in enumerate(self.metrics):
            metric_kwargs = {
                name: value[idx] if isinstance(value, (list, tuple)) else value for name, value in kwargs.items()
            }
            try:
                metric.update(y_pred[idx], (y_actual[0][idx], y_actual[1]), **metric_kwargs)
            except TypeError:  # silently update without kwargs if not supported
                metric.update(y_pred[idx], (y_actual[0][idx], y_actual[1]))

    def compute(self) -> torch.Tensor:
        """
        Get metric

        Returns:
            torch.Tensor: weighted sum of the results of all contained metrics
        """
        weighted = [metric.compute() * weight for weight, metric in zip(self.weights, self.metrics)]
        if len(weighted) == 1:
            return weighted[0]
        return torch.stack(weighted, dim=0).sum(0)

    def to_prediction(self, y_pred: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Convert network prediction into a point prediction.

        Every metric converts its own entry of ``y_pred``.

        Args:
            y_pred: prediction output of network, indexed by metric position
            **kwargs: arguments for metrics

        Returns:
            list with one point prediction per metric
        """
        predictions = []
        for idx, metric in enumerate(self.metrics):
            try:
                converted = metric.to_prediction(y_pred[idx], **kwargs)
            except TypeError:  # retry without kwargs
                converted = metric.to_prediction(y_pred[idx])
            predictions.append(converted)
        return predictions

    def to_quantiles(self, y_pred: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Convert network prediction into a quantile prediction.

        Every metric converts its own entry of ``y_pred``.

        Args:
            y_pred: prediction output of network, indexed by metric position
            **kwargs: parameters to each metric's ``to_quantiles()`` method

        Returns:
            list with one quantile prediction per metric
        """
        quantile_predictions = []
        for idx, metric in enumerate(self.metrics):
            try:
                converted = metric.to_quantiles(y_pred[idx], **kwargs)
            except TypeError:  # retry without kwargs
                converted = metric.to_quantiles(y_pred[idx])
            quantile_predictions.append(converted)
        return quantile_predictions

    def __getitem__(self, idx: int):
        """
        Return metric.

        Args:
            idx (int): metric index
        """
        return self.metrics[idx]

    def __getattr__(self, name: str):
        """
        Return dynamically attributes.

        Return attributes if defined in this class. If not, create dynamically attributes based on
        attributes of underlying metrics that are lists. Create functions if necessary.
        Arguments to functions are distributed to the functions if they are lists and their length
        matches the number of metrics. Otherwise, they are directly passed to each callable of the
        metrics

        Args:
            name (str): name of attribute

        Returns:
            attributes of this class or list of attributes of underlying class

        Raises:
            AttributeError: if neither this class nor all underlying metrics define ``name``.
        """
        try:
            return super().__getattr__(name)
        except AttributeError as e:
            defined_everywhere = all([hasattr(metric, name) for metric in self.metrics])
            if not defined_everywhere:  # attribute does not exist for all metrics
                raise e

            if not callable(getattr(self.metrics[0], name)):
                # plain attribute -> collect it from every metric
                return [getattr(metric, name) for metric in self.metrics]

            n = len(self.metrics)

            def func(*args, **kwargs):
                # if arg/kwarg is list and of length metric, then apply each part to a metric. otherwise
                # pass it directly to all metrics
                results = []
                for idx, m in enumerate(self.metrics):
                    new_args = []
                    for arg in args:
                        distribute = (
                            isinstance(arg, (list, tuple))
                            and not isinstance(arg, rnn.PackedSequence)
                            and len(arg) == n
                        )
                        new_args.append(arg[idx] if distribute else arg)
                    new_kwargs = {}
                    for key, val in kwargs.items():
                        # note: keyword values are only distributed when they are lists (not tuples)
                        distribute = (
                            isinstance(val, list) and not isinstance(val, rnn.PackedSequence) and len(val) == n
                        )
                        new_kwargs[key] = val[idx] if distribute else val
                    results.append(getattr(m, name)(*new_args, **new_kwargs))
                return results

            return func
class CompositeMetric(LightningMetric):
    """
    Metric that combines multiple metrics.

    Metric does not have to be called explicitly but is automatically created when adding and multiplying metrics
    with each other.

    Example:

        .. code-block:: python

            composite_metric = SMAPE() + 0.4 * MAE()
    """

    def __init__(self, metrics: List[LightningMetric] = None, weights: List[float] = None):
        """
        Args:
            metrics (List[LightningMetric], optional): list of metrics to combine. Defaults to an empty list.
            weights (List[float], optional): list of weights / multipliers for weights. Defaults to 1.0 for all metrics.
        """
        # Copy both lists: ``__add__`` extends them in place, so storing a shared default
        # or a caller-owned list would leak state between instances or back to the caller.
        metrics = [] if metrics is None else list(metrics)
        weights = [1.0 for _ in metrics] if weights is None else list(weights)
        assert len(weights) == len(metrics), "Number of weights has to match number of metrics"
        self.metrics = metrics
        self.weights = weights
        super().__init__()

    def __repr__(self):
        name = " + ".join([f"{w:.3g} * {repr(m)}" if w != 1.0 else repr(m) for w, m in zip(self.weights, self.metrics)])
        return name

    def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor):
        """
        Update all contained metrics with the same prediction and target.

        Args:
            y_pred: network output
            y_actual: actual values
        """
        for metric in self.metrics:
            metric.update(y_pred, y_actual)

    def compute(self) -> torch.Tensor:
        """
        Get metric

        Returns:
            torch.Tensor: weighted sum of the results of all contained metrics
        """
        results = [metric.compute() * weight for weight, metric in zip(self.weights, self.metrics)]
        if len(results) == 1:
            return results[0]
        return torch.stack(results, dim=0).sum(0)

    def to_prediction(self, y_pred: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Convert network prediction into a point prediction.

        Will use first metric in ``metrics`` attribute to calculate result.

        Args:
            y_pred: prediction output of network
            **kwargs: parameters to first metric `to_prediction` method

        Returns:
            torch.Tensor: point prediction
        """
        return self.metrics[0].to_prediction(y_pred, **kwargs)

    def to_quantiles(self, y_pred: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Convert network prediction into a quantile prediction.

        Will use first metric in ``metrics`` attribute to calculate result.

        Args:
            y_pred: prediction output of network
            **kwargs: parameters to first metric's ``to_quantiles()`` method

        Returns:
            torch.Tensor: prediction quantiles
        """
        return self.metrics[0].to_quantiles(y_pred, **kwargs)

    def __add__(self, metric: LightningMetric):
        """
        Add a metric (or all metrics of another composite) to this composite.

        The composite is modified in place and returned.
        """
        if isinstance(metric, self.__class__):
            self.metrics.extend(metric.metrics)
            self.weights.extend(metric.weights)
        else:
            self.metrics.append(metric)
            self.weights.append(1.0)
        return self

    def __mul__(self, multiplier: float):
        """
        Scale all weights by ``multiplier``.

        The composite is modified in place and returned.
        """
        self.weights = [w * multiplier for w in self.weights]
        return self

    __rmul__ = __mul__
class AggregationMetric(Metric):
    """
    Calculate metric on mean prediction and actuals.

    Predictions and targets are averaged over the first dimension (optionally weighted)
    before they are handed to the wrapped metric.
    """

    def __init__(self, metric: Metric, **kwargs):
        """
        Args:
            metric (Metric): metric which to calculate on aggregation.
            **kwargs: passed on to the parent class constructor.
        """
        super().__init__(**kwargs)
        self.metric = metric

    def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
        """
        Aggregate the batch and update the wrapped metric with the aggregate.

        Args:
            y_pred: network output
            y_actual: actual values; either the target itself or a ``(target, weight)`` pair
        """
        # a PackedSequence is a tuple as well and must not be split into target and weight
        is_target_weight_pair = isinstance(y_actual, (tuple, list)) and not isinstance(y_actual, rnn.PackedSequence)
        target, weight = y_actual if is_target_weight_pair else (y_actual, None)

        # handle rnn sequence as target
        if isinstance(target, rnn.PackedSequence):
            target, lengths = rnn.pad_packed_sequence(target, batch_first=True)
            # batch sizes reside on the CPU by default -> we need to bring them to GPU
            lengths = lengths.to(target.device)
            # the mask over time steps is folded into the weight
            length_mask = create_mask(target.size(1), lengths, inverse=True)
            weight = length_mask if weight is None else weight * length_mask

        if weight is None:
            y_mean = target.mean(0)
            y_pred_mean = y_pred.mean(0)
        else:
            # weighted means over the first dimension
            broadcast_weight = unsqueeze_like(weight, y_pred)
            weight_total = weight.sum(0)
            y_mean = (target * broadcast_weight).sum(0) / weight_total
            y_pred_sum = (y_pred * broadcast_weight).sum(0)
            y_pred_mean = y_pred_sum / unsqueeze_like(weight_total, y_pred_sum)

        # update metric. unsqueeze first batch dimension (as batches are collapsed)
        self.metric.update(y_pred_mean.unsqueeze(0), y_mean.unsqueeze(0))

    def compute(self):
        """Return the result of the wrapped metric."""
        return self.metric.compute()
class MultiHorizonMetric(Metric):
    """
    Abstract class for defining metric for a multihorizon forecast

    Derived classes implement :py:meth:`loss`; masking of padded time steps,
    accumulation of state and reduction are handled here.
    """

    def __init__(self, reduction: str = "mean", **kwargs) -> None:
        """
        Args:
            reduction (str, optional): Reduction, "none", "mean" or "sqrt-mean". Defaults to "mean".
            **kwargs: passed on to the parent class constructor.
        """
        super().__init__(reduction=reduction, **kwargs)
        keeps_elements = reduction == "none"
        self.add_state("losses", default=torch.tensor(0.0), dist_reduce_fx="cat" if keeps_elements else "sum")
        self.add_state("lengths", default=torch.tensor(0), dist_reduce_fx="mean" if keeps_elements else "sum")

    def loss(self, y_pred: Dict[str, torch.Tensor], target: torch.Tensor) -> torch.Tensor:
        """
        Calculate loss without reduction. Override in derived classes

        Args:
            y_pred: network output
            target: actual values

        Returns:
            torch.Tensor: unreduced loss/metric

        Raises:
            NotImplementedError: always - has to be overridden in derived classes.
        """
        raise NotImplementedError()

    def update(self, y_pred, target):
        """
        Update method of metric that handles masking of values.

        Do not override this method but :py:meth:`~loss` instead

        Args:
            y_pred (Dict[str, torch.Tensor]): network output
            target (Union[torch.Tensor, rnn.PackedSequence]): actual values, optionally as
                ``(target, weight)`` pair
        """
        # unpack weight (a PackedSequence is a tuple too and is therefore excluded)
        weight = None
        if isinstance(target, (list, tuple)) and not isinstance(target, rnn.PackedSequence):
            target, weight = target

        # unpack target
        if isinstance(target, rnn.PackedSequence):
            target, lengths = unpack_sequence(target)
        else:
            # dense target: every sample spans the full second dimension
            lengths = torch.full((target.size(0),), fill_value=target.size(1), dtype=torch.long, device=target.device)

        losses = self.loss(y_pred, target)
        # weight samples
        if weight is not None:
            losses = losses * unsqueeze_like(weight, losses)
        self._update_losses_and_lengths(losses, lengths)

    def _update_losses_and_lengths(self, losses: torch.Tensor, lengths: torch.Tensor):
        """Mask ``losses`` and add them together with ``lengths`` to the metric state."""
        masked = self.mask_losses(losses, lengths)

        if self.reduction == "none":
            if self.losses.ndim == 0:
                # state still holds the scalar default -> replace instead of concatenating
                self.losses = masked
                self.lengths = lengths
            else:
                self.losses = torch.cat([self.losses, masked], dim=0)
                self.lengths = torch.cat([self.lengths, lengths], dim=0)
            return

        total = masked.sum()
        if not torch.isfinite(total):
            total = torch.tensor(1e9, device=total.device)
            warnings.warn("Loss is not finite. Resetting it to 1e9")
        self.losses = self.losses + total
        self.lengths = self.lengths + lengths.sum()

    def compute(self):
        """Reduce the accumulated losses according to ``self.reduction``."""
        return self.reduce_loss(self.losses, lengths=self.lengths)

    def mask_losses(self, losses: torch.Tensor, lengths: torch.Tensor, reduction: str = None) -> torch.Tensor:
        """
        Mask losses.

        Args:
            losses (torch.Tensor): total loss. first dimension are samples, second timesteps
            lengths (torch.Tensor): total length
            reduction (str, optional): type of reduction. Defaults to ``self.reduction``.

        Returns:
            torch.Tensor: masked losses
        """
        if reduction is None:
            reduction = self.reduction

        if losses.ndim == 0:
            # scalar losses are returned untouched
            return losses

        # True for time steps at or beyond the length of the respective sample
        time_index = torch.arange(losses.size(1), device=losses.device).unsqueeze(0)
        mask = time_index >= lengths.unsqueeze(-1)
        if losses.ndim > 2:
            mask = mask.unsqueeze(-1)
            dim_normalizer = losses.size(-1)
        else:
            dim_normalizer = 1.0

        if reduction == "none":
            return losses.masked_fill(mask, float("nan"))
        return losses.masked_fill(mask, 0.0) / dim_normalizer

    def reduce_loss(self, losses: torch.Tensor, lengths: torch.Tensor, reduction: str = None) -> torch.Tensor:
        """
        Reduce loss.

        Args:
            losses (torch.Tensor): total loss. first dimension are samples, second timesteps
            lengths (torch.Tensor): total length
            reduction (str, optional): type of reduction. Defaults to ``self.reduction``.

        Returns:
            torch.Tensor: reduced loss

        Raises:
            ValueError: if the reduction is not one of "none", "mean" or "sqrt-mean".
        """
        if reduction is None:
            reduction = self.reduction

        if reduction == "none":
            return losses  # return immediately, no checks
        if reduction not in ("mean", "sqrt-mean"):
            raise ValueError(f"reduction {reduction} unknown")

        loss = losses.sum() / lengths.sum()
        if reduction == "sqrt-mean":
            loss = loss.sqrt()

        assert not torch.isnan(loss), (
            "Loss should not be nan - i.e. something went wrong "
            "in calculating the loss (e.g. log of a negative number)"
        )
        assert torch.isfinite(
            loss
        ), "Loss should not be infinite - i.e. something went wrong (e.g. input is not in log space)"
        return loss
class PoissonLoss(MultiHorizonMetric):
    """
    Poisson loss for count data

    The network output is interpreted as the logarithm of the Poisson rate
    (``log_input=True`` in :py:meth:`loss`).
    """

    def loss(self, y_pred: Dict[str, torch.Tensor], target: torch.Tensor) -> torch.Tensor:
        """
        Calculate the unreduced Poisson negative log-likelihood.

        Args:
            y_pred: network output (log-rate)
            target: actual values

        Returns:
            torch.Tensor: unreduced loss
        """
        return F.poisson_nll_loss(
            super().to_prediction(y_pred), target, log_input=True, full=False, eps=1e-6, reduction="none"
        )

    def to_prediction(self, out: Dict[str, torch.Tensor]):
        """
        Convert network output into the predicted rate.

        Args:
            out: network output (log-rate)

        Returns:
            torch.Tensor: rate, i.e. the exponential of the base class point prediction
        """
        rate = torch.exp(super().to_prediction(out))
        return rate

    def to_quantiles(self, out: Dict[str, torch.Tensor], quantiles=None):
        """
        Convert network output into quantiles of the predicted Poisson distribution.

        Args:
            out: network output (log-rate)
            quantiles (List[float], optional): quantiles to calculate. Defaults to the quantiles
                defined in the class initialization or, if those are not defined, to ``[0.5]``.

        Returns:
            torch.Tensor: quantiles, indexed by the last dimension
        """
        if quantiles is None:
            quantiles = [0.5] if self.quantiles is None else self.quantiles

        # Use the exponentiated rate (not the raw log-rate) as the Poisson mean so that
        # the quantiles agree with ``loss`` and ``to_prediction``.
        rate = self.to_prediction(out)
        # scipy works on numpy arrays, which requires the tensor to live on the CPU
        distribution = scipy.stats.poisson(rate.detach().cpu().numpy())
        return torch.stack([torch.tensor(distribution.ppf(q)) for q in quantiles], dim=-1).to(rate.device)
class QuantileLoss(MultiHorizonMetric):
    """
    Quantile loss, i.e. a quantile of ``q=0.5`` will give half of the mean absolute error.

    Defined as ``max(q * (y-y_pred), (1-q) * (y_pred-y))``
    """

    def __init__(
        self,
        quantiles: List[float] = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98],
        **kwargs,
    ):
        """
        Quantile loss

        Args:
            quantiles: quantiles for metric
            **kwargs: passed on to the parent class constructor.
        """
        super().__init__(quantiles=quantiles, **kwargs)

    def loss(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Calculate the quantile loss separately for each quantile.

        Args:
            y_pred: network output; the last dimension indexes the quantiles
            target: actual values

        Returns:
            torch.Tensor: unreduced losses, concatenated along dimension 2 in quantile order
        """

        def single_quantile_loss(position: int, q: float) -> torch.Tensor:
            errors = target - y_pred[..., position]
            return torch.max((q - 1) * errors, q * errors).unsqueeze(-1)

        per_quantile = [single_quantile_loss(i, q) for i, q in enumerate(self.quantiles)]
        return torch.cat(per_quantile, dim=2)

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a point prediction.

        For 3-dimensional input, the entry belonging to the ``0.5`` quantile is selected;
        other input is returned unchanged.

        Args:
            y_pred: prediction output of network

        Returns:
            torch.Tensor: point prediction
        """
        if y_pred.ndim != 3:
            return y_pred
        median_position = self.quantiles.index(0.5)
        return y_pred[..., median_position]

    def to_quantiles(self, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a quantile prediction.

        The network output is returned as is.

        Args:
            y_pred: prediction output of network

        Returns:
            torch.Tensor: prediction quantiles
        """
        return y_pred
class SMAPE(MultiHorizonMetric):
    """
    Symmetric mean absolute percentage. Assumes ``y >= 0``.

    Defined as ``2*(y - y_pred).abs() / (y.abs() + y_pred.abs())``
    """

    def loss(self, y_pred, target):
        """
        Calculate the unreduced symmetric absolute percentage error.

        Args:
            y_pred: network output
            target: actual values

        Returns:
            torch.Tensor: unreduced loss
        """
        y_pred = self.to_prediction(y_pred)
        absolute_error = (y_pred - target).abs()
        # small constant in the denominator avoids a division by zero
        magnitude = y_pred.abs() + target.abs() + 1e-8
        return 2 * absolute_error / magnitude
class MAPE(MultiHorizonMetric):
    """
    Mean absolute percentage. Assumes ``y >= 0``.

    Defined as ``(y - y_pred).abs() / y.abs()``
    """

    def loss(self, y_pred, target):
        """
        Calculate the unreduced absolute percentage error.

        Args:
            y_pred: network output
            target: actual values

        Returns:
            torch.Tensor: unreduced loss
        """
        absolute_error = (self.to_prediction(y_pred) - target).abs()
        # small constant in the denominator avoids a division by zero
        return absolute_error / (target.abs() + 1e-8)
class MAE(MultiHorizonMetric):
    """
    Mean absolute error.

    Defined as ``(y_pred - target).abs()``
    """

    def loss(self, y_pred, target):
        """
        Calculate the unreduced absolute error.

        Args:
            y_pred: network output
            target: actual values

        Returns:
            torch.Tensor: unreduced loss
        """
        point_prediction = self.to_prediction(y_pred)
        return (point_prediction - target).abs()
class CrossEntropy(MultiHorizonMetric):
    """
    Cross entropy loss for classification.
    """

    def loss(self, y_pred, target):
        """
        Calculate the unreduced cross entropy.

        Args:
            y_pred: network output; the last dimension indexes the classes
            target: actual class labels

        Returns:
            torch.Tensor: unreduced loss, reshaped to ``(-1, target.size(-1))``
        """
        # cross_entropy is applied to flattened inputs: one row of class scores per label
        flat_scores = y_pred.view(-1, y_pred.size(-1))
        flat_labels = target.view(-1)
        flat_loss = F.cross_entropy(flat_scores, flat_labels, reduction="none")
        return flat_loss.view(-1, target.size(-1))

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a point prediction.

        Returns best label

        Args:
            y_pred: prediction output of network

        Returns:
            torch.Tensor: point prediction
        """
        return y_pred.argmax(dim=-1)

    def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> torch.Tensor:
        """
        Convert network prediction into a quantile prediction.

        The network output is returned as is; ``quantiles`` is not used.

        Args:
            y_pred: prediction output of network
            quantiles (List[float], optional): accepted for interface compatibility only.

        Returns:
            torch.Tensor: the unchanged network output
        """
        return y_pred
class RMSE(MultiHorizonMetric):
    """
    Root mean square error

    The unreduced loss is defined as ``(y_pred - target)**2``; the square root is applied
    by the default ``"sqrt-mean"`` reduction.
    """

    def __init__(self, reduction="sqrt-mean", **kwargs):
        """
        Args:
            reduction (str, optional): Reduction. Defaults to "sqrt-mean".
            **kwargs: passed on to the parent class constructor.
        """
        super().__init__(reduction=reduction, **kwargs)

    def loss(self, y_pred: Dict[str, torch.Tensor], target):
        """
        Calculate the unreduced squared error.

        Args:
            y_pred: network output
            target: actual values

        Returns:
            torch.Tensor: unreduced loss
        """
        residual = self.to_prediction(y_pred) - target
        return torch.pow(residual, 2)
class MASE(MultiHorizonMetric):
    """
    Mean absolute scaled error

    Defined as ``(y_pred - target).abs() / (all_targets[:, :-1] - all_targets[:, 1:]).abs().mean(1)``.

    ``all_targets`` are here the concatenated encoder and decoder targets
    """

    def update(
        self,
        y_pred,
        target,
        encoder_target,
        encoder_lengths=None,
    ) -> torch.Tensor:
        """
        Update metric that handles masking of values.

        Args:
            y_pred (Dict[str, torch.Tensor]): network output
            target (Tuple[Union[torch.Tensor, rnn.PackedSequence], torch.Tensor]): tuple of actual values and weights
            encoder_target (Union[torch.Tensor, rnn.PackedSequence]): historic actual values
            encoder_lengths (torch.Tensor): optional encoder lengths, not necessary if encoder_target
                is rnn.PackedSequence. Assumed encoder_target is torch.Tensor
        """
        # unpack weight - a PackedSequence is a tuple as well and must not be split,
        # consistent with the other metrics in this module
        if isinstance(target, (list, tuple)) and not isinstance(target, rnn.PackedSequence):
            weight = target[1]
            target = target[0]
        else:
            weight = None

        # unpack target
        if isinstance(target, rnn.PackedSequence):
            target, lengths = unpack_sequence(target)
        else:
            lengths = torch.full((target.size(0),), fill_value=target.size(1), dtype=torch.long, device=target.device)

        # determine lengths for encoder
        if encoder_lengths is None:
            # derive the lengths from the encoder target itself (not from the decoder target)
            encoder_target, encoder_lengths = unpack_sequence(encoder_target)
        else:
            assert isinstance(encoder_target, torch.Tensor)
        assert not target.requires_grad

        # calculate loss with "none" reduction
        scaling = self.calculate_scaling(target, lengths, encoder_target, encoder_lengths)
        losses = self.loss(y_pred, target, scaling)

        # weight samples
        if weight is not None:
            losses = losses * weight.unsqueeze(-1)

        self._update_losses_and_lengths(losses, lengths)

    def loss(self, y_pred, target, scaling):
        """
        Calculate the unreduced absolute error divided by the per-sample scaling.

        Args:
            y_pred: network output
            target: actual values
            scaling: scaling per sample as returned by :py:meth:`calculate_scaling`

        Returns:
            torch.Tensor: unreduced loss
        """
        return (self.to_prediction(y_pred) - target).abs() / scaling.unsqueeze(-1)

    def calculate_scaling(self, target, lengths, encoder_target, encoder_lengths):
        """
        Calculate the scaling per sample from the absolute first differences of all targets.

        Args:
            target: decoder targets
            lengths: lengths of decoder targets
            encoder_target: encoder targets
            encoder_lengths: lengths of encoder targets

        Returns:
            torch.Tensor: sum of absolute differences divided by the total length plus a small constant
        """
        # calculate mean(abs(diff(targets)))
        eps = 1e-6
        batch_size = target.size(0)
        total_lengths = lengths + encoder_lengths
        assert (total_lengths > 1).all(), "Need at least 2 target values to be able to calculate MASE"
        max_length = target.size(1) + encoder_target.size(1)
        if (total_lengths != max_length).any():  # if decoder or encoder targets have sequences of different lengths
            # start from the encoder targets padded with zeros, then write the decoder targets
            # directly behind the valid encoder values of each sample
            targets = torch.cat(
                [
                    encoder_target,
                    torch.zeros(batch_size, target.size(1), device=target.device, dtype=encoder_target.dtype),
                ],
                dim=1,
            )
            target_index = torch.arange(target.size(1), device=target.device, dtype=torch.long).unsqueeze(0).expand(
                batch_size, -1
            ) + encoder_lengths.unsqueeze(-1)
            targets.scatter_(dim=1, src=target, index=target_index)
        else:
            targets = torch.cat([encoder_target, target], dim=1)

        # take absolute difference
        diffs = (targets[:, :-1] - targets[:, 1:]).abs()
        # set last difference to 0
        not_maximum_length = total_lengths != max_length
        zero_correction_indices = total_lengths[not_maximum_length] - 1
        if len(zero_correction_indices) > 0:
            diffs[
                torch.arange(batch_size, dtype=torch.long, device=diffs.device)[not_maximum_length],
                zero_correction_indices,
            ] = 0.0

        # calculate mean over differences
        scaling = diffs.sum(1) / total_lengths + eps

        return scaling
[docs]class DistributionLoss(MultiHorizonMetric):
"""
DistributionLoss base class.
Class should be inherited for all distribution losses, i.e. if a network predicts
the parameters of a probability distribution, DistributionLoss can be used to
score those parameters and calculate loss for given true values.
Define two class attributes in a child class:
Attributes:
distribution_class (distributions.Distribution): torch probability distribution
distribution_arguments (List[str]): list of parameter names for the distribution
Further, implement the methods :py:meth:`~map_x_to_distribution` and :py:meth:`~rescale_parameters`.
"""
distribution_class: distributions.Distribution
distribution_arguments: List[str]
def __init__(
    self, name: str = None, quantiles: List[float] = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98], reduction="mean"
):
    """
    Initialize metric

    All arguments are handed to the parent class constructor as keyword arguments.

    Args:
        name (str): metric name. Defaults to class name.
        quantiles (List[float], optional): quantiles for probability range.
            Defaults to [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98].
        reduction (str, optional): Reduction, "none", "mean" or "sqrt-mean". Defaults to "mean".
    """
    init_arguments = dict(name=name, quantiles=quantiles, reduction=reduction)
    super().__init__(**init_arguments)
def map_x_to_distribution(self, x: torch.Tensor) -> distributions.Distribution:
    """
    Map a tensor of parameters to a probability distribution.

    Has to be implemented by child classes.

    Args:
        x (torch.Tensor): parameters for probability distribution. Last dimension will index the parameters

    Returns:
        distributions.Distribution: torch probability distribution as defined in the
            class attribute ``distribution_class``

    Raises:
        NotImplementedError: always, unless overridden.
    """
    raise NotImplementedError("implement this method")
[docs] def loss(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
"""
Calculate negative likelihood
Args:
y_pred: network output
y_actual: actual values
Returns:
torch.Tensor: metric value on which backpropagation can be applied
"""
distribution = self.map_x_to_distribution(y_pred)
loss = -distribution.log_prob(y_actual)
return loss
[docs] def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
"""
Convert network prediction into a point prediction.
Args:
y_pred: prediction output of network
Returns:
torch.Tensor: mean prediction
"""
distribution = self.map_x_to_distribution(y_pred)
return distribution.mean
[docs] def sample(self, y_pred, n_samples: int) -> torch.Tensor:
"""
Sample from distribution.
Args:
y_pred: prediction output of network (shape batch_size x n_timesteps x n_paramters)
n_samples (int): number of samples to draw
Returns:
torch.Tensor: tensor with samples (shape batch_size x n_timesteps x n_samples)
"""
dist = self.map_x_to_distribution(y_pred)
samples = dist.sample((n_samples,))
if samples.ndim == 3:
samples = samples.permute(1, 2, 0)
elif samples.ndim == 2:
samples = samples.transpose(0, 1)
return samples
[docs] def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None, n_samples: int = 100) -> torch.Tensor:
"""
Convert network prediction into a quantile prediction.
Args:
y_pred: prediction output of network
quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles as
as defined in the class initialization.
n_samples (int): number of samples to draw for quantiles. Defaults to 100.
Returns:
torch.Tensor: prediction quantiles (last dimension)
"""
if quantiles is None:
quantiles = self.quantiles
try:
distribution = self.map_x_to_distribution(y_pred)
quantiles = distribution.icdf(torch.tensor(quantiles, device=y_pred.device)[:, None, None]).permute(1, 2, 0)
except NotImplementedError: # resort to derive quantiles empirically
samples = torch.sort(self.sample(y_pred, n_samples), -1).values
quantiles = torch.quantile(samples, torch.tensor(quantiles, device=samples.device), dim=2).permute(1, 2, 0)
return quantiles
class NormalDistributionLoss(DistributionLoss):
    """
    Loss for targets modelled with a normal (Gaussian) distribution.

    Requirements for original target normalizer:
        * not normalized in log space (use :py:class:`~LogNormalDistributionLoss`)
        * not coerced to be positive
    """

    distribution_class = distributions.Normal
    distribution_arguments = ["loc", "scale"]

    def map_x_to_distribution(self, x: torch.Tensor) -> distributions.Normal:
        """
        Build the normal distribution from the parameter tensor.

        Args:
            x (torch.Tensor): parameters - index 0 of the last dimension is used as ``loc``,
                index 1 as ``scale``

        Returns:
            distributions.Normal: distribution defined by the parameters
        """
        location, spread = x[..., 0], x[..., 1]
        return self.distribution_class(loc=location, scale=spread)

    def rescale_parameters(
        self, parameters: torch.Tensor, target_scale: torch.Tensor, encoder: BaseEstimator
    ) -> torch.Tensor:
        """
        Rescale normalized parameters into ``loc`` and ``scale`` of the output distribution.

        Args:
            parameters (torch.Tensor): normalized parameters (indexed by last dimension)
            target_scale (torch.Tensor): scale of the target; index 1 of the last dimension
                multiplies the predicted scale
            encoder (BaseEstimator): target normalizer, called to transform the location back

        Returns:
            torch.Tensor: ``loc`` and ``scale`` stacked along a new last dimension
        """
        # transformations that are incompatible with this loss, paired with the failure message
        incompatible = (
            (("log", "log1p"), "Use LogNormalDistributionLoss for log scaled data"),
            (("softplus", "relu"), "Cannot use NormalDistributionLoss for positive data"),
            (("logit",), "Cannot use bound transformation such as 'logit'"),
        )
        for forbidden, message in incompatible:
            assert encoder.transformation not in forbidden, message

        raw_loc, raw_scale = parameters[..., 0], parameters[..., 1]
        # the location is reverted by the encoder itself
        loc = encoder({"prediction": raw_loc, "target_scale": target_scale})
        # softplus keeps the scale positive before it is stretched by the target scale
        scale = F.softplus(raw_scale) * target_scale[..., 1].unsqueeze(1)
        return torch.stack([loc, scale], dim=-1)
class NegativeBinomialDistributionLoss(DistributionLoss):
    """
    Loss based on the negative binomial distribution, e.g. for count data.

    Requirements for original target normalizer:
        * not centered normalization (only rescaled)
    """

    distribution_class = distributions.NegativeBinomial
    distribution_arguments = ["mean", "shape"]

    def map_x_to_distribution(self, x: torch.Tensor) -> distributions.NegativeBinomial:
        """
        Translate ``(mean, shape)`` into the ``(total_count, probs)`` arguments of the torch distribution.

        Args:
            x (torch.Tensor): parameters - index 0 of the last dimension is the mean, index 1 the shape

        Returns:
            distributions.NegativeBinomial: distribution defined by the parameters
        """
        mean, shape = x[..., 0], x[..., 1]
        total_count = 1.0 / shape
        success_probability = mean / (mean + total_count)
        return self.distribution_class(total_count=total_count, probs=success_probability)

    def rescale_parameters(
        self, parameters: torch.Tensor, target_scale: torch.Tensor, encoder: BaseEstimator
    ) -> torch.Tensor:
        """
        Rescale normalized parameters into ``mean`` and ``shape`` of the output distribution.

        Args:
            parameters (torch.Tensor): normalized parameters (indexed by last dimension)
            target_scale (torch.Tensor): scale of the target; only index 1 of the last dimension is used
            encoder (BaseEstimator): target normalizer; must not center and must not use "logit"

        Returns:
            torch.Tensor: ``mean`` and ``shape`` stacked along a new last dimension
        """
        assert not encoder.center, "NegativeBinomialDistributionLoss is not compatible with `center=True` normalization"
        assert encoder.transformation not in ["logit"], "Cannot use bound transformation such as 'logit'"

        in_log_space = encoder.transformation in ["log", "log1p"]
        scale = target_scale[..., 1].unsqueeze(-1)
        raw_mean, raw_shape = parameters[..., 0], parameters[..., 1]

        if not in_log_space:
            mean = F.softplus(raw_mean) * scale
            shape = F.softplus(raw_shape) / scale.sqrt()
        else:
            mean = torch.exp(raw_mean * scale)
            shape = F.softplus(torch.exp(raw_shape)) / torch.exp(scale).sqrt()  # todo: is this correct?
        return torch.stack([mean, shape], dim=-1)

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Convert network prediction into a point prediction.

        The first parameter of this loss is the mean, so it is returned as the point
        prediction without constructing the distribution.

        Args:
            y_pred: prediction output of network
                in this case the two parameters for the negative binomial

        Returns:
            torch.Tensor: mean prediction
        """
        mean_prediction = y_pred[..., 0]
        return mean_prediction
class LogNormalDistributionLoss(DistributionLoss):
    """
    Log-normal loss.

    Requirements for original target normalizer:
        * normalized target in log space
    """

    distribution_class = distributions.LogNormal
    distribution_arguments = ["loc", "scale"]

    def map_x_to_distribution(self, x: torch.Tensor) -> distributions.LogNormal:
        """
        Build the log-normal distribution from the parameter tensor.

        Args:
            x (torch.Tensor): parameters - index 0 of the last dimension is used as ``loc``,
                index 1 as ``scale``

        Returns:
            distributions.LogNormal: distribution defined by the parameters
        """
        return self.distribution_class(loc=x[..., 0], scale=x[..., 1])

    def rescale_parameters(
        self, parameters: torch.Tensor, target_scale: torch.Tensor, encoder: BaseEstimator
    ) -> torch.Tensor:
        """
        Rescale normalized parameters into ``loc`` and ``scale`` of the output distribution.

        Args:
            parameters (torch.Tensor): normalized parameters (indexed by last dimension)
            target_scale (torch.Tensor): scale of the target; index 0 of the last dimension is added
                to the location, index 1 multiplies both location and scale
            encoder (BaseEstimator): target normalizer; its ``transformation`` has to be
                "log" or "log1p"

        Returns:
            torch.Tensor: ``loc`` and ``scale`` stacked along a new last dimension

        Raises:
            AssertionError: if the encoder does not use a "log" or "log1p" transformation
        """
        # this check also rules out bound transformations such as "logit" - no separate assertion required
        assert isinstance(encoder.transformation, str) and encoder.transformation in [
            "log",
            "log1p",
        ], f"Log distribution requires log scaling but found `transformation={encoder.transformation}`"

        # softplus keeps the scale positive before it is stretched by the target scale
        scale = F.softplus(parameters[..., 1]) * target_scale[..., 1].unsqueeze(-1)
        loc = parameters[..., 0] * target_scale[..., 1].unsqueeze(-1) + target_scale[..., 0].unsqueeze(-1)
        return torch.stack([loc, scale], dim=-1)
class BetaDistributionLoss(DistributionLoss):
    """
    Loss based on the beta distribution for data on the unit interval.

    Requirements for original target normalizer:
        * logit transformation
    """

    distribution_class = distributions.Beta
    distribution_arguments = ["mean", "shape"]
    # margin used to keep targets and means away from 0 and 1 and to avoid a division by zero
    eps = 1e-4

    def map_x_to_distribution(self, x: torch.Tensor) -> distributions.Beta:
        """
        Translate ``(mean, shape)`` into the two concentrations of the torch distribution.

        Args:
            x (torch.Tensor): parameters - index 0 of the last dimension is the mean, index 1 the shape

        Returns:
            distributions.Beta: distribution defined by the parameters
        """
        mean, shape = x[..., 0], x[..., 1]
        concentration0 = (1 - mean) * shape
        concentration1 = mean * shape
        return self.distribution_class(concentration0=concentration0, concentration1=concentration1)

    def loss(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
        """
        Calculate negative log-likelihood of the clipped actuals.

        Args:
            y_pred: network output
            y_actual: actual values

        Returns:
            torch.Tensor: metric value on which backpropagation can be applied
        """
        predicted = self.map_x_to_distribution(y_pred)
        # targets are clipped to [eps, 1 - eps] to avoid infinite losses
        clipped_target = y_actual.clip(self.eps, 1 - self.eps)
        return -predicted.log_prob(clipped_target)

    def rescale_parameters(
        self, parameters: torch.Tensor, target_scale: torch.Tensor, encoder: BaseEstimator
    ) -> torch.Tensor:
        """
        Rescale normalized parameters into ``mean`` and ``shape`` of the output distribution.

        Args:
            parameters (torch.Tensor): normalized parameters (indexed by last dimension)
            target_scale (torch.Tensor): scale of the target; index 1 of the last dimension is used
                to scale the shape
            encoder (BaseEstimator): target normalizer; has to center and to use the "logit" transformation

        Returns:
            torch.Tensor: ``mean`` and ``shape`` stacked along a new last dimension
        """
        assert encoder.transformation in ["logit"], "Beta distribution is only compatible with logit transformation"
        assert encoder.center, "Beta distribution requires normalizer to center data"

        mean = encoder({"prediction": parameters[..., 0], "target_scale": target_scale})
        # shrink towards the interior so that the mean is never exactly 0 or 1
        mean = mean * (1 - 2 * self.eps) + self.eps

        # The target scale standard deviation lives in logit space and has to be brought to real space.
        # A normal distribution is assumed in logit space (logit transform followed by a standard scaler)
        # and the variance of the beta distribution is limited by `mean * (1 - mean)`.
        mean_derivative = mean * (1 - mean)

        # The variance is approximated as `tanh(std * sqrt(mean_derivative)) ** 2 * mean_derivative`
        # and the shape is `(positive) parameter * mean_derivative / variance` - the mean derivative cancels.
        logit_std = target_scale[..., 1].unsqueeze(1)
        shape_scaler = torch.tanh(logit_std * torch.sqrt(mean_derivative)) ** 2 + self.eps
        shape = F.softplus(parameters[..., 1]) / shape_scaler
        return torch.stack([mean, shape], dim=-1)