get_rnn

pytorch_forecasting.models.nn.rnn.get_rnn(cell_type: type[RNN] | str) -> type[RNN]

Get the LSTM or GRU RNN class for the given cell type.

Parameters:

cell_type (type[RNN] or str) – An RNN class, or a string identifier: either "LSTM" or "GRU".

Returns:

The LSTM or GRU RNN class corresponding to cell_type.

Return type:

type[RNN]
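The parameter and return descriptions above amount to a small lookup from a cell type to a class. The following is a minimal, self-contained sketch of that dispatch, not the library's implementation: the RNN, LSTM and GRU classes here are hypothetical plain-Python stand-ins for the torch-based classes in pytorch_forecasting.models.nn.rnn, get_rnn_sketch is a hypothetical name, and both the pass-through of an RNN subclass and the ValueError for unsupported input are assumptions of this sketch.

```python
from typing import Union


# Hypothetical stand-ins for the RNN base class and its LSTM/GRU
# subclasses; the real classes in pytorch_forecasting.models.nn.rnn
# are torch modules.
class RNN:
    pass


class LSTM(RNN):
    pass


class GRU(RNN):
    pass


# Accepted string identifiers mapped to their classes (assumed table).
_RNN_CLASSES = {"LSTM": LSTM, "GRU": GRU}


def get_rnn_sketch(cell_type: Union[type, str]) -> type:
    """Resolve cell_type to an RNN class (illustrative sketch only)."""
    # Assumption: a class that already derives from RNN is returned unchanged.
    if isinstance(cell_type, type) and issubclass(cell_type, RNN):
        return cell_type
    # Otherwise treat the argument as a string identifier.
    try:
        return _RNN_CLASSES[cell_type]
    except (KeyError, TypeError):
        # Assumption: unsupported identifiers raise ValueError.
        raise ValueError(
            f"unsupported cell_type {cell_type!r}; expected 'LSTM' or 'GRU'"
        ) from None


print(get_rnn_sketch("LSTM").__name__)  # LSTM
print(get_rnn_sketch(GRU).__name__)     # GRU
```

With the real library the returned value is a class, not an instance, so it still has to be constructed; assuming it accepts the usual torch.nn.GRU constructor arguments, that would look like get_rnn("GRU")(input_size=8, hidden_size=16).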