AutoRegressiveBaseModel
- class pytorch_forecasting.models.base_model.AutoRegressiveBaseModel(dataset_parameters: Dict[str, Any] = None, log_interval: int | float = -1, log_val_interval: int | float = None, learning_rate: float | List[float] = 0.001, log_gradient_flow: bool = False, loss: Metric = SMAPE(), logging_metrics: ModuleList = ModuleList(), reduce_on_plateau_patience: int = 1000, reduce_on_plateau_reduction: float = 2.0, reduce_on_plateau_min_lr: float = 1e-05, weight_decay: float = 0.0, optimizer_params: Dict[str, Any] = None, monotone_constaints: Dict[str, int] = None, output_transformer: Callable = None, optimizer=None)
Bases: BaseModel
Model with additional methods for autoregressive models.
Adds in particular the decode_autoregressive() method for making auto-regressive predictions. Assumes the following hyperparameters:
- Parameters:
target (str) – name of target variable
target_lags (Dict[str, Dict[str, int]]) – dictionary of target names, each mapped to a dictionary of corresponding lagged variables and their lags. Lags can be useful to indicate seasonality to the models. If you know the seasonality (or seasonalities) of your data, add at least the target variables with the corresponding lags to improve performance. Defaults to no lags, i.e. an empty dictionary. See the illustrative sketch below.
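As an illustration only, a target_lags value for a single target might look like the following sketch; the target name "volume" and the lagged column names are made-up examples, not part of the API:

# hypothetical hyperparameter value: outer keys are target names, inner keys are the
# names of the lagged covariate columns, values are the lag in time steps
target_lags = {
    "volume": {
        "volume_lagged_by_7": 7,    # weekly seasonality
        "volume_lagged_by_14": 14,  # two-week seasonality
    }
}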
BaseModel for timeseries forecasting from which to inherit.
- Parameters:
log_interval (Union[int, float], optional) – Batches after which predictions are logged. If < 1.0, will log multiple entries per batch. Defaults to -1.
log_val_interval (Union[int, float], optional) – batches after which predictions for validation are logged. Defaults to None/log_interval.
learning_rate (float, optional) – Learning rate. Defaults to 1e-3.
log_gradient_flow (bool) – Whether to log gradient flow; this takes time and should only be done to diagnose training failures. Defaults to False.
loss (Metric, optional) – metric to optimize, can also be list of metrics. Defaults to SMAPE().
logging_metrics (nn.ModuleList[MultiHorizonMetric]) – list of metrics that are logged during training. Defaults to [].
reduce_on_plateau_patience (int) – patience after which learning rate is reduced by a factor of reduce_on_plateau_reduction. Defaults to 1000.
reduce_on_plateau_reduction (float) – reduction in learning rate when encountering plateau. Defaults to 2.0.
reduce_on_plateau_min_lr (float) – minimum learning rate for reduce on plateau learning rate scheduler. Defaults to 1e-5.
weight_decay (float) – weight decay. Defaults to 0.0.
optimizer_params (Dict[str, Any]) – additional parameters for the optimizer. Defaults to {}.
monotone_constaints (Dict[str, int]) – dictionary of monotonicity constraints for continuous decoder variables mapping position (e.g. "0" for first position) to constraint (-1 for negative and +1 for positive; larger numbers add more weight to the constraint vs. the loss but are usually not necessary). This constraint significantly slows down training. Defaults to {}.
output_transformer (Callable) – transformer that takes network output and transforms it to prediction space. Defaults to None, which is equivalent to lambda out: out["prediction"].
optimizer (str) – Optimizer: “ranger”, “sgd”, “adam”, “adamw” or the class name of an optimizer in torch.optim or pytorch_optimizer. Alternatively, a class or function can be passed which takes parameters as first argument and a lr argument (optionally also weight_decay). Defaults to “ranger” if pytorch_optimizer is installed, otherwise “adam”. A sketch of the monotone_constaints and optimizer formats follows this list.
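A minimal sketch of the two less obvious argument formats above; the positions in monotone_constaints and the choice of AdamW are illustrative assumptions, not defaults:

import torch

# monotone_constaints maps the position of a continuous decoder variable (as a string)
# to the direction of the constraint: -1 for monotonically decreasing, +1 for increasing
monotone_constaints = {"0": +1, "2": -1}  # hypothetical positions

# optimizer accepts a string name, or a class/callable that takes the parameters as
# first argument and a lr keyword argument, e.g. an optimizer class from torch.optim
optimizer = torch.optim.AdamW

Both values would typically be passed as keyword arguments to a concrete model's from_dataset() call, which forwards them to the model's __init__.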
Methods
- decode_autoregressive(decode_one, ...[, ...]) – Make predictions in auto-regressive manner.
- from_dataset(dataset, **kwargs) – Create model from dataset.
- output_to_prediction(...[, n_samples]) – Convert network output to rescaled and normalized prediction.
- plot_prediction(x, out[, idx, ...]) – Plot prediction vs actuals.
- decode_autoregressive(decode_one: Callable, first_target: List[Tensor] | Tensor, first_hidden_state: Any, target_scale: List[Tensor] | Tensor, n_decoder_steps: int, n_samples: int = 1, **kwargs) → List[Tensor] | Tensor
Make predictions in auto-regressive manner.
Supports only continuous targets.
- Parameters:
decode_one (Callable) – function that takes at least the following arguments:
  - idx (int): index of decoding step (from 0 to n_decoder_steps-1)
  - lagged_targets (List[torch.Tensor]): list of normalized targets. The list is idx + 1 elements long with the most recent entry at the end, i.e. previous_target = lagged_targets[-1] and in general lagged_targets[-lag].
  - hidden_state (Any): current hidden state required for prediction. Keys are variable names. Only lags that are greater than idx are included.
  - additional arguments are not dynamic but can be passed via the **kwargs argument
The function returns a tuple of the (not rescaled) network prediction output and the hidden state for the next auto-regressive step.
first_target (Union[List[torch.Tensor], torch.Tensor]) – first target value to use for decoding
first_hidden_state (Any) – first hidden state used for decoding
target_scale (Union[List[torch.Tensor], torch.Tensor]) – target scale as in x
n_decoder_steps (int) – number of decoding/prediction steps
n_samples (int) – number of independent samples to draw from the distribution - only relevant for multivariate models. Defaults to 1.
**kwargs – additional arguments that are passed to the decode_one function.
- Returns:
re-scaled prediction
- Return type:
Union[List[torch.Tensor], torch.Tensor]
Example
LSTM/GRU decoder
def decode(self, x, hidden_state):
    # create input vector
    input_vector = x["decoder_cont"].clone()
    input_vector[..., self.target_positions] = torch.roll(
        input_vector[..., self.target_positions],
        shifts=1,
        dims=1,
    )
    # but this time fill in missing target from encoder_cont at the first time step instead of
    # throwing it away
    last_encoder_target = x["encoder_cont"][
        torch.arange(x["encoder_cont"].size(0), device=x["encoder_cont"].device),
        x["encoder_lengths"] - 1,
        self.target_positions.unsqueeze(-1),
    ].T.contiguous()
    input_vector[:, 0, self.target_positions] = last_encoder_target

    if self.training:  # training mode
        decoder_output, _ = self.rnn(
            input_vector,
            hidden_state,
            lengths=x["decoder_lengths"],
            enforce_sorted=False,
        )

        # from hidden state size to outputs
        if isinstance(self.hparams.target, str):  # single target
            output = self.distribution_projector(decoder_output)
        else:
            output = [projector(decoder_output) for projector in self.distribution_projector]

        # predictions are not yet rescaled -> so rescale now
        return self.transform_output(output, target_scale=x["target_scale"])

    else:  # prediction mode
        target_pos = self.target_positions
        lagged_target_positions = self.lagged_target_positions  # lag -> positions of lagged targets

        def decode_one(idx, lagged_targets, hidden_state):
            x = input_vector[:, [idx]]
            x[:, 0, target_pos] = lagged_targets[-1]  # overwrite at target positions

            # overwrite at lagged targets positions
            for lag, lag_positions in lagged_target_positions.items():
                if idx > lag:  # only overwrite if target has been generated
                    x[:, 0, lag_positions] = lagged_targets[-lag]

            decoder_output, hidden_state = self.rnn(x, hidden_state)
            decoder_output = decoder_output[:, 0]  # take first timestep
            # from hidden state size to outputs
            if isinstance(self.hparams.target, str):  # single target
                output = self.distribution_projector(decoder_output)
            else:
                output = [projector(decoder_output) for projector in self.distribution_projector]
            return output, hidden_state

        # make predictions which are fed into next step
        output = self.decode_autoregressive(
            decode_one,
            first_target=input_vector[:, 0, target_pos],
            first_hidden_state=hidden_state,
            target_scale=x["target_scale"],
            n_decoder_steps=input_vector.size(1),
        )

        # predictions are already rescaled
        return output
- classmethod from_dataset(dataset: TimeSeriesDataSet, **kwargs) → LightningModule
Create model from dataset.
- Parameters:
dataset – timeseries dataset
**kwargs – additional arguments such as hyperparameters for model (see __init__())
- Returns:
LightningModule
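A hedged usage sketch, assuming a concrete autoregressive subclass such as DeepAR and an already constructed TimeSeriesDataSet named training (both are assumptions for illustration, not taken from this page):

from pytorch_forecasting import DeepAR

# "training" is assumed to be a TimeSeriesDataSet built elsewhere; additional keyword
# arguments are forwarded to the model's __init__ as hyperparameters
model = DeepAR.from_dataset(
    training,
    learning_rate=1e-2,
    hidden_size=32,
    rnn_layers=2,
)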
- output_to_prediction(normalized_prediction_parameters: Tensor, target_scale: List[Tensor] | Tensor, n_samples: int = 1, **kwargs) → Tuple[List[Tensor] | Tensor, Tensor]
Convert network output to rescaled and normalized prediction.
Function is typically not called directly but via decode_autoregressive().
- Parameters:
normalized_prediction_parameters (torch.Tensor) – network prediction output
target_scale (Union[List[torch.Tensor], torch.Tensor]) – target scale to rescale network output
n_samples (int, optional) – Number of samples to draw independently. Defaults to 1.
**kwargs – extra arguments for dictionary passed to the transform_output() method.
- Returns:
- tuple of rescaled prediction and normalized prediction (e.g. for input into next auto-regressive step)
- Return type:
Tuple[Union[List[torch.Tensor], torch.Tensor], torch.Tensor]
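As a rough sketch of where this method fits inside a custom decoding step of the model (so self refers to the model); normalized_params and x are assumed names for the network output and input, not part of the documented signature:

# rescale normalized network output and keep the normalized version for the next step
prediction, normalized_prediction = self.output_to_prediction(
    normalized_params,               # hypothetical tensor of normalized prediction parameters
    target_scale=x["target_scale"],  # target scale taken from the network input dictionary
    n_samples=1,
)
# "prediction" is in the original target space; "normalized_prediction" can be appended
# to lagged_targets when decoding auto-regressively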
- plot_prediction(x: Dict[str, Tensor], out: Dict[str, Tensor], idx: int = 0, add_loss_to_title: Metric | Tensor | bool = False, show_future_observed: bool = True, ax=None, quantiles_kwargs: Dict[str, Any] | None = None, prediction_kwargs: Dict[str, Any] | None = None)
Plot prediction vs actuals.
- Parameters:
x – network input
out – network output
idx – index of prediction to plot
add_loss_to_title – whether to add the loss to the title, or a loss function to calculate. Can be either a metric, a bool indicating whether to use the loss metric, or a tensor which contains losses for all samples. Calculated losses are determined without weights. Defaults to False.
show_future_observed – whether to show actuals for the future. Defaults to True.
ax – matplotlib axes to plot on
quantiles_kwargs (Dict[str, Any]) – parameters for to_quantiles() of the loss metric.
prediction_kwargs (Dict[str, Any]) – parameters for to_prediction() of the loss metric.
- Returns:
matplotlib figure
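A hedged usage sketch, assuming a trained model, a validation dataloader, and a pytorch_forecasting version whose predict(..., mode="raw", return_x=True) returns an object with output and x attributes (older releases return a plain tuple instead):

# obtain raw network output together with the corresponding network input
raw = model.predict(val_dataloader, mode="raw", return_x=True)

# plot the first sample in the batch and add the loss to the title
fig = model.plot_prediction(raw.x, raw.output, idx=0, add_loss_to_title=True)
fig.savefig("prediction_0.png")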
- property lagged_target_positions: Dict[int, LongTensor]
Positions of lagged target variable(s) in covariates.
- Returns:
dictionary mapping integer lags to tensor of variable positions.
- Return type:
Dict[int, torch.LongTensor]
- property target_positions: LongTensor
Positions of target variable(s) in covariates.
- Returns:
tensor of positions.
- Return type:
torch.LongTensor
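Purely illustrative return values for these two properties, assuming a single target stored at covariate position 3 with lags of 7 and 14 configured via target_lags (all positions made up):

model.target_positions         # -> tensor([3])
model.lagged_target_positions  # -> {7: tensor([5]), 14: tensor([6])}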