AggregationMetric

class pytorch_forecasting.metrics.base_metrics.AggregationMetric(metric: Metric, **kwargs)

Bases: Metric

Calculate a metric on the mean prediction and the mean actual values.

Parameters:

metric (Metric) – metric to calculate on the aggregated (mean) predictions and actuals.
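To illustrate what calculating a metric "on mean prediction and actuals" means, here is a framework-free sketch that uses plain Python lists instead of torch.Tensor. It assumes the mean is taken over the sample (batch) dimension separately for each time step. The helpers mae, aggregated_metric and per_sample_metric are hypothetical and are not part of pytorch_forecasting.

```python
from statistics import fmean


def mae(pred, actual):
    """Mean absolute error between two equal-length sequences."""
    return fmean(abs(p - a) for p, a in zip(pred, actual))


def aggregated_metric(metric, y_pred, y_actual):
    """Hypothetical helper: apply `metric` to the sample-mean of predictions and actuals.

    y_pred and y_actual are lists of samples; each sample is a list of values,
    one per time step (an assumption made for this sketch).
    """
    # zip(*samples) groups values by time step; average each group over samples
    mean_pred = [fmean(step) for step in zip(*y_pred)]
    mean_actual = [fmean(step) for step in zip(*y_actual)]
    return metric(mean_pred, mean_actual)


def per_sample_metric(metric, y_pred, y_actual):
    """For contrast: average of `metric` evaluated on each sample separately."""
    return fmean(metric(p, a) for p, a in zip(y_pred, y_actual))


y_pred = [[1.0, 2.0], [3.0, 4.0]]
y_actual = [[2.0, 2.0], [2.0, 4.0]]
print(aggregated_metric(mae, y_pred, y_actual))  # 0.0
print(per_sample_metric(mae, y_pred, y_actual))  # 0.5
```

With these inputs the per-step means of predictions and actuals coincide, so the aggregated MAE is 0.0 even though each individual sample has an MAE of 0.5: at the first time step the two samples err in opposite directions, and those errors cancel before the metric is applied.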

Methods

compute()

Abstract method that calculates the metric.

forward(y_pred, y_actual, **kwargs)

Calculate the aggregated metric.

persistent([mode])

Change, after initialization, whether metric states should be saved to the metric's state_dict.

reset()

Reset metric state variables to their default value.

update(y_pred, y_actual, **kwargs)

Calculate the aggregated metric.

compute()

Abstract method that calculates the metric.

Should be overridden in derived classes.

Takes no arguments; the metric value is computed from the state accumulated by previous calls to update().

Returns:

metric value on which backpropagation can be applied

Return type:

torch.Tensor
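The compute() contract follows the torchmetrics pattern that pytorch_forecasting metrics build on: update() accumulates state, compute() derives the value from that state without taking arguments, and reset() restores the default state. The sketch below is a hypothetical, framework-free class (SketchAggregatedMAE) that mimics this cycle with plain Python lists. It is not the library's implementation, and it assumes predictions and actuals are averaged over samples at each time step before the absolute error is accumulated.

```python
from statistics import fmean


class SketchAggregatedMAE:
    """Hypothetical stand-in illustrating the update/compute/reset cycle."""

    def __init__(self):
        self.reset()

    def reset(self):
        # default state: nothing accumulated yet
        self.abs_error_sum = 0.0
        self.count = 0

    def update(self, y_pred, y_actual):
        # average over samples at each time step, then accumulate absolute errors
        mean_pred = [fmean(step) for step in zip(*y_pred)]
        mean_actual = [fmean(step) for step in zip(*y_actual)]
        for p, a in zip(mean_pred, mean_actual):
            self.abs_error_sum += abs(p - a)
            self.count += 1

    def compute(self):
        # no arguments: the value comes entirely from the accumulated state
        if self.count == 0:
            raise ValueError("compute() called before update()")
        return self.abs_error_sum / self.count


metric = SketchAggregatedMAE()
metric.update([[0.0, 0.0]], [[1.0, 3.0]])
print(metric.compute())  # 2.0
metric.update([[1.0, 1.0], [3.0, 3.0]], [[2.0, 2.0], [2.0, 2.0]])
print(metric.compute())  # 1.0 (errors 1, 3, 0, 0 over four time steps)
metric.reset()
```

Raising on an empty state is a choice made for this sketch; a real metric may instead warn or return a default value.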

forward(y_pred: Tensor, y_actual: Tensor, **kwargs)

Calculate the aggregated metric.

Parameters:
  • y_pred – network output

  • y_actual – actual values

  • **kwargs – arguments passed to the update function

Returns:

metric value on which backpropagation can be applied

Return type:

torch.Tensor
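Assuming torchmetrics-style semantics, forward() both accumulates the running state and returns the metric for the current batch only, whereas compute() reports the value over everything seen since the last reset. The hypothetical SketchForwardMAE below illustrates that difference with plain Python lists; it does not model gradients, which the real method supports because it operates on torch.Tensor.

```python
from statistics import fmean


class SketchForwardMAE:
    """Hypothetical metric contrasting forward() with compute()."""

    def __init__(self):
        self.abs_error_sum = 0.0
        self.count = 0

    @staticmethod
    def _batch_errors(y_pred, y_actual):
        # mean over samples per time step, then absolute error per step
        mean_pred = [fmean(step) for step in zip(*y_pred)]
        mean_actual = [fmean(step) for step in zip(*y_actual)]
        return [abs(p - a) for p, a in zip(mean_pred, mean_actual)]

    def forward(self, y_pred, y_actual):
        errors = self._batch_errors(y_pred, y_actual)
        # accumulate into the running state ...
        self.abs_error_sum += sum(errors)
        self.count += len(errors)
        # ... but return the value for this batch only
        return sum(errors) / len(errors)

    def compute(self):
        if self.count == 0:
            raise ValueError("no batches seen yet")
        return self.abs_error_sum / self.count


metric = SketchForwardMAE()
print(metric.forward([[0.0, 0.0]], [[1.0, 3.0]]))  # 2.0 for this batch
print(metric.forward([[1.0, 1.0]], [[1.0, 1.0]]))  # 0.0 for this batch
print(metric.compute())  # 1.0 across both batches
```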

persistent(mode: bool = False) → None

Change, after initialization, whether metric states should be saved to the metric's state_dict.

reset() → None

Reset metric state variables to their default value.

update(y_pred: Tensor, y_actual: Tensor, **kwargs) → Tensor

Calculate the aggregated metric.

Parameters:
  • y_pred – network output

  • y_actual – actual values

Returns:

metric value on which backpropagation can be applied

Return type:

torch.Tensor