BaseModel
- class pytorch_forecasting.models.base_model.BaseModel(log_interval: Union[int, float] = -1, log_val_interval: Optional[Union[float, int]] = None, learning_rate: Union[float, List[float]] = 0.001, log_gradient_flow: bool = False, loss: pytorch_forecasting.metrics.Metric = SMAPE(), logging_metrics: torch.nn.modules.container.ModuleList = ModuleList(), reduce_on_plateau_patience: int = 1000, reduce_on_plateau_min_lr: float = 1e-05, weight_decay: float = 0.0, optimizer_params: Optional[Dict[str, Any]] = None, monotone_constaints: Dict[str, int] = {}, output_transformer: Optional[Callable] = None, optimizer='ranger')
Bases: pytorch_lightning.core.lightning.LightningModule
Base class from which new timeseries models should inherit. The hparams of the created object will default to the parameters indicated in __init__().
The forward() method should return a named tuple with at least the entry prediction, which contains the network’s output. See the function’s documentation for more details.
The idea of the base model is that common methods do not have to be re-implemented for every new architecture. The class is a [LightningModule](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html) and follows its conventions. However, there are important additions:
- You need to specify a loss attribute that stores the function to calculate the MultiHorizonLoss for backpropagation.
- The from_dataset() method can be used to initialize a network using the specifications of a dataset. Often, parameters such as the number of features can be easily deduced from the dataset. Further, the method will also store how to rescale normalized predictions into the unnormalized prediction space. Override it to pass additional arguments to the __init__ method of your network that depend on your dataset.
- The transform_output() method rescales the network output using the target normalizer from the dataset.
- The step() method takes care of calculating the loss, logging additional metrics defined in the logging_metrics attribute, and plotting sample predictions. You can override this method to add custom interpretations or pass extra arguments to the network’s forward method.
- The epoch_end() method can be used to calculate summaries of each epoch, such as statistics on the encoder length.
- The predict() method makes predictions using a dataloader or dataset. Override it if you need to pass additional arguments to forward by default.
To implement your own architecture, it is best to go through the Using custom data and implementing custom models tutorial and to look at existing models to understand what might be a good approach.
Example
    class Network(BaseModel):

        def __init__(self, my_first_parameter: int = 2, loss=SMAPE()):
            self.save_hyperparameters()
            super().__init__(loss=loss)

        def forward(self, x):
            normalized_prediction = self.module(x)
            prediction = self.transform_output(prediction=normalized_prediction, target_scale=x["target_scale"])
            return self.to_network_output(prediction=prediction)
BaseModel for timeseries forecasting from which to inherit.
- Parameters
log_interval (Union[int, float], optional) – Batches after which predictions are logged. If < 1.0, will log multiple entries per batch. Defaults to -1.
log_val_interval (Union[int, float], optional) – Batches after which predictions for validation are logged. Defaults to None/log_interval.
learning_rate (Union[float, List[float]], optional) – Learning rate. Defaults to 1e-3.
log_gradient_flow (bool) – Whether to log gradient flow. This takes time and should only be done to diagnose training failures. Defaults to False.
loss (Metric, optional) – Metric to optimize; can also be a list of metrics. Defaults to SMAPE().
logging_metrics (nn.ModuleList[MultiHorizonMetric]) – List of metrics that are logged during training. Defaults to [].
reduce_on_plateau_patience (int) – Patience after which the learning rate is reduced by a factor of 10. Defaults to 1000.
reduce_on_plateau_min_lr (float) – Minimum learning rate for the reduce-on-plateau learning rate scheduler. Defaults to 1e-5.
weight_decay (float) – Weight decay. Defaults to 0.0.
optimizer_params (Dict[str, Any]) – Additional parameters for the optimizer. Defaults to {}.
monotone_constaints (Dict[str, int]) – Dictionary of monotonicity constraints for continuous decoder variables, mapping position (e.g. "0" for first position) to constraint (-1 for negative and +1 for positive; larger numbers add more weight to the constraint vs. the loss but are usually not necessary). This constraint significantly slows down training. Defaults to {}.
output_transformer (Callable) – Transformer that takes network output and transforms it to prediction space. Defaults to None, which is equivalent to lambda out: out["prediction"].
optimizer (str) – Optimizer: “ranger”, “sgd”, “adam”, “adamw”, or the class name of an optimizer in torch.optim. Alternatively, a class or function can be passed which takes parameters as first argument and a lr argument (optionally also weight_decay). Defaults to “ranger”.
Methods
configure_optimizers() – Configure optimizers.
create_log(x, y, out, batch_idx[, ...]) – Create the log used in the training and validation step.
deduce_default_output_parameters(dataset, kwargs) – Deduce default parameters for output for from_dataset() method.
epoch_end(outputs) – Run at epoch end for training or validation.
forward(x) – Network forward pass.
from_dataset(dataset, **kwargs) – Create model from dataset, i.e. save dataset parameters in model.
log_gradient_flow(named_parameters) – Log distribution of gradients to identify exploding / vanishing gradients.
log_metrics(x, y, out[, prediction_kwargs]) – Log metrics every training/validation step.
log_prediction(x, out, batch_idx, **kwargs) – Log metrics every training/validation step.
Log gradient flow for debugging.
on_load_checkpoint(checkpoint) – Called by Lightning to restore your model.
on_save_checkpoint(checkpoint) – Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
plot_prediction(x, out[, idx, ...]) – Plot predictions vs. actuals.
predict(data[, mode, return_index, ...]) – Run inference / prediction.
predict_dependency(data, variable, values[, ...]) – Predict partial dependency.
size() – Get number of parameters in model.
step(x, y, batch_idx, **kwargs) – Run for each train/val step.
test_epoch_end(outputs) – Called at the end of a test epoch with the output of all test steps.
test_step(batch, batch_idx) – Operates on a single batch of data from the test set.
to_network_output(**results) – Convert output into a named (and immutable) tuple.
to_prediction(out[, use_metric]) – Convert output to prediction using the loss metric.
to_quantiles(out[, use_metric]) – Convert output to quantiles using the loss metric.
training_epoch_end(outputs) – Called at the end of the training epoch with the outputs of all training steps.
training_step(batch, batch_idx) – Train on batch.
transform_output(prediction, target_scale) – Extract prediction from network output and rescale it to real space / de-normalize it.
validation_epoch_end(outputs) – Called at the end of the validation epoch with the outputs of all validation steps.
validation_step(batch, batch_idx) – Operates on a single batch of data from the validation set.
- configure_optimizers()
Configure optimizers.
Uses a single Ranger optimizer. Depending on whether the learning rate is a list or a single float, implements a dynamic learning rate scheduler or a deterministic version.
- Returns
first entry is list of optimizers and second is list of schedulers
- Return type
Tuple[List]
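The reduce_on_plateau_patience and reduce_on_plateau_min_lr parameters describe a rule that cuts the learning rate by a factor of 10 once the monitored loss stops improving, without going below a minimum. The following is a minimal pure-Python sketch of that rule, not the scheduler the library actually configures. The function name reduce_on_plateau_sketch and the exact patience bookkeeping are assumptions made for illustration.

```python
def reduce_on_plateau_sketch(losses, learning_rate=1e-3, patience=2, min_lr=1e-5, factor=0.1):
    """Return the learning rate in effect after each epoch (illustrative only)."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in losses:
        if loss < best:
            # Improvement: remember the new best loss and reset the counter.
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            # Plateau detected: shrink the rate, but never below min_lr.
            learning_rate = max(learning_rate * factor, min_lr)
            bad_epochs = 0
        history.append(learning_rate)
    return history


history = reduce_on_plateau_sketch([1.0, 0.9, 0.9, 0.9, 0.9])
print(history)
```

With the default reduce_on_plateau_patience of 1000, a reduction of this kind would rarely trigger in short training runs.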
- create_log(x: Dict[str, torch.Tensor], y: Tuple[torch.Tensor, torch.Tensor], out: Dict[str, torch.Tensor], batch_idx: int, prediction_kwargs: Dict[str, Any] = {}, quantiles_kwargs: Dict[str, Any] = {}) → Dict[str, Any]
Create the log used in the training and validation step.
- Parameters
x (Dict[str, torch.Tensor]) – x as passed to the network by the dataloader
y (Tuple[torch.Tensor, torch.Tensor]) – y as passed to the loss function by the dataloader
out (Dict[str, torch.Tensor]) – output of the network
batch_idx (int) – batch number
prediction_kwargs (Dict[str, Any], optional) – arguments to pass to to_prediction(). Defaults to {}.
quantiles_kwargs (Dict[str, Any], optional) – arguments to pass to to_quantiles(). Defaults to {}.
- Returns
log dictionary to be returned by training and validation steps
- Return type
Dict[str, Any]
- static deduce_default_output_parameters(dataset: pytorch_forecasting.data.timeseries.TimeSeriesDataSet, kwargs: Dict[str, Any], default_loss: Optional[pytorch_forecasting.metrics.MultiHorizonMetric] = None) → Dict[str, Any]
Deduce default parameters for output for from_dataset() method.
Determines output_size and loss parameters.
- Parameters
dataset (TimeSeriesDataSet) – timeseries dataset
kwargs (Dict[str, Any]) – current hyperparameters
default_loss (MultiHorizonMetric, optional) – default loss function. Defaults to MAE.
- Returns
dictionary with output_size and loss.
- Return type
Dict[str, Any]
- epoch_end(outputs)
Run at epoch end for training or validation. Can be overridden in models.
- forward(x: Dict[str, Union[torch.Tensor, List[torch.Tensor]]]) → Dict[str, Union[torch.Tensor, List[torch.Tensor]]]
Network forward pass.
- Parameters
x (Dict[str, Union[torch.Tensor, List[torch.Tensor]]]) – network input (x as returned by the dataloader). See the to_dataloader() method, which returns a tuple of x and y. This function expects x.
- Returns
network outputs / dictionary of tensors or list of tensors. Create it using the to_network_output() method. The minimal required entries in the dictionary are (shapes in brackets):
prediction (batch_size x n_decoder_time_steps x n_outputs, or a list thereof with one entry per target): re-scaled predictions that can be fed to the metric. List of tensors if multiple targets are predicted at the same time.
Before outputting the predictions, you want to rescale them into real space. By default, you can use the transform_output() method to achieve this.
- Return type
NamedTuple[Union[torch.Tensor, List[torch.Tensor]]]
Example
    def forward(self, x):
        # x is a batch generated based on the TimeSeriesDataset, here we just use the
        # continuous variables for the encoder
        network_input = x["encoder_cont"].squeeze(-1)
        prediction = self.linear(network_input)

        # rescale predictions into target space
        prediction = self.transform_output(prediction, target_scale=x["target_scale"])

        # We need to return a dictionary that at least contains the prediction.
        # The parameter can be directly forwarded from the input.
        # The conversion to a named tuple can be directly achieved with the `to_network_output` function.
        return self.to_network_output(prediction=prediction)
- classmethod from_dataset(dataset: pytorch_forecasting.data.timeseries.TimeSeriesDataSet, **kwargs) → pytorch_lightning.core.lightning.LightningModule
Create model from dataset, i.e. save dataset parameters in model.
This function should be called as super().from_dataset() in derived models that implement it.
- Parameters
dataset (TimeSeriesDataSet) – timeseries dataset
- Returns
Model that can be trained
- Return type
LightningModule
- log_gradient_flow(named_parameters: Dict[str, torch.Tensor]) → None
Log distribution of gradients to identify exploding / vanishing gradients.
- log_metrics(x: Dict[str, torch.Tensor], y: torch.Tensor, out: Dict[str, torch.Tensor], prediction_kwargs: Optional[Dict[str, Any]] = None) → None
Log metrics every training/validation step.
- Parameters
x (Dict[str, torch.Tensor]) – x as passed to the network by the dataloader
y (torch.Tensor) – y as passed to the loss function by the dataloader
out (Dict[str, torch.Tensor]) – output of the network
prediction_kwargs (Dict[str, Any]) – parameters for to_prediction() of the loss metric.
- log_prediction(x: Dict[str, torch.Tensor], out: Dict[str, torch.Tensor], batch_idx: int, **kwargs) → None
Log metrics every training/validation step.
- Parameters
x (Dict[str, torch.Tensor]) – x as passed to the network by the dataloader
out (Dict[str, torch.Tensor]) – output of the network
batch_idx (int) – current batch index
**kwargs – parameters to pass to plot_prediction
- on_load_checkpoint(checkpoint: Dict[str, Any]) → None
Called by Lightning to restore your model. If you saved something with on_save_checkpoint(), this is your chance to restore it.
- Parameters
checkpoint – Loaded checkpoint
Example:
    def on_load_checkpoint(self, checkpoint):
        # 99% of the time you don't need to implement this method
        self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Note
Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.
- on_save_checkpoint(checkpoint: Dict[str, Any]) → None
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters
checkpoint – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
Example:
    def on_save_checkpoint(self, checkpoint):
        # 99% of use cases you don't need to implement this method
        checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
- plot_prediction(x: Dict[str, torch.Tensor], out: Dict[str, torch.Tensor], idx: int = 0, add_loss_to_title: Union[pytorch_forecasting.metrics.Metric, torch.Tensor, bool] = False, show_future_observed: bool = True, ax=None, quantiles_kwargs: Optional[Dict[str, Any]] = None, prediction_kwargs: Optional[Dict[str, Any]] = None) → matplotlib.figure.Figure
Plot predictions vs. actuals.
- Parameters
x – network input
out – network output
idx – index of prediction to plot
add_loss_to_title – whether to add the loss to the title, or the loss function to calculate. Can be a metric, a bool indicating whether to use the loss metric, or a tensor which contains losses for all samples. Calculated losses are determined without weights. Defaults to False.
show_future_observed – whether to show actuals for the future. Defaults to True.
ax – matplotlib axes to plot on
quantiles_kwargs (Dict[str, Any]) – parameters for to_quantiles() of the loss metric.
prediction_kwargs (Dict[str, Any]) – parameters for to_prediction() of the loss metric.
- Returns
matplotlib figure
- predict(data: Union[torch.utils.data.dataloader.DataLoader, pandas.core.frame.DataFrame, pytorch_forecasting.data.timeseries.TimeSeriesDataSet], mode: Union[str, Tuple[str, str]] = 'prediction', return_index: bool = False, return_decoder_lengths: bool = False, batch_size: int = 64, num_workers: int = 0, fast_dev_run: bool = False, show_progress_bar: bool = False, return_x: bool = False, mode_kwargs: Optional[Dict[str, Any]] = None, **kwargs)
Run inference / prediction.
- Parameters
data – dataloader, dataframe or dataset
mode – one of “prediction”, “quantiles”, or “raw”, or tuple ("raw", output_name) where output_name is a name in the dictionary returned by forward()
return_index – whether to return the prediction index (in the same order as the output, i.e. the row of the dataframe corresponds to the first dimension of the output and the given time index is the time index of the first prediction)
return_decoder_lengths – whether to return decoder_lengths (in the same order as the output)
batch_size – batch size for dataloader; only used if data is not a dataloader
num_workers – number of workers for dataloader; only used if data is not a dataloader
fast_dev_run – whether to only return results of the first batch
show_progress_bar – whether to show a progress bar. Defaults to False.
return_x – whether to return network inputs (in the same order as prediction output)
mode_kwargs (Dict[str, Any]) – keyword arguments for to_prediction() or to_quantiles() for modes “prediction” and “quantiles”
**kwargs – additional arguments to network’s forward method
- Returns
output, x, index, decoder_lengths: some elements might not be present depending on what is configured to be returned
- predict_dependency(data: Union[torch.utils.data.dataloader.DataLoader, pandas.core.frame.DataFrame, pytorch_forecasting.data.timeseries.TimeSeriesDataSet], variable: str, values: Iterable, mode: str = 'dataframe', target='decoder', show_progress_bar: bool = False, **kwargs) → Union[numpy.ndarray, torch.Tensor, pandas.core.series.Series, pandas.core.frame.DataFrame]
Predict partial dependency.
- Parameters
data (Union[DataLoader, pd.DataFrame, TimeSeriesDataSet]) – data
variable (str) – variable which to modify
values (Iterable) – array of values to probe
mode (str, optional) – Output mode. Defaults to “dataframe”. Either
- “series”: values are average prediction and index are probed values
- “dataframe”: columns are as obtained by the dataset.x_to_index() method, prediction (which is the mean prediction over the time horizon), normalized_prediction (which are predictions divided by the prediction for the first probed value), and the variable name for the probed values
- “raw”: outputs a tensor of shape len(values) x prediction_shape
target – Defines which values are overwritten for making a prediction. Same as in set_overwrite_values(). Defaults to “decoder”.
show_progress_bar – whether to show a progress bar. Defaults to False.
**kwargs – additional kwargs to predict() method
- Returns
output
- Return type
Union[np.ndarray, torch.Tensor, pd.Series, pd.DataFrame]
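The normalized_prediction column of the “dataframe” mode divides each prediction by the prediction for the first probed value. A minimal sketch of that arithmetic follows, using plain Python lists in place of tensors and dataframes. The helper name partial_dependency_rows, the variable name "price", and the sample numbers are hypothetical, and the predictions are assumed to be already averaged over the time horizon.

```python
def partial_dependency_rows(values, mean_predictions, variable="price"):
    """Build one row per probed value (illustrative stand-in for the dataframe output)."""
    baseline = mean_predictions[0]  # prediction for the first probed value
    return [
        {
            variable: value,
            "prediction": prediction,
            "normalized_prediction": prediction / baseline,
        }
        for value, prediction in zip(values, mean_predictions)
    ]


rows = partial_dependency_rows(values=[1.0, 2.0, 3.0], mean_predictions=[10.0, 15.0, 5.0])
for row in rows:
    print(row)
```

By construction the first row always has a normalized_prediction of 1.0, so the other rows read as relative changes against the first probed value.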
- step(x: Dict[str, torch.Tensor], y: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, **kwargs) → Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
Run for each train/val step.
- Parameters
x (Dict[str, torch.Tensor]) – x as passed to the network by the dataloader
y (Tuple[torch.Tensor, torch.Tensor]) – y as passed to the loss function by the dataloader
batch_idx (int) – batch number
**kwargs – additional arguments to pass to the network apart from
x
- Returns
tuple where the first entry is a dictionary to which additional logging results can be added for consumption in the epoch_end hook and the second entry is the model’s output.
- Return type
Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
- test_epoch_end(outputs)
Called at the end of a test epoch with the output of all test steps.
    # the pseudocode for these calls
    test_outs = []
    for test_batch in test_data:
        out = test_step(test_batch)
        test_outs.append(out)
    test_epoch_end(test_outs)
- Parameters
outputs – List of outputs you defined in test_step_end(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader
- Returns
None
Note
If you didn’t define a test_step(), this won’t be called.
Examples
With a single dataloader:
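    def test_epoch_end(self, outputs):
        # do something with the outputs of all test batches
        # (assumes each test_step output is a dict with a "predictions" entry)
        all_test_preds = [out["predictions"] for out in outputs]

        # calc_all_results is a placeholder for your own aggregation
        some_result = calc_all_results(all_test_preds)
        self.log("some_result", some_result)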
With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each test step for that dataloader.
    def test_epoch_end(self, outputs):
        final_value = 0
        for dataloader_outputs in outputs:
            for test_step_out in dataloader_outputs:
                # do something
                final_value += test_step_out

        self.log("final_metric", final_value)
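To make the list-of-lists layout concrete, here is a small stand-alone sketch that mimics the call order with plain functions and toy data. It does not use Lightning. The function names sketch_step and sketch_epoch_end and the toy batches are hypothetical stand-ins for test_step() and test_epoch_end().

```python
def sketch_step(batch, batch_idx):
    # Hypothetical per-batch result: the mean of the batch values.
    return sum(batch) / len(batch)


def sketch_epoch_end(outputs):
    # outputs holds one inner list per dataloader, each with one entry per step.
    final_value = 0.0
    for dataloader_outputs in outputs:
        for step_out in dataloader_outputs:
            final_value += step_out
    return final_value


# Two toy "dataloaders": the first yields two batches, the second yields one.
dataloaders = [[[1.0, 3.0], [2.0, 2.0]], [[4.0, 6.0]]]
outputs = [
    [sketch_step(batch, batch_idx) for batch_idx, batch in enumerate(loader)]
    for loader in dataloaders
]
final_metric = sketch_epoch_end(outputs)
print(outputs, final_metric)
```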
- test_step(batch, batch_idx)
Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.
    # the pseudocode for these calls
    test_outs = []
    for test_batch in test_data:
        out = test_step(test_batch)
        test_outs.append(out)
    test_epoch_end(test_outs)
- Parameters
batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.
batch_idx (int) – The index of this batch.
dataloader_idx (int) – The index of the dataloader that produced this batch (only if multiple test dataloaders used).
- Returns
Any of:
- Any object or value
- None: Testing will skip to the next batch
    # if you have one test dataloader:
    def test_step(self, batch, batch_idx):
        ...

    # if you have multiple test dataloaders:
    def test_step(self, batch, batch_idx, dataloader_idx):
        ...
Examples:
    # CASE 1: A single test dataset
    def test_step(self, batch, batch_idx):
        x, y = batch

        # implement your own
        out = self(x)
        loss = self.loss(out, y)

        # log 6 example images
        # or generated text... or whatever
        sample_imgs = x[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image('example_images', grid, 0)

        # calculate acc
        labels_hat = torch.argmax(out, dim=1)
        test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

        # log the outputs!
        self.log_dict({'test_loss': loss, 'test_acc': test_acc})
If you pass in multiple test dataloaders, test_step() will have an additional argument.
    # CASE 2: multiple test dataloaders
    def test_step(self, batch, batch_idx, dataloader_idx):
        # dataloader_idx tells you which dataset this is.
        ...
Note
If you don’t need to test you don’t need to implement this method.
Note
When test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.
- to_network_output(**results)
Convert output into a named (and immutable) tuple.
This allows tracing the modules as graphs and prevents modifying the output.
- Returns
named tuple
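As a rough illustration of what a named, immutable output looks like, the sketch below builds one with the standard library's collections.namedtuple. It is not the library's implementation. The function name to_network_output_sketch and the field values are made up for the example.

```python
from collections import namedtuple


def to_network_output_sketch(**results):
    # Create a tuple type whose field names are the keyword names, then fill it.
    output_type = namedtuple("output", list(results))
    return output_type(**results)


out = to_network_output_sketch(prediction=[0.1, 0.2], attention=[0.5, 0.5])
print(out.prediction)  # access by field name
print(out[0])          # access by position

try:
    out.prediction = []  # named tuples reject attribute assignment
    modified = True
except AttributeError:
    modified = False
print(modified)
```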
- to_prediction(out: Dict[str, Any], use_metric: bool = True, **kwargs)
Convert output to prediction using the loss metric.
- Parameters
out (Dict[str, Any]) – output of network where “prediction” has been transformed with transform_output()
use_metric (bool) – whether to use the metric for the conversion; if False, simply take the average over out["prediction"]
**kwargs – arguments to metric to_prediction method
- Returns
predictions of shape batch_size x timesteps
- Return type
torch.Tensor
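The use_metric=False fallback is described as a plain average over out["prediction"]. The sketch below shows that reduction on nested Python lists, assuming the prediction has shape batch_size x timesteps x n_outputs and that the average is taken over the last dimension. The function name mean_over_last_dim and the sample numbers are hypothetical.

```python
def mean_over_last_dim(prediction):
    """Average the innermost list of a batch x timesteps x n_outputs structure."""
    return [
        [sum(outputs) / len(outputs) for outputs in series]
        for series in prediction
    ]


# batch_size=1, timesteps=2, n_outputs=2
prediction = [[[1.0, 3.0], [2.0, 4.0]]]
point_forecast = mean_over_last_dim(prediction)
print(point_forecast)  # shape batch_size x timesteps
```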
- to_quantiles(out: Dict[str, Any], use_metric: bool = True, **kwargs)
Convert output to quantiles using the loss metric.
- Parameters
out (Dict[str, Any]) – output of network where “prediction” has been transformed with transform_output()
use_metric (bool) – whether to use the metric for the conversion; if False, simply take the quantiles over out["prediction"]
**kwargs – arguments to metric to_quantiles method
- Returns
quantiles of shape batch_size x timesteps x n_quantiles
- Return type
torch.Tensor
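For the use_metric=False case, quantiles are taken directly over out["prediction"]. The following sketch computes linearly interpolated quantiles over the last dimension of a nested list, assuming that dimension holds several values per time step. The helper names, the chosen quantile levels, and the interpolation rule are assumptions for illustration and may differ from what the library computes.

```python
def quantile(samples, q):
    """Linearly interpolated quantile of a list of numbers, 0 <= q <= 1."""
    ordered = sorted(samples)
    position = q * (len(ordered) - 1)
    lower = int(position)
    upper = min(lower + 1, len(ordered) - 1)
    weight = position - lower
    return ordered[lower] * (1 - weight) + ordered[upper] * weight


def quantiles_over_last_dim(prediction, quantiles=(0.25, 0.5, 0.75)):
    # Result has shape batch_size x timesteps x n_quantiles.
    return [
        [[quantile(samples, q) for q in quantiles] for samples in series]
        for series in prediction
    ]


# batch_size=1, timesteps=2, five values per time step
prediction = [[[5.0, 1.0, 3.0, 2.0, 4.0], [10.0, 20.0, 30.0, 40.0, 50.0]]]
result = quantiles_over_last_dim(prediction)
print(result)
```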
- training_epoch_end(outputs)
Called at the end of the training epoch with the outputs of all training steps. Use this in case you need to do something with all the outputs returned by training_step().
    # the pseudocode for these calls
    train_outs = []
    for train_batch in train_data:
        out = training_step(train_batch)
        train_outs.append(out)
    training_epoch_end(train_outs)
- Parameters
outputs – List of outputs you defined in training_step(). If there are multiple optimizers, it is a list containing a list of outputs for each optimizer. If using truncated_bptt_steps > 1, each element is a list of outputs corresponding to the outputs of each processed split batch.
- Returns
None
Note
If this method is not overridden, this won’t be called.
    def training_epoch_end(self, training_step_outputs):
        # do something with all training_step outputs
        for out in training_step_outputs:
            ...
- transform_output(prediction: Union[torch.Tensor, List[torch.Tensor]], target_scale: Union[torch.Tensor, List[torch.Tensor]]) → torch.Tensor
Extract prediction from network output and rescale it to real space / de-normalize it.
- Parameters
prediction (Union[torch.Tensor, List[torch.Tensor]]) – normalized prediction
target_scale (Union[torch.Tensor, List[torch.Tensor]]) – scale to rescale prediction
- Returns
rescaled prediction
- Return type
torch.Tensor
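What "rescale to real space" means depends on the target normalizer stored from the dataset. As a rough illustration only, the sketch below assumes a simple affine normalizer where target_scale is a (center, scale) pair, and uses plain lists in place of tensors. The function name rescale_sketch and the numbers are hypothetical, and other normalizers may apply additional transformations.

```python
def rescale_sketch(normalized_prediction, target_scale):
    """Undo an affine normalization: value * scale + center (assumed form)."""
    center, scale = target_scale
    return [value * scale + center for value in normalized_prediction]


# A series normalized around center=10 with scale=2.
rescaled = rescale_sketch([0.0, 1.0, -1.0], target_scale=(10.0, 2.0))
print(rescaled)
```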
- validation_epoch_end(outputs)
Called at the end of the validation epoch with the outputs of all validation steps.
    # the pseudocode for these calls
    val_outs = []
    for val_batch in val_data:
        out = validation_step(val_batch)
        val_outs.append(out)
    validation_epoch_end(val_outs)
- Parameters
outputs – List of outputs you defined in validation_step(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
- Returns
None
Note
If you didn’t define a validation_step(), this won’t be called.
Examples
With a single dataloader:
    def validation_epoch_end(self, val_step_outputs):
        for out in val_step_outputs:
            ...
With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each validation step for that dataloader.
    def validation_epoch_end(self, outputs):
        final_value = 0
        for dataloader_outputs in outputs:
            for val_step_out in dataloader_outputs:
                # do something
                final_value += val_step_out

        self.log("final_metric", final_value)
- validation_step(batch, batch_idx)
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
    # the pseudocode for these calls
    val_outs = []
    for val_batch in val_data:
        out = validation_step(val_batch)
        val_outs.append(out)
    validation_epoch_end(val_outs)
- Parameters
batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.
batch_idx (int) – The index of this batch
dataloader_idx (int) – The index of the dataloader that produced this batch (only if multiple val dataloaders used)
- Returns
- Any object or value
- None: Validation will skip to the next batch
    # pseudocode of order
    val_outs = []
    for val_batch in val_data:
        out = validation_step(val_batch)
        if defined("validation_step_end"):
            out = validation_step_end(out)
        val_outs.append(out)
    val_outs = validation_epoch_end(val_outs)

    # if you have one val dataloader:
    def validation_step(self, batch, batch_idx):
        ...

    # if you have multiple val dataloaders:
    def validation_step(self, batch, batch_idx, dataloader_idx):
        ...
Examples:
    # CASE 1: A single validation dataset
    def validation_step(self, batch, batch_idx):
        x, y = batch

        # implement your own
        out = self(x)
        loss = self.loss(out, y)

        # log 6 example images
        # or generated text... or whatever
        sample_imgs = x[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image('example_images', grid, 0)

        # calculate acc
        labels_hat = torch.argmax(out, dim=1)
        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

        # log the outputs!
        self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument.
    # CASE 2: multiple validation dataloaders
    def validation_step(self, batch, batch_idx, dataloader_idx):
        # dataloader_idx tells you which dataset this is.
        ...
Note
If you don’t need to validate you don’t need to implement this method.
Note
When validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
- property current_stage: str
Available inside lightning loops.
- Returns
current trainer stage. One of [“train”, “val”, “test”, “predict”, “sanity_check”]
- property log_interval: float
Log interval, depending on whether the model is training or validating.
- property n_targets: int
Number of targets to forecast.
Based on loss function.
- Returns
number of targets
- Return type
int
- property target_names: List[str]
List of targets that are predicted.
- Returns
list of target names
- Return type
List[str]