parametrize_with_checks

pytorch_forecasting.utils._estimator_checks.parametrize_with_checks(objs, obj_varname='obj', check_varname='test_name')

Pytest-specific decorator for parametrizing estimator checks.

Designed for setting up API compliance checks in compatible second- and third-party libraries, using pytest.mark.parametrize.

Inspired by the sklearn utility of the same name.

Parameters:
  • objs (object class or instance, or list thereof) – Object or objects for which to generate test names.

  • obj_varname (str, optional, default = 'obj') – Name of the variable for objects to use in the parametrization.

  • check_varname (str, optional, default = 'test_name') – Name of the variable for test name strings to use in the parametrization.
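To make the parameters concrete, the sketch below shows one way the inputs could be expanded into parametrization cases. It is an illustration under stated assumptions, not the library's implementation: the helper `expand_cases`, the stand-in classes `ModelA` and `ModelB`, and the check names in `HYPOTHETICAL_CHECKS` are all invented for this example, and the real utility may select a different set of checks per object.

```python
# Illustrative sketch only; not the pytorch-forecasting implementation.
# expand_cases, ModelA, ModelB and HYPOTHETICAL_CHECKS are hypothetical names.

HYPOTHETICAL_CHECKS = ["test_check_a", "test_check_b"]


class ModelA:
    """Stand-in for an estimator class."""


class ModelB:
    """Stand-in for a second estimator class."""


def expand_cases(objs, obj_varname="obj", check_varname="test_name"):
    """Pair every object with every check name, keyed by the chosen variable names."""
    # Accept a single class or instance as well as a list of them.
    if not isinstance(objs, (list, tuple)):
        objs = [objs]
    return [
        {obj_varname: obj, check_varname: check}
        for obj in objs
        for check in HYPOTHETICAL_CHECKS
    ]


# obj_varname controls the key (and, in a test, the argument name) for the object.
cases = expand_cases([ModelA, ModelB], obj_varname="estimator")
for case in cases:
    print(case["estimator"].__name__, case["test_name"])
```

The names passed as `obj_varname` and `check_varname` must match the argument names of the decorated test function, which is why the single-object example in this page declares `estimator` rather than `obj`.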

Returns:

decorator

Return type:

pytest.mark.parametrize
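As a rough sketch of what the returned decorator looks like, the snippet below builds a `pytest.mark.parametrize` mark from (object, test name) pairs. The helper `make_decorator`, the class `ModelA`, and the check name are hypothetical and only stand in for whatever the library computes internally; the `pytest.mark.parametrize` call itself is standard pytest.

```python
import pytest


class ModelA:
    """Hypothetical stand-in for an estimator class."""


def make_decorator(pairs, obj_varname="obj", check_varname="test_name"):
    """Hypothetical helper: wrap (object, test name) pairs in a parametrize mark."""
    # pytest accepts argument names as a single comma-separated string.
    argnames = f"{obj_varname},{check_varname}"
    # Readable test ids, e.g. "ModelA-test_check_a".
    ids = [f"{obj.__name__}-{check}" for obj, check in pairs]
    return pytest.mark.parametrize(argnames, pairs, ids=ids)


decorator = make_decorator([(ModelA, "test_check_a")], obj_varname="estimator")


@decorator
def test_example(estimator, test_name):
    # A real test would call check_estimator here; this body is a placeholder.
    assert isinstance(test_name, str)
```

Applying the mark does not run anything by itself; pytest expands it into one test case per pair at collection time.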

See also

check_estimator

Check if estimator adheres to pytorch-forecasting API contracts.

Examples

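>>> from pytorch_forecasting.utils import parametrize_with_checks
>>> # Assumption: check_estimator lives in the same module as parametrize_with_checks.
>>> from pytorch_forecasting.utils._estimator_checks import check_estimator
>>> from pytorch_forecasting.models import DecoderMLP, NBeats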
>>> @parametrize_with_checks(NBeats, obj_varname='estimator')
... def test_sktime_compatible_estimator(estimator, test_name):
...     check_estimator(estimator, tests_to_run=test_name, raise_exceptions=True)
>>> @parametrize_with_checks([NBeats, DecoderMLP])
... def test_sktime_compatible_estimators(obj, test_name):
...     check_estimator(obj, tests_to_run=test_name, raise_exceptions=True)