RNN

class pytorch_forecasting.models.nn.rnn.RNN(mode: str, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, dropout: float = 0.0, bidirectional: bool = False, proj_size: int = 0, device=None, dtype=None)

Bases: ABC, RNNBase

Base class for flexible RNNs.

Forward function can handle sequences of length 0.

Methods

  • forward(x[, hx, lengths, enforce_sorted]) – Forward function of rnn that allows zero-length sequences.

  • handle_no_encoding(hidden_state, ...) – Mask the hidden_state where there is no encoding.

  • init_hidden_state(x) – Initialise a hidden_state.

  • repeat_interleave(hidden_state, n_samples) – Duplicate the hidden_state n_samples times.

forward(x: PackedSequence | Tensor, hx: tuple[Tensor, Tensor] | Tensor = None, lengths: LongTensor = None, enforce_sorted: bool = True) → tuple[PackedSequence | Tensor, tuple[Tensor, Tensor] | Tensor]

Forward function of rnn that allows zero-length sequences.

Behaves like a standard RNN forward pass. The result differs only when lengths is passed.

Parameters:
  • x (rnn.PackedSequence or torch.Tensor) – Input to RNN. Either packed sequence or tensor of padded sequences.

  • hx (HiddenState, optional) – Hidden state. Defaults to None.

  • lengths (torch.LongTensor, optional) – Lengths of sequences. If not None, used to determine correct returned hidden state. Can contain zeros. Defaults to None.

  • enforce_sorted (bool, optional) – If lengths are passed, determines if RNN expects them to be sorted. Defaults to True.

Returns:

Output and hidden state. Output is a packed sequence if input was a packed sequence.

Return type:

tuple of (rnn.PackedSequence or torch.Tensor, HiddenState)
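
The zero-length handling described above can be illustrated with a minimal sketch built on a plain torch.nn.LSTM. It is not this class's implementation: the helper name forward_with_zero_lengths is hypothetical, and the sketch assumes batch_first=True and that no hx is supplied, so the initial state is all zeros.

```python
import torch
from torch import nn
from torch.nn.utils import rnn as rnn_utils


def forward_with_zero_lengths(lstm, x, lengths):
    """Hypothetical helper: run a batch_first nn.LSTM on padded input whose lengths may be 0."""
    # pack_padded_sequence rejects zero lengths, so treat them as length 1 for packing only.
    pack_lengths = lengths.clamp(min=1)
    packed = rnn_utils.pack_padded_sequence(
        x, pack_lengths.cpu(), batch_first=True, enforce_sorted=False
    )
    packed_out, (hidden, cell) = lstm(packed)
    out, _ = rnn_utils.pad_packed_sequence(packed_out, batch_first=True)
    # Where nothing was encoded, fall back to the initial state (assumed to be zeros here).
    no_encoding = (lengths == 0)[None, :, None]
    hidden = hidden.masked_fill(no_encoding, 0.0)
    cell = cell.masked_fill(no_encoding, 0.0)
    return out, (hidden, cell)


torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=4, batch_first=True)
x = torch.randn(2, 5, 3)          # (batch, time, features)
lengths = torch.tensor([5, 0])    # second sequence has no encoder steps
out, (hidden, cell) = forward_with_zero_lengths(lstm, x, lengths)
```

In this sketch the hidden and cell state of the second sequence are reset to zeros, while the first sequence keeps the state computed over its five steps.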

abstractmethod handle_no_encoding(hidden_state: tuple[Tensor, Tensor] | Tensor, no_encoding: BoolTensor, initial_hidden_state: tuple[Tensor, Tensor] | Tensor) → tuple[Tensor, Tensor] | Tensor

Mask the hidden_state where there is no encoding.

Parameters:
  • hidden_state (HiddenState) – Hidden state where some entries need replacement.

  • no_encoding (torch.BoolTensor) – Positions that need replacement.

  • initial_hidden_state (HiddenState) – Hidden state to use for replacement.

Returns:

Hidden state in which the positions flagged by no_encoding carry the initial hidden state.

Return type:

HiddenState
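
As a rough illustration of the masking step, the sketch below handles the single-tensor (GRU-style) case with torch.where. The function name handle_no_encoding_sketch and the mask shape are assumptions made for this example; concrete subclasses may implement the replacement differently.

```python
import torch


def handle_no_encoding_sketch(hidden_state, no_encoding, initial_hidden_state):
    """Hypothetical single-tensor version: use the initial state wherever the mask is True."""
    return torch.where(no_encoding, initial_hidden_state, hidden_state)


hidden_state = torch.ones(1, 3, 2)           # (layers * directions, batch, hidden_size)
initial_hidden_state = torch.zeros(1, 3, 2)
lengths = torch.tensor([4, 0, 2])
# Assumed mask layout: one flag per batch entry, broadcast over layers and hidden units.
no_encoding = (lengths == 0)[None, :, None]
result = handle_no_encoding_sketch(hidden_state, no_encoding, initial_hidden_state)
```

Only the second batch entry, whose length is 0, takes its values from initial_hidden_state.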

abstractmethod init_hidden_state(x: Tensor) → tuple[Tensor, Tensor] | Tensor

Initialise a hidden_state.

Parameters:

x (torch.Tensor) – Network input.

Returns:

Default (zero-like) hidden state.

Return type:

HiddenState
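
A zero-like hidden state for an LSTM-style module could be built as in the sketch below. The helper init_hidden_state_sketch is hypothetical; it reads the standard torch.nn.LSTM attributes and assumes proj_size is 0, so the hidden and cell tensors share the same shape.

```python
import torch
from torch import nn


def init_hidden_state_sketch(rnn_module, x):
    """Hypothetical helper: zero (hidden, cell) pair matching the module's configuration."""
    num_directions = 2 if rnn_module.bidirectional else 1
    batch_size = x.size(0) if rnn_module.batch_first else x.size(1)
    shape = (rnn_module.num_layers * num_directions, batch_size, rnn_module.hidden_size)
    hidden = torch.zeros(shape, device=x.device, dtype=x.dtype)
    cell = torch.zeros(shape, device=x.device, dtype=x.dtype)
    return hidden, cell


lstm = nn.LSTM(input_size=3, hidden_size=4, num_layers=2, bidirectional=True, batch_first=True)
x = torch.randn(5, 7, 3)                      # (batch, time, features)
hidden, cell = init_hidden_state_sketch(lstm, x)
out, _ = lstm(x, (hidden, cell))              # the zero state is accepted as hx
```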

abstractmethod repeat_interleave(hidden_state: tuple[Tensor, Tensor] | Tensor, n_samples: int) → tuple[Tensor, Tensor] | Tensor

Duplicate the hidden_state n_samples times.

Parameters:
  • hidden_state (HiddenState) – Hidden state to repeat.

  • n_samples (int) – Number of repetitions.

Returns:

Repeated hidden state.

Return type:

HiddenState
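
For an LSTM-style (hidden, cell) tuple, the duplication might look like the sketch below, which repeats each batch entry along dimension 1 with torch.Tensor.repeat_interleave. The function name repeat_interleave_sketch and the choice of the batch dimension are assumptions for illustration.

```python
import torch


def repeat_interleave_sketch(hidden_state, n_samples):
    """Hypothetical helper: repeat every batch entry n_samples times, keeping entries adjacent."""
    hidden, cell = hidden_state
    return (
        hidden.repeat_interleave(n_samples, dim=1),
        cell.repeat_interleave(n_samples, dim=1),
    )


hidden = torch.tensor([[[1.0], [2.0]]])    # (layers, batch=2, hidden_size=1)
cell = torch.tensor([[[10.0], [20.0]]])
rep_hidden, rep_cell = repeat_interleave_sketch((hidden, cell), n_samples=3)
```

Copies of the same batch entry end up next to each other (1, 1, 1, 2, 2, 2) rather than tiled (1, 2, 1, 2, 1, 2), which keeps each sample aligned with its source sequence.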