TorchMetricWrapper

class pytorch_forecasting.metrics.TorchMetricWrapper(torchmetric: torchmetrics.metric.Metric, reduction: Optional[str] = None, **kwargs)

Bases: pytorch_forecasting.metrics.Metric

Wrap a torchmetrics metric so that it can be used as a PyTorch Forecasting metric.

The wrapper does not support weighting of errors and only supports metrics for point predictions.

Parameters
  • torchmetric (torchmetrics.Metric) – the torchmetrics metric to wrap.

  • reduction (str, optional) – not supported by the wrapper; configure any reduction directly on the wrapped torchmetric instead. Defaults to None.
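The wrapper keeps a reference to the wrapped metric and passes the stateful calls through to it. The stdlib-only sketch below illustrates that delegation pattern under stated assumptions: `RunningMAE` and `DelegatingWrapper` are hypothetical stand-ins written for illustration (not the pytorch_forecasting or torchmetrics classes), plain lists replace tensors, and raising `ValueError` for a non-None `reduction` is an assumption based on the parameter description above.

```python
from typing import List, Optional


class RunningMAE:
    """Hypothetical stand-in for a torchmetric: accumulates absolute errors."""

    def __init__(self) -> None:
        self.reset()

    def update(self, preds: List[float], target: List[float]) -> None:
        # Accumulate the sum of absolute errors and the number of elements.
        self.sum_abs_error += sum(abs(p - t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        # Mean absolute error over everything seen since the last reset.
        return self.sum_abs_error / self.total

    def reset(self) -> None:
        self.sum_abs_error = 0.0
        self.total = 0


class DelegatingWrapper:
    """Hypothetical sketch of the wrapping idea; not the library class."""

    def __init__(self, metric, reduction: Optional[str] = None) -> None:
        # Assumption: reduction must be configured on the wrapped metric itself.
        if reduction is not None:
            raise ValueError("configure reduction on the wrapped metric directly")
        self.metric = metric
        self.name = type(metric).__name__  # e.g. "RunningMAE"

    def update(self, y_pred: List[float], target: List[float]) -> None:
        self.metric.update(y_pred, target)

    def compute(self) -> float:
        return self.metric.compute()

    def reset(self) -> None:
        self.metric.reset()


wrapped = DelegatingWrapper(RunningMAE())
wrapped.update([1.0, 2.0], [1.5, 2.0])
print(wrapped.name, wrapped.compute())  # RunningMAE 0.25
```

Delegating instead of re-implementing keeps all accumulation logic inside the wrapped metric, so the wrapper only has to adapt inputs and names.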

Methods

  • compute() – Abstract method that calculates the metric.

  • forward(y_pred, target, **kwargs) – Automatically calls update().

  • persistent([mode]) – Change after initialization whether the metric states are saved to the metric's state_dict.

  • reset() – Reset the metric state variables to their default values.

  • update(y_pred, target, **kwargs) – Override this method to update the state variables of your metric class.
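The methods above follow a stateful protocol: update() is called once per batch, compute() aggregates everything accumulated so far, and reset() clears the state before the next epoch. A minimal stdlib-only sketch of that call order, using a hypothetical `RunningMSE` class with lists in place of tensors (not part of either library):

```python
class RunningMSE:
    """Hypothetical stateful metric following the update/compute/reset protocol."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Restore the state variables to their default values.
        self.sum_sq_error = 0.0
        self.count = 0

    def update(self, y_pred, target) -> None:
        # Fold one batch into the running state; nothing is returned.
        self.sum_sq_error += sum((p - t) ** 2 for p, t in zip(y_pred, target))
        self.count += len(target)

    def compute(self) -> float:
        # Aggregate over all batches seen since the last reset.
        return self.sum_sq_error / self.count


metric = RunningMSE()
batches = [([1.0, 2.0], [0.0, 2.0]), ([3.0], [1.0])]
for y_pred, target in batches:  # e.g. one training epoch
    metric.update(y_pred, target)
epoch_value = metric.compute()  # (1 + 0 + 4) / 3
metric.reset()  # start the next epoch from a clean state
print(epoch_value)
```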

compute()

Abstract method that calculates the metric.

Should be overridden in derived classes.

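compute() takes no arguments; the metric is calculated from the state accumulated by previous calls to update(), which receive the network output (y_pred) and the actual values (target).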

Returns

metric value on which backpropagation can be applied

Return type

torch.Tensor
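Assuming the wrapped torchmetric accumulates sums and element counts in update(), as a mean-absolute-error metric typically does, compute() returns the value over all accumulated elements rather than the average of per-batch values. For batches $b = 1, \dots, B$ with $n_b$ elements and errors $e_{b,i} = \hat{y}_{b,i} - y_{b,i}$:

```latex
\mathrm{MAE}_{\text{accumulated}}
  = \frac{\sum_{b=1}^{B} \sum_{i=1}^{n_b} \lvert e_{b,i} \rvert}{\sum_{b=1}^{B} n_b}
\qquad \text{vs.} \qquad
\frac{1}{B} \sum_{b=1}^{B} \frac{1}{n_b} \sum_{i=1}^{n_b} \lvert e_{b,i} \rvert
```

The two coincide when all $n_b$ are equal but differ in general: with absolute errors {1, 0} in the first batch and {2} in the second, the accumulated value is (1 + 0 + 2) / 3 = 1, while the mean of the per-batch values is (0.5 + 2) / 2 = 1.25.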

forward(y_pred, target, **kwargs)

Automatically calls update().

Returns the metric value over inputs if compute_on_step is True.
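In torchmetrics-style metrics, forward() returns the value for the current batch while also folding that batch into the running state that compute() later aggregates. The sketch below illustrates this contract with a hypothetical stdlib-only `RunningMAE`; it is not the torchmetrics implementation, it ignores the compute_on_step flag mentioned above, and list inputs stand in for tensors.

```python
from typing import List, Tuple


class RunningMAE:
    """Hypothetical stateful metric used to illustrate forward() semantics."""

    def __init__(self) -> None:
        self.sum_abs_error = 0.0
        self.total = 0

    @staticmethod
    def _batch_stats(y_pred: List[float], target: List[float]) -> Tuple[float, int]:
        # Sufficient statistics for one batch: sum of |error| and element count.
        return sum(abs(p - t) for p, t in zip(y_pred, target)), len(target)

    def update(self, y_pred: List[float], target: List[float]) -> None:
        batch_sum, batch_count = self._batch_stats(y_pred, target)
        self.sum_abs_error += batch_sum
        self.total += batch_count

    def compute(self) -> float:
        return self.sum_abs_error / self.total

    def forward(self, y_pred: List[float], target: List[float]) -> float:
        # Fold the batch into the accumulated state (what update() does) ...
        self.update(y_pred, target)
        # ... and return the metric computed on this batch alone.
        batch_sum, batch_count = self._batch_stats(y_pred, target)
        return batch_sum / batch_count


metric = RunningMAE()
first = metric.forward([1.0, 2.0], [0.0, 2.0])  # batch MAE: (1 + 0) / 2
second = metric.forward([3.0], [1.0])  # batch MAE: 2 / 1
overall = metric.compute()  # (1 + 0 + 2) / 3
print(first, second, overall)
```

This is why a per-step logged value can differ from the epoch-level value: the former comes from forward(), the latter from compute().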

persistent(mode: bool = False) → None

Change after initialization whether the metric states should be saved to the metric's state_dict.
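Conceptually, persistent() flips a flag that decides whether the accumulated states are written out when the metric's state_dict is saved (for example, when a model checkpoint is written). The stdlib-only sketch below illustrates that idea with a hypothetical `StatefulMetricSketch`; real torchmetrics states are registered on a torch.nn.Module, which this sketch does not model, and the non-persistent starting value is an assumption.

```python
from typing import Dict


class StatefulMetricSketch:
    """Hypothetical illustration of a persistent() switch; not torchmetrics code."""

    def __init__(self) -> None:
        self.states: Dict[str, float] = {"sum_abs_error": 0.0, "total": 0.0}
        # Assumption: states start out non-persistent.
        self._persistent = False

    def persistent(self, mode: bool = False) -> None:
        # Decide after construction whether states are written to state_dict().
        self._persistent = mode

    def state_dict(self) -> Dict[str, float]:
        # Non-persistent states are simply left out of the saved dictionary.
        return dict(self.states) if self._persistent else {}


m = StatefulMetricSketch()
m.states["sum_abs_error"] = 3.0
m.states["total"] = 3.0
before = m.state_dict()  # {}
m.persistent(True)
after = m.state_dict()  # {'sum_abs_error': 3.0, 'total': 3.0}
print(before, after)
```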

reset() → None

Resets the metric state variables to their default values.

update(y_pred: torch.Tensor, target: torch.Tensor, **kwargs) → torch.Tensor

Override this method to update the state variables of your metric class.
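Because the wrapper only handles unweighted point predictions, inputs have to be brought into a form the wrapped torchmetric accepts before they are passed on. A stdlib-only sketch of such a conversion, using nested lists in place of tensors; `convert_inputs` is a hypothetical helper, and the exact checks the library performs are not reproduced here.

```python
from typing import List, Tuple


def _flatten(x) -> List[float]:
    # Recursively flatten nested lists (stand-ins for tensors) into floats.
    if isinstance(x, list):
        flat: List[float] = []
        for item in x:
            flat.extend(_flatten(item))
        return flat
    return [float(x)]


def convert_inputs(y_pred, target) -> Tuple[List[float], List[float]]:
    """Hypothetical sketch: prepare inputs for a point-prediction metric."""
    # A (target, weight) pair is unpacked; actual weights are rejected.
    if isinstance(target, tuple):
        target, weight = target
        if weight is not None:
            raise NotImplementedError("weighting is not supported")
    flat_pred, flat_target = _flatten(y_pred), _flatten(target)
    # Point predictions only: exactly one predicted value per target value,
    # so a trailing output dimension of size 1 is fine but quantiles are not.
    if len(flat_pred) != len(flat_target):
        raise ValueError("expected one point prediction per target value")
    return flat_pred, flat_target


# Predictions shaped (batch=2, time=2, outputs=1); targets shaped (2, 2).
y_pred = [[[1.0], [2.0]], [[3.0], [4.0]]]
target = ([[1.5, 2.0], [2.0, 4.0]], None)  # (target, weight) with no weights
preds, targets = convert_inputs(y_pred, target)
print(preds, targets)  # [1.0, 2.0, 3.0, 4.0] [1.5, 2.0, 2.0, 4.0]
```

The flattened lists can then be handed to the wrapped metric's update(), which never sees the batch and time structure.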