TorchMetricWrapper

class pytorch_forecasting.metrics.base_metrics.TorchMetricWrapper(torchmetric: Metric, reduction: str | None = None, **kwargs)

Bases: Metric

Wrap a torchmetric to work with PyTorch Forecasting.

Does not support weighting of errors and only supports metrics for point predictions.

Parameters:
  • torchmetric (LightningMetric) – Torchmetric to wrap, i.e. an instance of torchmetrics.Metric.

  • reduction (str, optional) – Leave as None; configure any reduction on the wrapped torchmetric directly. Defaults to None.
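To make the wrapping idea concrete, here is a minimal pure-Python sketch of the same adapter pattern. It does not use PyTorch Forecasting or torchmetrics: MSESketch (a hypothetical stand-in for a stateful torchmetric) and WrapperSketch (a hypothetical, simplified analogue of TorchMetricWrapper) are illustrative names, and nested lists stand in for tensors. The sketch assumes predictions are already point predictions, flattens each batch before handing it to the wrapped metric, rejects a (target, weight) pair whose weight is not None, and delegates update, compute and reset to the wrapped metric.

```python
from typing import List, Optional, Sequence, Tuple, Union

# Hypothetical 2-D batch: n_samples rows of point predictions (or targets).
Batch = List[List[float]]


class MSESketch:
    """Hypothetical stand-in for a stateful torchmetric computing mean squared error."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Accumulated state: running sum of squared errors and number of points seen.
        self.sum_sq = 0.0
        self.count = 0

    def update(self, y_pred: Sequence[float], target: Sequence[float]) -> None:
        # Expects flat, aligned sequences of predictions and targets.
        for p, t in zip(y_pred, target):
            self.sum_sq += (p - t) ** 2
            self.count += 1

    def compute(self) -> float:
        # Mean over everything accumulated since the last reset().
        return self.sum_sq / self.count


class WrapperSketch:
    """Hypothetical, simplified analogue of TorchMetricWrapper (not the library code)."""

    def __init__(self, inner: MSESketch) -> None:
        self.inner = inner  # all metric state lives in the wrapped metric

    def _convert(
        self, y_pred: Batch, target: Union[Batch, Tuple[Batch, Optional[Batch]]]
    ) -> Tuple[List[float], List[float]]:
        # Targets may arrive as a (target, weight) pair; weights are not supported.
        if isinstance(target, tuple):
            target, weight = target
            if weight is not None:
                raise NotImplementedError("weighting is not supported")
        # Flatten both batches so the wrapped metric sees 1-D point predictions.
        flat_pred = [v for row in y_pred for v in row]
        flat_target = [v for row in target for v in row]
        return flat_pred, flat_target

    def update(self, y_pred: Batch, target: Union[Batch, Tuple[Batch, Optional[Batch]]]) -> None:
        self.inner.update(*self._convert(y_pred, target))

    def compute(self) -> float:
        return self.inner.compute()

    def reset(self) -> None:
        self.inner.reset()


metric = WrapperSketch(MSESketch())
metric.update([[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 6.0]])
print(metric.compute())  # 1.0: squared errors 0, 0, 0, 4 averaged over 4 points
```

Keeping all state inside the wrapped metric means its own accumulation and reduction logic stays in charge; the wrapper in this sketch only adapts the input format.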

Methods

  • compute() – Abstract method that calculates the metric.

  • forward(y_pred, target, **kwargs) – Aggregate and evaluate batch input directly.

  • persistent([mode]) – Change, after initialization, whether metric states should be saved to its state_dict.

  • reset() – Reset metric state variables to their default value.

  • update(y_pred, target, **kwargs) – Override this method to update the state variables of your metric class.

compute()

Abstract method that calculates the metric.

Should be overridden in derived classes.

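Takes no arguments; the metric value is computed from the state accumulated by previous calls to update() or forward(), which receive the network output (y_pred) and the actual values (target).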

Returns:

metric value on which backpropagation can be applied

Return type:

torch.Tensor

forward(y_pred, target, **kwargs)

Aggregate and evaluate batch input directly.

Serves the dual purpose of computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state. Input arguments are the same as those of the corresponding update method. The returned output has the same form as the output of compute, but is evaluated on the current batch only.

Parameters:
  • y_pred, target – Network output and actual values, as required by the metric update method.

  • kwargs – Any keyword arguments as required by the metric update method.

Returns:

The output of the compute method evaluated on the current batch.

Raises:

TorchMetricsUserError – If the metric is already synced and forward is called again.
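The dual role of forward can be illustrated with another hypothetical, simplified pure-Python metric, MAESketch (mean absolute error). This is an assumption-laden sketch of the semantics described above, not the torchmetrics implementation, which manages state caching more carefully: forward evaluates the batch on a fresh state and then adds the same batch to the accumulated state, so its return value covers only the current batch while compute covers everything seen so far.

```python
from typing import List


class MAESketch:
    """Hypothetical stateful mean-absolute-error metric with a forward()-style method."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Accumulated state across all batches since the last reset().
        self.sum_abs = 0.0
        self.count = 0

    def update(self, y_pred: List[float], target: List[float]) -> None:
        for p, t in zip(y_pred, target):
            self.sum_abs += abs(p - t)
            self.count += 1

    def compute(self) -> float:
        return self.sum_abs / self.count

    def forward(self, y_pred: List[float], target: List[float]) -> float:
        # 1) Evaluate the current batch on its own, using a fresh state.
        batch_metric = MAESketch()
        batch_metric.update(y_pred, target)
        batch_value = batch_metric.compute()
        # 2) Add the same batch to the accumulated state.
        self.update(y_pred, target)
        return batch_value


metric = MAESketch()
print(metric.forward([1.0, 2.0], [2.0, 2.0]))  # 0.5 (batch 1 only)
print(metric.forward([0.0], [3.0]))            # 3.0 (batch 2 only)
print(metric.compute())                        # (1 + 0 + 3) / 3, about 1.333
```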

persistent(mode: bool = False) → None

Change, after initialization, whether metric states should be saved to its state_dict.
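As a rough illustration of what this flag controls, the hypothetical StatefulSketch class below keeps its metric states out of state_dict() until persistent(True) is called. It assumes, as in torchmetrics, that states are registered as non-persistent by default; for TorchMetricWrapper the flag is assumed to apply to the wrapped torchmetric's states.

```python
from typing import Dict


class StatefulSketch:
    """Hypothetical metric whose states are excluded from state_dict() unless persistent."""

    def __init__(self) -> None:
        self.states: Dict[str, float] = {"total": 0.0, "count": 0.0}
        # Assumption: every state starts out non-persistent.
        self._persistent: Dict[str, bool] = {name: False for name in self.states}

    def persistent(self, mode: bool = False) -> None:
        # Flip the flag for every state after construction ("post-init").
        for name in self._persistent:
            self._persistent[name] = mode

    def state_dict(self) -> Dict[str, float]:
        # Only persistent states are saved.
        return {name: value for name, value in self.states.items() if self._persistent[name]}


m = StatefulSketch()
print(m.state_dict())  # {}
m.persistent(True)
print(m.state_dict())  # {'total': 0.0, 'count': 0.0}
```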

reset() → None

Reset metric state variables to their default value.

update(y_pred: Tensor, target: Tensor, **kwargs) → Tensor

Override this method to update the state variables of your metric class.