"""
Helper functions for PyTorch forecasting
"""
from collections import namedtuple
from contextlib import redirect_stdout
import inspect
import os
from typing import Any, Callable, Union
import lightning.pytorch as pl
import torch
from torch import nn
from torch.fft import irfft, rfft
import torch.nn.functional as F
from torch.nn.utils import rnn
[docs]
def integer_histogram(
    data: torch.LongTensor, min: Union[None, int] = None, max: Union[None, int] = None
) -> torch.Tensor:
    """
    Count occurrences of each integer in the range ``[min, max]``.

    Args:
        data: integer tensor whose values are counted
        min: lower end of the histogram range; defaults to the smallest value in ``data``
        max: upper end of the histogram range; defaults to the largest value in ``data``

    Returns:
        long tensor of length ``max - min + 1`` where entry ``i`` is the number of
        occurrences of ``min + i`` in ``data``
    """
    observed, frequencies = torch.unique(data, return_counts=True)
    lower = observed.min() if min is None else min
    upper = observed.max() if max is None else max
    n_bins = upper - lower + 1
    empty_histogram = torch.zeros(n_bins, dtype=torch.long, device=data.device)
    # shift the observed values so that ``lower`` maps to bin 0
    return empty_histogram.scatter(dim=0, index=observed - lower, src=frequencies)
