Source code for pytorch_forecasting.utils

"""
Helper functions for PyTorch forecasting
"""
from collections import namedtuple
from contextlib import redirect_stdout
import inspect
import os
from typing import Any, Callable, Dict, List, Tuple, Union

import lightning.pytorch as pl
import torch
from torch import nn
from torch.fft import irfft, rfft
import torch.nn.functional as F
from torch.nn.utils import rnn


def integer_histogram(
    data: torch.LongTensor, min: Union[None, int] = None, max: Union[None, int] = None
) -> torch.Tensor:
    """
    Count how often each integer of the inclusive range ``[min, max]`` occurs in ``data``.

    Args:
        data: integer data for which to create the histogram
        min: lower end of the histogram range; inferred as the smallest value in ``data`` if not given
        max: upper end of the histogram range; inferred as the largest value in ``data`` if not given

    Returns:
        long tensor of length ``max - min + 1`` whose entry ``i`` counts occurrences of ``min + i``.
        Values outside ``[min, max]`` are not supported (the underlying ``scatter`` fails on them).
    """
    values, occurrences = torch.unique(data, return_counts=True)
    lower = values.min() if min is None else min
    upper = values.max() if max is None else max
    n_bins = upper - lower + 1
    empty = torch.zeros(n_bins, dtype=torch.long, device=data.device)
    # place each count at the offset of its value relative to the lower bound
    return empty.scatter(0, values - lower, occurrences)
