# Source code for pytorch_forecasting.utils._estimator_checks

# copyright: pytorch-forecasting developers, BSD-3-Clause License (see LICENSE file)
# mostly based on the sktime utility of the same name (BSD-3 Clause)
# which in turn was inspired by the scikit-learn utility of the same name
"""Estimator checker for extension."""

__author__ = ["fkiraly"]
# parametrize_with_checks is a public utility defined in this module (its
# docstring shows it being imported by users), so it is exported alongside
# check_estimator.
__all__ = ["check_estimator", "parametrize_with_checks"]


def check_estimator(
    estimator,
    raise_exceptions=False,
    tests_to_run=None,
    fixtures_to_run=None,
    verbose=True,
    tests_to_exclude=None,
    fixtures_to_exclude=None,
):
    """Run all applicable conformance tests on a single object.

    Looks up every test class registered for ``estimator`` (the tests for its
    specific subtype and all supertypes), runs each of them on ``estimator``,
    and merges the per-class results into a single dictionary.

    If ``estimator`` is an instance, tests are run on that instance and its
    class; if ``estimator`` is a class, tests are run on the class and on all
    instances constructed via its ``create_test_instances_and_names`` method.

    For packaged objects such as neural network models, the package class is
    fetched via the ``pkg`` attribute and all tests are run on the package
    class as well.

    Parameters
    ----------
    estimator : estimator class or estimator instance
        any object from ``pytorch-forecasting`` for which suite tests exist.
    raise_exceptions : bool, optional, default=False
        whether failures are returned in the results dict or raised

        * if False: exceptions are stored in the returned ``results`` dict
        * if True: exceptions are raised as they occur

    tests_to_run : str or list of str, optional. Default = run all tests.
        Names (test/function name strings) of the tests to run;
        restricts the tests that are run to the ones given here.
    fixtures_to_run : str or list of str, optional. Default = run all tests.
        pytest test-fixture combination codes to run;
        restricts tests and fixtures to the combinations given here.
        If both ``tests_to_run`` and ``fixtures_to_run`` are provided, the
        *union* is run, i.e., all test-fixture combinations for tests in
        ``tests_to_run``, plus all combinations in ``fixtures_to_run``.
    verbose : int or bool, optional, default=True
        verbosity level for printouts from the test run.

        * 0 or False: no printout
        * 1 or True (default): print summary of test run,
          but no print from tests
        * 2: print all test output, including output from within the tests

        The value is forwarded to the individual test classes only if
        ``raise_exceptions=True``; otherwise they are run with
        ``verbose=False`` and only the final summary is printed here.
    tests_to_exclude : str or list of str, optional, default=None
        names of tests that should not be run;
        applied after subsetting via ``tests_to_run``.
    fixtures_to_exclude : str or list of str, optional, default=None
        test-fixture combinations that should not be run;
        applied after subsetting via ``fixtures_to_run``.

    Returns
    -------
    results : dict
        results of the tests that were run.
        Keys are test/fixture strings, identical as in pytest, e.g.,
        ``test[fixture]``. Values are the string ``"PASSED"`` if the test
        passed, or the exception raised if the test did not pass.
        Returned only if all tests pass, or ``raise_exceptions=False``.

    Raises
    ------
    if ``raise_exceptions=True``,
    raises any exception produced by the tests directly

    Examples
    --------
    >>> from pytorch_forecasting.models import NBeats
    >>> from pytorch_forecasting.utils import check_estimator

    Running all tests for the NBeats class,
    this uses all instances from get_test_params and compatible scenarios

    >>> results = check_estimator(NBeats)
    All tests PASSED!

    Running a specific test (all fixtures) for NBeats

    >>> check_estimator(NBeats, tests_to_run="test_pkg_linkage")
    All tests PASSED!
    {'test_pkg_linkage[NBeats-0]': 'PASSED',
    'test_pkg_linkage[NBeats-1]': 'PASSED'}

    Running one specific test-fixture-combination for NBeats

    >>> check_estimator(
    ...    NBeats, fixtures_to_run="test_pkg_linkage[NBeats_pkg-NBeats]"
    ... )
    All tests PASSED!
    {'test_pkg_linkage[NBeats_pkg-NBeats]': 'PASSED'}
    """
    from skbase.utils.dependencies import _check_soft_dependencies

    pkg_name = "pytorch-forecasting"
    pytest_missing_msg = (
        "check_estimator is a testing utility for developers, and requires "
        "pytest to be present in the python environment, but pytest was not "
        "found. pytest is a developer dependency and not included in the "
        f"base {pkg_name} installation. Please run: `pip install pytest` "
        "to install the pytest package. "
        f"To install {pkg_name} with all developer dependencies, run: "
        f"`pip install {pkg_name}[dev]`"
    )
    _check_soft_dependencies("pytest", msg=pytest_missing_msg)

    # kept after the pytest check on purpose: presumably the test register
    # needs pytest at import time - TODO confirm
    from pytorch_forecasting.tests.test_class_register import get_test_classes_for_obj

    # the test classes only get the verbosity level when exceptions are raised;
    # otherwise they run silently and the summary below is the only printout
    inner_verbose = verbose if raise_exceptions else False

    results = {}
    for test_cls in get_test_classes_for_obj(estimator):
        results.update(
            test_cls().run_tests(
                obj=estimator,
                raise_exceptions=raise_exceptions,
                tests_to_run=tests_to_run,
                fixtures_to_run=fixtures_to_run,
                tests_to_exclude=tests_to_exclude,
                fixtures_to_exclude=fixtures_to_exclude,
                verbose=inner_verbose,
            )
        )

    # anything that is not the literal "PASSED" marker counts as a failure
    failures = [name for name, outcome in results.items() if outcome != "PASSED"]
    if failures:
        summary = "\n".join("FAILED: " + name for name in failures)
    else:
        summary = "All tests PASSED!"

    if int(verbose) > 0:
        # printing is an intended feature, for console usage and interactive debugging
        print(summary)  # noqa T001

    return results
