AggregationMetric
- class pytorch_forecasting.metrics.base_metrics.AggregationMetric(metric: Metric, **kwargs)
Bases: Metric
Calculate metric on mean prediction and actuals.
- Parameters:
metric (Metric) – metric to calculate on the aggregation.
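To illustrate what "metric on mean prediction and actuals" means, here is a minimal, library-free sketch. It does not use pytorch_forecasting; `mae` and `aggregated` are hypothetical helpers, and averaging over all samples is an assumption about how the aggregation is done.

```python
from statistics import fmean


def mae(y_pred, y_actual):
    """Mean absolute error over paired values."""
    return fmean([abs(p - a) for p, a in zip(y_pred, y_actual)])


def aggregated(metric, y_pred, y_actual):
    """Hypothetical helper: apply `metric` to the mean prediction and mean actual.

    Assumption: the aggregation is a plain mean over all samples.
    """
    return metric([fmean(y_pred)], [fmean(y_actual)])


y_pred = [1.0, 3.0]
y_actual = [2.0, 2.0]

per_sample = mae(y_pred, y_actual)            # error of each sample, then averaged
on_means = aggregated(mae, y_pred, y_actual)  # error of the averaged values
```

In this sketch the per-sample errors are each 1.0, but they cancel once predictions and actuals are averaged first, so the aggregated value measures bias in the total rather than per-sample accuracy.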
Methods
compute() – Abstract method that calculates the metric.
forward(y_pred, y_actual, **kwargs) – Calculate composite metric.
persistent([mode]) – Post-init method to change whether metric states are saved to its state_dict.
reset() – Reset the metric state variables to their default values.
update(y_pred, y_actual, **kwargs) – Calculate composite metric.
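The method names above suggest an accumulate-then-compute lifecycle: `update` gathers state, `compute` turns it into a value, and `reset` restores the defaults. The sketch below mimics that pattern in plain Python. `ToyAggregatedMAE` is a hypothetical stand-in, not the pytorch_forecasting class, and its state variables are assumptions made for illustration.

```python
class ToyAggregatedMAE:
    """Hypothetical stand-in illustrating update / compute / reset."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Restore the state variables to their default values.
        self.pred_sum = 0.0
        self.actual_sum = 0.0
        self.count = 0

    def update(self, y_pred, y_actual):
        # Accumulate running sums, one batch at a time.
        self.pred_sum += sum(y_pred)
        self.actual_sum += sum(y_actual)
        self.count += len(y_pred)

    def compute(self):
        # Absolute error between the mean prediction and the mean actual.
        return abs(self.pred_sum / self.count - self.actual_sum / self.count)


metric = ToyAggregatedMAE()
metric.update([1.0, 3.0], [2.0, 2.0])  # first batch
metric.update([4.0, 4.0], [2.0, 2.0])  # second batch
value = metric.compute()               # uses state from both batches
metric.reset()                         # state is back to its defaults
```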
- compute()
Abstract method that calculates the metric.
Should be overridden in derived classes.
- Parameters:
y_pred – network output
y_actual – actual values
- Returns:
metric value on which backpropagation can be applied
- Return type:
torch.Tensor
- forward(y_pred: Tensor, y_actual: Tensor, **kwargs)
Calculate composite metric
- Parameters:
y_pred – network output
y_actual – actual values
**kwargs – arguments to update function
- Returns:
metric value on which backpropagation can be applied
- Return type:
torch.Tensor
- persistent(mode: bool = False) → None
Post-init method to change whether metric states are saved to its state_dict.