"""
Implementation of `nn.Modules` for TimeXer model.
"""
import math
from math import sqrt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
[docs]
class TriangularCausalMask:
    """
    Boolean causal mask for attention scores.

    The stored mask has shape ``[B, 1, L, L]`` and is ``True`` strictly above
    the main diagonal, i.e. at every (query, key) pair where the key index is
    greater than the query index.

    Args:
        B (int): Batch size (first dimension of the mask).
        L (int): Sequence length (size of the last two dimensions).
        device (str or torch.device): Device the mask is moved to.
            Defaults to "cpu".
    """

    def __init__(self, B, L, device="cpu"):
        with torch.no_grad():
            all_true = torch.ones([B, 1, L, L], dtype=torch.bool)
            # diagonal=1 keeps only entries strictly above the main diagonal.
            upper = torch.triu(all_true, diagonal=1)
            self._mask = upper.to(device)

    @property
    def mask(self):
        """torch.Tensor: The boolean mask built in the constructor."""
        return self._mask
[docs]
class FullAttention(nn.Module):
"""
Full attention mechanism with optional masking and dropout.
Args:
mask_flag (bool): Whether to apply masking.
factor (int): Factor for scaling the attention scores.
scale (float): Scaling factor for attention scores.
attention_dropout (float): Dropout rate for attention scores.
output_attention (bool): Whether to output attention weights.
use_efficient_attention (bool): Whether to use PyTorch's native,
optimized Scaled Dot Product Attention implementation which can
reduce computation time and memory consumption for longer sequences.
PyTorch automatically selects the optimal backend (FlashAttention-2,
Memory-Efficient Attention, or their own C++ implementation) based
on user's input properties, hardware capabilities, and build
configuration.
"""
def __init__(
self,
mask_flag=True,
factor=5,
scale=None,
attention_dropout=0.1,
output_attention=False,
use_efficient_attention=False,
):
super().__init__()
if output_attention and use_efficient_attention:
raise ValueError(
"Cannot output attention scores using efficient attention. "
"Set `use_efficient_attention=False` or "
"`output_attention=False`."
)
self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.use_efficient_attention = use_efficient_attention
self.dropout = nn.Dropout(attention_dropout)
[docs]
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
if self.use_efficient_attention:
V, A = self._efficient_attention(queries, keys, values, attn_mask)
else:
V, A = self._einsum_attention(queries, keys, values, attn_mask)
if self.output_attention:
return V.contiguous(), A
else:
return V.contiguous(), None
def _einsum_attention(self, queries, keys, values, attn_mask):
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1.0 / sqrt(E)
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
if self.mask_flag:
if attn_mask is None:
attn_mask = TriangularCausalMask(B, L, device=queries.device)
scores.masked_fill_(attn_mask.mask, -np.abs)
A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)
return V, A
def _efficient_attention(self, queries, keys, values, attn_mask):
# SDPA expects [B, H, L, E] shape
queries = queries.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
V = nn.functional.scaled_dot_product_attention(
query=queries,
key=keys,
value=values,
attn_mask=attn_mask.mask if attn_mask is not None else None,
dropout_p=self.dropout.p if self.training else 0.0,
is_causal=self.mask_flag if attn_mask is None else False,
scale=self.scale, # if == None, PyTorch computes internally
)
V = V.transpose(1, 2)
return V, None
[docs]
class AttentionLayer(nn.Module):
    """
    Multi-head wrapper around an inner attention mechanism.

    Projects queries, keys and values with linear layers, splits the result
    into ``n_heads`` heads, delegates to ``attention`` and projects the merged
    heads back to ``d_model``.

    Args:
        attention (nn.Module): Attention mechanism to use. It is called as
            ``attention(queries, keys, values, attn_mask, tau=..., delta=...)``
            and must return an ``(output, attention_weights)`` pair.
        d_model (int): Dimension of the model.
        n_heads (int): Number of attention heads.
        d_keys (int, optional): Per-head dimension of queries and keys.
            Defaults to d_model // n_heads.
        d_values (int, optional): Per-head dimension of the values.
            Defaults to d_model // n_heads.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super().__init__()

        default_head_dim = d_model // n_heads
        d_keys = d_keys or default_head_dim
        d_values = d_values or default_head_dim

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """
        Apply multi-head attention.

        Args:
            queries (torch.Tensor): ``[B, L, d_model]`` tensor.
            keys (torch.Tensor): ``[B, S, d_model]`` tensor.
            values (torch.Tensor): ``[B, S, d_model]`` tensor.
            attn_mask: Forwarded unchanged to the inner attention.
            tau: Forwarded unchanged to the inner attention.
            delta: Forwarded unchanged to the inner attention.

        Returns:
            tuple: ``(output, attn)`` with ``output`` of shape
            ``[B, L, d_model]`` and ``attn`` whatever the inner attention
            returned (None when the key sequence is empty).
        """
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape

        if kv_len == 0:
            # skip the cross attention process since there is no exogenous variables
            projected = self.query_projection(queries)
            return self.out_projection(projected), None

        heads = self.n_heads
        q = self.query_projection(queries).view(batch, q_len, heads, -1)
        k = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v = self.value_projection(values).view(batch, kv_len, heads, -1)

        context, attn = self.inner_attention(
            q, k, v, attn_mask, tau=tau, delta=delta
        )
        # Merge the head dimension back before the output projection.
        merged = context.view(batch, q_len, -1)
        return self.out_projection(merged), attn
[docs]
class DataEmbedding_inverted(nn.Module):
    """
    Inverted data embedding: each variate's whole series becomes one token.

    The input is transposed so that the time axis is last, and a single linear
    layer maps that axis (of size ``c_in``) to ``d_model``.

    Args:
        c_in (int): Input size of the linear embedding, i.e. the length of
            the axis that ends up last after the transpose.
        d_model (int): Dimension of the model.
        embed_type (str): Not used by this module. Defaults to "fixed".
        freq (str): Not used by this module. Defaults to "h".
        dropout (float): Dropout rate. Defaults to 0.1.
    """

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super().__init__()
        self.value_embedding = nn.Linear(c_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """
        Embed ``x`` and, when given, the extra covariates in ``x_mark``.

        Args:
            x (torch.Tensor): ``[Batch, Time, Variate]`` tensor.
            x_mark (torch.Tensor or None): Optional ``[Batch, Time, Marks]``
                tensor whose columns are appended as additional variates.

        Returns:
            torch.Tensor: ``[Batch, Variate (+ Marks), d_model]`` tensor.
        """
        # [Batch, Time, Variate] -> [Batch, Variate, Time]
        tokens = x.permute(0, 2, 1)
        if x_mark is not None:
            # Covariates are stacked along the variate axis.
            tokens = torch.cat([tokens, x_mark.permute(0, 2, 1)], 1)
        embedded = self.value_embedding(tokens)
        # embedded: [Batch, Variate, d_model]
        return self.dropout(embedded)
[docs]
class PositionalEmbedding(nn.Module):
    """
    Fixed sinusoidal positional embedding.

    A ``[1, max_len, d_model]`` table is precomputed once and stored as the
    non-trainable buffer ``pe``; even columns hold sines and odd columns hold
    cosines. Both even and odd ``d_model`` are supported.

    Args:
        d_model (int): Dimension of the model.
        max_len (int): Maximum length of the input sequence. Defaults to 5000.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.requires_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        ).exp()
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Return the embeddings for the first ``x.size(1)`` positions.

        Args:
            x (torch.Tensor): Tensor whose second dimension is the sequence
                length; its values are not read.

        Returns:
            torch.Tensor: ``[1, x.size(1), d_model]`` slice of the table.
        """
        return self.pe[:, : x.size(1)]
[docs]
class FlattenHead(nn.Module):
    """
    Output head: flatten the last two dimensions and project linearly.

    Args:
        n_vars (int): Number of input features. Stored on the module but not
            read in ``forward``.
        nf (int): Number of features in the last layer, i.e. the size of the
            flattened last two input dimensions.
        target_window (int): Target window size.
        head_dropout (float): Dropout rate for the head. Defaults to 0.
        n_quantiles (int, optional): Number of quantiles. Defaults to 1.
    """

    def __init__(self, n_vars, nf, target_window, head_dropout=0, n_quantiles=1):
        super().__init__()
        self.n_vars = n_vars
        self.n_quantiles = n_quantiles
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window * n_quantiles)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        """
        Project the input to per-variable, per-quantile predictions.

        Args:
            x (torch.Tensor): Tensor whose last two dimensions multiply to
                ``nf``; the first two dimensions are kept as batch and
                variable axes.

        Returns:
            torch.Tensor: Tensor of shape
            ``[batch, n_vars, target_window, n_quantiles]``.
        """
        projected = self.dropout(self.linear(self.flatten(x)))
        lead_dims = (projected.shape[0], projected.shape[1])
        # Split the projected axis into (horizon step, quantile).
        return projected.reshape(*lead_dims, -1, self.n_quantiles)
[docs]
class EnEmbedding(nn.Module):
    """
    Encoder embedding module for time series data. Handles endogenous feature
    embeddings in this case.

    The series of each variable is cut into non-overlapping patches of
    ``patch_len`` steps, every patch is embedded linearly and offset by a
    positional embedding, and one learnable global token per variable is
    appended after the patch tokens.

    Args:
        n_vars (int): Number of input features.
        d_model (int): Dimension of the model.
        patch_len (int): Length of the patches.
        dropout (float): Dropout rate (required; no default).
    """

    def __init__(self, n_vars, d_model, patch_len, dropout):
        super().__init__()
        self.patch_len = patch_len

        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
        self.glb_token = nn.Parameter(torch.randn(1, n_vars, 1, d_model))
        self.position_embedding = PositionalEmbedding(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Embed patches and append the global token.

        Args:
            x (torch.Tensor): ``[batch, n_vars, time]`` tensor.

        Returns:
            tuple: ``(tokens, n_vars)`` where ``tokens`` has shape
            ``[batch * n_vars, num_patches + 1, d_model]`` and ``n_vars`` is
            ``x.shape[1]``.
        """
        batch_size, n_vars = x.shape[0], x.shape[1]
        # One copy of the global token per batch element.
        glb = self.glb_token.repeat((batch_size, 1, 1, 1))

        # [batch, n_vars, time] -> [batch, n_vars, num_patches, patch_len]
        patches = x.unfold(dimension=-1, size=self.patch_len, step=self.patch_len)
        b, v, num_patches, patch_len = patches.shape
        patches = torch.reshape(patches, (b * v, num_patches, patch_len))

        # Input encoding
        tokens = self.value_embedding(patches) + self.position_embedding(patches)

        # Restore the variable axis so the global token can be appended
        # along the patch axis.
        tokens = torch.reshape(
            tokens, (-1, n_vars, tokens.shape[-2], tokens.shape[-1])
        )
        tokens = torch.cat([tokens, glb], dim=2)

        b, v, seq_len, width = tokens.shape
        tokens = torch.reshape(tokens, (b * v, seq_len, width))
        return self.dropout(tokens), n_vars
