LSTM
- class pytorch_forecasting.models.nn.rnn.LSTM(*args, **kwargs)
Bases: pytorch_forecasting.models.nn.rnn.RNN, torch.nn.modules.rnn.LSTM
LSTM that can handle zero-length sequences.
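The class description does not say how zero-length sequences are handled. The sketch below shows one general approach using plain `torch.nn.LSTM`, not this class. `pack_padded_sequence` rejects lengths of 0, so the sketch clamps lengths to 1 for packing. It then overwrites the final state of empty sequences with the initial (here zero) state, which is the kind of masking `handle_no_encoding` describes. The helper `lstm_final_state_with_empty` is hypothetical and assumes `batch_first=True`. The library's internals may differ.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence


def lstm_final_state_with_empty(lstm: nn.LSTM, x: torch.Tensor, lengths: torch.Tensor):
    """Final (h, c) of a batch_first nn.LSTM when some sequence lengths are 0.

    Hypothetical helper for illustration, not part of pytorch_forecasting.
    Assumes lstm.batch_first is True and x has shape (batch, time, features).
    """
    # pack_padded_sequence rejects lengths <= 0, so pack empty sequences as length 1
    pack_lengths = lengths.clamp(min=1)
    packed = pack_padded_sequence(
        x, pack_lengths.cpu(), batch_first=True, enforce_sorted=False
    )
    # With hx omitted, nn.LSTM starts from a zero initial state
    _, (h, c) = lstm(packed)
    # Bool mask of shape (1, batch, 1) broadcasts over layers and hidden units
    no_encoding = (lengths == 0)[None, :, None]
    # Empty sequences keep the zero initial state instead of the 1-step result
    h = torch.where(no_encoding, torch.zeros_like(h), h)
    c = torch.where(no_encoding, torch.zeros_like(c), c)
    return h, c
```

Because `enforce_sorted=False` is used, `nn.LSTM` returns the final states in the original batch order. This means the mask built from `lengths` lines up with the rows of `h` and `c`.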
Methods
handle_no_encoding(hidden_state, ...) – Mask the hidden_state where there is no encoding.
Initialise a hidden_state.
repeat_interleave(hidden_state, n_samples) – Duplicate the hidden_state n_samples times.
- handle_no_encoding(hidden_state: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], no_encoding: torch.BoolTensor, initial_hidden_state: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]) → Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
Mask the hidden_state where there is no encoding.
- Parameters
hidden_state (HiddenState) – hidden state where some entries need replacement
no_encoding (torch.BoolTensor) – positions that need replacement
initial_hidden_state (HiddenState) – hidden state to use for replacement
- Returns
hidden state in which the entries at the no_encoding positions are replaced by initial_hidden_state
- Return type
HiddenState
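The sketch below shows the masking described above as a free function. The name `mask_no_encoding` is hypothetical. It assumes `no_encoding` is a boolean tensor that broadcasts to the state shape `(num_layers * num_directions, batch, hidden_size)`, for example `(1, batch, 1)`. It handles both the `(h, c)` tuple of an LSTM and a single-tensor state. It uses `torch.where` and is not necessarily how the library implements the method.

```python
from typing import Tuple, Union

import torch

# Mirrors the HiddenState union in the signature above
HiddenState = Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]


def mask_no_encoding(
    hidden_state: HiddenState,
    no_encoding: torch.Tensor,
    initial_hidden_state: HiddenState,
) -> HiddenState:
    """Hypothetical stand-in for handle_no_encoding.

    Where no_encoding is True, take the value from initial_hidden_state;
    elsewhere keep hidden_state.
    """
    if isinstance(hidden_state, tuple):
        # LSTM state: replace entries in the hidden and the cell state separately
        return tuple(
            torch.where(no_encoding, init, state)
            for state, init in zip(hidden_state, initial_hidden_state)
        )
    # Single-tensor state (e.g. a GRU hidden state)
    return torch.where(no_encoding, initial_hidden_state, hidden_state)
```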
Initialise a hidden_state.
- Parameters
x (torch.Tensor) – network input
- Returns
default (zero-like) hidden state
- Return type
HiddenState
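The Returns entry says the default state is zero-like. The sketch below builds such a state for a `torch.nn.LSTM` from the network input, following the `h_0`/`c_0` shape conventions of `torch.nn.LSTM`. The helper name `zero_hidden_state` is hypothetical, and only batched (3-D) input is handled.

```python
from typing import Tuple

import torch
from torch import nn


def zero_hidden_state(lstm: nn.LSTM, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: zero (h, c) for a batched 3-D input x."""
    # The batch dimension of x depends on batch_first; the state always has batch at dim 1
    batch_size = x.size(0) if lstm.batch_first else x.size(1)
    num_directions = 2 if lstm.bidirectional else 1
    # With proj_size > 0, h uses the projected size while c keeps hidden_size
    proj_size = getattr(lstm, "proj_size", 0)
    h_size = proj_size if proj_size > 0 else lstm.hidden_size
    n = lstm.num_layers * num_directions
    h = torch.zeros(n, batch_size, h_size, dtype=x.dtype, device=x.device)
    c = torch.zeros(n, batch_size, lstm.hidden_size, dtype=x.dtype, device=x.device)
    return h, c
```

Passing this state explicitly as `hx` should give the same result as omitting `hx`, because `torch.nn.LSTM` also defaults to zeros.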
- repeat_interleave(hidden_state: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], n_samples: int) → Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
Duplicate the hidden_state n_samples times.
- Parameters
hidden_state (HiddenState) – hidden state to repeat
n_samples (int) – number of repetitions
- Returns
repeated hidden state
- Return type
HiddenState
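The sketch below shows one way to perform the duplication. It assumes the method follows `torch.repeat_interleave` semantics along the batch dimension of the state, which is dim 1 for `torch.nn.LSTM` states regardless of `batch_first`. The name `repeat_hidden_state` is hypothetical.

```python
import torch


def repeat_hidden_state(hidden_state, n_samples: int):
    """Hypothetical helper: repeat each batch entry n_samples times along dim 1.

    Accepts an (h, c) tuple or a single tensor of shape
    (num_layers * num_directions, batch, hidden_size).
    """
    if isinstance(hidden_state, tuple):
        # Repeat hidden and cell state the same way so they stay aligned
        return tuple(s.repeat_interleave(n_samples, dim=1) for s in hidden_state)
    return hidden_state.repeat_interleave(n_samples, dim=1)
```

Interleaving keeps the copies of each sample adjacent (b0, b0, b1, b1, ...). As a result, the repeated state lines up with a batch-first input expanded via `x.repeat_interleave(n_samples, dim=0)`, for example when drawing several samples per series.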