[docs]
def groupby_apply(
keys: torch.Tensor,
values: torch.Tensor,
bins: int = 95,
reduction: str = "mean",
return_histogram: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""
Groupby apply for torch tensors
Args:
keys: tensor of groups (``0`` to ``bins``)
values: values to aggregate - same size as keys
bins: total number of groups
reduction: either "mean" or "sum"
return_histogram: if to return histogram on top
Returns:
tensor of size ``bins`` with aggregated values
and optionally with counts of values
"""
if reduction == "mean":
reduce = torch.mean
elif reduction == "sum":
reduce = torch.sum
else:
raise ValueError(f"Unknown reduction '{reduction}'")
uniques, counts = keys.unique(return_counts=True)
groups = torch.stack(
[reduce(item) for item in torch.split_with_sizes(values, tuple(counts))]
)
reduced = torch.zeros(bins, dtype=values.dtype, device=values.device).scatter(
dim=0, index=uniques, src=groups
)
if return_histogram:
hist = torch.zeros(bins, dtype=torch.long, device=values.device).scatter(
dim=0, index=uniques, src=counts
)
return reduced, hist
else:
return reduced
[docs]
def profile(
    function: Callable, profile_fname: str, filter: str = "", period=0.0001, **kwargs
):
    """
    Profile a given function with ``vmprof``.

    Args:
        function (Callable): function to profile
        profile_fname (str): path where to save profile (`.txt` file will be saved with line profile)
        filter (str, optional): filter name (e.g. module name) to filter profile. Defaults to "".
        period (float, optional): frequency of calling profiler in seconds. Defaults to 0.0001.
        **kwargs: keyword arguments passed on to ``function``. The ``lines`` key is also
            read here: if it is falsy, no `.txt` line report is written (defaults to writing it).
    """  # noqa : E501
    import vmprof
    from vmprof.show import LinesPrinter

    # profiler config
    with open(profile_fname, "wb+") as fd:
        # start profiler
        vmprof.enable(fd.fileno(), lines=True, period=period)
        try:
            # run function
            function(**kwargs)
        finally:
            # always stop the profiler, even when ``function`` raises,
            # so it does not keep sampling into a closed file
            vmprof.disable()

    # write report to disk
    if kwargs.get("lines", True):
        report_fname = f"{os.path.splitext(profile_fname)[0]}.txt"
        with open(report_fname, "w") as f, redirect_stdout(f):
            LinesPrinter(filter=filter).show(profile_fname)
[docs]
def get_embedding_size(n: int, max_size: int = 100) -> int:
    """
    Suggest an embedding dimension for a categorical variable (fastai heuristic).

    Args:
        n (int): number of classes
        max_size (int, optional): upper bound for the embedding size. Defaults to 100.

    Returns:
        int: ``1`` for at most two classes, otherwise ``round(1.6 * n ** 0.56)``
        capped at ``max_size``
    """
    heuristic_size = lambda: min(round(1.6 * n**0.56), max_size)  # noqa: E731
    return heuristic_size() if n > 2 else 1
[docs]
def create_mask(
    size: int, lengths: torch.LongTensor, inverse: bool = False
) -> torch.BoolTensor:
    """
    Create boolean masks of shape ``len(lengths) x size``.

    By default (``inverse=False``) the mask marks padding: the entry at ``(i, j)``
    is True if ``j >= lengths[i]``. With ``inverse=True`` it marks valid positions
    instead: the entry at ``(i, j)`` is True if ``j < lengths[i]``.

    Args:
        size (int): size of second dimension
        lengths (torch.LongTensor): tensor of lengths
        inverse (bool, optional): If true, mark positions holding values instead of
            padding positions. Defaults to False.

    Returns:
        torch.BoolTensor: mask
    """
    # (1, size) positions broadcast against (len(lengths), 1) lengths
    positions = torch.arange(size, device=lengths.device).unsqueeze(0)
    limits = lengths.unsqueeze(-1)
    if inverse:  # True where values are
        return positions < limits
    return positions >= limits  # True where no values are
# Memoisation table for ``next_fast_len``: requested size -> computed fast length.
_NEXT_FAST_LEN = dict()
[docs]
def next_fast_len(size):
    """
    Returns the next largest number ``n >= size`` whose prime factors are all
    2, 3, or 5. These sizes are efficient for fast fourier transforms.
    Equivalent to :func:`scipy.fftpack.next_fast_len`.

    Implementation from pyro. Results are cached in ``_NEXT_FAST_LEN``.

    :param int size: A positive number.
    :returns: A possibly larger number.
    :rtype int:
    :raises ValueError: if ``size`` is not a positive integer.
    """
    try:
        return _NEXT_FAST_LEN[size]
    except KeyError:
        pass

    # explicit check rather than ``assert`` so it still runs under ``python -O``;
    # a non-positive size would otherwise make the search below loop forever
    if not isinstance(size, int) or size <= 0:
        raise ValueError(f"size must be a positive integer, got {size!r}")

    next_size = size
    while True:
        # strip all factors 2, 3 and 5; a remainder of 1 means no other prime factors
        remaining = next_size
        for n in (2, 3, 5):
            while remaining % n == 0:
                remaining //= n
        if remaining == 1:
            _NEXT_FAST_LEN[size] = next_size
            return next_size
        next_size += 1
[docs]
def autocorrelation(input, dim=0):
    """
    Computes the autocorrelation of samples at dimension ``dim``.

    Reference: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation

    Implementation copied form `pyro <https://github.com/pyro-ppl/pyro/blob/dev/pyro/ops/stats.py>`_.

    :param torch.Tensor input: the input tensor.
    :param int dim: the dimension to calculate autocorrelation.
    :returns torch.Tensor: autocorrelation of ``input`` (lag 0 normalised to 1).
    """
    # Adapted from Stan implementation
    # https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/autocorrelation.hpp
    n_samples = input.size(dim)
    fft_len = 2 * next_fast_len(n_samples)

    # move the target dimension last so the FFT runs along it, then center it
    signal = input.transpose(dim, -1)
    signal = signal - signal.mean(dim=-1, keepdim=True)

    # squared magnitude of the spectrum (re^2 + im^2); the transform length is at
    # least twice the number of samples, i.e. the signal is zero-padded
    spectrum = torch.view_as_real(rfft(signal, n=fft_len))
    power = spectrum.pow(2).sum(-1)

    # back to the lag domain and keep lags 0 .. n_samples - 1
    lagged = irfft(power, n=fft_len)[..., :n_samples]
    # lag k is divided by (n_samples - k)
    overlap = torch.tensor(
        range(n_samples, 0, -1), dtype=signal.dtype, device=signal.device
    )
    lagged = lagged / overlap
    normalised = lagged / lagged[..., :1]
    return normalised.transpose(dim, -1)
[docs]
def unpack_sequence(
    sequence: Union[torch.Tensor, rnn.PackedSequence],
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Unpack RNN sequence.

    Args:
        sequence (Union[torch.Tensor, rnn.PackedSequence]): RNN packed sequence or tensor of which
            first index are samples and second are timesteps

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: batch-first tensor and per-sample lengths
            (for a plain tensor every sample has length ``sequence.size(1)``)
    """  # noqa : E501
    if not isinstance(sequence, rnn.PackedSequence):
        n_samples, n_steps = sequence.size(0), sequence.size(1)
        full_lengths = torch.full(
            (n_samples,), n_steps, dtype=torch.long, device=sequence.device
        )
        return sequence, full_lengths

    padded, lengths = rnn.pad_packed_sequence(sequence, batch_first=True)
    # lengths come back on the CPU -> align them with the data's device
    return padded, lengths.to(padded.device)
[docs]
def concat_sequences(
    sequences: Union[list[torch.Tensor], list[rnn.PackedSequence]],
) -> Union[torch.Tensor, rnn.PackedSequence]:
    """
    Concatenate RNN sequences along the sample dimension.

    Args:
        sequences (Union[List[torch.Tensor], List[rnn.PackedSequence]): list of RNN packed sequences or tensors of which
            first index are samples and second are timesteps. Lists of tuples/lists are
            concatenated element-wise.

    Returns:
        Union[torch.Tensor, rnn.PackedSequence]: concatenated sequence (a tuple of
            concatenated sequences for tuple/list elements)

    Raises:
        ValueError: if the element type is not supported
    """  # noqa : E501
    # PackedSequence is a tuple subclass, so it must be checked before tuple/list
    if isinstance(sequences[0], rnn.PackedSequence):
        # ``rnn.pack_sequence`` only accepts plain tensors: split every packed batch
        # into its individual (unpadded) samples, then pack all samples together
        samples = []
        for packed in sequences:
            padded, lengths = rnn.pad_packed_sequence(packed, batch_first=True)
            samples.extend(
                padded[idx, :length] for idx, length in enumerate(lengths.tolist())
            )
        return rnn.pack_sequence(samples, enforce_sorted=False)
    elif isinstance(sequences[0], torch.Tensor):
        return torch.cat(sequences, dim=0)
    elif isinstance(sequences[0], (tuple, list)):
        return tuple(
            concat_sequences([sequence[i] for sequence in sequences])
            for i in range(len(sequences[0]))
        )
    else:
        raise ValueError("Unsupported sequence type")
[docs]
def padded_stack(
    tensors: list[torch.Tensor],
    side: str = "right",
    mode: str = "constant",
    value: Union[int, float] = 0,
) -> torch.Tensor:
    """
    Pad tensors along their last dimension to a common size and stack them along a new first dimension.

    Args:
        tensors (List[torch.Tensor]): list of tensors to stack
        side (str): side on which to pad - "left" or "right". Defaults to "right".
            Only validated when at least one tensor actually needs padding.
        mode (str): 'constant', 'reflect', 'replicate' or 'circular'. Default: 'constant'
        value (Union[int, float]): value to use for constant padding

    Returns:
        torch.Tensor: stacked tensor
    """  # noqa : E501
    target_len = max([t.size(-1) for t in tensors])

    def pad_to_target(t: torch.Tensor) -> torch.Tensor:
        missing = target_len - t.size(-1)
        if missing <= 0:
            return t
        if side == "left":
            padding = (missing, 0)
        elif side == "right":
            padding = (0, missing)
        else:
            raise ValueError(f"side for padding '{side}' is unknown")
        return F.pad(t, padding, mode=mode, value=value)

    return torch.stack([pad_to_target(t) for t in tensors], dim=0)
[docs]
def to_list(value: Any) -> list[Any]:
    """
    Wrap a single value into a list; lists and tuples pass through unchanged.

    ``rnn.PackedSequence`` is a tuple subclass but counts as a single value here.

    Args:
        value (Any): value to convert

    Returns:
        List[Any]: ``value`` itself if it is a list/tuple, otherwise ``[value]``
    """
    if isinstance(value, rnn.PackedSequence) or not isinstance(value, (tuple, list)):
        return [value]
    return value
[docs]
def unsqueeze_like(tensor: torch.Tensor, like: torch.Tensor):
    """
    Append trailing singleton dimensions to ``tensor`` until it has as many dimensions as ``like``.

    Args:
        tensor (torch.Tensor): tensor to unsqueeze
        like (torch.Tensor): tensor whose number of dimensions to match

    Returns:
        torch.Tensor: ``tensor`` itself if the dimensions already match, otherwise a view
        with extra size-1 dimensions at the end

    Raises:
        ValueError: if ``tensor`` has more dimensions than ``like``
    """
    missing_dims = like.ndim - tensor.ndim
    if missing_dims < 0:
        raise ValueError(f"tensor.ndim={tensor.ndim} > like.ndim={like.ndim}")
    expanded = tensor
    for _ in range(missing_dims):
        expanded = expanded.unsqueeze(-1)
    return expanded
[docs]
def apply_to_list(obj: Union[list[Any], Any], func: Callable) -> Union[list[Any], Any]:
    """
    Apply ``func`` element-wise to a list/tuple, or directly to any other object.

    Useful when an argument may be either a single object or a collection of
    objects that all need the same transformation. ``rnn.PackedSequence`` is
    treated as a single object even though it is a tuple subclass.

    Args:
        obj (Union[List[Any], Any]): list/tuple on whose elements to apply function,
            otherwise object to whom to apply function
        func (Callable): function to apply

    Returns:
        Union[List[Any], Any]: list of results for a list/tuple input (tuples become lists),
        otherwise ``func(obj)``
    """
    if isinstance(obj, rnn.PackedSequence) or not isinstance(obj, (list, tuple)):
        return func(obj)
    return list(map(func, obj))
[docs]
class OutputMixIn:
    """
    MixIn to give namedtuple some access capabilities of a dictionary.

    Meant to be combined with a namedtuple base class: it relies on the base's
    ``_fields`` and on tuple indexing/iteration.
    """

    def __getitem__(self, k):
        # integer/slice keys keep plain tuple semantics, strings select a field by name
        if not isinstance(k, str):
            return super().__getitem__(k)
        return getattr(self, k)

    def get(self, k, default=None):
        """Return field ``k`` or ``default`` if it does not exist."""
        return getattr(self, k, default)

    def items(self):
        """Iterate over ``(field name, value)`` pairs."""
        return zip(self._fields, self)

    def keys(self):
        """Return the field names."""
        return self._fields

    def iget(self, idx: Union[int, slice]):
        """Select item(s) row-wise.

        Args:
            idx ([int, slice]): item to select, applied to every field

        Returns:
            Output of single item (same class, each field indexed by ``idx``).
        """
        selected_fields = (field_value[idx] for field_value in self)
        return self.__class__(*selected_fields)
[docs]
class TupleOutputMixIn:
    """MixIn to give output a namedtuple-like access capabilities with ``to_network_output() function``."""  # noqa : E501

    def to_network_output(self, **results):
        """
        Convert output into a named (and immutable) tuple.

        This allows tracing the modules as graphs and prevents modifying the output.
        The tuple class is built from the keyword names of the first call and cached
        on the instance as ``_output_class``; later calls reuse it.

        Returns:
            named tuple
        """
        if not hasattr(self, "_output_class"):
            OutputTuple = namedtuple("output", results)

            class Output(OutputMixIn, OutputTuple):
                pass

            self._output_class = Output
        return self._output_class(**results)
[docs]
def move_to_device(
    x: Union[
        dict[str, Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]],
        torch.Tensor,
        list[torch.Tensor],
        tuple[torch.Tensor],
    ],
    device: Union[str, torch.DeviceObjType],
) -> Union[
    dict[str, Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]],
    torch.Tensor,
    list[torch.Tensor],
    tuple[torch.Tensor],
]:
    """
    Move object to device.

    Args:
        x (dictionary of list of tensors): object (e.g. dictionary) of tensors to move to device.
            Dictionaries are updated in place; output tuples (``OutputMixIn``) are rebuilt;
            lists/tuples are returned as a new list if any element had to be moved, otherwise
            unchanged. Other objects are returned as they are.
        device (Union[str, torch.DeviceObjType]): device, e.g. "cpu". ``"mps"`` falls back to
            ``"cpu"`` if the MPS backend is not available or not built.

    Returns:
        x on targeted device
    """  # noqa: E501
    if isinstance(device, str):
        if device == "mps":
            mps_backend = getattr(torch.backends, "mps", None)
            if (
                mps_backend is not None
                and mps_backend.is_available()
                and mps_backend.is_built()
            ):
                device = torch.device("mps")
            else:
                device = torch.device("cpu")
        else:
            device = torch.device(device)

    if isinstance(x, torch.Tensor):
        return x if x.device == device else x.to(device)
    if isinstance(x, dict):
        for name in x.keys():
            x[name] = move_to_device(x[name], device=device)
        return x
    # OutputMixIn tuples are tuple subclasses -> must be handled before list/tuple
    if isinstance(x, OutputMixIn):
        # ``Tensor.to`` returns a new tensor, so the output tuple has to be rebuilt
        return x.__class__(*(move_to_device(xi, device=device) for xi in x))
    if isinstance(x, (list, tuple)):
        moved = [move_to_device(xi, device=device) for xi in x]
        # only replace the container if some element actually changed
        if any(new is not old for new, old in zip(moved, x)):
            return moved
        return x
    return x
