TorchMetricWrapper

class pytorch_forecasting.metrics.base_metrics.TorchMetricWrapper(torchmetric: Metric, reduction: Optional[str] = None, **kwargs)

Bases: Metric

Wrap a torchmetric to work with PyTorch Forecasting.

Does not support weighting of errors and only supports metrics for point predictions.
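The two restrictions above (no error weighting, point predictions only) can be illustrated with a small pure-Python sketch. The helpers `flatten` and `to_point_inputs` are hypothetical and are not part of PyTorch Forecasting; they work on nested lists rather than tensors and assume, purely for this sketch, that a target passed as a tuple means `(target, weight)`.

```python
from typing import Any, List, Tuple


def flatten(values: Any) -> List[float]:
    """Recursively flatten nested lists of numbers into one flat list."""
    if isinstance(values, (list, tuple)):
        out: List[float] = []
        for v in values:
            out.extend(flatten(v))
        return out
    return [float(values)]


def to_point_inputs(y_pred: Any, target: Any) -> Tuple[List[float], List[float]]:
    """Hypothetical helper mirroring the wrapper's stated restrictions."""
    # Sketch convention: a tuple target means (target, weight); data uses lists.
    if isinstance(target, tuple):
        target, weight = target
        if weight is not None:
            raise NotImplementedError("weighting is not supported")
    preds = flatten(y_pred)
    targets = flatten(target)
    # Point predictions only: exactly one predicted value per target value.
    if len(preds) != len(targets):
        raise ValueError("expected exactly one point prediction per target value")
    return preds, targets


preds, targets = to_point_inputs([[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 5.0]])
print(preds)    # [1.0, 2.0, 3.0, 4.0]
print(targets)  # [1.0, 2.0, 3.0, 5.0]

try:
    to_point_inputs([[1.0]], ([[2.0]], [[0.5]]))
except NotImplementedError as err:
    print(err)  # weighting is not supported
```

The real wrapper operates on tensors, and its internal conversion is not documented on this page; the sketch only shows why a weight, or several outputs per target value, cannot be passed on to a plain point-prediction metric.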

Parameters
  • torchmetric (LightningMetric) – Torchmetric to wrap, i.e. a torchmetrics Metric instance.

  • reduction (str, optional) – Configure any reduction on the wrapped torchmetric directly instead of through this argument. Defaults to None.
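Because the wrapper delegates to the torchmetric it holds, settings such as a reduction belong on that inner metric. A minimal sketch of this delegation pattern follows, using two hypothetical classes (`InnerAbsError` standing in for a torchmetric, `WrapperSketch` standing in for the wrapper); neither exists in PyTorch Forecasting or torchmetrics.

```python
from typing import List


class InnerAbsError:
    """Hypothetical stand-in for a torchmetric that owns its own reduction."""

    def __init__(self, reduction: str = "mean") -> None:
        if reduction not in ("mean", "sum"):
            raise ValueError(f"unknown reduction: {reduction!r}")
        self.reduction = reduction
        self.reset()

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, preds: List[float], target: List[float]) -> None:
        self.total += sum(abs(p - t) for p, t in zip(preds, target))
        self.count += len(preds)

    def compute(self) -> float:
        if self.reduction == "sum":
            return self.total
        return self.total / self.count  # mean absolute error


class WrapperSketch:
    """Hypothetical wrapper: holds the inner metric and delegates to it."""

    def __init__(self, inner: InnerAbsError) -> None:
        # No reduction of its own; it uses whatever the inner metric was built with.
        self.inner = inner

    def update(self, y_pred: List[float], target: List[float]) -> None:
        self.inner.update(y_pred, target)

    def compute(self) -> float:
        return self.inner.compute()

    def reset(self) -> None:
        self.inner.reset()


summed = WrapperSketch(InnerAbsError(reduction="sum"))
averaged = WrapperSketch(InnerAbsError(reduction="mean"))
for metric in (summed, averaged):
    metric.update([1.0, 2.0], [2.0, 4.0])
print(summed.compute())    # 3.0
print(averaged.compute())  # 1.5
```

The choice between a summed and an averaged error is fixed when the inner metric is built; the wrapper only passes calls through.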

Methods

compute()

Abstract method that calculates the metric.

forward(y_pred, target, **kwargs)

forward serves the dual purpose of computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state.

persistent([mode])

Change, after initialization, whether the metric states are saved to its state_dict.

reset()

This method automatically resets the metric state variables to their default values.

update(y_pred, target, **kwargs)

Override this method to update the state variables of your metric class.

compute()

Abstract method that calculates the metric.

Should be overridden in derived classes.

Parameters
  • y_pred – network output

  • y_actual – actual values

Returns

metric value on which backpropagation can be applied

Return type

torch.Tensor

forward(y_pred, target, **kwargs)

forward serves the dual purpose of computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state.

Input arguments are the same as those of the corresponding update method. The returned output has the same form as the output of compute, but is computed for the current batch only.
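The split between update, compute and forward can be sketched with a running mean squared error in pure Python. `RunningMSE` is a hypothetical class written only to illustrate the contract described above; it is not the wrapper's implementation, and evaluating the batch with a temporary instance is just one simple way to get the per-batch value.

```python
from typing import List


class RunningMSE:
    """Hypothetical running MSE following the update/compute/forward contract."""

    def __init__(self) -> None:
        self.sum_sq = 0.0  # accumulated sum of squared errors
        self.n = 0         # accumulated number of values

    def update(self, y_pred: List[float], target: List[float]) -> None:
        # Add this batch's statistics to the accumulated state.
        self.sum_sq += sum((p - t) ** 2 for p, t in zip(y_pred, target))
        self.n += len(y_pred)

    def compute(self) -> float:
        # Metric over everything accumulated so far.
        if self.n == 0:
            raise ValueError("compute() called before any update()")
        return self.sum_sq / self.n

    def forward(self, y_pred: List[float], target: List[float]) -> float:
        # Same arguments as update(); same kind of output as compute(),
        # but evaluated on the current batch only.
        batch = RunningMSE()
        batch.update(y_pred, target)
        self.update(y_pred, target)  # also accumulate into the overall state
        return batch.compute()


metric = RunningMSE()
print(metric.forward([1.0, 2.0], [1.0, 4.0]))  # 2.0 (this batch only)
print(metric.forward([0.0], [3.0]))            # 9.0 (this batch only)
print(metric.compute())                        # 4.333333333333333 (all three values)
```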

persistent(mode: bool = False) → None

Change, after initialization, whether the metric states are saved to its state_dict.

reset() → None

This method automatically resets the metric state variables to their default values.
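The persistent and reset methods both act on the metric's state variables. A minimal pure-Python sketch of the two behaviors follows, using a hypothetical `StatefulSketch` class with a plain dict in place of registered tensor states. It is not how torchmetrics stores state, and the starting mode (states excluded from `state_dict()`) is an assumption of this sketch.

```python
from typing import Dict


class StatefulSketch:
    """Hypothetical metric with default state values, reset(), and a
    persistent() switch controlling what state_dict() returns."""

    DEFAULTS: Dict[str, float] = {"sum_abs": 0.0, "count": 0.0}

    def __init__(self) -> None:
        self.state: Dict[str, float] = dict(self.DEFAULTS)
        self._persistent = False  # assumption: states start out excluded

    def update(self, error: float) -> None:
        self.state["sum_abs"] += abs(error)
        self.state["count"] += 1

    def persistent(self, mode: bool = False) -> None:
        # Post-init switch: should the metric states be saved to state_dict()?
        self._persistent = mode

    def state_dict(self) -> Dict[str, float]:
        return dict(self.state) if self._persistent else {}

    def reset(self) -> None:
        # Restore every state variable to its default value.
        self.state = dict(self.DEFAULTS)


metric = StatefulSketch()
metric.update(-2.0)
print(metric.state_dict())  # {}
metric.persistent(True)
print(metric.state_dict())  # {'sum_abs': 2.0, 'count': 1.0}
metric.reset()
print(metric.state_dict())  # {'sum_abs': 0.0, 'count': 0.0}
```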

update(y_pred: Tensor, target: Tensor, **kwargs) → Tensor

Override this method to update the state variables of your metric class.