TimeSeriesDataSet#

class pytorch_forecasting.data.timeseries.TimeSeriesDataSet(data: DataFrame, time_idx: str, target: str | List[str], group_ids: List[str], weight: str | None = None, max_encoder_length: int = 30, min_encoder_length: int | None = None, min_prediction_idx: int | None = None, min_prediction_length: int | None = None, max_prediction_length: int = 1, static_categoricals: List[str] = [], static_reals: List[str] = [], time_varying_known_categoricals: List[str] = [], time_varying_known_reals: List[str] = [], time_varying_unknown_categoricals: List[str] = [], time_varying_unknown_reals: List[str] = [], variable_groups: Dict[str, List[int]] = {}, constant_fill_strategy: Dict[str, str | float | int | bool] = {}, allow_missing_timesteps: bool = False, lags: Dict[str, List[int]] = {}, add_relative_time_idx: bool = False, add_target_scales: bool = False, add_encoder_length: bool | str = 'auto', target_normalizer: TorchNormalizer | NaNLabelEncoder | EncoderNormalizer | str | List[TorchNormalizer | NaNLabelEncoder | EncoderNormalizer] | Tuple[TorchNormalizer | NaNLabelEncoder | EncoderNormalizer] = 'auto', categorical_encoders: Dict[str, NaNLabelEncoder] = {}, scalers: Dict[str, StandardScaler | RobustScaler | TorchNormalizer | EncoderNormalizer] = {}, randomize_length: None | Tuple[float, float] | bool = False, predict_mode: bool = False)[source]#

Bases: Dataset

PyTorch Dataset for fitting timeseries models.

The dataset automates common tasks such as

  • scaling and encoding of variables

  • normalizing the target variable

  • efficiently converting timeseries in pandas dataframes to torch tensors

  • holding information about static and time-varying variables known and unknown in the future

  • holding information about related categories (such as holidays)

  • downsampling for data augmentation

  • generating inference, validation and test datasets

  • etc.

Timeseries dataset holding data for models.

The tutorial on passing data to models is helpful to understand the output of the dataset and how it is coupled to models.

Each sample is a subsequence of a full time series. The subsequence consists of encoder and decoder/prediction timepoints for a given time series. This class constructs an index which defines which subsequences exist and can be sampled from (index attribute). The samples in the index are defined by the various parameters to the class (encoder and prediction lengths, minimum prediction length, randomize length and predict keywords). How samples are sampled into batches for training is determined by the DataLoader. The class provides the to_dataloader() method to convert the dataset into a dataloader.

Large datasets:

Currently the class is limited to in-memory operations (these can be sped up by an existing installation of numba). If you have extremely large data, however, you can pass prefitted encoders and scalers as well as a subset of the sequences to the class to construct a valid dataset (plus, the EncoderNormalizer should likely be used to normalize targets). When fitting a network, you would then need to create a custom DataLoader that rotates through the datasets. There are currently no built-in methods to do this.

Parameters:
  • data (pd.DataFrame) – dataframe with sequence data - each row can be identified with time_idx and the group_ids

  • time_idx (str) – integer column denoting the time index. This column is used to determine the sequence of samples. If there are no missing observations, the time index should increase by +1 for each subsequent sample. The first time_idx for each series does not necessarily have to be 0; any value is allowed.

  • target (Union[str, List[str]]) – column denoting the target or list of columns denoting the target - categorical or continuous.

  • group_ids (List[str]) – list of column names identifying a time series. This means that the group_ids identify a sample together with the time_idx. If you have only one timeseries, set this to the name of a column that is constant.

  • weight (str) – column name for weights. Defaults to None.

  • max_encoder_length (int) – maximum length to encode. This is the maximum history length used by the time series dataset.

  • min_encoder_length (int) – minimum allowed length to encode. Defaults to max_encoder_length.

  • min_prediction_idx (int) – minimum time_idx from where to start predictions. This parameter can be useful to create a validation or test set.

  • max_prediction_length (int) – maximum prediction/decoder length (do not choose this too short, as a longer prediction length can help convergence)

  • min_prediction_length (int) – minimum prediction/decoder length. Defaults to max_prediction_length

  • static_categoricals (List[str]) – list of categorical variables that do not change over time, entries can also be lists which are then encoded together (e.g. useful for product categories)

  • static_reals (List[str]) – list of continuous variables that do not change over time

  • time_varying_known_categoricals (List[str]) – list of categorical variables that change over time and are known in the future, entries can also be lists which are then encoded together (e.g. useful for special days or promotion categories)

  • time_varying_known_reals (List[str]) – list of continuous variables that change over time and are known in the future (e.g. price of a product, but not demand of a product)

  • time_varying_unknown_categoricals (List[str]) – list of categorical variables that change over time and are not known in the future, entries can also be lists which are then encoded together (e.g. useful for weather categories). You might want to include your target here.

  • time_varying_unknown_reals (List[str]) – list of continuous variables that change over time and are not known in the future. You might want to include your target here.

  • variable_groups (Dict[str, List[str]]) – dictionary mapping a name to a list of columns in the data. The name should be present in a categorical or real class argument to be able to encode or scale the columns by group. This effectively combines categorical variables and is particularly useful if a categorical variable can have multiple values at the same time. An example are holidays, which can be overlapping.

  • constant_fill_strategy (Dict[str, Union[str, float, int, bool]]) – dictionary of column names with constants to fill in missing values if there are gaps in the sequence (by default a forward fill strategy is used). The values will only be used if allow_missing_timesteps=True. A common use case is to denote that demand was 0 if the sample is not in the dataset.

  • allow_missing_timesteps (bool) – whether to allow missing timesteps that are automatically filled up. Missing values refer to gaps in the time_idx, e.g. if a specific timeseries has only samples for 1, 2, 4, 5, the sample for 3 will be generated on-the-fly. Allowing missings does not deal with NA values. You should fill NA values before passing the dataframe to the TimeSeriesDataSet.

  • lags (Dict[str, List[int]]) – dictionary of variable names mapped to a list of time steps by which the variable should be lagged. Lags can be useful to indicate seasonality to the models. If you know the seasonalit(ies) of your data, add at least the target variables with the corresponding lags to improve performance. Lags must not be larger than the shortest time series, as all time series will be cut by the largest lag value to prevent NA values. A lagged variable has to appear in the time-varying variables. If you only want the lagged but not the current value, lag it manually in your input data using data[lagged_variable_name] = data.sort_values(time_idx).groupby(group_ids, observed=True).shift(lag). Defaults to no lags.

  • add_relative_time_idx (bool) – whether to add a relative time index as a feature (i.e. for each sampled sequence, the index will range from -encoder_length to prediction_length)

  • add_target_scales (bool) – whether to add scales of the target to the static real features (i.e. add the center and scale of the unnormalized timeseries as features)

  • add_encoder_length (bool) – whether to add the encoder length to the list of static real variables. Defaults to “auto”, i.e. True if min_encoder_length != max_encoder_length.

  • target_normalizer (Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer, str, list, tuple]) – transformer that takes group_ids, target and time_idx to normalize targets. You can choose from TorchNormalizer, GroupNormalizer, NaNLabelEncoder, EncoderNormalizer (on which overfitting tests will fail) or None for using no normalizer. For multiple targets, use a MultiNormalizer. By default, an appropriate normalizer is chosen automatically.

  • categorical_encoders (Dict[str, NaNLabelEncoder]) – dictionary of scikit-learn label transformers. If you have unobserved categories in the future / a cold-start problem, you can use the NaNLabelEncoder with add_nan=True. Defaults effectively to sklearn’s LabelEncoder(). Prefitted encoders will not be fit again.

  • scalers (Dict[str, Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer]]) – dictionary of scikit-learn scalers. Defaults to sklearn’s StandardScaler(). Other options are EncoderNormalizer, GroupNormalizer or scikit-learn’s StandardScaler(), RobustScaler() or None for using no normalizer / a normalizer with center=0 and scale=1 (method=”identity”). Prefitted scalers will not be fit again (with the exception of the EncoderNormalizer, which is fit on every encoder sequence).

  • randomize_length (Union[None, Tuple[float, float], bool]) – None or False to not randomize lengths. A tuple of beta distribution concentrations from which probabilities are sampled that are used to sample new sequence lengths with a binomial distribution. If True, defaults to (0.2, 0.05), i.e. ~1/4 of samples around the minimum encoder length. Defaults to False.

  • predict_mode (bool) – whether to only iterate over each timeseries once (only the last provided samples). Effectively, this will choose for each time series identified by group_ids the last max_prediction_length samples of each time series as prediction samples and everything previous, up to max_encoder_length samples, as encoder samples.
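
For illustration, a minimal construction sketch follows; the dataframe and its column names ("series", "time_idx", "value") are hypothetical, not part of the API.

import numpy as np
import pandas as pd

from pytorch_forecasting import TimeSeriesDataSet

# toy long-format data: two series ("a", "b") with 100 consecutive time steps each
data = pd.DataFrame(
    {
        "series": np.repeat(["a", "b"], 100),
        "time_idx": np.tile(np.arange(100), 2),
        "value": np.random.randn(200),
    }
)

dataset = TimeSeriesDataSet(
    data,
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    max_encoder_length=24,
    max_prediction_length=6,
    time_varying_unknown_reals=["value"],  # the target is typically listed here
)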

Methods

calculate_decoder_length(time_last, ...)

Calculate length of decoder.

filter(filter_func[, copy])

Filter subsequences in dataset.

from_dataset(dataset, data[, ...])

Generate dataset with different underlying data but same variable encoders and scalers, etc.

from_parameters(parameters, data[, ...])

Generate dataset with different underlying data but same variable encoders and scalers, etc.

get_parameters()

Get parameters that can be used with from_parameters() to create a new dataset with the same scalers.

get_transformer(name[, group_id])

Get transformer for variable.

load(fname)

Load dataset from disk

plot_randomization([betas, length, min_length])

Plot expected randomized length distribution.

reset_overwrite_values()

Reset values used to override sample features.

save(fname)

Save dataset to disk

set_overwrite_values(values, variable[, target])

Convenience method to quickly overwrite values in decoder or encoder (or both) for a specific variable.

to_dataloader([train, batch_size, batch_sampler])

Get dataloader from dataset.

transform_values(name, values[, data, ...])

Scale and encode values.

x_to_index(x)

Decode dataframe index from x.

Attributes

categoricals

Categorical variables as used for modelling.

decoded_index

Get interpretable version of index.

dropout_categoricals

list of categorical variables that are unknown when making a forecast without observed history

flat_categoricals

Categorical variables as defined in input data.

lagged_targets

Subset of lagged_variables but only includes variables that are lagged targets.

lagged_variables

Lagged variables.

max_lag

Maximum number of time steps variables are lagged.

min_lag

Minimum number of time steps variables are lagged.

multi_target

If dataset encodes one or multiple targets.

reals

Continuous variables as used for modelling.

target_names

List of targets.

target_normalizers

List of target normalizers aligned with target_names.

variable_to_group_mapping

Mapping from categorical variables to variables in input data.

calculate_decoder_length(time_last: int | Series | ndarray, sequence_length: int | Series | ndarray) int | Series | ndarray[source]#

Calculate length of decoder.

Parameters:
  • time_last (Union[int, pd.Series, np.ndarray]) – last time index of the sequence

  • sequence_length (Union[int, pd.Series, np.ndarray]) – total length of the sequence

Returns:

decoder length(s)

Return type:

Union[int, pd.Series, np.ndarray]

filter(filter_func: Callable, copy: bool = True) TimeSeriesDataSet[source]#

Filter subsequences in dataset.

Uses interpretable version of index decoded_index() to filter subsequences in dataset.

Parameters:
  • filter_func (Callable) – function to filter. Should take decoded_index() dataframe as only argument which contains group ids and time index columns.

  • copy (bool) – whether to return a copy of the dataset or to filter inplace.

Returns:

filtered dataset

Return type:

TimeSeriesDataSet
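
A short sketch, assuming the dataset was built with a group-id column named "series" (hypothetical):

# keep only subsequences belonging to series "a"; the filter function receives
# the decoded index dataframe and returns a boolean mask
subset = dataset.filter(lambda index: index["series"] == "a", copy=True)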

classmethod from_dataset(dataset, data: DataFrame, stop_randomization: bool = False, predict: bool = False, **update_kwargs)[source]#

Generate dataset with different underlying data but same variable encoders and scalers, etc.

Calls from_parameters() under the hood.

Parameters:
  • dataset (TimeSeriesDataSet) – dataset from which to copy parameters

  • data (pd.DataFrame) – data from which new dataset will be generated

  • stop_randomization (bool, optional) – Whether to stop randomizing encoder and decoder lengths, e.g. useful for a validation set. Defaults to False.

  • predict (bool, optional) – Whether to predict on the last entries of the time index (i.e. one prediction per group only). Defaults to False.

  • **kwargs – keyword arguments overriding parameters in the original dataset

Returns:

new dataset

Return type:

TimeSeriesDataSet
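
A typical use, sketched here, is deriving a validation set that reuses the encoders and scalers fitted on a training dataset:

# one sample per series covering the last max_prediction_length points;
# length randomization is switched off for evaluation
validation = TimeSeriesDataSet.from_dataset(
    dataset, data, predict=True, stop_randomization=True
)
val_dataloader = validation.to_dataloader(train=False, batch_size=64)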

classmethod from_parameters(parameters: Dict[str, Any], data: DataFrame, stop_randomization: bool | None = None, predict: bool = False, **update_kwargs)[source]#

Generate dataset with different underlying data but same variable encoders and scalers, etc.

Parameters:
  • parameters (Dict[str, Any]) – dataset parameters which to use for the new dataset

  • data (pd.DataFrame) – data from which new dataset will be generated

  • stop_randomization (bool, optional) – Whether to stop randomizing encoder and decoder lengths, e.g. useful for a validation set. Defaults to False.

  • predict (bool, optional) – Whether to predict on the last entries of the time index (i.e. one prediction per group only). Defaults to False.

  • **kwargs – keyword arguments overriding parameters

Returns:

new dataset

Return type:

TimeSeriesDataSet

get_parameters() Dict[str, Any][source]#

Get parameters that can be used with from_parameters() to create a new dataset with the same scalers.

Returns:

dictionary of parameters

Return type:

Dict[str, Any]
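
A sketch of round-tripping parameters, e.g. to rebuild an equivalent dataset on new data without refitting encoders and scalers:

params = dataset.get_parameters()
# persist `params` if needed (e.g. with pickle), then later:
new_dataset = TimeSeriesDataSet.from_parameters(params, data, predict=False)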

get_transformer(name: str, group_id: bool = False)[source]#

Get transformer for variable.

Parameters:
  • name (str) – variable name

  • group_id (bool, optional) – If the passed name refers to a group id (different encoders are used for these). Defaults to False.

Returns:

transformer

classmethod load(fname: str)[source]#

Load dataset from disk

Parameters:

fname (str) – filename to load from

Returns:

TimeSeriesDataSet

plot_randomization(betas: Tuple[float, float] | None = None, length: int | None = None, min_length: int | None = None) Tuple[Figure, Tensor][source]#

Plot expected randomized length distribution.

Parameters:
  • betas (Tuple[float, float], optional) – Tuple of betas, e.g. (0.2, 0.05) to use for randomization. Defaults to randomize_length of dataset.

  • length (int, optional) – maximum sequence length for which to plot the distribution. Defaults to max_encoder_length.

  • min_length (int, optional) – minimum sequence length. Defaults to min_encoder_length.

Returns:

tuple of figure and histogram based on 1000 samples

Return type:

Tuple[plt.Figure, torch.Tensor]
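
For example, to inspect how a given pair of betas would distribute the sampled encoder lengths (the output path is hypothetical):

fig, lengths = dataset.plot_randomization(betas=(0.2, 0.05))
fig.savefig("randomized_lengths.png")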

reset_overwrite_values() None[source]#

Reset values used to override sample features.

save(fname: str) None[source]#

Save dataset to disk

Parameters:

fname (str) – filename to save to
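
A sketch of saving and restoring a dataset (the filename is arbitrary):

dataset.save("training_dataset.pt")
# later, e.g. in another process:
dataset = TimeSeriesDataSet.load("training_dataset.pt")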

set_overwrite_values(values: float | Tensor, variable: str, target: str | slice = 'decoder') None[source]#

Convenience method to quickly overwrite values in decoder or encoder (or both) for a specific variable.

Parameters:
  • values (Union[float, torch.Tensor]) – values to use for overwrite.

  • variable (str) – variable whose values should be overwritten.

  • target (Union[str, slice], optional) – positions to overwrite. One of “decoder”, “encoder” or “all” or a slice object which is directly used to overwrite indices, e.g. slice(-5, None) will overwrite the last 5 values. Defaults to “decoder”.
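
For illustration only (the variable name "price" is hypothetical), a what-if sketch:

# force a known covariate to a constant in the decoder, e.g. for scenario analysis
dataset.set_overwrite_values(values=1.0, variable="price", target="decoder")
# ... run predictions ...
dataset.reset_overwrite_values()  # restore the original values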

to_dataloader(train: bool = True, batch_size: int = 64, batch_sampler: Sampler | str | None = None, **kwargs) DataLoader[source]#

Get dataloader from dataset.

Parameters:
  • train (bool, optional) – whether the dataloader is used for training or prediction. Will shuffle and drop the last batch if True. Defaults to True.

  • batch_size (int) – batch size for training model. Defaults to 64.

  • batch_sampler (Union[Sampler, str]) –

    batch sampler or string. One of

    • ”synchronized”: ensure that samples in the decoder are aligned in time. Does not support missing values in the dataset. This only makes sense if the underlying algorithm makes use of values aligned in time.

    • PyTorch Sampler instance: any PyTorch sampler, e.g. the WeightedRandomSampler()

    • None: samples are taken randomly from time series.

  • **kwargs – additional arguments to DataLoader()

Returns:

dataloader that returns Tuple.

First entry is x, a dictionary of tensors with the entries (and shapes in brackets)

  • encoder_cat (batch_size x n_encoder_time_steps x n_features): long tensor of encoded categoricals for encoder

  • encoder_cont (batch_size x n_encoder_time_steps x n_features): float tensor of scaled continuous variables for encoder

  • encoder_target (batch_size x n_encoder_time_steps or list thereof with each entry for a different target): float tensor with unscaled continuous target or encoded categorical target, list of tensors for multiple targets

  • encoder_lengths (batch_size): long tensor with lengths of the encoder time series. No entry will be greater than n_encoder_time_steps

  • decoder_cat (batch_size x n_decoder_time_steps x n_features): long tensor of encoded categoricals for decoder

  • decoder_cont (batch_size x n_decoder_time_steps x n_features): float tensor of scaled continuous variables for decoder

  • decoder_target (batch_size x n_decoder_time_steps or list thereof with each entry for a different target): float tensor with unscaled continuous target or encoded categorical target for decoder - this corresponds to the first entry of y, list of tensors for multiple targets

  • decoder_lengths (batch_size): long tensor with lengths of the decoder time series. No entry will be greater than n_decoder_time_steps

  • group_ids (batch_size x number_of_ids): encoded group ids that identify a time series in the dataset

  • target_scale (batch_size x scale_size or list thereof with each entry for a different target): parameters used to normalize the target. Typically these are mean and standard deviation. Is list of tensors for multiple targets.

Second entry is y, a tuple of the form (target, weight)

  • target (batch_size x n_decoder_time_steps or list thereof with each entry for a different target): unscaled (continuous) or encoded (categories) targets, list of tensors for multiple targets

  • weight (None or batch_size x n_decoder_time_steps): weight

Return type:

DataLoader

Example

Weight by samples for training:

import numpy as np

from torch.utils.data import WeightedRandomSampler

# the probabilities passed to the sampler must have the same length as the dataset index
probabilities = np.sqrt(1 + data.loc[dataset.index, "target"])
sampler = WeightedRandomSampler(probabilities, len(probabilities))
dataset.to_dataloader(train=True, sampler=sampler, shuffle=False)
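
Iterating over the resulting dataloader yields (x, y) pairs as described above; a minimal sketch for a single-target dataset:

dataloader = dataset.to_dataloader(train=True, batch_size=64)
x, (target, weight) = next(iter(dataloader))
print(x["encoder_cont"].shape)  # batch_size x n_encoder_time_steps x n_features
print(target.shape)             # batch_size x n_decoder_time_steps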

transform_values(name: str, values: Series | Tensor | ndarray, data: DataFrame | None = None, inverse=False, group_id: bool = False, **kwargs) ndarray[source]#

Scale and encode values.

Parameters:
  • name (str) – name of variable

  • values (Union[pd.Series, torch.Tensor, np.ndarray]) – values to encode/scale

  • data (pd.DataFrame, optional) – extra data used for scaling (e.g. dataframe with group columns). Defaults to None.

  • inverse (bool, optional) – whether to conduct the inverse transformation. Defaults to False.

  • group_id (bool, optional) – If the passed name refers to a group id (different encoders are used for these). Defaults to False.

  • **kwargs – additional arguments for transform/inverse_transform method

Returns:

(de/en)coded/(de)scaled values

Return type:

np.ndarray

x_to_index(x: Dict[str, Tensor]) DataFrame[source]#

Decode dataframe index from x.

Returns:

dataframe with time index column for first prediction and group ids

property categoricals: List[str]#

Categorical variables as used for modelling.

Returns:

list of variables

Return type:

List[str]

property decoded_index: DataFrame#

Get interpretable version of index.

The DataFrame contains:

  • group_id columns in original encoding

  • time_idx_first column: first time index of subsequence

  • time_idx_last column: last time index of subsequence

  • time_idx_first_prediction column: first time index which is in the decoder

Returns:

index that can be understood in terms of original data

Return type:

pd.DataFrame

property dropout_categoricals: List[str]#

list of categorical variables that are unknown when making a forecast without observed history

property flat_categoricals: List[str]#

Categorical variables as defined in input data.

Returns:

list of variables

Return type:

List[str]

property lagged_targets: Dict[str, str]#

Subset of lagged_variables but only includes variables that are lagged targets.

property lagged_variables: Dict[str, str]#

Lagged variables.

Returns:

dictionary of variable names corresponding to lagged variables

mapped to variable that is lagged

Return type:

Dict[str, str]

property max_lag: int#

Maximum number of time steps variables are lagged.

Returns:

maximum lag

Return type:

int

property min_lag: int#

Minimum number of time steps variables are lagged.

Returns:

minimum lag

Return type:

int

property multi_target: bool#

If dataset encodes one or multiple targets.

Returns:

true if multiple targets

Return type:

bool

property reals: List[str]#

Continuous variables as used for modelling.

Returns:

list of variables

Return type:

List[str]

property target_names: List[str]#

List of targets.

Returns:

list of targets

Return type:

List[str]

property target_normalizers: List[TorchNormalizer]#

List of target normalizers aligned with target_names.

Returns:

list of target normalizers

Return type:

List[TorchNormalizer]

property variable_to_group_mapping: Dict[str, str]#

Mapping from categorical variables to variables in input data.

Returns:

dictionary mapping from categoricals() to flat_categoricals().

Return type:

Dict[str, str]