pytorch_forecasting.data.samplers.GroupedSampler

class pytorch_forecasting.data.samplers.GroupedSampler(sampler: Sampler, batch_size: int = 64, shuffle: bool = False, drop_last: bool = False)

Samples mini-batches randomly but in a grouped manner.

This means that items from the same group are always sampled together: each mini-batch contains items from only one group. This is an abstract class; implement the get_groups() method, which creates the groups to be sampled from.
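The grouping idea can be illustrated with a small, library-independent sketch. It assumes the groups are already given as a dictionary that maps a group name to a list of dataset indices; the helper name grouped_batches and the toy groups are hypothetical and not part of pytorch_forecasting. Samples are shuffled within each group, each group is cut into mini-batches, and then the order of the mini-batches is shuffled. Batches therefore come out in random order but never mix groups.

```python
import random
from typing import Dict, List


def grouped_batches(
    groups: Dict[str, List[int]], batch_size: int, seed: int = 0
) -> List[List[int]]:
    """Hypothetical helper: split each group into mini-batches, then shuffle batch order."""
    rng = random.Random(seed)
    batches: List[List[int]] = []
    for members in groups.values():
        shuffled = list(members)
        rng.shuffle(shuffled)  # random order within the group
        # cut the group into consecutive chunks of at most batch_size items
        for start in range(0, len(shuffled), batch_size):
            batches.append(shuffled[start:start + batch_size])
    rng.shuffle(batches)  # random order of mini-batches across groups
    return batches


# toy groups, e.g. dataset indices keyed by prediction time (assumed layout)
groups = {"t0": [0, 1, 2, 3, 4], "t1": [5, 6, 7], "t2": [8, 9]}
for batch in grouped_batches(groups, batch_size=2):
    print(batch)
```

Every printed batch contains indices from exactly one of t0, t1 or t2; only which batch comes first, and the order inside each group, is random.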

Initialize.

Parameters:
  • sampler (Sampler or Iterable) – Base sampler. Can be any iterable object.

  • batch_size (int) – Maximum number of samples in a mini-batch. Because mini-batches are grouped (for example, by prediction time) and never mix groups, the last mini-batch of each group can be smaller than this maximum, so many mini-batches may contain fewer than batch_size samples. Default is 64.

  • shuffle (bool) – If True, shuffle the samples within each group and the order of the mini-batches. Default is False.

  • drop_last (bool) – If True, drop the last mini-batch of each group if it is smaller than batch_size. Default is False.
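Because each mini-batch is drawn from a single group, the number of mini-batches a group yields follows from its size, batch_size and drop_last. A minimal sketch of that arithmetic, assuming only the last mini-batch of a group can be short; the helper batches_per_group is hypothetical, and it does not model any special handling the library may apply to groups smaller than batch_size.

```python
def batches_per_group(group_size: int, batch_size: int = 64, drop_last: bool = False) -> int:
    """Hypothetical helper: number of mini-batches one group yields."""
    if drop_last:
        # only full mini-batches are kept
        return group_size // batch_size
    # ceiling division: the final, possibly smaller mini-batch is kept
    return (group_size + batch_size - 1) // batch_size


# a group of 130 samples with the default batch_size of 64
print(batches_per_group(130))                  # 3 mini-batches: 64, 64 and 2 samples
print(batches_per_group(130, drop_last=True))  # 2 mini-batches: the 2-sample remainder is dropped
```

Summed over all groups, this total can be larger than the number of mini-batches an ungrouped sampler would produce for the same dataset, since every group contributes its own partial batch.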


Methods

__class_getitem__

Parameterizes a generic class.

__delattr__(name, /)

Implement delattr(self, name).

__dir__()

Default dir() implementation.

__eq__(value, /)

Return self==value.

__format__(format_spec, /)

Default object formatter.

__ge__(value, /)

Return self>=value.

__getattribute__(name, /)

Return getattr(self, name).

__getstate__()

Helper for pickle.

__gt__(value, /)

Return self>value.

__hash__()

Return hash(self).

__init_subclass__

Function to initialize subclasses.

__iter__()

Yield mini-batches of dataset indices, each drawn from a single group.

__le__(value, /)

Return self<=value.

__len__()

Return the number of mini-batches (not the number of samples).

__lt__(value, /)

Return self<value.

__ne__(value, /)

Return self!=value.

__new__(*args, **kwargs)

__reduce__()

Helper for pickle.

__reduce_ex__(protocol, /)

Helper for pickle.

__repr__()

Return repr(self).

__setattr__(name, value, /)

Implement setattr(self, name, value).

__sizeof__()

Size of object in memory, in bytes.

__str__()

Return str(self).

__subclasshook__

Abstract classes can override this to customize issubclass().

construct_batch_groups(groups)

Construct the index of mini-batches that can be sampled.

get_groups(sampler)

Create the groups which can be sampled. Abstract; must be implemented by subclasses.
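The methods above fit together roughly as follows: get_groups() defines the groups, construct_batch_groups() builds an index with one entry per mini-batch, and iteration and len() work off that index. The sketch below is a simplified, standalone re-implementation for illustration only, not the library's code. The names ToyGroupedSampler and ParityGroupedSampler and the (group name, batch number) index layout are assumptions, and shuffling is left out.

```python
from typing import Dict, Iterator, List, Sequence, Tuple


class ToyGroupedSampler:
    """Simplified stand-in for an abstract grouped sampler (illustration only)."""

    def __init__(self, sampler: Sequence[int], batch_size: int = 64, drop_last: bool = False):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        # build the groups first, then the batch index derived from them
        self.construct_batch_groups(self.get_groups(sampler))

    def get_groups(self, sampler: Sequence[int]) -> Dict[str, List[int]]:
        # subclasses decide how samples are grouped
        raise NotImplementedError

    def construct_batch_groups(self, groups: Dict[str, List[int]]) -> None:
        self._groups = groups
        # one (group name, batch number) entry per mini-batch that can be sampled
        self._batch_index: List[Tuple[str, int]] = []
        for name, indices in groups.items():
            if self.drop_last:
                n_batches = len(indices) // self.batch_size
            else:
                n_batches = -(-len(indices) // self.batch_size)  # ceiling division
            self._batch_index.extend((name, b) for b in range(n_batches))

    def __iter__(self) -> Iterator[List[int]]:
        for name, b in self._batch_index:
            start = b * self.batch_size
            yield self._groups[name][start:start + self.batch_size]

    def __len__(self) -> int:
        # the length counts mini-batches, not individual samples
        return len(self._batch_index)


class ParityGroupedSampler(ToyGroupedSampler):
    """Hypothetical subclass: groups indices into even and odd."""

    def get_groups(self, sampler: Sequence[int]) -> Dict[str, List[int]]:
        return {
            "even": [i for i in sampler if i % 2 == 0],
            "odd": [i for i in sampler if i % 2 == 1],
        }


batch_sampler = ParityGroupedSampler(range(7), batch_size=3)
print(len(batch_sampler))   # 3
print(list(batch_sampler))  # [[0, 2, 4], [6], [1, 3, 5]]
```

Because iteration yields lists of indices rather than single indices, an object like this would be handed to a torch DataLoader through its batch_sampler argument, not its sampler argument.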

Attributes

__annotations__

__dict__

__doc__

__module__

__orig_bases__

__parameters__

__weakref__

list of weak references to the object