def groupby_apply(
    keys: torch.Tensor, values: torch.Tensor, bins: int = 95, reduction: str = "mean", return_histogram: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Groupby apply for torch tensors

    Args:
        keys: 1-D tensor of integer group ids (``0`` to ``bins - 1``). Does not need to be sorted.
        values: values to aggregate - same size as keys (along the first dimension)
        bins: total number of groups
        reduction: either "mean" or "sum"
        return_histogram: if to return histogram on top

    Returns:
        tensor of size ``bins`` with aggregated values and optionally with counts of values.
        Groups without any values are reported as ``0``.
    """
    if reduction == "mean":
        reduce = torch.mean
    elif reduction == "sum":
        reduce = torch.sum
    else:
        raise ValueError(f"Unknown reduction '{reduction}'")

    reduced = torch.zeros(bins, dtype=values.dtype, device=values.device)
    hist = torch.zeros(bins, dtype=torch.long, device=values.device)
    if keys.numel() > 0:
        # splitting by counts requires each group to be contiguous, so order keys (and the matching
        # values) first; a stable sort keeps the original order of values within each group
        sorted_keys, order = torch.sort(keys, stable=True)
        sorted_values = values[order]
        uniques, counts = torch.unique_consecutive(sorted_keys, return_counts=True)
        groups = torch.stack([reduce(group) for group in torch.split(sorted_values, counts.tolist())])
        # e.g. torch.sum promotes small integer types - align with the output dtype for scatter
        reduced = reduced.scatter(dim=0, index=uniques, src=groups.to(reduced.dtype))
        hist = hist.scatter(dim=0, index=uniques, src=counts)

    if return_histogram:
        return reduced, hist
    else:
        return reduced
def profile(function: Callable, profile_fname: str, filter: str = "", period=0.0001, **kwargs):
    """
    Profile a given function with ``vmprof``.

    Args:
        function (Callable): function to profile
        profile_fname (str): path where to save profile (`.txt` file will be saved with line profile)
        filter (str, optional): filter name (e.g. module name) to filter profile. Defaults to "".
        period (float, optional): frequency of calling profiler in seconds. Defaults to 0.0001.
        **kwargs: keyword arguments passed to ``function``. A ``lines`` entry is additionally read here:
            if it is falsy, no line report is written (it is still forwarded to ``function``).
    """
    import vmprof
    from vmprof.show import LinesPrinter

    # profiler config
    with open(profile_fname, "wb+") as fd:
        # start profiler
        vmprof.enable(fd.fileno(), lines=True, period=period)
        try:
            # run function
            function(**kwargs)
        finally:
            # stop profiler even if ``function`` raises, so profiling does not leak into later code
            vmprof.disable()

    # write report to disk
    if kwargs.get("lines", True):
        report_fname = f"{os.path.splitext(profile_fname)[0]}.txt"
        with open(report_fname, "w") as f, redirect_stdout(f):
            LinesPrinter(filter=filter).show(profile_fname)
def get_embedding_size(n: int, max_size: int = 100) -> int:
    """
    Heuristic embedding dimension for a categorical variable (rule of thumb from fastai).

    Args:
        n (int): number of classes
        max_size (int, optional): upper bound for the embedding size. Defaults to 100.

    Returns:
        int: ``1`` for at most two classes, otherwise ``round(1.6 * n ** 0.56)`` capped at ``max_size``
    """
    if n <= 2:
        return 1
    suggested = round(1.6 * n**0.56)
    return min(suggested, max_size)
def create_mask(size: int, lengths: torch.LongTensor, inverse: bool = False) -> torch.BoolTensor:
    """
    Create boolean masks of shape len(lengths) x size.

    By default, an entry at (i, j) is True if ``j >= lengths[i]``, i.e. the mask marks the padded
    positions beyond each sequence's length. With ``inverse=True``, an entry at (i, j) is True if
    ``j < lengths[i]``, i.e. the mask marks the positions that hold values.

    Args:
        size (int): size of second dimension
        lengths (torch.LongTensor): tensor of lengths
        inverse (bool, optional): If true, mark positions with values instead of padded positions.
            Defaults to False.

    Returns:
        torch.BoolTensor: mask
    """
    positions = torch.arange(size, device=lengths.device).unsqueeze(0)
    limits = lengths.unsqueeze(-1)
    if inverse:
        # return where values are
        return positions < limits
    else:
        # return where no values are
        return positions >= limits
# memoizes next_fast_len results: maps requested size -> smallest 5-smooth number >= size
_NEXT_FAST_LEN = {}


def next_fast_len(size):
    """
    Returns the next largest number ``n >= size`` whose prime factors are all 2, 3, or 5.
    Such sizes are efficient for fast Fourier transforms.
    Equivalent to :func:`scipy.fftpack.next_fast_len`.

    Implementation from pyro

    :param int size: A positive number.
    :returns: A possibly larger number.
    :rtype int:
    """
    cached = _NEXT_FAST_LEN.get(size)
    if cached is not None:
        return cached

    assert isinstance(size, int) and size > 0
    candidate = size
    while True:
        rest = candidate
        # strip all factors 2, 3 and 5; only 5-smooth numbers reduce to 1
        for prime in (2, 3, 5):
            while rest % prime == 0:
                rest //= prime
        if rest == 1:
            break
        candidate += 1
    _NEXT_FAST_LEN[size] = candidate
    return candidate
def autocorrelation(input, dim=0):
    """
    Computes the autocorrelation of samples at dimension ``dim``.

    Reference: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation

    Implementation copied from `pyro <https://github.com/pyro-ppl/pyro/blob/dev/pyro/ops/stats.py>`_.

    :param torch.Tensor input: the input tensor.
    :param int dim: the dimension to calculate autocorrelation.
    :returns torch.Tensor: autocorrelation of ``input``, normalized so that lag 0 equals 1.
    """
    # Adapted from Stan implementation
    # https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/autocorrelation.hpp
    n_steps = input.size(dim)
    # pad to twice an FFT-friendly length so the circular correlation does not wrap around
    fft_size = 2 * next_fast_len(n_steps)

    # move the target dimension last, as the Fourier transform works on the last dimension
    signal = input.transpose(dim, -1)
    centered = signal - signal.mean(dim=-1, keepdim=True)

    # power spectrum: squared magnitude of the Fourier coefficients (real^2 + imag^2)
    spectrum = torch.view_as_real(rfft(centered, n=fft_size))
    power = spectrum.pow(2).sum(-1)

    # back-transform, keep the first n_steps lags and divide by the number of terms per lag
    autocov = irfft(power, n=fft_size)[..., :n_steps]
    terms_per_lag = torch.tensor(range(n_steps, 0, -1), dtype=signal.dtype, device=signal.device)
    autocov = autocov / terms_per_lag

    normalized = autocov / autocov[..., :1]
    return normalized.transpose(dim, -1)
def unpack_sequence(sequence: Union[torch.Tensor, rnn.PackedSequence]) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Unpack RNN sequence.

    Args:
        sequence (Union[torch.Tensor, rnn.PackedSequence]): RNN packed sequence or tensor of which
            first index are samples and second are timesteps

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: tuple of the (padded, batch-first) sequence and the length of each
            sample. For a plain tensor every sample has the full length ``sequence.size(1)``.
    """
    if not isinstance(sequence, rnn.PackedSequence):
        n_samples = sequence.size(0)
        full_lengths = torch.ones(n_samples, device=sequence.device, dtype=torch.long) * sequence.size(1)
        return sequence, full_lengths

    padded, lengths = rnn.pad_packed_sequence(sequence, batch_first=True)
    # pad_packed_sequence returns lengths on the CPU - align them with the data's device
    return padded, lengths.to(padded.device)
def concat_sequences(
    sequences: Union[List[torch.Tensor], List[rnn.PackedSequence]]
) -> Union[torch.Tensor, rnn.PackedSequence]:
    """
    Concatenate RNN sequences along the time dimension.

    Args:
        sequences (Union[List[torch.Tensor], List[rnn.PackedSequence]]): list of RNN packed sequences or tensors
            of which first index are samples and second are timesteps. All entries must hold the same number of
            samples. Entries that are tuples/lists of sequences are concatenated element-wise.

    Returns:
        Union[torch.Tensor, rnn.PackedSequence]: concatenated sequence (a tuple of them for tuple/list entries)
    """
    first = sequences[0]
    # PackedSequence is itself a tuple, so it has to be checked before the tuple/list branch
    if isinstance(first, rnn.PackedSequence):
        # rnn.pack_sequence expects plain tensors: unpack, join the valid (unpadded) timesteps of each
        # sample across all sequences and pack the result again
        unpacked = [rnn.pad_packed_sequence(sequence, batch_first=True) for sequence in sequences]
        n_samples = unpacked[0][0].size(0)
        samples = [
            torch.cat([padded[i, : int(lengths[i])] for padded, lengths in unpacked], dim=0)
            for i in range(n_samples)
        ]
        return rnn.pack_sequence(samples, enforce_sorted=False)
    elif isinstance(first, torch.Tensor):
        return torch.cat(sequences, dim=1)
    elif isinstance(first, (tuple, list)):
        return tuple(concat_sequences([sequence[i] for sequence in sequences]) for i in range(len(first)))
    else:
        raise ValueError("Unsupported sequence type")
def padded_stack(
    tensors: List[torch.Tensor], side: str = "right", mode: str = "constant", value: Union[int, float] = 0
) -> torch.Tensor:
    """
    Stack tensors along a new first dimension after padding their last dimension to a common size.

    Args:
        tensors (List[torch.Tensor]): list of tensors to stack
        side (str): side on which to pad - "left" or "right". Defaults to "right".
            Only validated for tensors that actually need padding.
        mode (str): 'constant', 'reflect', 'replicate' or 'circular'. Default: 'constant'
        value (Union[int, float]): value to use for constant padding

    Returns:
        torch.Tensor: stacked tensor
    """
    target_size = max(t.size(-1) for t in tensors)
    padded = []
    for t in tensors:
        missing = target_size - t.size(-1)
        if missing <= 0:
            # already at full size - keep the tensor untouched
            padded.append(t)
            continue
        if side == "left":
            padding = (missing, 0)
        elif side == "right":
            padding = (0, missing)
        else:
            raise ValueError(f"side for padding '{side}' is unknown")
        padded.append(F.pad(t, padding, mode=mode, value=value))
    return torch.stack(padded, dim=0)
def to_list(value: Any) -> List[Any]:
    """
    Wrap a single value into a list; lists and tuples are passed through unchanged.

    A ``PackedSequence`` (which is a tuple subclass) counts as a single value.

    Args:
        value (Any): value to convert

    Returns:
        List[Any]: ``value`` itself if it is a list/tuple, otherwise ``[value]``
    """
    is_sequence = isinstance(value, (tuple, list)) and not isinstance(value, rnn.PackedSequence)
    return value if is_sequence else [value]
def unsqueeze_like(tensor: torch.Tensor, like: torch.Tensor):
    """
    Append trailing singleton dimensions to ``tensor`` until it has as many dimensions as ``like``.

    Args:
        tensor (torch.Tensor): tensor to unsqueeze
        like (torch.Tensor): tensor whose number of dimensions to match

    Returns:
        ``tensor`` itself if the dimensions already match, otherwise a view with extra trailing dimensions

    Raises:
        ValueError: if ``tensor`` has more dimensions than ``like``
    """
    missing_dims = like.ndim - tensor.ndim
    if missing_dims < 0:
        raise ValueError(f"tensor.ndim={tensor.ndim} > like.ndim={like.ndim}")
    if missing_dims == 0:
        return tensor
    index = (Ellipsis, *([None] * missing_dims))
    return tensor[index]
def apply_to_list(obj: Union[List[Any], Any], func: Callable) -> Union[List[Any], Any]:
    """
    Apply a function element-wise to a list/tuple, or directly to any other object.

    This is useful when the passed object can either be a list to whose elements a function needs to be
    applied, or a single object to which to apply the function. A ``PackedSequence`` is treated as a
    single object.

    Args:
        obj (Union[List[Any], Any]): list/tuple on whose elements to apply function, otherwise object to
            which to apply function
        func (Callable): function to apply

    Returns:
        Union[List[Any], Any]: list of results (also for tuple input) or the single result
    """
    if not isinstance(obj, (list, tuple)) or isinstance(obj, rnn.PackedSequence):
        return func(obj)
    return [func(item) for item in obj]
class OutputMixIn:
    """
    MixIn to give namedtuple some access capabilities of a dictionary
    """

    def __getitem__(self, k):
        # string keys address fields by name; anything else (int, slice) is regular tuple indexing
        if isinstance(k, str):
            return getattr(self, k)
        else:
            return super().__getitem__(k)

    def get(self, k, default=None):
        return getattr(self, k, default)

    def items(self):
        return zip(self._fields, self)

    def keys(self):
        return self._fields

    def iget(self, idx: Union[int, slice]):
        """Select item(s) row-wise.

        Args:
            idx ([int, slice]): item to select

        Returns:
            Output of single item.
        """
        return self.__class__(*(x[idx] for x in self))


class TupleOutputMixIn:
    """MixIn to give output a namedtuple-like access capabilities with ``to_network_output() function``."""

    def to_network_output(self, **results):
        """
        Convert output into a named (and immutable) tuple.

        This allows tracing the modules as graphs and prevents modifying the output.

        One output class is created per distinct set of keyword names and cached on the instance, so calls
        with different outputs (e.g. in different modes) do not clash. The most recently used class is also
        kept as ``self._output_class``.

        Returns:
            named tuple
        """
        fields = frozenset(results)
        output_class = getattr(self, "_output_class", None)
        if output_class is None or frozenset(output_class._fields) != fields:
            cache = getattr(self, "_output_classes", None)
            if cache is None:
                cache = {}
                self._output_classes = cache
            output_class = cache.get(fields)
            if output_class is None:
                OutputTuple = namedtuple("output", results)

                class Output(OutputMixIn, OutputTuple):
                    pass

                output_class = Output
                cache[fields] = output_class
            self._output_class = output_class
        # construct via keywords so that the call order of the results does not matter
        return output_class(**results)
def move_to_device(
    x: Union[
        Dict[str, Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]],
        torch.Tensor,
        List[torch.Tensor],
        Tuple[torch.Tensor],
    ],
    device: Union[str, torch.DeviceObjType],
) -> Union[
    Dict[str, Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]],
    torch.Tensor,
    List[torch.Tensor],
    Tuple[torch.Tensor],
]:
    """
    Move object to device.

    Dictionaries are updated in place and returned. Named outputs (``OutputMixIn``) are rebuilt with moved
    fields. A ``PackedSequence`` is moved as a whole and keeps its type. For lists/tuples, every element is
    moved; if any element changed, a list of the moved elements is returned, otherwise the original object.
    Other objects are returned unchanged.

    Args:
        x (dictionary of list of tensors): object (e.g. dictionary) of tensors to move to device
        device (Union[str, torch.DeviceObjType]): device, e.g. "cpu"

    Returns:
        x on targeted device
    """
    if isinstance(device, str):
        device = torch.device(device)
    if isinstance(x, dict):
        for name in x.keys():
            x[name] = move_to_device(x[name], device=device)
    elif isinstance(x, torch.Tensor):
        if x.device != device:
            x = x.to(device)
    elif isinstance(x, rnn.PackedSequence):
        # PackedSequence is a namedtuple - it must not be treated as a plain tuple of tensors
        x = x.to(device)
    elif isinstance(x, OutputMixIn):
        # namedtuples are immutable, so the moved fields have to be put into a new instance
        x = x.__class__(**{name: move_to_device(xi, device=device) for name, xi in x.items()})
    elif isinstance(x, (list, tuple)):
        # check every element (not only the first) - lists may be empty or hold mixed content
        moved = [move_to_device(xi, device=device) for xi in x]
        if any(new is not old for new, old in zip(moved, x)):
            x = moved
    return x
