Source code for pytorch_forecasting.models.nbeats._grid_callback

from lightning.pytorch.callbacks import Callback


[docs] class GridUpdateCallback(Callback): """ Custom callback to update the grid of the model during training at regular intervals. Parameters ---------- update_interval : int The frequency at which the grid is updated. Examples -------- See the full example in: `examples/nbeats_with_kan.py` """ def __init__(self, update_interval): self.update_interval = update_interval
[docs] def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): """ Hook called at the end of each training batch. Updates the grid of KAN layers if the current step is a multiple of the update interval. Parameters ---------- trainer : Trainer The PyTorch Lightning Trainer object. pl_module : LightningModule The model being trained (LightningModule). outputs : Any Outputs from the model for the current batch. batch : Any The current batch of data. batch_idx : int Index of the current batch. """ # Check if the current step is a multiple of the update interval if (trainer.global_step + 1) % self.update_interval == 0: # Call the model's update_kan_grid method pl_module.update_kan_grid()