AggregationMetric

class pytorch_forecasting.metrics.base_metrics.AggregationMetric(metric: Metric, **kwargs)

Bases: Metric

Calculate a metric on the mean prediction and the mean actual values.

Parameters

metric (Metric) – metric to calculate on the aggregated (mean) predictions and actuals.
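To make "calculate metric on mean prediction and actuals" concrete, here is a framework-free sketch in plain Python (lists instead of torch tensors). It assumes the mean is taken across the series in a batch, separately for each time step, and uses a hand-written mean absolute error as the wrapped metric. The helpers mean_over_series, mae and aggregated_metric are hypothetical and not part of pytorch_forecasting.

```python
from typing import Callable, List, Sequence

# Assumed layout: outer index = series in the batch, inner index = time step
Batch = Sequence[Sequence[float]]


def mean_over_series(values: Batch) -> List[float]:
    # Average across series, separately for each time step
    return [sum(step) / len(values) for step in zip(*values)]


def mae(y_pred: Sequence[float], y_actual: Sequence[float]) -> float:
    # Plain mean absolute error over paired values
    return sum(abs(p - a) for p, a in zip(y_pred, y_actual)) / len(y_pred)


def aggregated_metric(
    metric: Callable[[Sequence[float], Sequence[float]], float],
    y_pred: Batch,
    y_actual: Batch,
) -> float:
    # Apply the wrapped metric to the mean prediction and the mean actuals
    return metric(mean_over_series(y_pred), mean_over_series(y_actual))


y_pred = [[1.0, 2.0], [3.0, 2.0]]
y_actual = [[2.0, 2.0], [2.0, 3.0]]

flat_pred = [v for series in y_pred for v in series]
flat_actual = [v for series in y_actual for v in series]
print(mae(flat_pred, flat_actual))                # 0.75 (per-point errors)
print(aggregated_metric(mae, y_pred, y_actual))   # 0.25 (error of the means)
```

Errors of opposite sign in different series (here, at the first time step) cancel once the series are averaged, so the aggregated value can be much smaller than the per-point one.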

Methods

compute()

Abstract method that calculates the metric

forward(y_pred, y_actual, **kwargs)

Calculate composite metric

persistent([mode])

Change, after initialization, whether the metric states should be saved to the metric's state_dict.

reset()

This method automatically resets the metric state variables to their default value.

update(y_pred, y_actual, **kwargs)

Calculate composite metric
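The methods listed above appear to follow the torchmetrics pattern that Metric builds on: update() accumulates state, compute() turns the accumulated state into a value, reset() restores the defaults, and forward() returns the value for the current batch while also accumulating. The plain-Python sketch below models that cycle for a wrapper that hands mean predictions and actuals to an inner metric. RunningMAE and AggregationSketch are hypothetical stand-ins, and the forward() behavior is assumed from the usual torchmetrics convention rather than confirmed by this page.

```python
from typing import List, Sequence

# Assumed layout: outer index = series in the batch, inner index = time step
Batch = Sequence[Sequence[float]]


def mean_over_series(values: Batch) -> List[float]:
    # Average across series, separately for each time step
    return [sum(step) / len(values) for step in zip(*values)]


class RunningMAE:
    """Hypothetical inner metric that accumulates absolute errors across calls."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Restore the default state values
        self.total_abs_error = 0.0
        self.count = 0

    def update(self, y_pred: Sequence[float], y_actual: Sequence[float]) -> None:
        self.total_abs_error += sum(abs(p - a) for p, a in zip(y_pred, y_actual))
        self.count += len(y_pred)

    def compute(self) -> float:
        if self.count == 0:
            raise ValueError("compute() called before update()")
        return self.total_abs_error / self.count

    def forward(self, y_pred: Sequence[float], y_actual: Sequence[float]) -> float:
        # Accumulate into the running state, but return the value for this batch only
        self.update(y_pred, y_actual)
        return sum(abs(p - a) for p, a in zip(y_pred, y_actual)) / len(y_pred)


class AggregationSketch:
    """Hypothetical wrapper: feeds mean predictions/actuals to the wrapped metric."""

    def __init__(self, metric: RunningMAE) -> None:
        self.metric = metric

    def update(self, y_pred: Batch, y_actual: Batch) -> None:
        self.metric.update(mean_over_series(y_pred), mean_over_series(y_actual))

    def forward(self, y_pred: Batch, y_actual: Batch) -> float:
        return self.metric.forward(mean_over_series(y_pred), mean_over_series(y_actual))

    def compute(self) -> float:
        return self.metric.compute()

    def reset(self) -> None:
        self.metric.reset()


agg = AggregationSketch(RunningMAE())
print(agg.forward([[1.0, 2.0], [3.0, 2.0]], [[2.0, 2.0], [2.0, 3.0]]))  # 0.25 (this batch)
print(agg.forward([[4.0], [6.0]], [[5.0], [3.0]]))                      # 1.0 (this batch)
print(agg.compute())                                                    # 0.5 (all batches)
agg.reset()
```

Keeping the state in the wrapped metric means the wrapper only has to transform its inputs; compute() and reset() simply delegate.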

compute()

Abstract method that calculates the metric

Should be overridden in derived classes.

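compute() takes no arguments. It evaluates the state accumulated by previous update() calls, i.e. from the network output (y_pred) and the actual values (y_actual).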

Returns

metric value on which backpropagation can be applied

Return type

torch.Tensor

forward(y_pred: Tensor, y_actual: Tensor, **kwargs)

Calculate composite metric

Parameters
  • y_pred – network output

  • y_actual – actual values

  • **kwargs – additional keyword arguments passed on to the update function

Returns

metric value on which backpropagation can be applied

Return type

torch.Tensor

persistent(mode: bool = False) → None

Change, after initialization, whether the metric states should be saved to the metric's state_dict.
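In torchmetrics, persistent(mode) sets the same flag on every registered state, and only states whose flag is True are written to state_dict; states are registered as non-persistent by default. The sketch below models that bookkeeping with plain dictionaries. TinyStatefulMetric and its state names are hypothetical, and this is not the torchmetrics implementation.

```python
from typing import Dict


class TinyStatefulMetric:
    """Hypothetical minimal model of persistent(); not the torchmetrics implementation."""

    def __init__(self) -> None:
        # Default state values, analogous to states registered with add_state()
        self.states: Dict[str, float] = {"total_abs_error": 0.0, "n_steps": 0.0}
        # States start out non-persistent
        self._persistent: Dict[str, bool] = {name: False for name in self.states}

    def persistent(self, mode: bool = False) -> None:
        # Apply the same flag to every registered state
        for name in self._persistent:
            self._persistent[name] = mode

    def state_dict(self) -> Dict[str, float]:
        # Only states flagged as persistent are saved
        return {name: value for name, value in self.states.items() if self._persistent[name]}


metric = TinyStatefulMetric()
print(metric.state_dict())  # {}
metric.persistent(True)
print(metric.state_dict())  # {'total_abs_error': 0.0, 'n_steps': 0.0}
```

Calling persistent() with no argument uses the default mode=False and therefore removes the states from state_dict again.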

reset() → None

This method automatically resets the metric state variables to their default value.

update(y_pred: Tensor, y_actual: Tensor, **kwargs) → Tensor

Calculate composite metric

Parameters
  • y_pred – network output

  • y_actual – actual values

Returns

metric value on which backpropagation can be applied

Return type

torch.Tensor
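If the series in a batch have different lengths and are padded to a common length, a plain mean would mix padding into the aggregate. The sketch below assumes the per-series lengths are known and averages only the valid positions at each time step. masked_mean_over_series is a hypothetical helper, and whether update() handles padding exactly this way is not stated on this page.

```python
from typing import List, Sequence


def masked_mean_over_series(
    values: Sequence[Sequence[float]], lengths: Sequence[int]
) -> List[float]:
    """Average across series per time step, ignoring padded positions.

    Hypothetical helper: series i is assumed valid only for its first lengths[i] steps.
    """
    n_steps = max(lengths)
    means = []
    for t in range(n_steps):
        # Keep only series that still have a real value at step t
        valid = [series[t] for series, length in zip(values, lengths) if t < length]
        means.append(sum(valid) / len(valid))
    return means


# The second series has length 2; its third value (0.0) is padding
y_actual = [[1.0, 2.0, 3.0], [3.0, 4.0, 0.0]]
print(masked_mean_over_series(y_actual, [3, 2]))  # [2.0, 3.0, 3.0]
```

Because n_steps is the longest length, at least one series is valid at every step, so the division never sees an empty list for non-empty input.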