parametrize_with_checks
- pytorch_forecasting.utils._estimator_checks.parametrize_with_checks(objs, obj_varname='obj', check_varname='test_name')
Pytest-specific decorator for parametrizing estimator checks.
Designed for setting up API compliance checks in compatible second- and third-party libraries, using
pytest.mark.parametrize. Inspired by the sklearn utility of the same name.
- Parameters:
objs (class or instance, or list thereof) – Objects to generate test names for.
obj_varname (str, optional, default = 'obj') – Name of the test-function argument that receives each object in the parametrization.
check_varname (str, optional, default = 'test_name') – Name of the test-function argument that receives each test name string in the parametrization.
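The parameters above can be pictured as feeding a single pytest.mark.parametrize call: the two variable names form the argument-name string, and each object is paired with each test name. The following is a minimal sketch under stated assumptions, not the library's implementation: build_parametrize_args, HYPOTHETICAL_TEST_NAMES and NBeatsStub are hypothetical names, and it assumes every object is paired with the same list of test names (the real test names come from pytorch-forecasting's own check suite).

```python
# Hypothetical stand-in for the test names the real decorator would collect;
# the actual names come from pytorch-forecasting's check suite (not shown here).
HYPOTHETICAL_TEST_NAMES = ("test_constructor", "test_fit_predict")


def build_parametrize_args(objs, obj_varname="obj", check_varname="test_name",
                           test_names=HYPOTHETICAL_TEST_NAMES):
    """Return (argnames, argvalues) in the shape pytest.mark.parametrize expects.

    Hypothetical helper for illustration only.
    """
    # Accept a single class/instance or a list of them, as `objs` allows.
    if not isinstance(objs, (list, tuple)):
        objs = [objs]
    # The two variable names become the comma-separated argnames string.
    argnames = f"{obj_varname}, {check_varname}"
    # Assumption: every object is paired with every test name (cross product).
    argvalues = [(obj, name) for obj in objs for name in test_names]
    return argnames, argvalues


class NBeatsStub:  # placeholder class, not the real NBeats model
    pass


argnames, argvalues = build_parametrize_args(NBeatsStub, obj_varname="estimator")
print(argnames)        # estimator, test_name
print(len(argvalues))  # 2
```

Passing obj_varname='estimator' is what lets the test function in the examples declare an estimator argument instead of obj.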
- Returns:
decorator
- Return type:
pytest.mark.parametrize
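To illustrate the return type, the sketch below builds a decorator with pytest.mark.parametrize directly from placeholder (object, test name) pairs and applies it to a test function. The pair values and test_example are hypothetical, and it assumes the decorator returned by parametrize_with_checks behaves like any other parametrize mark.

```python
import pytest

# Hypothetical (object, test name) pairs; in the real decorator these are
# generated from `objs` and pytorch-forecasting's own check suite.
pairs = [("ModelA", "test_constructor"), ("ModelB", "test_constructor")]

# pytest.mark.parametrize returns a MarkDecorator that can be applied to a
# test function.
decorator = pytest.mark.parametrize("obj, test_name", pairs)


@decorator
def test_example(obj, test_name):
    # Placeholder body; a real test would call the estimator checks here.
    assert isinstance(obj, str)
    assert test_name.startswith("test_")


# pytest stores the mark on the function; at collection time it expands it
# into one test item per (obj, test_name) pair.
mark = test_example.pytestmark[0]
print(mark.name)     # parametrize
print(mark.args[0])  # obj, test_name
```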
See also
check_estimator – Check whether an estimator adheres to pytorch-forecasting API contracts.
Examples
>>> from pytorch_forecasting.utils import parametrize_with_checks
>>> from pytorch_forecasting.utils._estimator_checks import check_estimator  # assumed location
>>> from pytorch_forecasting.models import DecoderMLP, NBeats
>>> @parametrize_with_checks(NBeats, obj_varname='estimator')
... def test_sktime_compatible_estimator(estimator, test_name):
...     check_estimator(estimator, tests_to_run=test_name, raise_exceptions=True)

>>> @parametrize_with_checks([NBeats, DecoderMLP])
... def test_sktime_compatible_estimators(obj, test_name):
...     check_estimator(obj, tests_to_run=test_name, raise_exceptions=True)