TorchMetricWrapper#
- class pytorch_forecasting.metrics.base_metrics.TorchMetricWrapper(torchmetric: Metric, reduction: str | None = None, **kwargs)[source]#
Bases:
Metric
Wrap a torchmetric to work with PyTorch Forecasting.
Does not support weighting of errors and only supports metrics for point predictions.
- Parameters:
torchmetric (LightningMetric) – Torchmetric to wrap.
reduction (str, optional) – reduction to use with the torchmetric directly. Defaults to None.
Methods
compute() – Abstract method that calculates the metric.
forward(y_pred, target, **kwargs) – Aggregate and evaluate batch input directly.
persistent([mode]) – Change post-init if metric states should be saved to its state_dict.
reset() – Reset metric state variables to their default value.
update(y_pred, target, **kwargs) – Override this method to update the state variables of your metric class.
- compute()[source]#
Abstract method that calculates the metric.
Should be overridden in derived classes.
- Parameters:
y_pred – network output
y_actual – actual values
- Returns:
metric value on which backpropagation can be applied
- Return type:
torch.Tensor
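The update/compute contract that derived classes implement can be illustrated with a minimal plain-Python sketch (no torch dependency; `SimpleMAE` is a hypothetical stand-in for a real metric class, not part of the library):

```python
class SimpleMAE:
    """Minimal sketch of the metric protocol: update() accumulates
    batch statistics into state, compute() turns that state into a value."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Reset metric state variables to their default value.
        self.abs_error_sum = 0.0
        self.count = 0

    def update(self, y_pred, target):
        # Accumulate batch statistics into the metric state.
        self.abs_error_sum += sum(abs(p - t) for p, t in zip(y_pred, target))
        self.count += len(y_pred)

    def compute(self):
        # Evaluate the metric from the accumulated state.
        return self.abs_error_sum / self.count


metric = SimpleMAE()
metric.update([1.0, 2.0], [1.0, 4.0])  # batch 1: errors 0 and 2
metric.update([3.0], [2.0])            # batch 2: error 1
print(metric.compute())                # (0 + 2 + 1) / 3 = 1.0
```

In the real classes, state lives in tensors and `compute` returns a `torch.Tensor` so that backpropagation can be applied; the accumulation pattern is the same.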
- forward(y_pred, target, **kwargs)[source]#
Aggregate and evaluate batch input directly.
Serves the dual purpose of computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state. Input arguments are exactly the same as for the corresponding update method. The returned output is exactly the same as the output of compute.
- Parameters:
args – Any arguments as required by the metric update method.
kwargs – Any keyword arguments as required by the metric update method.
- Returns:
The output of the compute method evaluated on the current batch.
- Raises:
TorchMetricsUserError – If the metric is already synced and forward is called again.
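This dual behavior of forward can be sketched in plain Python (a self-contained illustration; `BatchMetric` is a hypothetical stand-in, not a library class):

```python
class BatchMetric:
    """Sketch of forward(): evaluates the metric on the current batch
    alone while also folding the batch statistics into the running state."""

    def __init__(self):
        self.total, self.count = 0.0, 0

    def update(self, y_pred, target):
        # Add this batch's statistics to the accumulating state.
        self.total += sum(abs(p - t) for p, t in zip(y_pred, target))
        self.count += len(y_pred)

    def compute(self):
        # Metric over everything accumulated so far.
        return self.total / self.count

    def forward(self, y_pred, target):
        # Snapshot the state, update it with the batch...
        prev_total, prev_count = self.total, self.count
        self.update(y_pred, target)
        # ...and return the metric evaluated on this batch only.
        return (self.total - prev_total) / (self.count - prev_count)


m = BatchMetric()
print(m.forward([2.0], [0.0]))  # batch-only value: 2.0
print(m.forward([1.0], [0.0]))  # batch-only value: 1.0
print(m.compute())              # running value over both batches: 1.5
```

The batch-local result matches what compute would return for that batch in isolation, while the global state keeps growing so a later compute covers all batches seen.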