# Source code for pytorch_forecasting.data.data_module

#######################################################################################
# Disclaimer: This data module is still a work in progress and experimental; please
# use it with care. It is a basic skeleton of what the data-handling pipeline
# may look like in the future.
# This is the D2 layer, which handles the preprocessing and the data loaders.
# For now, this pipeline handles only the simplest situation: the whole dataset can
# be loaded into memory.
#######################################################################################

from typing import Any, Optional, Union
from warnings import warn

from lightning.pytorch import LightningDataModule
from sklearn.preprocessing import RobustScaler, StandardScaler
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_forecasting.data.encoders import (
    EncoderNormalizer,
    NaNLabelEncoder,
    TorchNormalizer,
)
from pytorch_forecasting.data.timeseries import TimeSeries
from pytorch_forecasting.utils._coerce import _coerce_to_dict

# Union of the normalizer/encoder classes accepted wherever a ``NORMALIZER`` is
# annotated in this module (e.g. the ``target_normalizer`` argument).
NORMALIZER = TorchNormalizer | EncoderNormalizer | NaNLabelEncoder


class EncoderDecoderTimeSeriesDataModule(LightningDataModule):
    """
    Lightning DataModule for processing time series data in an encoder-decoder format.

    This module handles preprocessing, splitting, and batching of time series data
    for use in deep learning models. It supports categorical and continuous features,
    various scalers, and automatic target normalization.

    Parameters
    ----------
    time_series_dataset : TimeSeries
        The dataset containing time series data.
    max_encoder_length : int, default=30
        Maximum length of the encoder input sequence.
    min_encoder_length : Optional[int], default=None
        Minimum length of the encoder input sequence.
        Defaults to `max_encoder_length` if not specified.
    max_prediction_length : int, default=1
        Maximum length of the decoder output sequence.
    min_prediction_length : Optional[int], default=None
        Minimum length of the decoder output sequence.
        Defaults to `max_prediction_length` if not specified.
    min_prediction_idx : Optional[int], default=None
        Minimum index from which predictions start.
    allow_missing_timesteps : bool, default=False
        Whether to allow missing timesteps in the dataset.
    add_relative_time_idx : bool, default=False
        Whether to add a relative time index feature.
    add_target_scales : bool, default=False
        Whether to add target scaling information.
    add_encoder_length : Union[bool, str], default="auto"
        Whether to include encoder length information.
    target_normalizer :
        Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None],
        default="auto"
        Normalizer for the target variable. If "auto", uses `RobustScaler`.
    categorical_encoders : Optional[Dict[str, NaNLabelEncoder]], default=None
        Dictionary of categorical encoders.
    scalers :
        Optional[Dict[str, Union[StandardScaler, RobustScaler,
        TorchNormalizer, EncoderNormalizer]]], default=None
        Dictionary of feature scalers.
    randomize_length : Union[None, Tuple[float, float], bool], default=False
        Whether to randomize input sequence length.
    batch_size : int, default=32
        Batch size for DataLoader.
    num_workers : int, default=0
        Number of workers for DataLoader.
    train_val_test_split : tuple, default=(0.7, 0.15, 0.15)
        Proportions for train, validation, and test dataset splits.
        Only the first two entries are read; the test split receives
        whatever series remain.
    """

    def __init__(
        self,
        time_series_dataset: TimeSeries,
        max_encoder_length: int = 30,
        min_encoder_length: int | None = None,
        max_prediction_length: int = 1,
        min_prediction_length: int | None = None,
        min_prediction_idx: int | None = None,
        allow_missing_timesteps: bool = False,
        add_relative_time_idx: bool = False,
        add_target_scales: bool = False,
        add_encoder_length: bool | str = "auto",
        target_normalizer: NORMALIZER
        | str
        | list[NORMALIZER]
        | tuple[NORMALIZER]
        | None = "auto",
        categorical_encoders: dict[str, NaNLabelEncoder] | None = None,
        scalers: dict[
            str, StandardScaler | RobustScaler | TorchNormalizer | EncoderNormalizer
        ]
        | None = None,
        randomize_length: None | tuple[float, float] | bool = False,
        batch_size: int = 32,
        num_workers: int = 0,
        train_val_test_split: tuple = (0.7, 0.15, 0.15),
    ):
        # Store the constructor arguments unchanged; derived values live in
        # underscore-prefixed attributes set further below.
        self.time_series_dataset = time_series_dataset
        self.max_encoder_length = max_encoder_length
        self.min_encoder_length = min_encoder_length
        self.max_prediction_length = max_prediction_length
        self.min_prediction_length = min_prediction_length
        self.min_prediction_idx = min_prediction_idx
        self.allow_missing_timesteps = allow_missing_timesteps
        self.add_relative_time_idx = add_relative_time_idx
        self.add_target_scales = add_target_scales
        self.add_encoder_length = add_encoder_length
        self.randomize_length = randomize_length
        self.target_normalizer = target_normalizer
        self.categorical_encoders = categorical_encoders
        self.scalers = scalers
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_val_test_split = train_val_test_split

        warn(
            "EncoderDecoderTimeSeriesDataModule is part of an experimental "
            "rework of the "
            "pytorch-forecasting data layer, "
            "scheduled for release with v2.0.0. "
            "The API is not stable and may change without prior warning. "
            "For beta testing, but not for stable production use. "
            "Feedback and suggestions are very welcome in "
            "pytorch-forecasting issue 1736, "
            "https://github.com/sktime/pytorch-forecasting/issues/1736",
            UserWarning,
            # attribute the warning to the code that instantiates the module
            stacklevel=2,
        )
        super().__init__()

        # handle defaults and derived attributes
        if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto":
            self._target_normalizer = RobustScaler()
        else:
            self._target_normalizer = target_normalizer

        self.time_series_metadata = time_series_dataset.get_metadata()
        self._min_prediction_length = min_prediction_length or max_prediction_length
        self._min_encoder_length = min_encoder_length or max_encoder_length
        self._categorical_encoders = _coerce_to_dict(categorical_encoders)
        self._scalers = _coerce_to_dict(scalers)

        self.n_targets = len(self.time_series_metadata["cols"]["y"])
        self._metadata = None

        # Partition the positions of the "x" columns: type "C" is categorical,
        # every other type is treated as continuous.
        self.categorical_indices = []
        self.continuous_indices = []
        for idx, col in enumerate(self.time_series_metadata["cols"]["x"]):
            if self.time_series_metadata["col_type"].get(col) == "C":
                self.categorical_indices.append(idx)
            else:
                self.continuous_indices.append(idx)

    def _prepare_metadata(self):
        """Prepare metadata for model initialisation.

        Returns
        -------
        dict
            dictionary containing the following keys:

            * ``encoder_cat``: Number of categorical variables in the encoder.
              Computed as ``len(self.categorical_indices)``.
            * ``encoder_cont``: Number of continuous variables in the encoder.
              Computed as ``len(self.continuous_indices)``.
            * ``decoder_cat``: Number of categorical variables in the decoder
              that are known in advance, i.e. columns in
              ``self.time_series_metadata["cols"]["x"]`` with
              col_type == "C" (categorical) and col_known == "K" (known).
            * ``decoder_cont``: Number of continuous variables in the decoder
              that are known in advance, i.e. columns in
              ``self.time_series_metadata["cols"]["x"]`` with
              col_type == "F" (continuous) and col_known == "K" (known).
            * ``target``: Number of target variables.
              Computed as ``len(self.time_series_metadata["cols"]["y"])``.
            * ``static_categorical_features``: Number of static columns
              (``self.time_series_metadata["cols"]["st"]``) with
              col_type == "C"; ``0`` if there are no static columns.
            * ``static_continuous_features``: Number of remaining static
              columns; ``0`` if there are no static columns.
            * ``max_encoder_length``: ``self.max_encoder_length``.
            * ``max_prediction_length``: ``self.max_prediction_length``.
            * ``min_encoder_length``: ``self._min_encoder_length``, i.e. the
              constructor value, falling back to ``max_encoder_length``.
            * ``min_prediction_length``: ``self._min_prediction_length``, i.e.
              the constructor value, falling back to ``max_prediction_length``.
        """
        cols = self.time_series_metadata["cols"]
        col_type = self.time_series_metadata["col_type"]
        col_known = self.time_series_metadata["col_known"]

        known_x_cols = [col for col in cols["x"] if col_known.get(col) == "K"]
        # A falsy "st" entry (empty or missing value) means no static features.
        static_cols = cols["st"] or []
        static_cat_count = sum(1 for col in static_cols if col_type.get(col) == "C")

        return {
            "encoder_cat": len(self.categorical_indices),
            "encoder_cont": len(self.continuous_indices),
            "decoder_cat": sum(1 for col in known_x_cols if col_type.get(col) == "C"),
            "decoder_cont": sum(
                1 for col in known_x_cols if col_type.get(col) == "F"
            ),
            "target": len(cols["y"]),
            "static_categorical_features": static_cat_count,
            "static_continuous_features": len(static_cols) - static_cat_count,
            "max_encoder_length": self.max_encoder_length,
            "max_prediction_length": self.max_prediction_length,
            "min_encoder_length": self._min_encoder_length,
            "min_prediction_length": self._min_prediction_length,
        }

    @property
    def metadata(self):
        """Compute metadata for model initialization.

        The dictionary is built on first access by ``_prepare_metadata`` and
        cached afterwards. It contains:

        * ``encoder_cat``: Number of categorical variables in the encoder.
        * ``encoder_cont``: Number of continuous variables in the encoder.
        * ``decoder_cat``: Number of categorical variables in the decoder that
          are known in advance.
        * ``decoder_cont``: Number of continuous variables in the decoder that
          are known in advance.
        * ``target``: Number of target variables.
        * ``static_categorical_features``: Number of static categorical
          features (``0`` when the dataset has no static columns).
        * ``static_continuous_features``: Number of static continuous features
          (``0`` when the dataset has no static columns).
        * ``max_encoder_length``: maximum encoder length
        * ``max_prediction_length``: maximum prediction length
        * ``min_encoder_length``: minimum encoder length
        * ``min_prediction_length``: minimum prediction length
        """
        if self._metadata is None:
            self._metadata = self._prepare_metadata()
        return self._metadata

    def _preprocess_data(self, series_idx: int) -> dict[str, Any]:
        """Preprocess one series before it is windowed by the processed dataset.

        Preprocessing steps
        --------------------

        * Converts target (`y`) and features (`x`) to `torch.float32`.
        * Builds a boolean mask that is ``True`` for time points at or before
          the cutoff time.
        * Splits features into categorical and continuous subsets based on
          ``self.categorical_indices`` and ``self.continuous_indices``.

        TODO: add scalers, target normalizers etc.

        Parameters
        ----------
        series_idx : int
            Index of the series in ``self.time_series_dataset``.

        Returns
        -------
        dict
            Dictionary with the keys ``features`` (itself a dict with
            ``categorical`` and ``continuous``), ``target``, ``static``
            (``None`` if the sample has no ``"st"`` entry), ``group``
            (``torch.tensor([0])`` if the sample has no ``"group"`` entry),
            ``length``, ``time_mask``, ``times`` and ``cutoff_time``.
        """
        sample = self.time_series_dataset[series_idx]

        target = sample["y"]
        features = sample["x"]
        times = sample["t"]
        cutoff_time = sample["cutoff_time"]

        # ``as_tensor`` accepts both tensor and array-like comparison results;
        # ``torch.tensor`` on an existing tensor emits a copy-construct warning.
        time_mask = torch.as_tensor(times <= cutoff_time, dtype=torch.bool)

        if isinstance(target, torch.Tensor):
            target = target.float()
        else:
            target = torch.tensor(target, dtype=torch.float32)

        if isinstance(features, torch.Tensor):
            features = features.float()
        else:
            features = torch.tensor(features, dtype=torch.float32)

        # TODO: add scalers, target normalizers etc.

        # An empty index list yields a zero-width placeholder, so downstream
        # code can always slice and stack both feature groups.
        categorical = (
            features[:, self.categorical_indices]
            if self.categorical_indices
            else torch.zeros((features.shape[0], 0))
        )
        continuous = (
            features[:, self.continuous_indices]
            if self.continuous_indices
            else torch.zeros((features.shape[0], 0))
        )

        return {
            "features": {"categorical": categorical, "continuous": continuous},
            "target": target,
            "static": sample.get("st", None),
            "group": sample.get("group", torch.tensor([0])),
            "length": len(target),
            "time_mask": time_mask,
            "times": times,
            "cutoff_time": cutoff_time,
        }

    class _ProcessedEncoderDecoderDataset(Dataset):
        """PyTorch Dataset for processed encoder-decoder time series data.

        Parameters
        ----------
        dataset : TimeSeries
            The base time series dataset that provides access to raw data and
            metadata.
        data_module : EncoderDecoderTimeSeriesDataModule
            The data module handling preprocessing and metadata configuration.
        windows : List[Tuple[int, int, int, int]]
            List of window tuples containing
            (series_idx, start_idx, enc_length, pred_length).
        add_relative_time_idx : bool, default=False
            Whether to include relative time indices.
        """

        def __init__(
            self,
            dataset: TimeSeries,
            data_module: "EncoderDecoderTimeSeriesDataModule",
            windows: list[tuple[int, int, int, int]],
            add_relative_time_idx: bool = False,
        ):
            self.dataset = dataset
            self.data_module = data_module
            self.windows = windows
            self.add_relative_time_idx = add_relative_time_idx

        def __len__(self):
            return len(self.windows)

        def __getitem__(self, idx):
            """Retrieve a processed time series window for dataloader input.

            Parameters
            ----------
            idx : int
                Index of the window to retrieve from the dataset.

            Returns
            -------
            x : dict
                Dictionary containing model inputs:

                * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features)
                  Categorical features for the encoder.
                * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features)
                  Continuous features for the encoder.
                * ``decoder_cat`` : tensor of shape
                  (pred_length, n_known_cat_features)
                  Known categorical features for the decoder.
                * ``decoder_cont`` : tensor of shape
                  (pred_length, n_known_cont_features)
                  Known continuous features for the decoder.
                * ``encoder_lengths`` : scalar tensor
                  Length of the encoder sequence.
                * ``decoder_lengths`` : scalar tensor
                  Length of the decoder sequence.
                * ``decoder_target_lengths`` : scalar tensor
                  Length of the decoder target sequence.
                * ``groups`` : tensor
                  Group identifier for the time series instance
                  (``torch.tensor([0])`` if the sample provides none).
                * ``encoder_time_idx`` : tensor of shape (enc_length,)
                  Time indices for the encoder sequence, starting at 0.
                * ``decoder_time_idx`` : tensor of shape (pred_length,)
                  Time indices for the decoder sequence, continuing after the
                  encoder indices.
                * ``target_past`` : torch.Tensor
                  Target values over the encoder window.
                * ``target_scale`` : scalar tensor
                  Mean absolute non-NaN value of ``target_past``; ``1.0`` if
                  that mean is NaN or zero.
                * ``encoder_mask`` : tensor of shape (enc_length,)
                  Boolean mask indicating valid encoder time points.
                * ``decoder_mask`` : tensor of shape (pred_length,)
                  Boolean mask indicating valid decoder time points.

                If the sample has static features, the following keys are added:

                * ``static_categorical_features`` : tensor of shape
                  (1, n_static_cat_features)
                  Static categorical features; shape (1, 0) if there are none.
                * ``static_continuous_features`` : tensor of shape
                  (1, n_static_cont_features)
                  Static continuous features; shape (1, 0) if there are none.

            y : torch.Tensor or list of torch.Tensor
                Target values for the decoder sequence. If ``n_targets`` > 1,
                a list with one tensor per target is returned. Otherwise a
                single tensor with the last dimension squeezed is returned.
            """
            series_idx, start_idx, enc_length, pred_length = self.windows[idx]
            data = self.data_module._preprocess_data(series_idx)

            end_idx = start_idx + enc_length + pred_length
            encoder_indices = slice(start_idx, start_idx + enc_length)
            decoder_indices = slice(start_idx + enc_length, end_idx)

            target_past = data["target"][encoder_indices]
            target_scale = target_past[~torch.isnan(target_past)].abs().mean()
            # The mean of an empty selection is NaN; fall back to a neutral scale.
            if torch.isnan(target_scale) or target_scale == 0:
                target_scale = torch.tensor(1.0)

            encoder_mask = (
                data["time_mask"][encoder_indices]
                if "time_mask" in data
                else torch.ones(enc_length, dtype=torch.bool)
            )
            decoder_mask = (
                data["time_mask"][decoder_indices]
                if "time_mask" in data
                else torch.zeros(pred_length, dtype=torch.bool)
            )

            features = data["features"]
            encoder_cat = features["categorical"][encoder_indices]
            encoder_cont = features["continuous"][encoder_indices]

            metadata = self.data_module.time_series_metadata
            x_cols = metadata["cols"]["x"]
            col_type = metadata["col_type"]
            col_known = metadata["col_known"]

            # ``features["categorical"]`` / ``features["continuous"]`` hold only
            # the columns listed in ``categorical_indices`` /
            # ``continuous_indices``, so known columns are addressed by their
            # position inside those lists, not by their position in "x".
            known_cat_positions = [
                pos
                for pos, orig_idx in enumerate(self.data_module.categorical_indices)
                if col_type.get(x_cols[orig_idx]) == "C"
                and col_known.get(x_cols[orig_idx]) == "K"
            ]
            known_cont_positions = [
                pos
                for pos, orig_idx in enumerate(self.data_module.continuous_indices)
                if col_type.get(x_cols[orig_idx]) == "F"
                and col_known.get(x_cols[orig_idx]) == "K"
            ]

            decoder_cat = (
                features["categorical"][decoder_indices][:, known_cat_positions]
                if known_cat_positions
                else torch.zeros((pred_length, 0))
            )
            decoder_cont = (
                features["continuous"][decoder_indices][:, known_cont_positions]
                if known_cont_positions
                else torch.zeros((pred_length, 0))
            )

            x = {
                "encoder_cat": encoder_cat,
                "encoder_cont": encoder_cont,
                "decoder_cat": decoder_cat,
                "decoder_cont": decoder_cont,
                "encoder_lengths": torch.tensor(enc_length),
                "decoder_lengths": torch.tensor(pred_length),
                "decoder_target_lengths": torch.tensor(pred_length),
                "groups": data["group"],
                "target_past": target_past,
                "encoder_time_idx": torch.arange(enc_length),
                "decoder_time_idx": torch.arange(
                    enc_length, enc_length + pred_length
                ),
                "target_scale": target_scale,
                "encoder_mask": encoder_mask,
                "decoder_mask": decoder_mask,
            }

            raw_static = data["static"]
            if raw_static is not None:
                # NOTE(review): boolean-mask indexing below assumes the static
                # values are a 1-D tensor aligned with cols["st"] - confirm.
                is_categorical_mask = torch.tensor(
                    [
                        col_type.get(col_name) == "C"
                        for col_name in metadata["cols"]["st"]
                    ],
                    dtype=torch.bool,
                )
                static_cat = raw_static[is_categorical_mask]
                static_cont = raw_static[~is_categorical_mask]

                x["static_categorical_features"] = (
                    static_cat.unsqueeze(0)
                    if static_cat.shape[0] > 0
                    else torch.zeros((1, 0), dtype=torch.float32)
                )
                x["static_continuous_features"] = (
                    static_cont.unsqueeze(0)
                    if static_cont.shape[0] > 0
                    else torch.zeros((1, 0), dtype=torch.float32)
                )

            y = data["target"][decoder_indices]
            if self.data_module.n_targets > 1:
                y = [t.squeeze(-1) for t in torch.split(y, 1, dim=1)]
            else:
                y = y.squeeze(-1)

            return x, y

    def _create_windows(self, indices: torch.Tensor) -> list[tuple[int, int, int, int]]:
        """Generate sliding windows for training, validation, and testing.

        Series shorter than ``max_encoder_length + max_prediction_length`` are
        skipped. Every window uses the maximum encoder and prediction lengths.

        Parameters
        ----------
        indices : torch.Tensor
            Indices of the series in ``time_series_dataset`` to build windows
            for.

        Returns
        -------
        List[Tuple[int, int, int, int]]
            A list of tuples, where each tuple consists of:

            - ``series_idx`` : int
              Index of the time series in `time_series_dataset`.
            - ``start_idx`` : int
              Start index of the encoder window.
            - ``enc_length`` : int
              Length of the encoder input sequence.
            - ``pred_length`` : int
              Length of the decoder output sequence.
        """
        enc_length = self.max_encoder_length
        pred_length = self.max_prediction_length
        window_length = enc_length + pred_length
        # Does not depend on the series, so it is resolved once up front.
        effective_min_prediction_idx = (
            self.min_prediction_idx
            if self.min_prediction_idx is not None
            else enc_length
        )

        windows = []
        for idx in indices:
            series_idx = idx.item()
            sample = self.time_series_dataset[series_idx]
            sequence_length = len(sample["y"])

            if sequence_length < window_length:
                continue

            max_prediction_idx = sequence_length - pred_length + 1
            if max_prediction_idx <= effective_min_prediction_idx:
                continue

            # Candidate starts are 0 .. (max_prediction_idx - min_idx - 1); keep
            # only those whose full window still fits inside the series.
            n_candidates = max_prediction_idx - effective_min_prediction_idx
            n_fitting = sequence_length - window_length + 1
            windows.extend(
                (series_idx, start_idx, enc_length, pred_length)
                for start_idx in range(min(n_candidates, n_fitting))
            )

        return windows

    def _make_dataset(self, windows: list[tuple[int, int, int, int]]):
        """Wrap a list of windows in a ``_ProcessedEncoderDecoderDataset``."""
        return self._ProcessedEncoderDecoderDataset(
            self.time_series_dataset,
            self,
            windows,
            self.add_relative_time_idx,
        )

    def setup(self, stage: str | None = None):
        """Prepare the datasets for training, validation, testing, or prediction.

        The random train/validation/test split of the series is drawn on the
        first call only. Later calls (e.g. ``"fit"`` followed by ``"test"``)
        reuse it, so the three subsets never overlap.

        Parameters
        ----------
        stage : Optional[str], default=None
            Specifies the stage of setup. Can be one of:

            - ``"fit"`` : Prepares training and validation datasets.
            - ``"validate"`` : Prepares training and validation datasets.
            - ``"test"`` : Prepares the test dataset.
            - ``"predict"`` : Prepares the dataset for inference,
              using all series.
            - ``None`` : Prepares ``fit`` datasets.
        """
        if not hasattr(self, "_split_indices"):
            total_series = len(self.time_series_dataset)
            self._split_indices = torch.randperm(total_series)

            self._train_size = int(self.train_val_test_split[0] * total_series)
            self._val_size = int(self.train_val_test_split[1] * total_series)
            val_end = self._train_size + self._val_size

            self._train_indices = self._split_indices[: self._train_size]
            self._val_indices = self._split_indices[self._train_size : val_end]
            # The test split takes all remaining series.
            self._test_indices = self._split_indices[val_end:]

        if stage is None or stage in ("fit", "validate"):
            if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"):
                self.train_windows = self._create_windows(self._train_indices)
                self.val_windows = self._create_windows(self._val_indices)

                self.train_dataset = self._make_dataset(self.train_windows)
                self.val_dataset = self._make_dataset(self.val_windows)

        elif stage == "test":
            if not hasattr(self, "test_dataset"):
                self.test_windows = self._create_windows(self._test_indices)
                self.test_dataset = self._make_dataset(self.test_windows)

        elif stage == "predict":
            predict_indices = torch.arange(len(self.time_series_dataset))
            self.predict_windows = self._create_windows(predict_indices)
            self.predict_dataset = self._make_dataset(self.predict_windows)

    def train_dataloader(self):
        """Return the shuffled DataLoader over ``self.train_dataset``."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        """Return the DataLoader over ``self.val_dataset``."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self):
        """Return the DataLoader over ``self.test_dataset``."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )

    def predict_dataloader(self):
        """Return the DataLoader over ``self.predict_dataset``."""
        return DataLoader(
            self.predict_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``(x, y)`` samples into one batch.

        Parameters
        ----------
        batch : list of tuple
            Samples as returned by ``_ProcessedEncoderDecoderDataset.__getitem__``.

        Returns
        -------
        x_batch : dict
            Each input tensor stacked along a new leading batch dimension.
            The two static-feature keys are included only if the first sample
            carries ``static_categorical_features``.
        y_batch : torch.Tensor or list of torch.Tensor
            Stacked targets; a list with one stacked tensor per target if the
            samples carry a list or tuple of targets.
        """
        always_present = (
            "encoder_cat",
            "encoder_cont",
            "decoder_cat",
            "decoder_cont",
            "encoder_lengths",
            "decoder_lengths",
            "decoder_target_lengths",
            "groups",
            "target_past",
            "encoder_time_idx",
            "decoder_time_idx",
            "target_scale",
            "encoder_mask",
            "decoder_mask",
        )
        x_batch = {
            key: torch.stack([x[key] for x, _ in batch]) for key in always_present
        }

        # Both static keys are written together per sample, so checking one
        # of them on the first sample is sufficient.
        if "static_categorical_features" in batch[0][0]:
            for key in ("static_categorical_features", "static_continuous_features"):
                x_batch[key] = torch.stack([x[key] for x, _ in batch])

        if isinstance(batch[0][1], (list, tuple)):
            num_targets = len(batch[0][1])
            y_batch = [
                torch.stack([sample_y[i] for _, sample_y in batch])
                for i in range(num_targets)
            ]
        else:
            y_batch = torch.stack([y for _, y in batch])
        return x_batch, y_batch