[docs]
def detach(
    x: Union[
        dict[str, Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]],
        torch.Tensor,
        list[torch.Tensor],
        tuple[torch.Tensor],
    ],
) -> Union[
    dict[str, Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]],
    torch.Tensor,
    list[torch.Tensor],
    tuple[torch.Tensor],
]:
    """
    Recursively detach tensors from the autograd graph.

    Args:
        x: tensor, dictionary, output tuple (``OutputMixIn``) or list/tuple of (nested)
            tensors to detach

    Returns:
        detached object: dictionaries become new dictionaries, output tuples keep their
        class, lists/tuples become lists and any other object is returned unchanged
    """
    if isinstance(x, torch.Tensor):
        return x.detach()
    if isinstance(x, dict):
        return {key: detach(item) for key, item in x.items()}
    if isinstance(x, OutputMixIn):
        detached_fields = {key: detach(item) for key, item in x.items()}
        return x.__class__(**detached_fields)
    if isinstance(x, (list, tuple)):
        return [detach(item) for item in x]
    return x
[docs]
def masked_op(
    tensor: torch.Tensor, op: str = "mean", dim: int = 0, mask: torch.Tensor = None
) -> torch.Tensor:
    """Reduce a tensor over ``dim`` while ignoring masked-out entries.

    Args:
        tensor (torch.Tensor): tensor to conduct operation over
        op (str): operation to apply. One of ["mean", "sum"]. Defaults to "mean".
        dim (int, optional): dimension to reduce over. Defaults to 0.
        mask (torch.Tensor, optional): boolean mask to apply (True=include, False=ignore).
            By default all non-nan entries are included.

    Returns:
        torch.Tensor: tensor with the reduced dimension removed. For "mean", slices
        without any included entry divide by zero.

    Raises:
        ValueError: if ``op`` is unknown
    """  # noqa : E501
    keep = ~torch.isnan(tensor) if mask is None else mask
    # zero out ignored entries so they do not contribute to the sum
    total = tensor.masked_fill(~keep, 0.0).sum(dim=dim)
    if op == "sum":
        return total
    if op == "mean":
        return total / keep.sum(dim=dim)
    raise ValueError(f"unkown operation {op}")