[docs] def _get_test_names_from_class(test_cls): """Get all test names from a test class. Parameters ---------- test_cls : class class of the test Returns ------- test_names : list of str list of test names """ test_names = [attr for attr in dir(test_cls) if attr.startswith("test")] return test_names
def _get_test_names_for_obj(obj):
    """Collect the names of all tests that apply to an object.

    Retrieves the test classes registered for ``obj`` and concatenates the
    test names of each class, in the order the classes are returned.
    Names are not de-duplicated across test classes.

    Parameters
    ----------
    obj : object
        object to get tests for

    Returns
    -------
    test_names : list of str
        list of test names
    """
    from pytorch_forecasting.tests.test_class_register import get_test_classes_for_obj

    return [
        test_name
        for test_cls in get_test_classes_for_obj(obj)
        for test_name in _get_test_names_from_class(test_cls)
    ]
def parametrize_with_checks(objs, obj_varname="obj", check_varname="test_name"):
    """Build a pytest parametrization over objects and their estimator checks.

    Designed for setting up API compliance checks in compatible 2nd and 3rd
    party libraries, using ``pytest.mark.parametrize``.

    Inspired by the ``sklearn`` utility of the same name.

    Parameters
    ----------
    objs : class or instance, or list thereof
        Objects to generate test names for.
        Anything that is not a ``list`` is treated as a single object.
    obj_varname : str, optional, default = 'obj'
        Name of the variable for objects to use in the parametrization.
    check_varname : str, optional, default = 'test_name'
        Name of the variable for test name strings to use in the
        parametrization.

    Returns
    -------
    decorator : `pytest.mark.parametrize`
        parametrized over all ``(object, test name)`` pairs

    See Also
    --------
    check_estimator : Check if estimator adheres to pytorch-forecasting
        API contracts.

    Examples
    --------
    >>> from pytorch_forecasting.utils import check_estimator
    >>> from pytorch_forecasting.utils import parametrize_with_checks
    >>> from pytorch_forecasting.models import DecoderMLP, NBeats

    >>> @parametrize_with_checks(NBeats, obj_varname='estimator')
    ... def test_compatible_estimator(estimator, test_name):
    ...     check_estimator(estimator, tests_to_run=test_name, raise_exceptions=True)

    >>> @parametrize_with_checks([NBeats, DecoderMLP])
    ... def test_compatible_estimators(obj, test_name):
    ...     check_estimator(obj, tests_to_run=test_name, raise_exceptions=True)
    """
    import pytest

    # a lone class or instance is handled as a one-element list
    obj_list = objs if isinstance(objs, list) else [objs]

    # one parametrization entry per (object, applicable test name) pair
    param_pairs = [
        (obj, test_name)
        for obj in obj_list
        for test_name in _get_test_names_for_obj(obj)
    ]

    return pytest.mark.parametrize(f"{obj_varname}, {check_varname}", param_pairs)