GRU
- class pytorch_forecasting.models.nn.rnn.GRU(input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, dropout: float = 0.0, bidirectional: bool = False, device=None, dtype=None)
- class pytorch_forecasting.models.nn.rnn.GRU(*args, **kwargs)
Bases: RNN, GRU
A GRU that can handle zero-length sequences.
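To illustrate what "handling zero-length sequences" involves, the sketch below wraps a plain `torch.nn.GRU` instead of using this class. It assumes one plausible approach, which may not match this class's implementation. `pack_padded_sequence` rejects zero lengths, so empty sequences are packed with length 1. Afterwards their final hidden state is reset to the initial (zero) state and their outputs are zeroed. The helper name `gru_with_zero_lengths` is hypothetical, and the sketch assumes `batch_first=True`.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def gru_with_zero_lengths(gru: nn.GRU, x: torch.Tensor, lengths: torch.Tensor):
    """Hypothetical helper: run a batch_first nn.GRU on a padded batch where some lengths may be 0."""
    assert gru.batch_first, "this sketch assumes batch_first=True"
    num_directions = 2 if gru.bidirectional else 1
    # zero initial hidden state with shape (num_layers * num_directions, batch, hidden_size)
    h0 = torch.zeros(
        gru.num_layers * num_directions, x.size(0), gru.hidden_size, dtype=x.dtype, device=x.device
    )
    # pack_padded_sequence rejects zero lengths, so empty sequences are packed with length 1
    packed = pack_padded_sequence(
        x, lengths.clamp(min=1).cpu(), batch_first=True, enforce_sorted=False
    )
    packed_out, h_n = gru(packed, h0)
    out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))
    no_encoding = lengths == 0
    # empty sequences keep the initial hidden state, not the state after the dummy step
    h_n = torch.where(no_encoding.view(1, -1, 1), h0, h_n)
    # their outputs are zeroed too, because the dummy step is not real data
    out = out.masked_fill(no_encoding.view(-1, 1, 1), 0.0)
    return out, h_n


torch.manual_seed(0)
gru = nn.GRU(input_size=3, hidden_size=4, batch_first=True)
out, h_n = gru_with_zero_lengths(gru, torch.randn(2, 5, 3), torch.tensor([5, 0]))
```

Non-empty sequences are processed exactly as a plain GRU would process them. Only the entries with length 0 are patched after the forward pass.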
Methods
handle_no_encoding(hidden_state, ...) – Mask the hidden_state where there is no encoding.
Initialise a hidden_state.
repeat_interleave(hidden_state, n_samples) – Duplicate the hidden_state n_samples times.
- handle_no_encoding(hidden_state: tuple[Tensor, Tensor] | Tensor, no_encoding: BoolTensor, initial_hidden_state: tuple[Tensor, Tensor] | Tensor) → tuple[Tensor, Tensor] | Tensor
Mask the hidden_state where there is no encoding.
- Parameters:
hidden_state (HiddenState) – Hidden state where some entries need replacement.
no_encoding (torch.BoolTensor) – Boolean mask marking the positions (entries without an encoding) that need replacement.
initial_hidden_state (HiddenState) – Hidden state to use for replacement.
- Returns:
Hidden state in which the entries flagged by no_encoding are replaced with the corresponding entries of initial_hidden_state.
- Return type:
HiddenState
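The masking described above can be sketched with `torch.where`, which is not necessarily how the library implements it. The sketch makes three assumptions. First, `no_encoding` is a 1-D boolean mask over the batch. Second, each hidden tensor has the standard PyTorch RNN layout `(num_layers * num_directions, batch, hidden_size)`. Third, a tuple is an LSTM-style `(h, c)` pair, which matches the documented `tuple[Tensor, Tensor] | Tensor` type. The helper name `mask_no_encoding` is hypothetical.

```python
from typing import Tuple, Union

import torch

# mirrors the documented HiddenState type: a tensor (GRU) or an (h, c) pair (LSTM)
HiddenState = Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]


def mask_no_encoding(
    hidden_state: HiddenState, no_encoding: torch.Tensor, initial_hidden_state: HiddenState
) -> HiddenState:
    """Hypothetical helper: take initial_hidden_state wherever no_encoding is True."""
    # assumption: no_encoding has shape (batch,); reshape it to broadcast over layers and features
    mask = no_encoding.view(1, -1, 1)
    if isinstance(hidden_state, tuple):
        # LSTM-style (h, c): mask each component separately
        return tuple(
            torch.where(mask, init, h) for h, init in zip(hidden_state, initial_hidden_state)
        )
    return torch.where(mask, initial_hidden_state, hidden_state)


h = torch.ones(2, 3, 4)
masked = mask_no_encoding(h, torch.tensor([False, True, False]), torch.zeros(2, 3, 4))
```

In this example only batch entry 1 is reset to zeros. Entries 0 and 2 keep their values.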
Initialise a hidden_state.
- Parameters:
x (torch.Tensor) – Network input.
- Returns:
Default (zero-like) hidden state.
- Return type:
HiddenState
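A zero-like GRU hidden state has the shape that `torch.nn.GRU` expects for `h_0`: `(num_layers * num_directions, batch, hidden_size)`. The batch size comes from dimension 0 or 1 of the input, depending on `batch_first`. The sketch below builds such a state for a plain `nn.GRU`. The helper name `zero_hidden_state` is hypothetical and is not this class's method.

```python
import torch
from torch import nn


def zero_hidden_state(gru: nn.GRU, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: zero initial hidden state matching gru's configuration and x's batch."""
    # the batch dimension depends on the batch_first setting
    batch_size = x.size(0) if gru.batch_first else x.size(1)
    num_directions = 2 if gru.bidirectional else 1
    # match the input's dtype and device so the state can be passed straight to gru(x, h0)
    return torch.zeros(
        gru.num_layers * num_directions, batch_size, gru.hidden_size, dtype=x.dtype, device=x.device
    )


gru = nn.GRU(3, 4, num_layers=2, bidirectional=True, batch_first=True)
x = torch.randn(5, 7, 3)
h0 = zero_hidden_state(gru, x)  # shape (4, 5, 4)
out, h_n = gru(x, h0)
```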
- repeat_interleave(hidden_state: tuple[Tensor, Tensor] | Tensor, n_samples: int) → tuple[Tensor, Tensor] | Tensor
Duplicate the hidden_state n_samples times.
- Parameters:
hidden_state (HiddenState) – Hidden state to repeat.
n_samples (int) – Number of repetitions.
- Returns:
Repeated hidden state.
- Return type:
HiddenState
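The duplication can be sketched with `torch.Tensor.repeat_interleave`. The sketch assumes the batch is dimension 1, as in the standard PyTorch RNN hidden-state layout. Interleaving keeps the `n_samples` copies of each batch entry adjacent: `[a, a, b, b]` rather than `[a, b, a, b]`. The helper name `repeat_hidden_state` is hypothetical, and tuples are treated as LSTM-style `(h, c)` pairs.

```python
from typing import Tuple, Union

import torch

HiddenState = Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]


def repeat_hidden_state(hidden_state: HiddenState, n_samples: int) -> HiddenState:
    """Hypothetical helper: give every batch entry n_samples adjacent copies along dim 1."""
    if isinstance(hidden_state, tuple):
        # LSTM-style (h, c): repeat each component the same way
        return tuple(h.repeat_interleave(n_samples, dim=1) for h in hidden_state)
    return hidden_state.repeat_interleave(n_samples, dim=1)


h = torch.arange(2, dtype=torch.float32).view(1, 2, 1)  # two batch entries: 0.0 and 1.0
r = repeat_hidden_state(h, 3)
print(r.flatten().tolist())  # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
```

By contrast, `h.repeat(1, 3, 1)` would tile the whole batch and give `[0.0, 1.0, 0.0, 1.0, 0.0, 1.0]`.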