pytorch_forecasting.models.timexer._timexer_v2.TimeXer
- class pytorch_forecasting.models.timexer._timexer_v2.TimeXer(loss: Module, enc_in: int = None, hidden_size: int = 512, n_heads: int = 8, e_layers: int = 2, d_ff: int = 2048, dropout: float = 0.1, patch_length: int = 4, factor: int = 5, activation: str = 'relu', use_efficient_attention: bool = False, logging_metrics: list[Module] | None = None, optimizer: Optimizer | str | None = 'adam', optimizer_params: dict | None = None, lr_scheduler: str | None = None, lr_scheduler_params: dict | None = None, metadata: dict | None = None, **kwargs: Any)
An implementation of the TimeXer model for v2 of pytorch-forecasting.
TimeXer empowers the canonical transformer with the ability to reconcile endogenous and exogenous information without any architectural modifications and achieves consistent state-of-the-art performance across twelve real-world forecasting benchmarks.
TimeXer employs patch-level and variate-level representations respectively for endogenous and exogenous variables, with an endogenous global token as a bridge in-between. With this design, TimeXer can jointly capture intra-endogenous temporal dependencies and exogenous-to-endogenous correlations.
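The two token granularities described above can be illustrated with a small sketch in plain Python. It is not the library's code: the function names are hypothetical, the embedding layers are omitted, and it assumes that trailing values which do not fill a whole patch are dropped. The endogenous global token is left out, since it would be an additional learned vector rather than something derived from the data.

```python
from typing import List, Tuple


def make_patches(series: List[float], patch_length: int) -> List[List[float]]:
    """Split one endogenous series into non-overlapping patches.

    Assumption: values that do not fill a complete patch are dropped.
    """
    n_patches = len(series) // patch_length
    return [series[i * patch_length:(i + 1) * patch_length] for i in range(n_patches)]


def tokenize(
    endogenous: List[float],
    exogenous: List[List[float]],
    patch_length: int,
) -> Tuple[List[List[float]], List[List[float]]]:
    """Return the raw (pre-embedding) token layout for both variable types."""
    # Patch-level: several tokens per endogenous series.
    patch_tokens = make_patches(endogenous, patch_length)
    # Variate-level: exactly one token per exogenous series (the whole series).
    variate_tokens = [list(series) for series in exogenous]
    return patch_tokens, variate_tokens


endo = [float(t) for t in range(12)]   # one target series, 12 time steps
exo = [[0.1] * 12, [0.2] * 12]         # two exogenous series
patches, variates = tokenize(endo, exo, patch_length=4)
print(len(patches), len(variates))     # 3 2
```

In this sketch the temporal detail of the target is preserved across three patch tokens, while each exogenous series contributes a single token regardless of its length.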
- Parameters:
loss (nn.Module) – Loss function to use for training.
enc_in (int, optional) – Number of input features for the encoder. If not provided, it will be set to the number of continuous features in the dataset.
hidden_size (int, default=512) – Dimension of the model embeddings and hidden representations of features.
n_heads (int, default=8) – Number of attention heads in the multi-head attention mechanism.
e_layers (int, default=2) – Number of encoder layers in the transformer architecture.
d_ff (int, default=2048) – Dimension of the feed-forward network in the transformer architecture.
dropout (float, default=0.1) – Dropout rate for regularization. This is used throughout the model to prevent overfitting.
patch_length (int, default=4) – Length of each non-overlapping patch for endogenous variable tokenization.
factor (int, default=5) – Factor for the attention mechanism, controlling the number of keys and values.
activation (str, default='relu') – Activation function to use in the feed-forward network. Common choices are ‘relu’, ‘gelu’, etc.
use_efficient_attention (bool, default=False) – If True, uses PyTorch’s native, optimized scaled dot-product attention implementation, which can reduce computation time and memory consumption for longer sequences. PyTorch automatically selects the backend (FlashAttention-2, Memory-Efficient Attention, or its own C++ implementation) based on the input properties, hardware capabilities, and build configuration.
logging_metrics (Optional[list[nn.Module]], default=None) – List of metrics to log during training, validation, and testing.
optimizer (Optional[Union[Optimizer, str]], default='adam') – Optimizer to use for training. Can be a string name or an instance of an optimizer.
optimizer_params (Optional[dict], default=None) – Parameters for the optimizer. If None, default parameters for the optimizer will be used.
lr_scheduler (Optional[str], default=None) – Learning rate scheduler to use. If None, no scheduler is used.
lr_scheduler_params (Optional[dict], default=None) – Parameters for the learning rate scheduler. If None, default parameters for the scheduler will be used.
metadata (Optional[dict], default=None) – Metadata for the model from TslibDataModule. This can include information about the dataset, such as the number of time steps, number of features, etc. It is used to initialize the model and ensure it is compatible with the data being used, including the split between endogenous (target) and exogenous covariates.
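For orientation on the use_efficient_attention parameter: scaled dot-product attention computes softmax(Q Kᵀ / sqrt(d_k)) V. The following is a plain-Python reference of that formula on nested lists, written only to show what is being computed. It is a sketch, not the code path used by TimeXer or PyTorch, and it ignores batching, multiple heads, masking and dropout.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(value - peak) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Reference attention: softmax(q @ k^T / sqrt(d_k)) @ v."""
    d_k = len(k[0])
    scale = 1.0 / math.sqrt(d_k)
    output = []
    for q_row in q:
        # Similarity of this query with every key, scaled by 1/sqrt(d_k).
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        # Weighted average of the value rows.
        output.append(
            [sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))]
        )
    return output


q = [[1.0, 0.0]]
k = [[1.0, 0.0], [1.0, 0.0]]   # two identical keys, so the weights are equal
v = [[2.0, 0.0], [4.0, 0.0]]
result = scaled_dot_product_attention(q, k, v)
print(result)
```

With two identical keys the query attends to both values equally, so the output is their mean. A fused implementation returns the same quantity without materializing the full score matrix in the same way.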
References
[1] https://arxiv.org/abs/2402.19072
[2] thuml/TimeXer
Notes
- [1] This implementation handles only continuous variables in the context length. Support for categorical variables will be added in the future.
- [2] The TimeXer model obtains many of its attributes from the TslibBaseModel class, a base class that implements much of the boilerplate code for metadata handling and model initialization.
- __init__(loss: Module, enc_in: int = None, hidden_size: int = 512, n_heads: int = 8, e_layers: int = 2, d_ff: int = 2048, dropout: float = 0.1, patch_length: int = 4, factor: int = 5, activation: str = 'relu', use_efficient_attention: bool = False, logging_metrics: list[Module] | None = None, optimizer: Optimizer | str | None = 'adam', optimizer_params: dict | None = None, lr_scheduler: str | None = None, lr_scheduler_params: dict | None = None, metadata: dict | None = None, **kwargs: Any)
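One way to keep the architecture hyperparameters of this signature together is a small container whose defaults mirror the ones listed above. The class below is a hypothetical helper, not part of pytorch-forecasting. It only builds a keyword-argument dictionary; how loss and metadata are obtained is assumed to be handled elsewhere, so the constructor call is shown as a comment.

```python
from dataclasses import asdict, dataclass


@dataclass
class TimeXerHyperParams:
    """Hypothetical helper mirroring the documented __init__ defaults."""

    hidden_size: int = 512
    n_heads: int = 8
    e_layers: int = 2
    d_ff: int = 2048
    dropout: float = 0.1
    patch_length: int = 4
    factor: int = 5
    activation: str = "relu"
    use_efficient_attention: bool = False

    def to_kwargs(self) -> dict:
        """Return the fields as a dict suitable for ** unpacking."""
        return asdict(self)


# Override a few defaults for a smaller model; the rest keep the documented values.
params = TimeXerHyperParams(hidden_size=128, d_ff=256, patch_length=8)
kwargs = params.to_kwargs()
print(kwargs["hidden_size"], kwargs["n_heads"], kwargs["patch_length"])  # 128 8 8

# Intended use (assumption: `loss` and `metadata` are prepared elsewhere, not run here):
# model = TimeXer(loss=loss, metadata=metadata, **kwargs)
```

Keeping loss, optimizer settings and metadata out of the container separates what defines the network from what depends on the training setup and the dataset.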
Methods
__call__(*args, **kwargs) – Call self as a function.
__delattr__(name) – Implement delattr(self, name).
__dir__() – Default dir() implementation.
__eq__(value, /) – Return self==value.
__format__(format_spec, /) – Default object formatter.
__ge__(value, /) – Return self>=value.
__getattr__(name)
__getattribute__(name, /) – Return getattr(self, name).
__getstate__() – Helper for pickle.
__gt__(value, /) – Return self>value.
__hash__() – Return hash(self).
__init_subclass__ – This method is called when a class is subclassed.
__le__(value, /) – Return self<=value.
__lt__(value, /) – Return self<value.
__ne__(value, /) – Return self!=value.
__new__(*args, **kwargs)
__reduce__() – Helper for pickle.
__reduce_ex__(protocol, /) – Helper for pickle.
__repr__() – Return repr(self).
__setattr__(name, value) – Implement setattr(self, name, value).
__setstate__(state)
__sizeof__() – Size of object in memory, in bytes.
__str__() – Return str(self).
__subclasshook__ – Abstract classes can override this to customize issubclass().
_apply(fn[, recurse])
_apply_batch_transfer_handler(batch[, ...])
_call_batch_hook(hook_name, *args)
_call_impl(*args, **kwargs)
_forecast(x) – Forward pass of the TimeXer model.
_get_backward_hooks() – Return the backward hooks for use in the call function.
_get_backward_pre_hooks()
_get_name()
_get_optimizer() – Get the optimizer based on the specified optimizer name and parameters.
_get_scheduler(optimizer) – Get the lr scheduler based on the specified scheduler name and params.
_init_network() – Initialize the network for TimeXer's architecture.
_load_from_state_dict(state_dict, prefix, ...) – Copy parameters and buffers from state_dict into only this module, but not its descendants.
_log_dict_through_fabric(dictionary[, logger])
_maybe_warn_non_full_backward_hook(inputs, ...)
_named_members(get_members_fn[, prefix, ...]) – Help yield various names + members of modules.
_on_before_batch_transfer(batch[, ...])
_pkg() – Package containing the model.
_register_load_state_dict_pre_hook(hook[, ...]) – See register_load_state_dict_pre_hook() for details.
_register_state_dict_hook(hook) – Register a post-hook for the state_dict() method.
_replicate_for_data_parallel()
_save_to_state_dict(destination, prefix, ...) – Save module state to the destination dictionary.
_set_hparams(hp)
_slow_forward(*input, **kwargs)
_to_hparams_dict(hp)
_verify_is_manual_optimization(fn_name)
_wrapped_call_impl(*args, **kwargs)
add_module(name, module) – Add a child module to the current module.
all_gather(data[, group, sync_grads]) – Gather tensors or collections of tensors from multiple processes.
apply(fn) – Apply fn recursively to every submodule (as returned by .children()) as well as self.
backward(loss, *args, **kwargs) – Called to perform backward on the loss returned in training_step().
bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]) – Return an iterator over module buffers.
children() – Return an iterator over immediate children modules.
clip_gradients(optimizer[, ...]) – Handles gradient clipping internally.
compile(*args, **kwargs) – Compile this Module's forward using torch.compile().
configure_callbacks() – Configure model-specific callbacks.
configure_gradient_clipping(optimizer[, ...]) – Perform gradient clipping for the optimizer parameters.
configure_model() – Hook to create modules in a strategy and precision aware context.
configure_optimizers() – Configure the optimizer and learning rate scheduler.
configure_sharded_model() – Deprecated.
cpu() – See torch.nn.Module.cpu().
cuda([device]) – Moves all model parameters and buffers to the GPU.
double() – See torch.nn.Module.double().
eval() – Set the module in evaluation mode.
extra_repr() – Return the extra representation of the module.
float() – See torch.nn.Module.float().
forward(x) – Forward pass of the TimeXer model.
freeze() – Freeze all params for inference.
get_buffer(target) – Return the buffer given by target if it exists, otherwise throw an error.
get_extra_state() – Return any extra state to include in the module's state_dict.
get_parameter(target) – Return the parameter given by target if it exists, otherwise throw an error.
get_submodule(target) – Return the submodule given by target if it exists, otherwise throw an error.
half() – See torch.nn.Module.half().
ipu([device]) – Move all model parameters and buffers to the IPU.
load_from_checkpoint(checkpoint_path[, ...]) – Primary way of loading a model from a checkpoint.
load_state_dict(state_dict[, strict, assign]) – Copy parameters and buffers from state_dict into this module and its descendants.
log(name, value[, prog_bar, logger, ...]) – Log a key, value pair.
log_dict(dictionary[, prog_bar, logger, ...]) – Log a dictionary of values at once.
log_metrics(y_hat, y[, prefix]) – Log additional metrics during training, validation, or testing.
lr_scheduler_step(scheduler, metric) – Override this method to adjust the default way the Trainer calls each scheduler.
lr_schedulers() – Returns the learning rate scheduler(s) that are being used during training.
manual_backward(loss, *args, **kwargs) – Call this directly from your training_step() when doing optimizations manually.
modules([remove_duplicate]) – Return an iterator over all modules in the network.
mtia([device]) – Move all model parameters and buffers to the MTIA.
named_buffers([prefix, recurse, ...]) – Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children() – Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix, remove_duplicate]) – Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([prefix, recurse, ...]) – Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
on_after_backward() – Called after loss.backward() and before optimizers are stepped.
on_after_batch_transfer(batch, dataloader_idx) – Override to alter or apply batch augmentations to your batch after it is transferred to the device.
on_before_backward(loss) – Called before loss.backward().
on_before_batch_transfer(batch, dataloader_idx) – Override to alter or apply batch augmentations to your batch before it is transferred to the device.
on_before_optimizer_step(optimizer) – Called before optimizer.step().
on_before_zero_grad(optimizer) – Called after training_step() and before optimizer.zero_grad().
on_fit_end() – Called at the very end of fit.
on_fit_start() – Called at the very beginning of fit.
on_load_checkpoint(checkpoint) – Called by Lightning to restore your model.
on_predict_batch_end(outputs, batch, batch_idx) – Called in the predict loop after the batch.
on_predict_batch_start(batch, batch_idx[, ...]) – Called in the predict loop before anything happens for that batch.
on_predict_end() – Called at the end of predicting.
on_predict_epoch_end() – Called at the end of predicting.
on_predict_epoch_start() – Called at the beginning of predicting.
on_predict_model_eval() – Called when the predict loop starts.
on_predict_start() – Called at the beginning of predicting.
on_save_checkpoint(checkpoint) – Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
on_test_batch_end(outputs, batch, batch_idx) – Called in the test loop after the batch.
on_test_batch_start(batch, batch_idx[, ...]) – Called in the test loop before anything happens for that batch.
on_test_end() – Called at the end of testing.
on_test_epoch_end() – Called in the test loop at the very end of the epoch.
on_test_epoch_start() – Called in the test loop at the very beginning of the epoch.
on_test_model_eval() – Called when the test loop starts.
on_test_model_train() – Called when the test loop ends.
on_test_start() – Called at the beginning of testing.
on_train_batch_end(outputs, batch, batch_idx) – Called in the training loop after the batch.
on_train_batch_start(batch, batch_idx) – Called in the training loop before anything happens for that batch.
on_train_end() – Called at the end of training before logger experiment is closed.
on_train_epoch_end() – Called in the training loop at the very end of the epoch.
on_train_epoch_start() – Called in the training loop at the very beginning of the epoch.
on_train_start() – Called at the beginning of training after sanity check.
on_validation_batch_end(outputs, batch, ...) – Called in the validation loop after the batch.
on_validation_batch_start(batch, batch_idx) – Called in the validation loop before anything happens for that batch.
on_validation_end() – Called at the end of validation.
on_validation_epoch_end() – Called in the validation loop at the very end of the epoch.
on_validation_epoch_start() – Called in the validation loop at the very beginning of the epoch.
on_validation_model_eval() – Called when the validation loop starts.
on_validation_model_train() – Called when the validation loop ends.
on_validation_model_zero_grad() – Called by the training loop to release gradients before entering the validation loop.
on_validation_start() – Called at the beginning of validation.
optimizer_step(epoch, batch_idx, optimizer) – Override this method to adjust the default way the Trainer calls the optimizer.
optimizer_zero_grad(epoch, batch_idx, optimizer) – Override this method to change the default behaviour of optimizer.zero_grad().
optimizers([use_pl_optimizer]) – Returns the optimizer(s) that are being used during training.
parameters([recurse]) – Return an iterator over module parameters.
predict(dataloader[, mode, return_info, ...]) – Generate predictions for new data using the lightning.Trainer.
predict_dataloader() – An iterable or collection of iterables specifying prediction samples.
predict_step(batch, batch_idx[, dataloader_idx]) – Prediction step for the model.
prepare_data() – Use this to download and prepare data.
print(*args, **kwargs) – Prints only from process 0.
register_backward_hook(hook) – Register a backward hook on the module.
register_buffer(name, tensor[, persistent]) – Add a buffer to the module.
register_forward_hook(hook, *[, prepend, ...]) – Register a forward hook on the module.
register_forward_pre_hook(hook, *[, ...]) – Register a forward pre-hook on the module.
register_full_backward_hook(hook[, prepend]) – Register a backward hook on the module.
register_full_backward_pre_hook(hook[, prepend]) – Register a backward pre-hook on the module.
register_load_state_dict_post_hook(hook) – Register a post-hook to be run after module's load_state_dict() is called.
register_load_state_dict_pre_hook(hook) – Register a pre-hook to be run before module's load_state_dict() is called.
register_module(name, module) – Alias for add_module().
register_parameter(name, param) – Add a parameter to the module.
register_state_dict_post_hook(hook) – Register a post-hook for the state_dict() method.
register_state_dict_pre_hook(hook) – Register a pre-hook for the state_dict() method.
remove_ignored_hparams(ignore_list) – Remove ignored hyperparameters from the stored state.
requires_grad_([requires_grad]) – Change if autograd should record operations on parameters in this module.
save_hyperparameters(*args[, ignore, frame, ...]) – Save arguments to hparams attribute.
set_extra_state(state) – Set extra state contained in the loaded state_dict.
set_submodule(target, module[, strict]) – Set the submodule given by target if it exists, otherwise throw an error.
setup(stage) – Called at the beginning of fit (train + validate), validate, test, or predict.
share_memory() – See torch.Tensor.share_memory_().
state_dict(*args[, destination, prefix, ...]) – Return a dictionary containing references to the whole state of the module.
teardown(stage) – Called at the end of fit (train + validate), validate, test, or predict.
test_dataloader() – An iterable or collection of iterables specifying test samples.
test_step(batch, batch_idx) – Test step for the model.
to(*args, **kwargs) – See torch.nn.Module.to().
to_empty(*, device[, recurse]) – Move the parameters and buffers to the specified device without copying storage.
to_onnx([file_path, input_sample]) – Saves the model in ONNX format.
to_prediction(out, **kwargs) – Converts raw model output to point forecasts.
to_quantiles(out, **kwargs) – Converts raw model output to quantile forecasts.
to_tensorrt([file_path, input_sample, ir, ...]) – Export the model to ScriptModule or GraphModule using TensorRT compile backend.
to_torchscript([file_path, method, ...]) – By default compiles the whole model to a torch.jit.ScriptModule.
toggle_optimizer(optimizer) – Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.
toggled_optimizer(optimizer) – Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.
train([mode]) – Set the module in training mode.
train_dataloader() – An iterable or collection of iterables specifying training samples.
training_step(batch, batch_idx) – Training step for the model.
transfer_batch_to_device(batch, device, ...) – Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
transform_output(y_hat, target_scale) – Transform the output of the model to the original scale.
type(dst_type) – See torch.nn.Module.type().
unfreeze() – Unfreeze all parameters for training.
untoggle_optimizer(optimizer) – Resets the state of required gradients that were toggled with toggle_optimizer().
val_dataloader() – An iterable or collection of iterables specifying validation samples.
validation_step(batch, batch_idx) – Validation step for the model.
xpu([device]) – Move all model parameters and buffers to the XPU.
zero_grad([set_to_none]) – Reset gradients of all model parameters.
Attributes
CHECKPOINT_HYPER_PARAMS_KEY
CHECKPOINT_HYPER_PARAMS_NAME
CHECKPOINT_HYPER_PARAMS_TYPE
T_destination
_OPTIMIZER_REGISTRY
_SCHEDULER_REGISTRY
__annotations__
__dict__
__doc__
__jit_unused_properties__
__module__
__weakref__ – list of weak references to the object
_compiled_call_impl
_jit_is_scripting
_version – This allows better BC support for load_state_dict().
automatic_optimization – If set to False you are responsible for calling .backward(), .step(), .zero_grad().
call_super_init
current_epoch – The current epoch in the Trainer, or 0 if not attached.
device
device_mesh – Strategies like ModelParallelStrategy will create a device mesh that can be accessed in the configure_model() hook to parallelize the LightningModule.
dtype
dump_patches
example_input_array – The example input array is a specification of what the module can consume in the forward() method.
fabric
global_rank – The index of the current process across all nodes and devices.
global_step – Total training batches seen across all epochs.
hparams – The collection of hyperparameters saved with save_hyperparameters().
hparams_initial – The collection of hyperparameters saved with save_hyperparameters().
local_rank – The index of the current process within a single node.
logger – Reference to the logger object in the Trainer.
loggers – Reference to the list of loggers in the Trainer.
on_gpu – Returns True if this model is currently located on a GPU.
strict_loading – Determines how Lightning loads this model using .load_state_dict(..., strict=model.strict_loading).
trainer
training
_parameters
_buffers
_non_persistent_buffers_set
_backward_pre_hooks
_backward_hooks
_is_full_backward_hook
_forward_hooks
_forward_hooks_with_kwargs
_forward_hooks_always_called
_forward_pre_hooks
_forward_pre_hooks_with_kwargs
_state_dict_hooks
_load_state_dict_pre_hooks
_state_dict_pre_hooks
_load_state_dict_post_hooks
_modules