[docs]
class Encoder(nn.Module):
    """
    Encoder module for the TimeXer model.

    Runs the input through every layer in order, then through the optional
    normalization and the optional projection.

    Args:
        layers (list): List of encoder layers.
        norm_layer (nn.Module, optional): Normalization layer. Defaults to None.
        projection (nn.Module, optional): Projection layer. Defaults to None.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        """
        Encode ``x`` using ``cross`` as the cross-attention source.

        Args:
            x (torch.Tensor): Input passed through the layers.
            cross (torch.Tensor): Second input handed to every layer.
            x_mask: Forwarded to every layer. Defaults to None.
            cross_mask: Forwarded to every layer. Defaults to None.
            tau: Forwarded to every layer. Defaults to None.
            delta: Forwarded to every layer. Defaults to None.

        Returns:
            torch.Tensor: Output of the last stage that is present.
        """
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(
                hidden,
                cross,
                x_mask=x_mask,
                cross_mask=cross_mask,
                tau=tau,
                delta=delta,
            )

        # Optional post-processing stages, applied in this fixed order.
        for stage in (self.norm, self.projection):
            if stage is not None:
                hidden = stage(hidden)
        return hidden
[docs]
class EncoderLayer(nn.Module):
    """
    Encoder layer for the TimeXer model.

    Applies self-attention over all tokens, lets the last token of every
    sequence (the global token) cross-attend to ``cross``, and finishes with
    a position-wise feed-forward block built from two kernel-size-1
    convolutions. Each of the three stages is followed by a residual
    connection and layer normalization.

    Args:
        self_attention (nn.Module): Self-attention mechanism.
        cross_attention (nn.Module): Cross-attention mechanism.
        d_model (int): Dimension of the model.
        d_ff (int, optional):
            Dimension of the feedforward layer. Defaults to 4 * d_model.
        dropout (float): Dropout rate. Defaults to 0.1.
        activation (str): Activation function; "relu" selects ReLU and any
            other value selects GELU. Defaults to "relu".
    """

    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation="relu",
    ):
        super().__init__()
        d_ff = d_ff or 4 * d_model

        self.self_attention = self_attention
        self.cross_attention = cross_attention

        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        """
        Run one encoder layer.

        Args:
            x (torch.Tensor): ``[N, tokens, d_model]`` tensor whose last token
                is the global token. ``N`` must be a multiple of the batch
                size of ``cross``.
            cross (torch.Tensor): ``[B, S, d_model]`` tensor used as keys and
                values of the cross-attention.
            x_mask: Mask forwarded to the self-attention. Defaults to None.
            cross_mask: Mask forwarded to the cross-attention. Defaults to None.
            tau: Forwarded to both attentions. Defaults to None.
            delta: Forwarded to the cross-attention only; the self-attention
                always receives ``delta=None``. Defaults to None.

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.
        """
        batch, _, width = cross.shape

        # 1) Self-attention over every token, residual + norm.
        self_out = self.self_attention(
            x, x, x, attn_mask=x_mask, tau=tau, delta=None
        )[0]
        x = self.norm1(x + self.dropout(self_out))

        # 2) Only the global (last) token cross-attends. It is regrouped to
        #    [B, N // B, d_model] so the tokens of one batch element attend
        #    together.
        glb_token = x[:, -1, :].unsqueeze(1)
        glb_query = torch.reshape(glb_token, (batch, -1, width))
        cross_out = self.cross_attention(
            glb_query, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
        )[0]
        cross_out = self.dropout(cross_out)
        # Back to one row per sequence: [N, 1, d_model].
        cross_out = torch.reshape(
            cross_out, (cross_out.shape[0] * cross_out.shape[1], cross_out.shape[2])
        ).unsqueeze(1)
        glb_token = self.norm2(glb_token + cross_out)

        # 3) Put the updated global token back and apply the feed-forward
        #    block (Conv1d expects channels on dim 1, hence the transposes).
        x = torch.cat([x[:, :-1, :], glb_token], dim=1)
        ff = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ff = self.dropout(self.conv2(ff).transpose(-1, 1))

        return self.norm3(x + ff)