def detach(
    x: Union[
        Dict[str, Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]],
        torch.Tensor,
        List[torch.Tensor],
        Tuple[torch.Tensor],
    ],
) -> Union[
    Dict[str, Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]],
    torch.Tensor,
    List[torch.Tensor],
    Tuple[torch.Tensor],
]:
    """
    Detach object from the autograd graph.

    Tensors are detached; dictionaries, named outputs (``OutputMixIn``) and ``PackedSequence`` objects are
    rebuilt with detached content and keep their type; lists/tuples become lists of detached elements.
    Anything else is returned unchanged.

    Args:
        x: object to detach

    Returns:
        detached object
    """
    if isinstance(x, torch.Tensor):
        return x.detach()
    elif isinstance(x, dict):
        return {name: detach(xi) for name, xi in x.items()}
    elif isinstance(x, rnn.PackedSequence):
        # PackedSequence is a namedtuple - handle it before the generic tuple branch to keep its type;
        # only ``data`` can carry gradients, the index tensors are reused as they are
        return rnn.PackedSequence(x.data.detach(), x.batch_sizes, x.sorted_indices, x.unsorted_indices)
    elif isinstance(x, OutputMixIn):
        return x.__class__(**{name: detach(xi) for name, xi in x.items()})
    elif isinstance(x, (list, tuple)):
        return [detach(xi) for xi in x]
    else:
        return x
