TorchMetricWrapper
- class pytorch_forecasting.metrics.TorchMetricWrapper(torchmetric: torchmetrics.metric.Metric, reduction: Optional[str] = None, **kwargs)
Bases: pytorch_forecasting.metrics.Metric
Wrap a torchmetric to work with PyTorch Forecasting.
Does not support weighting of errors and only supports metrics for point predictions.
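The wrapping idea above can be illustrated without torch. The following is a minimal plain-Python sketch, not the library's implementation: `MeanAbsoluteError`, `WrappedMetric` and `flatten` are hypothetical stand-ins. It assumes targets may arrive as a `(values, weights)` pair and that point predictions are nested `(batch, time)` lists. The wrapper refuses weights, flattens predictions and targets, and delegates all state handling to the inner metric.

```python
from typing import List, Optional, Sequence, Tuple, Union

# hypothetical (batch, time) point predictions stored as nested lists
Batch = Sequence[Sequence[float]]


class MeanAbsoluteError:
    """Hypothetical stand-in for a stateful torchmetric; not the real torchmetrics class."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.total_error = 0.0
        self.count = 0

    def update(self, preds: List[float], target: List[float]) -> None:
        # accumulate absolute errors over every element seen so far
        for p, t in zip(preds, target):
            self.total_error += abs(p - t)
            self.count += 1

    def compute(self) -> float:
        return self.total_error / self.count


def flatten(rows: Batch) -> List[float]:
    """Flatten a (batch, time) nested list into one flat list."""
    return [value for row in rows for value in row]


class WrappedMetric:
    """Hypothetical wrapper: rejects weights, flattens point predictions,
    and delegates all state handling to the inner metric."""

    def __init__(self, inner: MeanAbsoluteError) -> None:
        self.inner = inner

    def update(self, y_pred: Batch, target: Union[Batch, Tuple[Batch, Optional[Batch]]]) -> None:
        # assumption: a tuple target is a (values, weights) pair; plain batches are lists
        if isinstance(target, tuple):
            target, weight = target
            if weight is not None:
                raise NotImplementedError("weighting is not supported")
        self.inner.update(flatten(y_pred), flatten(target))

    def compute(self) -> float:
        return self.inner.compute()

    def reset(self) -> None:
        self.inner.reset()


metric = WrappedMetric(MeanAbsoluteError())
metric.update([[1.0, 2.0], [3.0, 4.0]], ([[1.5, 2.0], [3.0, 5.0]], None))
print(metric.compute())  # 0.375
```

Keeping the state inside the inner metric means the wrapper only has to adapt input shapes; everything about how the metric accumulates and evaluates stays with the wrapped object.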
- Parameters
torchmetric (LightningMetric) – torchmetrics metric to wrap (an instance of torchmetrics.metric.Metric, as in the signature).
reduction (str, optional) – the reduction should be set on the wrapped torchmetric directly. Defaults to None.
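One reading of the `reduction` parameter is that the wrapper does not apply a reduction of its own and leaves it to the wrapped metric. The plain-Python sketch below assumes exactly that, and additionally assumes that a non-None `reduction` passed to the wrapper is rejected rather than silently ignored. `AbsoluteError` and `ReductionAwareWrapper` are hypothetical names, not part of PyTorch Forecasting or torchmetrics.

```python
from typing import List, Optional


class AbsoluteError:
    """Hypothetical inner metric that owns its own reduction ("mean" or "sum")."""

    def __init__(self, reduction: str = "mean") -> None:
        if reduction not in ("mean", "sum"):
            raise ValueError(f"unknown reduction: {reduction!r}")
        self.reduction = reduction
        self.errors: List[float] = []

    def update(self, preds: List[float], target: List[float]) -> None:
        self.errors.extend(abs(p - t) for p, t in zip(preds, target))

    def compute(self) -> float:
        total = sum(self.errors)
        return total / len(self.errors) if self.reduction == "mean" else total


class ReductionAwareWrapper:
    """Hypothetical wrapper that leaves reduction entirely to the inner metric."""

    def __init__(self, inner: AbsoluteError, reduction: Optional[str] = None) -> None:
        # assumption: a non-None reduction is rejected instead of being applied twice
        if reduction is not None:
            raise ValueError("set the reduction on the wrapped metric directly")
        self.inner = inner

    def update(self, preds: List[float], target: List[float]) -> None:
        self.inner.update(preds, target)

    def compute(self) -> float:
        return self.inner.compute()


summed = ReductionAwareWrapper(AbsoluteError(reduction="sum"))
summed.update([1.0, 2.0], [2.0, 4.0])
print(summed.compute())  # 3.0
```

Having a single owner for the reduction avoids the ambiguity of two reductions (one in the wrapper, one in the metric) being applied in sequence.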
Methods
compute() – Abstract method that calculates the metric.
forward(y_pred, target, **kwargs) – Automatically calls update().
persistent([mode]) – Method for post-init to change whether metric states should be saved to its state_dict.
reset() – This method automatically resets the metric state variables to their default value.
update(y_pred, target, **kwargs) – Override this method to update the state variables of your metric class.
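The methods listed above follow a stateful cycle: update() is called once per batch, compute() evaluates everything accumulated so far, and reset() restores the defaults. The plain-Python sketch below, with a hypothetical `RunningMSE` class rather than any library class, shows that cycle for one epoch and why accumulating state matters when batch sizes differ: averaging per-batch values would weight a one-element batch as heavily as a three-element batch.

```python
from typing import List


class RunningMSE:
    """Hypothetical stateful metric following the update/compute/reset cycle."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # restore state variables to their default values
        self.sum_squared_error = 0.0
        self.n_observations = 0

    def update(self, preds: List[float], target: List[float]) -> None:
        # fold one batch into the accumulated state; nothing is returned
        for p, t in zip(preds, target):
            self.sum_squared_error += (p - t) ** 2
            self.n_observations += 1

    def compute(self) -> float:
        # evaluate the metric over everything seen since the last reset
        return self.sum_squared_error / self.n_observations


batches = [
    ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),  # 3 observations, squared error 1 each
    ([0.0], [3.0]),                      # 1 observation, squared error 9
]

metric = RunningMSE()
for preds, target in batches:
    metric.update(preds, target)
epoch_mse = metric.compute()  # (3 * 1 + 9) / 4 = 3.0
metric.reset()                # start clean for the next epoch

naive = (1.0 + 9.0) / 2       # mean of per-batch MSEs, ignores batch sizes
print(epoch_mse, naive)  # 3.0 5.0
```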
- compute()
Abstract method that calculates the metric.
Should be overridden in derived classes.
- Parameters
y_pred – network output
y_actual – actual values
- Returns
metric value on which backpropagation can be applied
- Return type
torch.Tensor
- forward(y_pred, target, **kwargs)
Automatically calls update(). Returns the metric value over inputs if compute_on_step is True.
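The forward() behaviour described above can be sketched in plain Python: the batch is folded into the accumulated state through update(), and, when `compute_on_step` is True, the metric value for that batch alone is returned. `BatchAndEpochMAE` is a hypothetical class; this is a simplified illustration, not how torchmetrics implements forward() internally.

```python
from typing import List, Optional


class BatchAndEpochMAE:
    """Hypothetical metric whose forward() calls update() on the accumulated
    state and returns the value for the current batch only."""

    def __init__(self, compute_on_step: bool = True) -> None:
        self.compute_on_step = compute_on_step
        self.total_error = 0.0
        self.count = 0

    def update(self, preds: List[float], target: List[float]) -> None:
        for p, t in zip(preds, target):
            self.total_error += abs(p - t)
            self.count += 1

    def compute(self) -> float:
        return self.total_error / self.count

    def forward(self, preds: List[float], target: List[float]) -> Optional[float]:
        # 1) fold the batch into the accumulated (epoch-level) state
        self.update(preds, target)
        if not self.compute_on_step:
            return None
        # 2) evaluate this batch on its own, without touching the accumulated state
        return sum(abs(p - t) for p, t in zip(preds, target)) / len(preds)


metric = BatchAndEpochMAE()
print(metric.forward([1.0, 2.0], [2.0, 2.0]))  # 0.5 (batch 1 only)
print(metric.forward([0.0], [3.0]))            # 3.0 (batch 2 only)
print(metric.compute())                        # 1.3333333333333333 (all 3 observations)
```

The per-batch return value is what a training loop would typically log at each step, while compute() gives the aggregate at the end of the epoch.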
- persistent(mode: bool = False) → None
Method for post-init to change whether metric states should be saved to its state_dict.
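The effect of persistent() can be pictured as a single switch that decides whether metric states appear in the saved state_dict. The plain-Python sketch below uses a hypothetical `StatefulMetric` class with its own `state_dict()` method; it is not the torchmetrics or PyTorch implementation, and it assumes states are not saved until persistent(True) is called.

```python
from typing import Dict


class StatefulMetric:
    """Hypothetical metric that tracks whether its states go into state_dict()."""

    def __init__(self) -> None:
        self._states: Dict[str, float] = {"total": 0.0, "count": 0.0}
        self._persistent = False  # assumption: states are not saved by default

    def update(self, value: float) -> None:
        self._states["total"] += value
        self._states["count"] += 1.0

    def persistent(self, mode: bool = False) -> None:
        # post-init switch: decide whether metric states are saved to state_dict()
        self._persistent = mode

    def state_dict(self) -> Dict[str, float]:
        return dict(self._states) if self._persistent else {}


metric = StatefulMetric()
metric.update(4.0)
print(metric.state_dict())  # {}
metric.persistent(True)
print(metric.state_dict())  # {'total': 4.0, 'count': 1.0}
```

Leaving states out of checkpoints by default keeps a restored model from carrying over half-accumulated metric values from an earlier run; switching persistence on is useful when such partial state should survive a restart.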