[docs]
def repr_class(
    obj,
    attributes: Union[list[str], dict[str, Any]],
    max_characters_before_break: int = 100,
    extra_attributes: Union[dict[str, Any], None] = None,
) -> str:
    """Print class name and parameters.

    Args:
        obj: class to format
        attributes (Union[List[str], Dict[str, Any]]): list of attribute names to show (names
            missing on ``obj`` are skipped) or dictionary of attribute names and values to show
        max_characters_before_break (int): number of characters before breaking the
            representation into multiple lines
        extra_attributes (Dict[str, Any]): extra attributes to show in square brackets after
            the class name

    Returns:
        str
    """  # noqa E501
    if extra_attributes is None:
        extra_attributes = {}
    # get attributes
    if isinstance(attributes, (tuple, list)):
        attributes = {
            name: getattr(obj, name) for name in attributes if hasattr(obj, name)
        }
    attributes_strings = [f"{name}={repr(value)}" for name, value in attributes.items()]
    # get header
    header_name = obj.__class__.__name__
    # add extra attributes
    if len(extra_attributes) > 0:
        extra_attributes_strings = [
            f"{name}={repr(value)}" for name, value in extra_attributes.items()
        ]
        if (
            len(header_name) + 2 + len(", ".join(extra_attributes_strings))
            > max_characters_before_break
        ):
            # too long for one line: one EXTRA attribute per line inside the brackets
            header = (
                f"{header_name}[\n\t" + ",\n\t".join(extra_attributes_strings) + "\n]("
            )
        else:
            header = f"{header_name}[{', '.join(extra_attributes_strings)}]("
    else:
        header = f"{header_name}("
    # create final representation
    attributes_string = ", ".join(attributes_strings)
    # only the last header line counts towards the length of the current line
    if (
        len(attributes_string) + len(header.split("\n")[-1]) + 1
        > max_characters_before_break
    ):
        attributes_string = "\n\t" + ",\n\t".join(attributes_strings) + "\n"
    return f"{header}{attributes_string})"
[docs]
class InitialParameterRepresenterMixIn:
    """
    MixIn that builds ``__repr__`` / ``extra_repr`` from the parameters in the class signature.
    """

    def __repr__(self) -> str:
        # torch modules keep the representation of the next class in the MRO
        if not isinstance(self, nn.Module):
            init_parameter_names = list(inspect.signature(self.__class__).parameters)
            return repr_class(self, attributes=init_parameter_names)
        return super().__repr__()

    def extra_repr(self) -> str:
        if isinstance(self, pl.LightningModule):
            # lightning modules: show the hyperparameters, indented by one tab
            return "\t" + repr(self.hparams).replace("\n", "\n\t")
        init_parameter_names = list(inspect.signature(self.__class__).parameters)
        # only parameters that are also stored as attributes are shown
        shown = (
            f"{name}={getattr(self, name)!r}"
            for name in init_parameter_names
            if hasattr(self, name)
        )
        return ", ".join(shown)