pytorch_forecasting.models.nn.rnn.RNN
Bases: abc.ABC, torch.nn.modules.rnn.RNNBase
Base class for flexible RNNs.
Its forward function can handle sequences of length 0.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
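Zero lengths need special handling because plain PyTorch packing rejects them. The quick check below uses only `torch.nn.utils.rnn.pack_padded_sequence`. No pytorch_forecasting code is involved, and the helper name `packing_rejects_zero_length` is illustrative only.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence


def packing_rejects_zero_length(x: torch.Tensor, lengths: torch.Tensor) -> bool:
    """Return True if PyTorch refuses to pack a batch_first padded batch."""
    try:
        pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    except RuntimeError:
        # pack_padded_sequence requires every length to be greater than 0
        return True
    return False


x = torch.randn(2, 3, 4)  # 2 padded sequences, max length 3, 4 features
print(packing_rejects_zero_length(x, torch.tensor([3, 2])))  # False
print(packing_rejects_zero_length(x, torch.tensor([3, 0])))  # True
```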
Methods
forward(x[, hx, lengths, enforce_sorted])
Forward function of the RNN that allows zero-length sequences.
handle_no_encoding(hidden_state, …)
Mask the hidden_state where there is no encoding.
init_hidden_state(x)
Initialise a hidden_state.
repeat_interleave(hidden_state, n_samples)
Duplicate the hidden_state n_samples times.
forward(x, hx=None, lengths=None, enforce_sorted=True)
Behaves like a standard RNN forward pass and only alters the output if lengths are passed.
Parameters:
x (Union[rnn.PackedSequence, torch.Tensor]) – input to the RNN: either a packed sequence or a tensor of padded sequences.
hx (HiddenState, optional) – hidden state. Defaults to None.
lengths (torch.LongTensor, optional) – lengths of the sequences. If not None, used to determine the correct hidden state to return. May contain zeros. Defaults to None.
enforce_sorted (bool, optional) – if lengths are passed, determines whether the RNN expects them to be sorted. Defaults to True.
Returns:
output and hidden state. The output is a packed sequence if the input was a packed sequence.
Return type:
Tuple[Union[rnn.PackedSequence, torch.Tensor], HiddenState]
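The zero-length handling can be illustrated with plain PyTorch. This is a minimal sketch of one possible approach, not pytorch_forecasting's implementation. `forward_with_zero_lengths` is a hypothetical helper that assumes a `torch.nn.LSTM` built with `batch_first=True`. It clamps lengths to at least 1 so packing succeeds, runs the LSTM, and then hands the initial hidden state back for sequences whose length was 0.

```python
from typing import Optional, Tuple

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

HiddenState = Tuple[torch.Tensor, torch.Tensor]  # (h, c) for an LSTM


def forward_with_zero_lengths(
    lstm: nn.LSTM,
    x: torch.Tensor,
    lengths: torch.Tensor,
    hx: Optional[HiddenState] = None,
    enforce_sorted: bool = True,
) -> Tuple[torch.Tensor, HiddenState]:
    """Hypothetical helper: run a batch_first LSTM on padded input whose lengths may be 0."""
    assert lstm.batch_first, "this sketch assumes batch_first=True"
    if hx is None:
        num_directions = 2 if lstm.bidirectional else 1
        zeros = torch.zeros(
            lstm.num_layers * num_directions, x.size(0), lstm.hidden_size,
            dtype=x.dtype, device=x.device,
        )
        hx = (zeros, zeros.clone())
    no_encoding = lengths == 0
    # pack_padded_sequence rejects 0, so treat empty sequences as length 1;
    # clamping keeps a decreasing order intact, so enforce_sorted still holds
    packed = pack_padded_sequence(
        x, lengths.clamp(min=1).cpu(), batch_first=True, enforce_sorted=enforce_sorted
    )
    packed_out, (h, c) = lstm(packed, hx)
    out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))
    # Empty sequences saw no real input: give them back the initial hidden state.
    # Their output at time step 0 was computed from padding and should be ignored.
    mask = no_encoding.to(h.device).view(1, -1, 1)  # broadcast over layers and units
    h = torch.where(mask, hx[0], h)
    c = torch.where(mask, hx[1], c)
    return out, (h, c)


torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=2, batch_first=True)
x = torch.randn(3, 5, 4)
lengths = torch.tensor([5, 2, 0])  # sorted in decreasing order
out, (h, c) = forward_with_zero_lengths(lstm, x, lengths)
print(out.shape, h.shape)  # torch.Size([3, 5, 8]) torch.Size([2, 3, 8])
print(h[:, 2].abs().sum().item())  # 0.0
```

Clamping rather than dropping the empty rows keeps the batch size, and therefore the hidden-state layout, unchanged, so callers can index results by their original batch position.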
handle_no_encoding(hidden_state, no_encoding, initial_hidden_state)
Parameters:
hidden_state (HiddenState) – hidden state where some entries need replacement
no_encoding (torch.BoolTensor) – positions that need replacement
initial_hidden_state (HiddenState) – hidden state to use for replacement
Returns:
hidden state with propagated initial hidden state where appropriate
Return type:
HiddenState
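Masking of this kind maps naturally onto `torch.where`. A minimal sketch, assuming hidden tensors shaped `(num_layers * num_directions, batch, hidden_size)` (PyTorch's RNN convention), a boolean `no_encoding` of shape `(batch,)`, and a tuple `(h, c)` for LSTM states. `mask_no_encoding` is a hypothetical stand-in, not the library method.

```python
from typing import Tuple, Union

import torch

HiddenState = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


def mask_no_encoding(
    hidden_state: HiddenState,
    no_encoding: torch.Tensor,
    initial_hidden_state: HiddenState,
) -> HiddenState:
    """Hypothetical sketch: use the initial state wherever no_encoding is True."""
    mask = no_encoding.view(1, -1, 1)  # broadcast over layers and hidden units
    if isinstance(hidden_state, tuple):  # e.g. LSTM: (h, c)
        return tuple(
            torch.where(mask, init, state)
            for state, init in zip(hidden_state, initial_hidden_state)
        )
    return torch.where(mask, initial_hidden_state, hidden_state)  # e.g. GRU


hidden = torch.ones(2, 3, 4)  # 2 layers, batch of 3, 4 hidden units
initial = torch.zeros(2, 3, 4)
no_encoding = torch.tensor([False, True, False])
masked = mask_no_encoding(hidden, no_encoding, initial)
print(masked[:, :, 0].tolist())  # [[1.0, 0.0, 1.0], [1.0, 0.0, 1.0]]
```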
init_hidden_state(x)
Parameters:
x (torch.Tensor) – network input
Returns:
default (zero-like) hidden state
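A zero-like hidden state depends only on the RNN's configuration and the batch size of the input. The following sketch uses a hypothetical `zero_hidden_state` helper and assumes a standard `torch.nn.GRU`, `torch.nn.RNN`, or `torch.nn.LSTM` without projection (`proj_size == 0`).

```python
from typing import Tuple, Union

import torch
from torch import nn


def zero_hidden_state(
    rnn: nn.RNNBase, x: torch.Tensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Hypothetical helper: zero hidden state matching rnn's config and x's batch size."""
    batch_size = x.size(0) if rnn.batch_first else x.size(1)
    num_directions = 2 if rnn.bidirectional else 1
    # PyTorch RNNs expect (num_layers * num_directions, batch, hidden_size),
    # even when batch_first=True
    zeros = torch.zeros(
        rnn.num_layers * num_directions, batch_size, rnn.hidden_size,
        dtype=x.dtype, device=x.device,
    )
    if isinstance(rnn, nn.LSTM):
        return zeros, zeros.clone()  # LSTMs carry a (hidden, cell) pair
    return zeros


gru = nn.GRU(input_size=4, hidden_size=8, num_layers=2, batch_first=True)
x = torch.randn(5, 7, 4)  # batch of 5, 7 time steps, 4 features
h0 = zero_hidden_state(gru, x)
print(h0.shape)  # torch.Size([2, 5, 8])
```

PyTorch already starts from zeros when no hidden state is passed, so an explicit zero state mainly matters when it has to be handed around, for example as the replacement state for positions without an encoding.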
repeat_interleave(hidden_state, n_samples)
Parameters:
hidden_state (HiddenState) – hidden state to repeat
n_samples (int) – number of repetitions
Returns:
repeated hidden state
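Duplicating a hidden state can be useful, for instance, to draw several samples per series in a single batch (an assumed use case). A minimal sketch, assuming the batch sits in dimension 1 as in PyTorch's hidden-state layout. `repeat_hidden_state` is hypothetical.

```python
from typing import Tuple, Union

import torch

HiddenState = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


def repeat_hidden_state(hidden_state: HiddenState, n_samples: int) -> HiddenState:
    """Hypothetical sketch: repeat each batch entry n_samples times along dim 1."""
    if isinstance(hidden_state, tuple):  # LSTM: (h, c)
        return tuple(h.repeat_interleave(n_samples, dim=1) for h in hidden_state)
    return hidden_state.repeat_interleave(n_samples, dim=1)


h = torch.tensor([[[1.0], [2.0]]])  # 1 layer, batch of 2, 1 hidden unit
print(repeat_hidden_state(h, 3).flatten().tolist())  # [1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
```

Using `repeat_interleave` rather than `Tensor.repeat` keeps the copies of each series adjacent (`[a, a, a, b, b, b]` instead of `[a, b, a, b, a, b]`).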