"""Classes and functions for the MQF2 metric."""
from typing import Optional
from skbase.utils.dependencies import _safe_import
import torch
from torch.distributions import (
AffineTransform,
Distribution,
Normal,
TransformedDistribution,
)
import torch.nn.functional as F
DeepConvexFlow = _safe_import("cpflows.flows.DeepConvexFlow")
SequentialFlow = _safe_import("cpflows.flows.SequentialFlow")
[docs]
class DeepConvexNet(DeepConvexFlow):
r"""
Class that takes a partially input convex neural network (picnn)
as input and equips it with functions of logdet
computation (both estimation and exact computation).
This class is based on DeepConvexFlow of the CP-Flow
repo (https://github.com/CW-Huang/CP-Flow)
For details of the logdet estimator, see
``Convex potential flows: Universal probability distributions
with optimal transport and convex optimization``
Parameters
----------
picnn
A partially input convex neural network (picnn)
dim
Dimension of the input
is_energy_score
Indicates if energy score is used as the objective function
If yes, the network is not required to be strictly convex,
so we can just use the picnn
otherwise, a quadratic term is added to the output of picnn
to render it strictly convex
m1
Dimension of the Krylov subspace of the Lanczos tridiagonalization
used in approximating H of logdet(H)
m2
Iteration number of the conjugate gradient algorithm
used to approximate logdet(H)
rtol
relative tolerance of the conjugate gradient algorithm
atol
absolute tolerance of the conjugate gradient algorithm
"""
def __init__(
self,
picnn: torch.nn.Module,
dim: int,
is_energy_score: bool = False,
estimate_logdet: bool = False,
m1: int = 10,
m2: int | None = None,
rtol: float = 0.0,
atol: float = 1e-3,
) -> None:
super().__init__(
picnn,
dim,
m1=m1,
m2=m2,
rtol=rtol,
atol=atol,
)
self.picnn = self.icnn
self.is_energy_score = is_energy_score
self.estimate_logdet = estimate_logdet
[docs]
def get_potential(
self, x: torch.Tensor, context: torch.Tensor | None = None
) -> torch.Tensor:
n = x.size(0)
output = self.picnn(x, context)
if self.is_energy_score:
return output
else:
return (
F.softplus(self.w1) * output
+ F.softplus(self.w0) * (x.view(n, -1) ** 2).sum(1, keepdim=True) / 2
)
[docs]
class SequentialNet(SequentialFlow):
r"""
Class that combines a list of DeepConvexNet and ActNorm
layers and provides energy score computation
This class is based on SequentialFlow of the CP-Flow repo
(https://github.com/CW-Huang/CP-Flow)
Parameters
----------
networks
list of DeepConvexNet and/or ActNorm instances
"""
def __init__(self, networks: list[torch.nn.Module]) -> None:
super().__init__(networks)
self.networks = self.flows
[docs]
def forward(
self, x: torch.Tensor, context: torch.Tensor | None = None
) -> torch.Tensor:
for network in self.networks:
if isinstance(network, DeepConvexNet):
x = network.forward(x, context=context)
else:
x = network.forward(x)
return x
[docs]
def es_sample(self, hidden_state: torch.Tensor, dimension: int) -> torch.Tensor:
"""
Auxiliary function for energy score computation
Drawing samples conditioned on the hidden state
Parameters
----------
hidden_state
hidden_state which the samples conditioned
on (num_samples, hidden_size)
dimension
dimension of the input
Returns
-------
samples
samples drawn (num_samples, dimension)
"""
num_samples = hidden_state.shape[0]
zero = torch.tensor(0, dtype=hidden_state.dtype, device=hidden_state.device)
one = torch.ones_like(zero)
standard_normal = Normal(zero, one)
samples = self.forward(
standard_normal.sample([num_samples * dimension]).view(
num_samples, dimension
),
context=hidden_state,
)
return samples
[docs]
def energy_score(
self,
z: torch.Tensor,
hidden_state: torch.Tensor,
es_num_samples: int = 50,
beta: float = 1.0,
) -> torch.Tensor:
"""
Computes the (approximated) energy score sum_i ES(g,z_i),
where ES(g,z_i) =
-1/(2*es_num_samples^2) * sum_{w,w'} ||w-w'||_2^beta
+ 1/es_num_samples * sum_{w''} ||w''-z_i||_2^beta,
w's are samples drawn from the
quantile function g(., h_i) (gradient of picnn),
h_i is the hidden state associated with z_i,
and es_num_samples is the number of samples drawn
for each of w, w', w'' in energy score approximation
Parameters
----------
z
Observations (numel_batch, dimension)
hidden_state
Hidden state (numel_batch, hidden_size)
es_num_samples
Number of samples drawn for each of w, w', w''
in energy score approximation
beta
Hyperparameter of the energy score, see the formula above
Returns
-------
loss
energy score (numel_batch)
"""
numel_batch, dimension = z.shape[0], z.shape[1]
# (numel_batch * dimension * es_num_samples x hidden_size)
hidden_state_repeat = hidden_state.repeat_interleave(
repeats=es_num_samples, dim=0
)
w = self.es_sample(hidden_state_repeat, dimension)
w_prime = self.es_sample(hidden_state_repeat, dimension)
first_term = (
torch.norm(
w.view(numel_batch, 1, es_num_samples, dimension)
- w_prime.view(numel_batch, es_num_samples, 1, dimension),
dim=-1,
)
** beta
)
mean_first_term = torch.mean(first_term.view(numel_batch, -1), dim=-1)
# since both tensors are huge (numel_batch*es_num_samples, dimension),
# delete to free up GPU memories
del w, w_prime
z_repeat = z.repeat_interleave(repeats=es_num_samples, dim=0)
w_bar = self.es_sample(hidden_state_repeat, dimension)
second_term = (
torch.norm(
w_bar.view(numel_batch, es_num_samples, dimension)
- z_repeat.view(numel_batch, es_num_samples, dimension),
dim=-1,
)
** beta
)
mean_second_term = torch.mean(second_term.view(numel_batch, -1), dim=-1)
loss = -0.5 * mean_first_term + mean_second_term
return loss
[docs]
class MQF2Distribution(Distribution):
r"""
Distribution class for the model MQF2 proposed in the paper
``Multivariate Quantile Function Forecaster``
by Kan, Aubet, Januschowski, Park, Benidis, Ruthotto, Gasthaus
Parameters
----------
picnn
A SequentialNet instance of a
partially input convex neural network (picnn)
hidden_state
hidden_state obtained by unrolling the RNN encoder
shape = (batch_size, context_length, hidden_size) in training
shape = (batch_size, hidden_size) in inference
prediction_length
Length of the prediction horizon
is_energy_score
If True, use energy score as objective function
otherwise use maximum likelihood as
objective function (normalizing flows)
es_num_samples
Number of samples drawn to approximate the energy score
beta
Hyperparameter of the energy score (power of the two terms)
threshold_input
Clamping threshold of the (scaled) input when maximum
likelihood is used as objective function
this is used to make the forecaster more robust
to outliers in training samples
validate_args
Sets whether validation is enabled or disabled
For more details, refer to the descriptions in
torch.distributions.distribution.Distribution
"""
def __init__(
self,
picnn: torch.nn.Module,
hidden_state: torch.Tensor,
prediction_length: int,
is_energy_score: bool = True,
es_num_samples: int = 50,
beta: float = 1.0,
threshold_input: float = 100.0,
validate_args: bool = False,
) -> None:
self.picnn = picnn
self.hidden_state = hidden_state
self.prediction_length = prediction_length
self.is_energy_score = is_energy_score
self.es_num_samples = es_num_samples
self.beta = beta
self.threshold_input = threshold_input
super().__init__(batch_shape=self.batch_shape, validate_args=validate_args)
self.context_length = (
self.hidden_state.shape[-2] if len(self.hidden_state.shape) > 2 else 1
)
self.numel_batch = self.get_numel(self.batch_shape)
# mean zero and std one
mu = torch.tensor(0, dtype=hidden_state.dtype, device=hidden_state.device)
sigma = torch.ones_like(mu)
self.standard_normal = Normal(mu, sigma)
[docs]
def stack_sliding_view(self, z: torch.Tensor) -> torch.Tensor:
"""
Auxiliary function for loss computation
Unfolds the observations by sliding a window of size prediction_length
over the observations z
Then, reshapes the observations into a 2-dimensional tensor for
further computation
Parameters
----------
z
A batch of time series with shape
(batch_size, context_length + prediction_length - 1)
Returns
-------
Tensor
Unfolded time series with shape
(batch_size * context_length, prediction_length)
"""
z = z.unfold(dimension=-1, size=self.prediction_length, step=1)
z = z.reshape(-1, z.shape[-1])
return z
def loss(self, z: torch.Tensor) -> torch.Tensor:
if self.is_energy_score:
return self.energy_score(z)
else:
return -self.log_prob(z)
[docs]
def log_prob(self, z: torch.Tensor) -> torch.Tensor:
"""
Computes the log likelihood log(g(z)) + logdet(dg(z)/dz),
where g is the gradient of the picnn
Parameters
----------
z
A batch of time series with shape
(batch_size, context_length + prediction_length - 1)
Returns
-------
loss
Tesnor of shape (batch_size * context_length,)
"""
z = torch.clamp(z, min=-self.threshold_input, max=self.threshold_input)
z = self.stack_sliding_view(z)
loss = self.picnn.logp(
z, self.hidden_state.reshape(-1, self.hidden_state.shape[-1])
)
return loss
[docs]
def energy_score(self, z: torch.Tensor) -> torch.Tensor:
"""
Computes the (approximated) energy score sum_i ES(g,z_i),
where ES(g,z_i) =
-1/(2*es_num_samples^2) * sum_{w,w'} ||w-w'||_2^beta
+ 1/es_num_samples * sum_{w''} ||w''-z_i||_2^beta,
w's are samples drawn from the
quantile function g(., h_i) (gradient of picnn),
h_i is the hidden state associated with z_i,
and es_num_samples is the number of samples drawn
for each of w, w', w'' in energy score approximation
Parameters
----------
z
A batch of time series with shape
(batch_size, context_length + prediction_length - 1)
Returns
-------
loss
Tensor of shape (batch_size * context_length,)
"""
es_num_samples = self.es_num_samples
beta = self.beta
z = self.stack_sliding_view(z)
reshaped_hidden_state = self.hidden_state.reshape(
-1, self.hidden_state.shape[-1]
)
loss = self.picnn.energy_score(
z, reshaped_hidden_state, es_num_samples=es_num_samples, beta=beta
)
return loss
[docs]
def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
"""
Generates the sample paths
Parameters
----------
sample_shape
Shape of the samples
Returns
-------
sample_paths
Tesnor of shape (batch_size, * sample_shape, prediction_length)
"""
numel_batch = self.numel_batch
prediction_length = self.prediction_length
num_samples_per_batch = MQF2Distribution.get_numel(sample_shape)
num_samples = num_samples_per_batch * numel_batch
hidden_state_repeat = self.hidden_state.repeat_interleave(
repeats=num_samples_per_batch, dim=0
)
alpha = torch.rand(
(num_samples, prediction_length),
dtype=self.hidden_state.dtype,
device=self.hidden_state.device,
layout=self.hidden_state.layout,
).clamp(
min=1e-4, max=1 - 1e-4
) # prevent numerical issues by preventing to sample beyond 0.1% and 99.9% percentiles # noqa: E501
samples = (
self.quantile(alpha, hidden_state_repeat)
.reshape((numel_batch,) + sample_shape + (prediction_length,))
.transpose(0, 1)
)
return samples
[docs]
def quantile(
self, alpha: torch.Tensor, hidden_state: torch.Tensor | None = None
) -> torch.Tensor:
"""
Generates the predicted paths associated with the quantile levels alpha
Parameters
----------
alpha
quantile levels,
shape = (batch_shape, prediction_length)
hidden_state
hidden_state, shape = (batch_shape, hidden_size)
Returns
-------
results
predicted paths of shape = (batch_shape, prediction_length)
"""
if hidden_state is None:
hidden_state = self.hidden_state
normal_quantile = self.standard_normal.icdf(alpha)
# In the energy score approach, we directly draw samples from picnn
# In the MLE (Normalizing flows) approach, we need to invert the picnn
# (go backward through the flow) to draw samples
if self.is_energy_score:
result = self.picnn(normal_quantile, context=hidden_state)
else:
result = self.picnn.reverse(normal_quantile, context=hidden_state)
return result
@staticmethod
def get_numel(tensor_shape: torch.Size) -> int:
# Auxiliary function
# compute number of elements specified in a torch.Size()
return torch.prod(torch.tensor(tensor_shape)).item()
@property
def batch_shape(self) -> torch.Size:
# last dimension is the hidden state size
return self.hidden_state.shape[:-1]
@property
def event_shape(self) -> tuple:
return (self.prediction_length,)
@property
def event_dim(self) -> int:
return 1