pytorch_forecasting.models.nbeats.NBeatsKAN
- class pytorch_forecasting.models.nbeats.NBeatsKAN(stack_types: list[str] | None = None, num_blocks: list[int] | None = None, num_block_layers: list[int] | None = None, widths: list[int] | None = None, sharing: list[bool] | None = None, expansion_coefficient_lengths: list[int] | None = None, prediction_length: int = 1, context_length: int = 1, dropout: float = 0.1, learning_rate: float = 0.01, log_interval: int = -1, log_gradient_flow: bool = False, log_val_interval: int = None, weight_decay: float = 0.001, loss: MultiHorizonMetric = None, reduce_on_plateau_patience: int = 1000, backcast_loss_ratio: float = 0.0, logging_metrics: ModuleList = None, num: int = 5, k: int = 3, noise_scale: float = 0.5, scale_base_mu: float = 0.0, scale_base_sigma: float = 1.0, scale_sp: float = 1.0, base_fun: callable = None, grid_eps: float = 0.02, grid_range: list[int] = None, sp_trainable: bool = True, sb_trainable: bool = True, sparse_init: bool = False, **kwargs)
Initialize the NBeatsKAN model; use its from_dataset() method if possible. Based on the article N-BEATS: Neural basis expansion analysis for interpretable time series forecasting. The network has (if used as an ensemble) outperformed all other methods, including ensembles of traditional statistical methods, in the M4 competition. The M4 competition is arguably the most important benchmark for univariate time series forecasting.
The NHiTS network has recently been shown to consistently outperform N-BEATS.
- Parameters:
stack_types (list of str) – Each entry is one of the following values: “generic”, “seasonality” or “trend”. A list of strings of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [“generic”]. Recommended value for interpretable mode: [“trend”, “seasonality”].
num_blocks (list of int) – The number of blocks per stack. A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [1]. Recommended value for interpretable mode: [3].
num_block_layers (list of int) – Number of fully connected layers with ReLU activation per block. A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [4]. Recommended value for interpretable mode: [4].
widths (list of int) – Widths of the fully connected layers with ReLU activation in the blocks. A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [512]. Recommended value for interpretable mode: [256, 2048].
sharing (list of bool) – Whether the weights are shared with the other blocks per stack. A list of bools of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [False]. Recommended value for interpretable mode: [True].
expansion_coefficient_lengths (list of int) – If the type is “G” (generic), then the length of the expansion coefficient. If the type is “T” (trend), then it corresponds to the degree of the polynomial. If the type is “S” (seasonal), then this is the minimum period allowed, e.g. 2 for changes every timestep. A list of ints of length 1 or ‘num_stacks’. Default value for generic mode: [32]. Recommended value for interpretable mode: [3].
prediction_length (int) – Length of the prediction. Also known as ‘horizon’.
context_length (int) – Number of time units that condition the predictions. Also known as ‘lookback period’. Should be between 1 and 10 times the prediction length.
backcast_loss_ratio (float) – Weight of backcast in comparison to forecast when calculating the loss. A weight of 1.0 means that forecast and backcast losses are weighted the same (regardless of backcast and forecast lengths). Defaults to 0.0, i.e. no weight.
loss (MultiHorizonMetric) – Loss to optimize. Defaults to MASE().
log_gradient_flow (bool) – Whether to log gradient flow; this takes time and should only be done to diagnose training failures.
reduce_on_plateau_patience (int) – Patience after which the learning rate is reduced by a factor of 10.
logging_metrics (nn.ModuleList of MultiHorizonMetric) – List of metrics that are logged during training. Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]).
num (int) – Parameter for KAN layer: the number of grid intervals G. Default: 5.
k (int) – Parameter for KAN layer: the order of the piecewise polynomial. Default: 3.
noise_scale (float) – Parameter for KAN layer: the scale of the noise injected at initialization. Default: 0.5.
scale_base_mu (float) – Parameter for KAN layer: the scale of the residual function b(x) is initialized from N(scale_base_mu, scale_base_sigma^2). Default: 0.0.
scale_base_sigma (float) – Parameter for KAN layer: the scale of the residual function b(x) is initialized from N(scale_base_mu, scale_base_sigma^2). Default: 1.0.
scale_sp (float) – Parameter for KAN layer: the scale of the spline function spline(x). Default: 1.0.
base_fun (callable) – Parameter for KAN layer: the residual function b(x). Default: None.
grid_eps (float) – Parameter for KAN layer: when grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples; 0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02.
grid_range (list of int) – Parameter for KAN layer: a list or np.array of shape (2,) setting the range of the grid. Default: None.
sp_trainable (bool) – Parameter for KAN layer: if True, scale_sp is trainable. Default: True.
sb_trainable (bool) – Parameter for KAN layer: if True, scale_base is trainable. Default: True.
sparse_init (bool) – Parameter for KAN layer: if True, sparse initialization is applied. Default: False.
**kwargs – Additional arguments to BaseModel.
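In interpretable mode, a “trend” stack projects its expansion coefficients onto a low-degree polynomial basis, which is why expansion_coefficient_lengths is described as the polynomial degree for “T” stacks. The sketch below follows the trend basis from the N-BEATS paper (powers of a normalized time index t = i / H); it is not the library's implementation, and the helper names `trend_basis` and `trend_forecast` are hypothetical.

```python
def trend_basis(degree: int, length: int) -> list[list[float]]:
    """Rows t**0, t**1, ..., t**degree evaluated on t = i / length, i = 0..length-1."""
    t = [i / length for i in range(length)]
    return [[ti ** p for ti in t] for p in range(degree + 1)]


def trend_forecast(thetas: list[float], length: int) -> list[float]:
    """Combine expansion coefficients (thetas) with the polynomial basis.

    Assumption for this sketch: len(thetas) == degree + 1, one coefficient per term.
    """
    basis = trend_basis(len(thetas) - 1, length)
    return [sum(theta * row[i] for theta, row in zip(thetas, basis)) for i in range(length)]


# A degree-2 trend over a 4-step horizon: forecast = 1 + 2t + 0t^2.
print(trend_forecast([1.0, 2.0, 0.0], length=4))  # [1.0, 1.5, 2.0, 2.5]
```

Because the basis is fixed and only the few coefficients are learned, the output of a trend block is constrained to a smooth polynomial, which is what makes its contribution interpretable.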
Examples
See the full example in: examples/nbeats_with_kan.py
Notes
The KAN blocks are based on the Kolmogorov-Arnold representation theorem and replace fixed MLP edge weights with learnable univariate spline functions. This allows KAN-augmented N-BEATS to better capture complex patterns, improve interpretability, and achieve parameter efficiency. Additionally, when applied in a doubly-residual adversarial framework, the model has shown strong zero-shot time-series forecasting performance across markets [2].
Key differences from the original N-BEATS:
- MLP layers are replaced by KAN layers with spline-based edge functions.
- Each weight is a trainable function, not a scalar.
- Enables visualization of learned functions and better domain adaptation.
- Yields improved accuracy and interpretability with fewer parameters.
References
propose replacing MLP weights with spline-based learnable edge functions, enabling improved accuracy, interpretability, and scaling behavior compared to standard MLPs.
[2] A. Bhattacharya & N. Haq (2024), “Zero Shot Time Series Forecasting Using Kolmogorov Arnold Networks”, incorporate KAN layers into a doubly-residual N-BEATS architecture with adversarial domain adaptation, achieving strong zero-shot cross-market electricity price forecasting performance.
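BaseModel for time series forecasting, from which NBeatsKAN inherits. The defaults listed below are those of BaseModel; when constructing NBeatsKAN, the defaults in its own signature apply (e.g. learning_rate=0.01, weight_decay=0.001).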
- Parameters:
log_interval (Union[int, float], optional) – Batches after which predictions are logged. If < 1.0, will log multiple entries per batch. Defaults to -1.
log_val_interval (Union[int, float], optional) – Batches after which predictions for validation are logged. Defaults to None (i.e. log_interval).
learning_rate (float, optional) – Learning rate. Defaults to 1e-3.
log_gradient_flow (bool) – Whether to log gradient flow; this takes time and should only be done to diagnose training failures. Defaults to False.
loss (Metric, optional) – Metric to optimize; can also be a list of metrics. Defaults to SMAPE().
logging_metrics (nn.ModuleList[MultiHorizonMetric]) – List of metrics that are logged during training. Defaults to [].
reduce_on_plateau_patience (int) – Patience after which the learning rate is reduced by a factor of 10. Defaults to 1000.
reduce_on_plateau_reduction (float) – Reduction in learning rate when encountering a plateau. Defaults to 2.0.
reduce_on_plateau_min_lr (float) – Minimum learning rate for the reduce-on-plateau learning rate scheduler. Defaults to 1e-5.
weight_decay (float) – Weight decay. Defaults to 0.0.
optimizer_params (Dict[str, Any]) – Additional parameters for the optimizer. Defaults to {}.
monotone_constraints (Dict[str, int]) – Dictionary of monotonicity constraints for continuous decoder variables, mapping position (e.g. "0" for the first position) to constraint (-1 for negative and +1 for positive; larger numbers add more weight to the constraint vs. the loss but are usually not necessary). This constraint significantly slows down training. Defaults to {}.
output_transformer (Callable) – Transformer that takes network output and transforms it to prediction space. Defaults to None, which is equivalent to lambda out: out["prediction"].
optimizer (str) – Optimizer: “ranger”, “sgd”, “adam”, “adamw”, or the class name of an optimizer in torch.optim or pytorch_optimizer. Alternatively, a class or function can be passed which takes parameters as first argument and a lr argument (optionally also weight_decay). Defaults to “adam”.
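The reduce_on_plateau_* parameters configure a plateau-based learning rate scheduler. The following is a simplified sketch, assuming the learning rate is divided by reduce_on_plateau_reduction once the validation loss has failed to improve for more than reduce_on_plateau_patience epochs (the convention of PyTorch's ReduceLROnPlateau, here without its threshold and cooldown options) and is never lowered below reduce_on_plateau_min_lr. The function `plateau_lr_schedule` is hypothetical.

```python
def plateau_lr_schedule(val_losses, lr, patience, reduction=2.0, min_lr=1e-5):
    """Return the learning rate in effect after each epoch's validation loss."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best, bad_epochs = loss, 0  # improvement resets the counter
        else:
            bad_epochs += 1
        if bad_epochs > patience:  # plateau lasted longer than `patience` epochs
            lr = max(lr / reduction, min_lr)  # divide, but respect the floor
            bad_epochs = 0
        history.append(lr)
    return history


# Loss stops improving after epoch 2; with patience=2 the rate halves at epoch 5.
print(plateau_lr_schedule([1.0, 0.9, 0.95, 0.96, 0.97, 0.98], lr=0.01, patience=2))
```

With the BaseModel default patience of 1000, a reduction would effectively never trigger in a typical run, so the patience usually needs to be set explicitly for the scheduler to have any effect.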
- __init__(stack_types: list[str] | None = None, num_blocks: list[int] | None = None, num_block_layers: list[int] | None = None, widths: list[int] | None = None, sharing: list[bool] | None = None, expansion_coefficient_lengths: list[int] | None = None, prediction_length: int = 1, context_length: int = 1, dropout: float = 0.1, learning_rate: float = 0.01, log_interval: int = -1, log_gradient_flow: bool = False, log_val_interval: int = None, weight_decay: float = 0.001, loss: MultiHorizonMetric = None, reduce_on_plateau_patience: int = 1000, backcast_loss_ratio: float = 0.0, logging_metrics: ModuleList = None, num: int = 5, k: int = 3, noise_scale: float = 0.5, scale_base_mu: float = 0.0, scale_base_sigma: float = 1.0, scale_sp: float = 1.0, base_fun: callable = None, grid_eps: float = 0.02, grid_range: list[int] = None, sp_trainable: bool = True, sb_trainable: bool = True, sparse_init: bool = False, **kwargs)
Methods
__call__(*args, **kwargs) – Call self as a function.
__delattr__(name) – Implement delattr(self, name).
__dir__() – Default dir() implementation.
__eq__(value, /) – Return self==value.
__format__(format_spec, /) – Default object formatter.
__ge__(value, /) – Return self>=value.
__getattr__(name)
__getattribute__(name, /) – Return getattr(self, name).
__getstate__() – Helper for pickle.
__gt__(value, /) – Return self>value.
__hash__() – Return hash(self).
__init_subclass__ – This method is called when a class is subclassed.
__le__(value, /) – Return self<=value.
__lt__(value, /) – Return self<value.
__ne__(value, /) – Return self!=value.
__new__(*args, **kwargs)
__reduce__() – Helper for pickle.
__reduce_ex__(protocol, /) – Helper for pickle.
__repr__() – Return repr(self).
__setattr__(name, value) – Implement setattr(self, name, value).
__setstate__(state)
__sizeof__() – Size of object in memory, in bytes.
__str__() – Return str(self).
__subclasshook__ – Abstract classes can override this to customize issubclass().
_apply(fn[, recurse])
_apply_batch_transfer_handler(batch[, ...])
_call_batch_hook(hook_name, *args)
_call_impl(*args, **kwargs)
_get_backward_hooks() – Return the backward hooks for use in the call function.
_get_backward_pre_hooks()
_get_name()
_load_from_state_dict(state_dict, prefix, ...) – Copy parameters and buffers from state_dict into only this module, but not its descendants.
_log_dict_through_fabric(dictionary[, logger])
_logger_supports(method) – Whether logger supports method.
_maybe_warn_non_full_backward_hook(inputs, ...)
_named_members(get_members_fn[, prefix, ...]) – Help yield various names + members of modules.
_on_before_batch_transfer(batch[, ...])
_pkg() – Package for the model.
_register_load_state_dict_pre_hook(hook[, ...]) – See register_load_state_dict_pre_hook() for details.
_register_state_dict_hook(hook) – Register a post-hook for the state_dict() method.
_replicate_for_data_parallel()
_save_to_state_dict(destination, prefix, ...) – Save module state to the destination dictionary.
_set_hparams(hp)
_slow_forward(*input, **kwargs)
_to_hparams_dict(hp)
_verify_is_manual_optimization(fn_name)
_wrapped_call_impl(*args, **kwargs)
add_module(name, module) – Add a child module to the current module.
all_gather(data[, group, sync_grads]) – Gather tensors or collections of tensors from multiple processes.
apply(fn) – Apply fn recursively to every submodule (as returned by .children()) as well as self.
backward(loss, *args, **kwargs) – Called to perform backward on the loss returned in training_step().
bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]) – Return an iterator over module buffers.
children() – Return an iterator over immediate children modules.
clip_gradients(optimizer[, ...]) – Handles gradient clipping internally.
compile(*args, **kwargs) – Compile this Module's forward using torch.compile().
configure_callbacks() – Configure model-specific callbacks.
configure_gradient_clipping(optimizer[, ...]) – Perform gradient clipping for the optimizer parameters.
configure_model() – Hook to create modules in a strategy and precision aware context.
configure_optimizers() – Configure optimizers.
configure_sharded_model() – Deprecated.
cpu() – See torch.nn.Module.cpu().
create_log(x, y, out, batch_idx[, ...]) – Create the log used in the training and validation step.
cuda([device]) – Moves all model parameters and buffers to the GPU.
deduce_default_output_parameters(dataset, kwargs) – Deduce default parameters for output for from_dataset() method.
double() – See torch.nn.Module.double().
eval() – Set the module in evaluation mode.
extra_repr() – Return extra information about parameters for representation/logging.
float() – See torch.nn.Module.float().
forward(x) – Pass forward of network.
freeze() – Freeze all params for inference.
from_dataset(dataset, **kwargs) – Convenience function to create network from pytorch_forecasting.data.timeseries.TimeSeriesDataSet.
get_buffer(target) – Return the buffer given by target if it exists, otherwise throw an error.
get_extra_state() – Return any extra state to include in the module's state_dict.
get_parameter(target) – Return the parameter given by target if it exists, otherwise throw an error.
get_submodule(target) – Return the submodule given by target if it exists, otherwise throw an error.
half() – See torch.nn.Module.half().
ipu([device]) – Move all model parameters and buffers to the IPU.
load_from_checkpoint(checkpoint_path[, ...]) – Primary way of loading a model from a checkpoint.
load_state_dict(state_dict[, strict, assign]) – Copy parameters and buffers from state_dict into this module and its descendants.
log(*args, **kwargs) – See lightning.pytorch.core.lightning.LightningModule.log().
log_dict(dictionary[, prog_bar, logger, ...]) – Log a dictionary of values at once.
log_gradient_flow(named_parameters) – Log distribution of gradients to identify exploding / vanishing gradients.
log_interpretation(x, out, batch_idx) – Log interpretation of network predictions in tensorboard.
log_metrics(x, y, out[, prediction_kwargs]) – Log metrics every training/validation step.
log_prediction(x, out, batch_idx, **kwargs) – Log metrics every training/validation step.
lr_scheduler_step(scheduler, metric) – Override this method to adjust the default way the Trainer calls each scheduler.
lr_schedulers() – Returns the learning rate scheduler(s) that are being used during training.
manual_backward(loss, *args, **kwargs) – Call this directly from your training_step() when doing optimizations manually.
modules([remove_duplicate]) – Return an iterator over all modules in the network.
mtia([device]) – Move all model parameters and buffers to the MTIA.
named_buffers([prefix, recurse, ...]) – Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children() – Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix, remove_duplicate]) – Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([prefix, recurse, ...]) – Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
on_after_backward() – Log gradient flow for debugging.
on_after_batch_transfer(batch, dataloader_idx) – Override to alter or apply batch augmentations to your batch after it is transferred to the device.
on_before_backward(loss) – Called before loss.backward().
on_before_batch_transfer(batch, dataloader_idx) – Override to alter or apply batch augmentations to your batch before it is transferred to the device.
on_before_optimizer_step(optimizer) – Called before optimizer.step().
on_before_zero_grad(optimizer) – Called after training_step() and before optimizer.zero_grad().
on_epoch_end(outputs) – Run at epoch end for training or validation.
on_fit_end() – Called at the very end of fit.
on_fit_start() – Called at the very beginning of fit.
on_load_checkpoint(checkpoint) – Called by Lightning to restore your model.
on_predict_batch_end(outputs, batch, batch_idx) – Called in the predict loop after the batch.
on_predict_batch_start(batch, batch_idx[, ...]) – Called in the predict loop before anything happens for that batch.
on_predict_end() – Called at the end of predicting.
on_predict_epoch_end() – Called at the end of predicting.
on_predict_epoch_start() – Called at the beginning of predicting.
on_predict_model_eval() – Called when the predict loop starts.
on_predict_start() – Called at the beginning of predicting.
on_save_checkpoint(checkpoint) – Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
on_test_batch_end(outputs, batch, batch_idx) – Called in the test loop after the batch.
on_test_batch_start(batch, batch_idx[, ...]) – Called in the test loop before anything happens for that batch.
on_test_end() – Called at the end of testing.
on_test_epoch_end() – Called in the test loop at the very end of the epoch.
on_test_epoch_start() – Called in the test loop at the very beginning of the epoch.
on_test_model_eval() – Called when the test loop starts.
on_test_model_train() – Called when the test loop ends.
on_test_start() – Called at the beginning of testing.
on_train_batch_end(outputs, batch, batch_idx) – Called in the training loop after the batch.
on_train_batch_start(batch, batch_idx) – Called in the training loop before anything happens for that batch.
on_train_end() – Called at the end of training before logger experiment is closed.
on_train_epoch_end() – Called in the training loop at the very end of the epoch.
on_train_epoch_start() – Called in the training loop at the very beginning of the epoch.
on_train_start() – Called at the beginning of training after sanity check.
on_validation_batch_end(outputs, batch, ...) – Called in the validation loop after the batch.
on_validation_batch_start(batch, batch_idx) – Called in the validation loop before anything happens for that batch.
on_validation_end() – Called at the end of validation.
on_validation_epoch_end() – Called in the validation loop at the very end of the epoch.
on_validation_epoch_start() – Called in the validation loop at the very beginning of the epoch.
on_validation_model_eval() – Called when the validation loop starts.
on_validation_model_train() – Called when the validation loop ends.
on_validation_model_zero_grad() – Called by the training loop to release gradients before entering the validation loop.
on_validation_start() – Called at the beginning of validation.
optimizer_step(epoch, batch_idx, optimizer) – Override this method to adjust the default way the Trainer calls the optimizer.
optimizer_zero_grad(epoch, batch_idx, optimizer) – Override this method to change the default behaviour of optimizer.zero_grad().
optimizers([use_pl_optimizer]) – Returns the optimizer(s) that are being used during training.
parameters([recurse]) – Return an iterator over module parameters.
plot_interpretation(x, output, idx[, ax, ...]) – Plot interpretation.
plot_prediction(x, out[, idx, ...]) – Plot prediction of prediction vs actuals.
predict(data[, mode, return_index, ...]) – Run inference / prediction.
predict_dataloader() – An iterable or collection of iterables specifying prediction samples.
predict_dependency(data, variable, values[, ...]) – Predict partial dependency.
predict_step(batch, batch_idx) – Step function called during predict().
prepare_data() – Use this to download and prepare data.
print(*args, **kwargs) – Prints only from process 0.
register_backward_hook(hook) – Register a backward hook on the module.
register_buffer(name, tensor[, persistent]) – Add a buffer to the module.
register_forward_hook(hook, *[, prepend, ...]) – Register a forward hook on the module.
register_forward_pre_hook(hook, *[, ...]) – Register a forward pre-hook on the module.
register_full_backward_hook(hook[, prepend]) – Register a backward hook on the module.
register_full_backward_pre_hook(hook[, prepend]) – Register a backward pre-hook on the module.
register_load_state_dict_post_hook(hook) – Register a post-hook to be run after module's load_state_dict() is called.
register_load_state_dict_pre_hook(hook) – Register a pre-hook to be run before module's load_state_dict() is called.
register_module(name, module) – Alias for add_module().
register_parameter(name, param) – Add a parameter to the module.
register_state_dict_post_hook(hook) – Register a post-hook for the state_dict() method.
register_state_dict_pre_hook(hook) – Register a pre-hook for the state_dict() method.
remove_ignored_hparams(ignore_list) – Remove ignored hyperparameters from the stored state.
requires_grad_([requires_grad]) – Change if autograd should record operations on parameters in this module.
save_hyperparameters(*args[, ignore, frame, ...]) – Save arguments to hparams attribute.
set_extra_state(state) – Set extra state contained in the loaded state_dict.
set_submodule(target, module[, strict]) – Set the submodule given by target if it exists, otherwise throw an error.
setup(stage) – Called at the beginning of fit (train + validate), validate, test, or predict.
share_memory() – See torch.Tensor.share_memory_().
size() – Get number of parameters in model.
state_dict(*args[, destination, prefix, ...]) – Return a dictionary containing references to the whole state of the module.
step(x, y, batch_idx) – Take training / validation step.
teardown(stage) – Called at the end of fit (train + validate), validate, test, or predict.
test_dataloader() – An iterable or collection of iterables specifying test samples.
test_step(batch, batch_idx) – Operates on a single batch of data from the test set.
to(*args, **kwargs) – See torch.nn.Module.to().
to_empty(*, device[, recurse]) – Move the parameters and buffers to the specified device without copying storage.
to_network_output(**results) – Convert output into a named (and immutable) tuple.
to_onnx([file_path, input_sample]) – Saves the model in ONNX format.
to_prediction(out[, use_metric]) – Convert output to prediction using the loss metric.
to_quantiles(out[, use_metric]) – Convert output to quantiles using the loss metric.
to_tensorrt([file_path, input_sample, ir, ...]) – Export the model to ScriptModule or GraphModule using TensorRT compile backend.
to_torchscript([file_path, method, ...]) – By default compiles the whole model to a torch.jit.ScriptModule.
toggle_optimizer(optimizer) – Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.
toggled_optimizer(optimizer) – Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.
train([mode]) – Set the module in training mode.
train_dataloader() – An iterable or collection of iterables specifying training samples.
training_step(batch, batch_idx) – Train on batch.
transfer_batch_to_device(batch, device, ...) – Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
transform_output(prediction, target_scale[, ...]) – Extract prediction from network output and rescale it to real space / de-normalize it.
type(dst_type) – See torch.nn.Module.type().
unfreeze() – Unfreeze all parameters for training.
untoggle_optimizer(optimizer) – Resets the state of required gradients that were toggled with toggle_optimizer().
update_kan_grid() – Updates grid of KAN layers when using KAN layers in NBEATSBlock.
val_dataloader() – An iterable or collection of iterables specifying validation samples.
validation_step(batch, batch_idx) – Operates on a single batch of data from the validation set.
xpu([device]) – Move all model parameters and buffers to the XPU.
zero_grad([set_to_none]) – Reset gradients of all model parameters.
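update_kan_grid() is listed only by name. The sketch below assumes it follows the grid-update rule of the original KAN implementation (pykan), where grid_eps blends a percentile-based grid with a uniform one exactly as the grid_eps parameter describes (1 gives a uniform grid, 0 gives sample percentiles). The function `update_grid` is hypothetical, handles a single input dimension in plain Python, and omits the extension of the grid by k knots on each side.

```python
def update_grid(samples: list[float], num: int, grid_eps: float) -> list[float]:
    """Return num + 1 grid points blending percentile and uniform placement."""
    xs = sorted(samples)
    n = len(xs)
    # Adaptive grid: points at evenly spaced ranks of the sorted samples (percentiles).
    adaptive = [xs[int(n / num * i)] for i in range(num)] + [xs[-1]]
    # Uniform grid spanning the observed range of the samples.
    lo, hi = xs[0], xs[-1]
    uniform = [lo + i * (hi - lo) / num for i in range(num + 1)]
    # grid_eps = 1 -> uniform, grid_eps = 0 -> percentiles, in between -> linear mix.
    return [grid_eps * u + (1.0 - grid_eps) * a for u, a in zip(uniform, adaptive)]


samples = [0, 1, 2, 3, 4, 5, 6, 7, 8, 100]  # one outlier stretches the range
print(update_grid(samples, num=2, grid_eps=1.0))  # [0.0, 50.0, 100.0]
print(update_grid(samples, num=2, grid_eps=0.0))  # [0.0, 5.0, 100.0]
```

With an outlier in the data, a uniform grid wastes most of its intervals on empty range, while the percentile grid concentrates knots where samples actually lie; the small default grid_eps = 0.02 leans strongly toward the percentile placement.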
Attributes
CHECKPOINT_HYPER_PARAMS_KEY
CHECKPOINT_HYPER_PARAMS_NAME
CHECKPOINT_HYPER_PARAMS_SPECIAL_KEY
CHECKPOINT_HYPER_PARAMS_TYPE
T_destination
__annotations__
__dict__
__doc__
__jit_unused_properties__
__module__
__weakref__ – list of weak references to the object
_compiled_call_impl
_jit_is_scripting
_version – This allows better BC support for load_state_dict().
automatic_optimization – If set to False you are responsible for calling .backward(), .step(), .zero_grad().
call_super_init
current_epoch – The current epoch in the Trainer, or 0 if not attached.
current_stage – Available inside lightning loops.
device
device_mesh – Strategies like ModelParallelStrategy will create a device mesh that can be accessed in the configure_model() hook to parallelize the LightningModule.
dtype
dump_patches
example_input_array – The example input array is a specification of what the module can consume in the forward() method.
fabric
global_rank – The index of the current process across all nodes and devices.
global_step – Total training batches seen across all epochs.
hparams – The collection of hyperparameters saved with save_hyperparameters().
hparams_initial – The collection of hyperparameters saved with save_hyperparameters().
local_rank – The index of the current process within a single node.
log_interval – Log interval, depending on whether training or validating.
logger – Reference to the logger object in the Trainer.
loggers – Reference to the list of loggers in the Trainer.
n_targets – Number of targets to forecast.
on_gpu – Returns True if this model is currently located on a GPU.
predicting
strict_loading – Determines how Lightning loads this model using .load_state_dict(..., strict=model.strict_loading).
target_names – List of targets that are predicted.
trainer
training
_parameters
_buffers
_non_persistent_buffers_set
_backward_pre_hooks
_backward_hooks
_is_full_backward_hook
_forward_hooks
_forward_hooks_with_kwargs
_forward_hooks_always_called
_forward_pre_hooks
_forward_pre_hooks_with_kwargs
_state_dict_hooks
_load_state_dict_pre_hooks
_state_dict_pre_hooks
_load_state_dict_post_hooks
_modules