Source code for pytorch_forecasting.metrics.base_metrics._base_object
"""Base object class for pytorch-forecasting metrics."""
from pytorch_forecasting.base._base_object import _BaseObject
[docs]
class _BasePtMetric(_BaseObject):
    """Base class for metric objects that can be discovered for testing.

    Subclasses describe a concrete metric class: they must implement
    :meth:`get_cls`, and may override :meth:`get_metric_test_params` and
    :meth:`get_encoder`.
    """

    # Marks this object as a metric; presumably read by the object
    # registry / test discovery machinery of ``_BaseObject`` -- confirm.
    _tags = {"object_type": "metric"}

    @classmethod
    def name(cls):
        """Get the name of the metric.

        The name is taken from the class returned by :meth:`get_cls`,
        not from this descriptor class itself.

        Returns
        -------
        str
            The name of the metric.

        Raises
        ------
        NotImplementedError
            If the subclass does not implement :meth:`get_cls`.
        """
        metric_cls = cls.get_cls()
        return metric_cls.__name__

    @classmethod
    def get_cls(cls):
        """Get the metric class.

        Must be implemented by subclasses.

        Returns
        -------
        type
            The metric class.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError("get_cls must be implemented in subclasses.")

    @classmethod
    def get_metric_test_params(cls):
        """Return parameter sets for initializing the metric in tests.

        Returns
        -------
        list of dict
            List of dictionaries, each containing parameters for
            initializing the metric. The base implementation returns an
            empty list; subclasses override this to supply test
            configurations.
        """
        return []

    @classmethod
    def get_encoder(cls):
        """Get the encoder for the metric.

        Subclasses can override this to provide a metric-specific encoder.

        Returns
        -------
        TorchNormalizer
            In the base implementation, a new default-constructed
            ``TorchNormalizer``; overrides may return a similar encoder.
        """
        # Local import, presumably to keep ``pytorch_forecasting.data`` out of
        # module import time (or to avoid an import cycle) -- confirm.
        from pytorch_forecasting.data import TorchNormalizer

        return TorchNormalizer()