pytorch_forecasting.data.data_module._tslib_data_module.TslibDataModule
- class pytorch_forecasting.data.data_module._tslib_data_module.TslibDataModule(time_series_dataset: TimeSeries, context_length: int, prediction_length: int, freq: str = 'h', add_relative_time_idx: bool = False, add_target_scales: bool = False, target_normalizer: TorchNormalizer | EncoderNormalizer | NaNLabelEncoder | str | list[TorchNormalizer | EncoderNormalizer | NaNLabelEncoder] | tuple[TorchNormalizer | EncoderNormalizer | NaNLabelEncoder] | None = 'auto', scalers: dict[str, StandardScaler | RobustScaler | TorchNormalizer | EncoderNormalizer] | None = None, shuffle: bool = True, window_stride: int = 1, batch_size: int = 32, num_workers: int = 0, train_val_test_split: tuple[float, float, float] = (0.7, 0.15, 0.15), collate_fn: Callable | None = None, **kwargs)
Experimental data module for integrating tslib time series into PyTorch Forecasting.
This module serves as the D2 layer for tslib models, including transformer-based architectures such as Informer, AutoFormer and TimeXer, as well as other deep learning architectures.
- Parameters:
time_series_dataset (TimeSeries) – The time series dataset to be used for training and validation. This is the newly implemented D1 layer.
context_length (int) – The length of the context window for the model. This is the number of time steps used as input to the model.
prediction_length (int) – The length of the prediction window for the model. This is the number of time steps to be predicted by the model.
freq (str, default = "h") – The frequency of the time series data. This is used to determine the time steps for the model.
features (str = "MS") –
- Feature combination mode:
"S": Single variable forecasting (target only)
"M": Multivariate forecasting, using all variables
"MS": Multivariate to single, using all variables to predict target
add_relative_time_idx (bool = False) – Whether to allow the relative time index to be used with the model.
add_target_scales (bool = False) – Whether to add target scaling info.
target_normalizer (Union[NORMALIZER, str, list[NORMALIZER], tuple[NORMALIZER], None], default="auto") – Normalizer for the target variable. If "auto", uses RobustScaler.
scalers (Optional[dict[str, Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer]]], default=None) – Dictionary of feature scalers.
shuffle (bool, default=True) – Whether to shuffle the data at every epoch.
window_stride (int, default=1) – The stride for the sliding window. This is used to create overlapping windows for the data.
batch_size (int, default=32) – Batch size for dataloader.
num_workers (int, default=0) – Number of workers for dataloader.
train_val_test_split (tuple, default=(0.7, 0.15, 0.15)) – Proportions for train, validation, and test dataset splits.
collate_fn (Optional[callable], default=None) – Custom collate function for the dataloader.
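To illustrate how context_length, prediction_length and window_stride interact, the sketch below enumerates window boundaries over a single series. The helper window_bounds is hypothetical and not part of PyTorch Forecasting; the module's internal _create_windows may treat edge cases (for example, series shorter than one window) differently. A stride smaller than context_length + prediction_length produces the overlapping windows mentioned above.

```python
def window_bounds(series_length: int, context_length: int,
                  prediction_length: int, window_stride: int = 1) -> list[tuple[int, int, int]]:
    """Hypothetical helper: return (start, split, end) for each full window.

    The context (encoder input) covers [start, split) and the
    prediction target covers [split, end).
    """
    window_length = context_length + prediction_length
    if series_length < window_length:
        return []  # too short for even one full window
    bounds = []
    # Slide the window start forward by window_stride steps each time.
    for start in range(0, series_length - window_length + 1, window_stride):
        split = start + context_length
        bounds.append((start, split, split + prediction_length))
    return bounds


# A 10-step series with 4 context steps, 2 prediction steps and stride 2.
print(window_bounds(10, 4, 2, window_stride=2))
# [(0, 4, 6), (2, 6, 8), (4, 8, 10)]
```

With the default window_stride=1 the same series would yield five windows, each shifted by one step.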
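The three features modes differ only in which variables go into the model and which come out. The hypothetical helper select_channels below restates that documented rule under the assumption that the target is one of the listed variables; it is not the module's actual implementation, and the variable names in the usage example are made up.

```python
def select_channels(variables: list[str], target: str,
                    features: str = "MS") -> tuple[list[str], list[str]]:
    """Hypothetical helper: return (input_variables, output_variables)."""
    if target not in variables:
        raise ValueError(f"target {target!r} must be one of the variables")
    if features == "S":
        return [target], [target]  # target only, in and out
    if features == "M":
        return list(variables), list(variables)  # all variables in and out
    if features == "MS":
        return list(variables), [target]  # all variables in, target out
    raise ValueError(f"unknown features mode: {features!r}")


variables = ["load", "temperature", "humidity"]  # illustrative names
for mode in ("S", "M", "MS"):
    print(mode, select_channels(variables, "load", mode))
```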
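A minimal sketch of how train_val_test_split proportions could be turned into index sets. It assumes the split is made over series indices before windowing (the _create_windows(indices) method suggests this, but it is an assumption). The helper split_indices and its shuffle_series and seed arguments are hypothetical; note that shuffle_series is unrelated to the documented shuffle parameter, which controls per-epoch dataloader shuffling.

```python
import random


def split_indices(n_series: int,
                  proportions: tuple[float, float, float] = (0.7, 0.15, 0.15),
                  shuffle_series: bool = True,
                  seed: int = 0) -> tuple[list[int], list[int], list[int]]:
    """Hypothetical helper: partition range(n_series) into train/val/test."""
    if abs(sum(proportions) - 1.0) > 1e-6:
        raise ValueError("proportions must sum to 1")
    indices = list(range(n_series))
    if shuffle_series:
        random.Random(seed).shuffle(indices)  # reproducible permutation
    n_train = int(n_series * proportions[0])  # floor of the train share
    n_val = int(n_series * proportions[1])    # floor of the validation share
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # remainder goes to the test split
    return train, val, test


train, val, test = split_indices(20, shuffle_series=False)
print(len(train), len(val), len(test))  # 14 3 3
```

Assigning the rounding remainder to the test split guarantees every series lands in exactly one split.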
- prepare_data_per_node
If True, the process with LOCAL_RANK=0 on each node calls prepare_data(). Otherwise, only the process with NODE_RANK=0 and LOCAL_RANK=0 prepares data.
- Type:
bool
- allow_zero_length_dataloader_with_multiple_devices
If True, a dataloader with zero length on a local rank is allowed. Default value is False.
- Type:
bool
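The prepare_data_per_node rule above can be restated as a small predicate. The function calls_prepare_data is a hypothetical illustration of the documented behavior, not a Lightning API; in practice the trainer applies this rule itself.

```python
def calls_prepare_data(node_rank: int, local_rank: int,
                       prepare_data_per_node: bool) -> bool:
    """Hypothetical helper: does this process run prepare_data()?"""
    if prepare_data_per_node:
        return local_rank == 0  # one process per node
    return node_rank == 0 and local_rank == 0  # one process in the whole job


# Two nodes with two processes each.
ranks = [(node, local) for node in range(2) for local in range(2)]
print([r for r in ranks if calls_prepare_data(*r, prepare_data_per_node=True)])   # [(0, 0), (1, 0)]
print([r for r in ranks if calls_prepare_data(*r, prepare_data_per_node=False)])  # [(0, 0)]
```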
- __init__(time_series_dataset: TimeSeries, context_length: int, prediction_length: int, freq: str = 'h', add_relative_time_idx: bool = False, add_target_scales: bool = False, target_normalizer: TorchNormalizer | EncoderNormalizer | NaNLabelEncoder | str | list[TorchNormalizer | EncoderNormalizer | NaNLabelEncoder] | tuple[TorchNormalizer | EncoderNormalizer | NaNLabelEncoder] | None = 'auto', scalers: dict[str, StandardScaler | RobustScaler | TorchNormalizer | EncoderNormalizer] | None = None, shuffle: bool = True, window_stride: int = 1, batch_size: int = 32, num_workers: int = 0, train_val_test_split: tuple[float, float, float] = (0.7, 0.15, 0.15), collate_fn: Callable | None = None, **kwargs) → None
Methods
__delattr__(name, /) – Implement delattr(self, name).
__dir__() – Default dir() implementation.
__eq__(value, /) – Return self==value.
__format__(format_spec, /) – Default object formatter.
__ge__(value, /) – Return self>=value.
__getattribute__(name, /) – Return getattr(self, name).
__getstate__() – Helper for pickle.
__gt__(value, /) – Return self>value.
__hash__() – Return hash(self).
__init_subclass__ – This method is called when a class is subclassed.
__le__(value, /) – Return self<=value.
__lt__(value, /) – Return self<value.
__ne__(value, /) – Return self!=value.
__new__(*args, **kwargs)
__reduce__() – Helper for pickle.
__reduce_ex__(protocol, /) – Helper for pickle.
__repr__() – Return repr(self).
__setattr__(name, value, /) – Implement setattr(self, name, value).
__sizeof__() – Size of object in memory, in bytes.
__str__() – Return a string representation of the datasets that are set up.
__subclasshook__ – Abstract classes can override this to customize issubclass().
_create_windows(indices) – Create windows for the data at the given indices, for training, testing and validation.
_prepare_metadata() – Prepare metadata for the tslib time series data module.
_preprocess_data(idx) – Process the time series data at the given index before feeding it to the _TslibDataset class.
_set_hparams(hp)
_to_hparams_dict(hp)
_validate_indices() – Validate that we have meaningful features for training.
collate_fn(batch) – Custom collate function for the dataloader.
from_datasets([train_dataset, val_dataset, ...]) – Create an instance from torch.utils.data.Dataset.
load_from_checkpoint(checkpoint_path[, ...]) – Primary way of loading a datamodule from a checkpoint.
load_state_dict(state_dict) – Called when loading a checkpoint; implement to reload datamodule state given datamodule state_dict.
on_after_batch_transfer(batch, dataloader_idx) – Override to alter or apply batch augmentations to your batch after it is transferred to the device.
on_before_batch_transfer(batch, dataloader_idx) – Override to alter or apply batch augmentations to your batch before it is transferred to the device.
on_exception(exception) – Called when the trainer execution is interrupted by an exception.
predict_dataloader() – Create the prediction dataloader.
prepare_data() – Use this to download and prepare data.
remove_ignored_hparams(ignore_list) – Remove ignored hyperparameters from the stored state.
save_hyperparameters(*args[, ignore, frame, ...]) – Save arguments to the hparams attribute.
setup([stage]) – Set up the data module by preparing the datasets for training, testing and validation.
state_dict() – Called when saving a checkpoint; implement to generate and save datamodule state.
teardown(stage) – Called at the end of fit (train + validate), validate, test, or predict.
test_dataloader() – Create the test dataloader.
train_dataloader() – Create the train dataloader.
transfer_batch_to_device(batch, device, ...) – Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
val_dataloader() – Create the validation dataloader.
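A custom function can be supplied through the collate_fn constructor parameter. The sketch below assumes, purely for illustration, that each sample is an (x, y) pair where x is a dict with the same keys in every sample; the real samples produced by the module's internal _TslibDataset may be structured differently. The function name collate_dict_samples is hypothetical, and a production collate_fn would typically stack tensors (for example with torch.stack) rather than build Python lists.

```python
def collate_dict_samples(batch):
    """Hypothetical collate_fn: list of (x, y) samples -> (batched_x, batched_y).

    Each x is assumed to be a dict with identical keys across samples.
    """
    xs, ys = zip(*batch)  # separate inputs from targets
    # Gather each input field across the batch, preserving sample order.
    batched_x = {key: [x[key] for x in xs] for key in xs[0]}
    return batched_x, list(ys)


batch = [({"history": [1.0, 2.0]}, [3.0]), ({"history": [4.0, 5.0]}, [6.0])]
print(collate_dict_samples(batch))
# ({'history': [[1.0, 2.0], [4.0, 5.0]]}, [[3.0], [6.0]])
```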
Attributes
CHECKPOINT_HYPER_PARAMS_KEY
CHECKPOINT_HYPER_PARAMS_NAME
CHECKPOINT_HYPER_PARAMS_TYPE
__annotations__
__dict__
__doc__
__jit_unused_properties__
__module__
__weakref__ – list of weak references to the object
hparams – The collection of hyperparameters saved with save_hyperparameters().
hparams_initial – The collection of hyperparameters saved with save_hyperparameters().
metadata
name