GroupedSampler#

class pytorch_forecasting.data.samplers.GroupedSampler(sampler: Sampler, batch_size: int = 64, shuffle: bool = False, drop_last: bool = False)[source]#

Bases: Sampler

Samples mini-batches randomly but in a grouped manner.

This means that items belonging to the same group are always sampled together, so a mini-batch is drawn from a single group. This is an abstract class: implement the get_groups() method, which creates the groups to be sampled from.

Initialize.

Parameters:
  • sampler (Sampler or Iterable) – Base sampler. Can be any iterable object.

  • drop_last (bool) – whether to drop the last mini-batch of a group if it is smaller than batch_size. Defaults to False.

  • shuffle (bool) – whether to shuffle the dataset. Defaults to False.

  • batch_size (int, optional) – Number of samples in a mini-batch. This is the maximum rather than a guaranteed size: because mini-batches are formed within groups (e.g. by prediction time), many batches will contain fewer samples than this maximum. Defaults to 64.
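
For orientation, here is a minimal sketch of how the concrete subclass shipped in this module, TimeSynchronizedBatchSampler, is typically constructed; `training` is assumed to be an existing TimeSeriesDataSet, and the keyword arguments mirror the parameters listed above:

    from torch.utils.data.sampler import SequentialSampler

    from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler

    batch_sampler = TimeSynchronizedBatchSampler(
        SequentialSampler(training),  # base sampler; can be any iterable over the dataset
        batch_size=64,                # maximum number of samples per mini-batch
        shuffle=True,                 # whether to shuffle the dataset
        drop_last=False,              # keep trailing mini-batches smaller than batch_size
    )

Because the sampler yields whole mini-batches of indices, it is passed to a DataLoader as batch_sampler rather than sampler; in practice this wiring is usually requested via TimeSeriesDataSet.to_dataloader(train=True, batch_sampler="synchronized").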

Methods

construct_batch_groups(groups)

Construct an index of batches from which samples can be drawn.

get_groups(sampler)

Create the groups which can be sampled.

construct_batch_groups(groups)[source]#

Construct an index of batches from which samples can be drawn.
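
For illustration only (the group keys and positions below are invented), the groups argument is the dict-like object returned by get_groups(): group names mapped to positions in data_source.index, which this method turns into an index of mini-batches:

    groups = {
        10: [0, 1, 5],     # group key (e.g. a prediction time) -> positions in data_source.index
        11: [2, 3, 4, 6],
    }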

get_groups(sampler: Sampler)[source]#

Create the groups which can be sampled.

Parameters:

sampler (Sampler) – sampler with a data_source attribute of type TimeSeriesDataSet.

Returns:

dictionary-like object with data_source.index as values and group names as keys

Return type:

dict-like
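
Since GroupedSampler is abstract, a custom sampler only needs to implement get_groups(). The following is a hedged sketch: the class name and the bucketing rule are invented for illustration, and it assumes data_source.index exposes an integer time column, as TimeSeriesDataSet's index does:

    import pandas as pd
    from torch.utils.data.sampler import Sampler

    from pytorch_forecasting.data.samplers import GroupedSampler


    class TimeBucketBatchSampler(GroupedSampler):
        """Hypothetical sampler that groups samples into coarse buckets of 7 time steps."""

        def get_groups(self, sampler: Sampler):
            index = sampler.data_source.index  # index of the underlying TimeSeriesDataSet
            buckets = index.time // 7          # assumption: an integer `time` column exists
            # Return a dict-like object mapping group names to the positions in
            # data_source.index belonging to each group, as described above.
            return pd.DataFrame(dict(bucket=buckets)).groupby("bucket").groups

Such a subclass would then be instantiated and passed to a DataLoader in the same way as TimeSynchronizedBatchSampler in the example above.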