Source code for pytorch_forecasting.models.base._base_object

"""Base Classes for pytorch-forecasting models, skbase compatible for indexing."""

import inspect

from pytorch_forecasting.base._base_object import _BaseObject


class _BasePtForecaster_Common(_BaseObject):
    """Base class for all PyTorch Forecasting forecaster packages.

    This class points to model objects and contains metadata as tags.

    Subclasses provide the model class by overriding ``get_cls``. The
    test-instance constructors in this class build objects of that model
    class from the parameter sets returned by ``get_test_params``.
    """

    @classmethod
    def get_cls(cls):
        """Get model class.

        Raises
        ------
        NotImplementedError
            Always, in this base class. Subclasses must override.
        """
        raise NotImplementedError

    @classmethod
    def get_model_cls(cls):
        """Get model class, alias of ``get_cls``.

        ``name`` and the test-instance constructors look the model class up
        through this method. Delegating to ``get_cls`` means a subclass that
        overrides only ``get_cls`` works with them, while a subclass that
        overrides ``get_model_cls`` directly keeps its own behaviour.
        """
        return cls.get_cls()

    @classmethod
    def name(cls):
        """Get model name.

        Returns
        -------
        name : str
            Value of the ``"info:name"`` class tag if it is set (not ``None``),
            otherwise the ``__name__`` of the model class.
        """
        name = cls.get_class_tags().get("info:name", None)
        if name is None:
            name = cls.get_model_cls().__name__
        return name

    @classmethod
    def _resolve_test_params(cls, parameter_set):
        """Call ``get_test_params``, passing ``parameter_set`` only if accepted."""
        # Some get_test_params implementations take no ``parameter_set``
        # argument, so inspect the signature before forwarding it.
        accepted_args = inspect.getfullargspec(cls.get_test_params).args
        if "parameter_set" in accepted_args:
            return cls.get_test_params(parameter_set=parameter_set)
        return cls.get_test_params()

    @classmethod
    def create_test_instance(cls, parameter_set="default"):
        """Construct an instance of the class, using first test parameter set.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If
            no special parameters are defined for a value, will return
            `"default"` set.

        Returns
        -------
        instance : instance of the class with default parameters

        Raises
        ------
        TypeError
            If ``get_test_params`` returns neither a dict nor a non-empty list
            whose first element is a dict.
        """
        params = cls._resolve_test_params(parameter_set)

        # The non-empty check keeps an empty list from raising IndexError on
        # ``params[0]``; it falls through to the TypeError below instead.
        if isinstance(params, list) and params and isinstance(params[0], dict):
            params = params[0]
        elif not isinstance(params, dict):
            raise TypeError(
                "get_test_params should either return a dict or list of dict."
            )

        return cls.get_model_cls()(**params)

    @classmethod
    def create_test_instances_and_names(cls, parameter_set="default"):
        """Create list of all test instances and a list of names for them.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If
            no special parameters are defined for a value, will return
            `"default"` set.

        Returns
        -------
        objs : list of instances of cls
            i-th instance is ``cls(**cls.get_test_params()[i])``
        names : list of str, same length as objs
            i-th element is name of i-th instance of obj in tests.
            The naming convention is ``{cls.__name__}-{i}`` if more than one
            instance, otherwise ``{cls.__name__}``

        Raises
        ------
        RuntimeError
            If ``get_test_params`` returns something other than a dict or a
            list, or if any list element is not a dict.
        """
        param_list = cls._resolve_test_params(parameter_set)

        if not isinstance(param_list, dict | list):
            raise RuntimeError(
                f"Error in {cls.__name__}.get_test_params, "
                "return must be param dict for class, or list thereof"
            )
        # Normalize the single-dict case so the loop below handles both shapes.
        if isinstance(param_list, dict):
            param_list = [param_list]

        objs = []
        for params in param_list:
            if not isinstance(params, dict):
                raise RuntimeError(
                    f"Error in {cls.__name__}.get_test_params, "
                    "return must be param dict for class, or list thereof"
                )
            objs.append(cls.get_model_cls()(**params))

        num_instances = len(param_list)
        if num_instances > 1:
            names = [f"{cls.__name__}-{i}" for i in range(num_instances)]
        else:
            names = [cls.__name__]

        return objs, names
class _BasePtForecaster(_BasePtForecaster_Common):
    """Base class for PyTorch Forecasting v1 forecasters.

    Sets the ``object_type`` tag to both the generic ``"forecaster_pytorch"``
    type and the version-specific ``"forecaster_pytorch_v1"`` type.
    """

    _tags = {
        "object_type": [
            "forecaster_pytorch",
            "forecaster_pytorch_v1",
        ],
    }
class _BasePtForecasterV2(_BasePtForecaster_Common):
    """Base class for PyTorch Forecasting v2 forecasters.

    Sets the ``object_type`` tag to ``"forecaster_pytorch_v2"``.
    """

    _tags = {"object_type": "forecaster_pytorch_v2"}