get_rnn

pytorch_forecasting.models.nn.rnn.get_rnn(cell_type: Union[Type[pytorch_forecasting.models.nn.rnn.RNN], str]) → Type[pytorch_forecasting.models.nn.rnn.RNN]

Get the LSTM or GRU RNN class that corresponds to cell_type.

Parameters

cell_type (Union[Type[RNN], str]) – the cell type to use, given as “LSTM” or “GRU”

Returns

the LSTM or GRU RNN module class (a class, not an instance)

Return type

Type[RNN]
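The name-to-class lookup described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: RNN, LSTM and GRU below are hypothetical stand-in classes (not torch modules), and get_rnn_sketch is a hypothetical name. It assumes that passing an RNN subclass returns it unchanged (as the Type[RNN] in the signature suggests) and that unsupported names raise ValueError.

```python
from typing import Type, Union


class RNN:
    """Hypothetical stand-in for the RNN base class."""


class LSTM(RNN):
    """Hypothetical stand-in for the LSTM cell class."""


class GRU(RNN):
    """Hypothetical stand-in for the GRU cell class."""


# Accepted string names, mapped to their (stand-in) classes.
_CELL_TYPES = {"LSTM": LSTM, "GRU": GRU}


def get_rnn_sketch(cell_type: Union[Type[RNN], str]) -> Type[RNN]:
    """Return the RNN class for cell_type (assumed behavior, see lead-in)."""
    # Assumption: an RNN subclass passed directly is returned unchanged.
    if isinstance(cell_type, type) and issubclass(cell_type, RNN):
        return cell_type
    try:
        # Look up the class by name; note that a class is returned, not an instance.
        return _CELL_TYPES[cell_type]
    except KeyError:
        # Assumption: unknown names are rejected with ValueError.
        raise ValueError(
            f"Unsupported cell_type {cell_type!r}; expected one of {sorted(_CELL_TYPES)}"
        ) from None


print(get_rnn_sketch("GRU").__name__)  # prints "GRU"
```

Because a class is returned, the caller still instantiates it with its own constructor arguments. This keeps the choice of cell type, which is often a plain string hyperparameter, separate from how the layer is configured.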