GRU

class pytorch_forecasting.models.nn.rnn.GRU(*args, **kwargs)

Bases: pytorch_forecasting.models.nn.rnn.RNN, torch.nn.modules.rnn.GRU

GRU that can handle zero-length sequences
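The underlying difficulty is that torch.nn.utils.rnn.pack_padded_sequence rejects sequences of length 0. The sketch below shows one way to work around this with a plain torch.nn.GRU. It is not taken from the pytorch_forecasting source: the helper name gru_with_empty_sequences is hypothetical, and it assumes batch_first=True and a zero initial hidden state. Zero lengths are clamped to 1 so that packing succeeds, and afterwards the final hidden state of every empty entry is reset to the initial state.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def gru_with_empty_sequences(gru: nn.GRU, x: torch.Tensor, lengths: torch.Tensor):
    """Hypothetical helper: run a batch_first nn.GRU over padded x, allowing lengths of 0."""
    assert gru.batch_first, "this sketch assumes batch_first=True"
    # pack_padded_sequence raises for zero lengths, so pack with at least one step
    packed = pack_padded_sequence(
        x, lengths.clamp(min=1).cpu(), batch_first=True, enforce_sorted=False
    )
    packed_out, h_n = gru(packed)  # h_n: (num_layers * num_directions, batch, hidden_size)
    out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))

    empty = lengths == 0
    h_0 = torch.zeros_like(h_n)  # assumed zero initial state
    # restore the initial state for empty entries; the mask broadcasts over layers and units
    h_n = torch.where(empty[None, :, None], h_0, h_n)
    # the single step computed for empty entries is discarded from the outputs as well
    out = out.masked_fill(empty[:, None, None], 0.0)
    return out, h_n
```

Clamping keeps the whole batch in one packed call; the one spurious step computed for each empty entry is then masked away rather than excluded up front.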

Methods

handle_no_encoding(hidden_state, ...)

Mask the hidden_state where there is no encoding.

init_hidden_state(x)

Initialise a hidden_state.

repeat_interleave(hidden_state, n_samples)

Duplicate the hidden_state n_samples times.

handle_no_encoding(hidden_state: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], no_encoding: torch.BoolTensor, initial_hidden_state: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]) → Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]

Mask the hidden_state where there is no encoding.

Parameters
  • hidden_state (HiddenState) – hidden state where some entries need replacement

  • no_encoding (torch.BoolTensor) – positions that need replacement

  • initial_hidden_state (HiddenState) – hidden state to use for replacement

Returns

hidden state in which the entries at the no_encoding positions are replaced by the corresponding entries of initial_hidden_state

Return type

HiddenState
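Although the type annotation also admits a tuple of tensors, a torch.nn.GRU hidden state is a single tensor of shape (num_layers * num_directions, batch_size, hidden_size). Masking therefore amounts to choosing, per batch entry, between the computed state and the initial state. The sketch below assumes no_encoding is a boolean mask broadcastable to that shape, for example (1, batch_size, 1). It uses torch.where, and the function name handle_no_encoding_sketch is hypothetical; this is not necessarily how the library implements the method.

```python
import torch


def handle_no_encoding_sketch(hidden_state, no_encoding, initial_hidden_state):
    # where no_encoding is True, take the initial state; elsewhere keep the computed state
    return torch.where(no_encoding, initial_hidden_state, hidden_state)


# example: a batch of 3 in which the second entry had no encoder steps
hidden = torch.ones(2, 3, 4)    # (num_layers, batch_size, hidden_size)
initial = torch.zeros(2, 3, 4)
lengths = torch.tensor([4, 0, 1])
no_encoding = (lengths == 0)[None, :, None]  # (1, batch_size, 1), broadcasts over layers and units
masked = handle_no_encoding_sketch(hidden, no_encoding, initial)
```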

init_hidden_state(x: torch.Tensor) → Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]

Initialise a hidden_state.

Parameters

x (torch.Tensor) – network input

Returns

default (zero-like) hidden state

Return type

HiddenState
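A zero-like default state for a GRU has the h_0 layout that torch.nn.GRU expects. Only the batch size has to be read from the input, and where it sits depends on batch_first. The following is a minimal sketch under that assumption. The helper init_hidden_state_sketch is hypothetical and reads the standard torch.nn.GRU attributes rather than the library's internals.

```python
import torch
import torch.nn as nn


def init_hidden_state_sketch(gru: nn.GRU, x: torch.Tensor) -> torch.Tensor:
    # the batch dimension of the input depends on batch_first
    batch_size = x.size(0) if gru.batch_first else x.size(1)
    num_directions = 2 if gru.bidirectional else 1
    # h_0 is never batch-first: (num_layers * num_directions, batch_size, hidden_size)
    return torch.zeros(
        gru.num_layers * num_directions, batch_size, gru.hidden_size,
        dtype=x.dtype, device=x.device,
    )


gru = nn.GRU(input_size=3, hidden_size=5, num_layers=2, bidirectional=True, batch_first=True)
x = torch.randn(4, 7, 3)  # (batch_size, seq_len, input_size)
h0 = init_hidden_state_sketch(gru, x)
```

Because torch.nn.GRU itself defaults to a zero initial state, passing this tensor explicitly gives the same result as passing no hidden state at all.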

repeat_interleave(hidden_state: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], n_samples: int) → Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]

Duplicate the hidden_state n_samples times.

Parameters
  • hidden_state (HiddenState) – hidden state to repeat

  • n_samples (int) – number of repetitions

Returns

repeated hidden state

Return type

HiddenState
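In a GRU hidden state the batch is dimension 1, so one plausible reading of "duplicate n_samples times" is to repeat each batch entry consecutively along that dimension with torch.Tensor.repeat_interleave. This matches the ordering produced by repeating a batch-first input with x.repeat_interleave(n_samples, 0). The sketch below rests on that assumption about the method's semantics, and repeat_interleave_sketch is a hypothetical name.

```python
import torch


def repeat_interleave_sketch(hidden_state: torch.Tensor, n_samples: int) -> torch.Tensor:
    # batch is dim 1 of (num_layers * num_directions, batch_size, hidden_size);
    # each entry is repeated in place: [a, b] -> [a, a, a, b, b, b] for n_samples=3
    return hidden_state.repeat_interleave(n_samples, dim=1)


hidden = torch.arange(2.0).reshape(1, 2, 1)  # two batch entries with values 0 and 1
repeated = repeat_interleave_sketch(hidden, 3)
print(repeated.flatten().tolist())  # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
```

Interleaving, rather than tiling with Tensor.repeat, keeps all samples drawn for one series adjacent in the batch.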