def masked_op(tensor: torch.Tensor, op: str = "mean", dim: int = 0, mask: torch.Tensor = None) -> torch.Tensor:
    """Reduce a tensor along one dimension while ignoring masked-out entries.

    Args:
        tensor (torch.Tensor): tensor to reduce
        op (str): operation to apply. One of ["mean", "sum"]. Defaults to "mean".
        dim (int, optional): dimension to reduce over. Defaults to 0.
        mask (torch.Tensor, optional): boolean mask (True = include, False = ignore).
            If not given, all non-nan entries are included.

    Returns:
        torch.Tensor: tensor with the reduced dimension removed

    Raises:
        ValueError: if ``op`` is neither "mean" nor "sum"
    """
    if mask is None:
        mask = ~torch.isnan(tensor)
    # zero out ignored entries so they do not contribute to the sum
    total = tensor.masked_fill(~mask, 0.0).sum(dim=dim)
    if op == "sum":
        return total
    if op == "mean":
        # divide by the number of included entries only
        return total / mask.sum(dim=dim)
    raise ValueError(f"unkown operation {op}")
def repr_class(
    obj,
    attributes: Union[List[str], Dict[str, Any]],
    max_characters_before_break: int = 100,
    extra_attributes: Union[None, Dict[str, Any]] = None,
) -> str:
    """Print class name and parameters.

    Args:
        obj: class to format
        attributes (Union[List[str], Dict[str]]): list of attributes to show (missing attributes are skipped)
            or dictionary of attributes and values to show
        max_characters_before_break (int): number of characters before breaking the representation
            into multiple lines
        extra_attributes (Dict[str, Any], optional): extra attributes to show in square brackets
            after the class name

    Returns:
        str
    """
    if extra_attributes is None:
        extra_attributes = {}
    # get attributes
    if isinstance(attributes, (tuple, list)):
        attributes = {name: getattr(obj, name) for name in attributes if hasattr(obj, name)}
    attributes_strings = [f"{name}={repr(value)}" for name, value in attributes.items()]
    # get header
    header_name = obj.__class__.__name__
    # add extra attributes
    if len(extra_attributes) > 0:
        extra_attributes_strings = [f"{name}={repr(value)}" for name, value in extra_attributes.items()]
        if len(header_name) + 2 + len(", ".join(extra_attributes_strings)) > max_characters_before_break:
            # too long for one line: put each extra attribute on its own line
            header = f"{header_name}[\n\t" + ",\n\t".join(extra_attributes_strings) + "\n]("
        else:
            header = f"{header_name}[{', '.join(extra_attributes_strings)}]("
    else:
        header = f"{header_name}("
    # create final representation; only the last header line counts towards the line length
    attributes_string = ", ".join(attributes_strings)
    if len(attributes_string) + len(header.split("\n")[-1]) + 1 > max_characters_before_break:
        attributes_string = "\n\t" + ",\n\t".join(attributes_strings) + "\n"
    return f"{header}{attributes_string})"
class InitialParameterRepresenterMixIn:
    """
    MixIn that builds the representation of an object from the parameters of its class signature.

    For ``nn.Module`` subclasses, ``__repr__`` defers to the parent implementation. ``extra_repr`` provides
    the parameter listing: the hyperparameters for Lightning modules, otherwise ``name=value`` pairs
    for all signature parameters that exist as attributes.
    """

    def __repr__(self) -> str:
        if not isinstance(self, nn.Module):
            init_parameters = list(inspect.signature(self.__class__).parameters)
            return repr_class(self, attributes=init_parameters)
        return super().__repr__()

    def extra_repr(self) -> str:
        if isinstance(self, pl.LightningModule):
            # indent every line of the hyperparameter representation by one tab
            return "\t" + repr(self.hparams).replace("\n", "\n\t")
        init_parameters = list(inspect.signature(self.__class__).parameters)
        shown = [f"{name}={repr(getattr(self, name))}" for name in init_parameters if hasattr(self, name)]
        return ", ".join(shown)