Source code for pytorch_forecasting.models.temporal_fusion_transformer.sub_modules

"""
Implementation of ``nn.Modules`` for temporal fusion transformer.
"""
import math
from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


class TimeDistributed(nn.Module):
    """Run a wrapped module over the last axis of an input, folding leading axes.

    Inputs with at most two dimensions are handed to the wrapped module
    unchanged. Higher-rank inputs are flattened to two dimensions, passed
    through the module and reshaped back to three dimensions.

    Args:
        module: module applied to the flattened input.
        batch_first: selects which of the two reshape rules in ``forward``
            is used to restore the three-dimensional output.
    """

    def __init__(self, module: nn.Module, batch_first: bool = False):
        super().__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        """Apply the wrapped module to ``x`` and restore a 3-d layout."""
        # Nothing to fold for inputs of rank <= 2.
        if x.dim() <= 2:
            return self.module(x)

        # Collapse all leading axes: (samples * timesteps, input_size)
        flat_in = x.contiguous().view(-1, x.size(-1))
        flat_out = self.module(flat_in)
        out_features = flat_out.size(-1)

        if not self.batch_first:
            # (timesteps, samples, output_size)
            return flat_out.view(-1, x.size(1), out_features)
        # (samples, timesteps, output_size)
        return flat_out.contiguous().view(x.size(0), -1, out_features)
class TimeDistributedInterpolation(nn.Module):
    """Linearly resample the last axis of an input to ``output_size`` entries.

    Args:
        output_size: length of the last axis after interpolation.
        batch_first: selects which reshape rule restores the 3-d output in
            ``forward``.
        trainable: if True, the interpolated values are additionally scaled
            by ``2 * sigmoid(mask)`` where ``mask`` is a learnable vector
            initialised to zeros.
    """

    def __init__(self, output_size: int, batch_first: bool = False, trainable: bool = False):
        super().__init__()
        self.output_size = output_size
        self.batch_first = batch_first
        self.trainable = trainable
        if self.trainable:
            self.mask = nn.Parameter(torch.zeros(self.output_size, dtype=torch.float32))
            self.gate = nn.Sigmoid()

    def interpolate(self, x):
        """Interpolate a 2-d input along its last axis."""
        # F.interpolate wants a channel axis, so add one and drop it afterwards.
        with_channel = x.unsqueeze(1)
        resized = F.interpolate(with_channel, self.output_size, mode="linear", align_corners=True)
        resized = resized.squeeze(1)
        if not self.trainable:
            return resized
        return resized * self.gate(self.mask.unsqueeze(0)) * 2.0

    def forward(self, x):
        """Interpolate ``x``; inputs of rank > 2 are flattened first."""
        if x.dim() <= 2:
            return self.interpolate(x)

        # Collapse all leading axes: (samples * timesteps, input_size)
        flat_out = self.interpolate(x.contiguous().view(-1, x.size(-1)))
        out_features = flat_out.size(-1)

        if not self.batch_first:
            # (timesteps, samples, output_size)
            return flat_out.view(-1, x.size(1), out_features)
        # (samples, timesteps, output_size)
        return flat_out.contiguous().view(x.size(0), -1, out_features)
class GatedLinearUnit(nn.Module):
    """Linear projection followed by a GLU gate, with optional input dropout.

    Args:
        input_size: size of the last axis of the input.
        hidden_size: size of the last axis of the output; falls back to
            ``input_size`` when falsy.
        dropout: dropout probability applied to the input, or ``None`` to
            skip dropout entirely.
    """

    def __init__(self, input_size: int, hidden_size: int = None, dropout: float = None):
        super().__init__()
        self.dropout = nn.Dropout(dropout) if dropout is not None else None
        self.hidden_size = hidden_size or input_size
        # Twice the hidden size: F.glu splits the projection into value and gate halves.
        self.fc = nn.Linear(input_size, self.hidden_size * 2)
        self.init_weights()

    def init_weights(self):
        """Zero all biases and Xavier-initialise the projection weights."""
        for param_name, param in self.named_parameters():
            if "bias" in param_name:
                torch.nn.init.zeros_(param)
                continue
            if "fc" in param_name:
                torch.nn.init.xavier_uniform_(param)

    def forward(self, x):
        """Return ``glu(fc(dropout(x)))`` along the last axis."""
        hidden = x if self.dropout is None else self.dropout(x)
        projected = self.fc(hidden)
        return F.glu(projected, dim=-1)
class ResampleNorm(nn.Module):
    """Optionally resample the last axis, optionally gate it, then layer-normalise.

    Args:
        input_size: size of the last axis of the input.
        output_size: size of the last axis of the output; falls back to
            ``input_size`` when falsy. When it differs from ``input_size``
            the input is interpolated to ``output_size`` first.
        trainable_add: if True, the (resampled) input is scaled by
            ``2 * sigmoid(mask)`` with a learnable, zero-initialised ``mask``.
    """

    def __init__(self, input_size: int, output_size: int = None, trainable_add: bool = True):
        super().__init__()
        self.input_size = input_size
        self.trainable_add = trainable_add
        self.output_size = output_size or input_size

        if self.input_size != self.output_size:
            self.resample = TimeDistributedInterpolation(self.output_size, batch_first=True, trainable=False)

        if self.trainable_add:
            self.mask = nn.Parameter(torch.zeros(self.output_size, dtype=torch.float))
            self.gate = nn.Sigmoid()
        self.norm = nn.LayerNorm(self.output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resample (if sizes differ), gate (if enabled) and normalise ``x``."""
        resized = self.resample(x) if self.input_size != self.output_size else x
        if self.trainable_add:
            resized = resized * self.gate(self.mask) * 2.0
        return self.norm(resized)
class AddNorm(nn.Module):
    """Add a skip connection to the input and layer-normalise the sum.

    Args:
        input_size: size of the last axis of ``x``.
        skip_size: size of the last axis of ``skip``; falls back to
            ``input_size`` when falsy. When it differs from ``input_size``
            the skip tensor is interpolated to ``input_size`` first.
        trainable_add: if True, the skip tensor is scaled by
            ``2 * sigmoid(mask)`` with a learnable, zero-initialised ``mask``.
    """

    def __init__(self, input_size: int, skip_size: int = None, trainable_add: bool = True):
        super().__init__()
        self.input_size = input_size
        self.trainable_add = trainable_add
        self.skip_size = skip_size or input_size

        if self.input_size != self.skip_size:
            self.resample = TimeDistributedInterpolation(self.input_size, batch_first=True, trainable=False)

        if self.trainable_add:
            self.mask = nn.Parameter(torch.zeros(self.input_size, dtype=torch.float))
            self.gate = nn.Sigmoid()
        self.norm = nn.LayerNorm(self.input_size)

    def forward(self, x: torch.Tensor, skip: torch.Tensor):
        """Return ``LayerNorm(x + skip)`` after resampling/gating ``skip``."""
        shortcut = self.resample(skip) if self.input_size != self.skip_size else skip
        if self.trainable_add:
            shortcut = shortcut * self.gate(self.mask) * 2.0
        return self.norm(x + shortcut)
class GateAddNorm(nn.Module):
    """Chain a :class:`GatedLinearUnit` with an :class:`AddNorm`.

    Args:
        input_size: size of the last axis of the input to the GLU.
        hidden_size: output size of the GLU; falls back to ``input_size``
            when falsy.
        skip_size: size of the last axis of the skip tensor; falls back to
            the resolved hidden size when falsy.
        trainable_add: forwarded to :class:`AddNorm`.
        dropout: forwarded to :class:`GatedLinearUnit`.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int = None,
        skip_size: int = None,
        trainable_add: bool = False,
        dropout: float = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size or input_size
        self.skip_size = skip_size or self.hidden_size
        self.dropout = dropout

        self.glu = GatedLinearUnit(
            self.input_size,
            hidden_size=self.hidden_size,
            dropout=self.dropout,
        )
        self.add_norm = AddNorm(
            self.hidden_size,
            skip_size=self.skip_size,
            trainable_add=trainable_add,
        )

    def forward(self, x, skip):
        """Gate ``x``, then add ``skip`` and normalise."""
        return self.add_norm(self.glu(x), skip)
class GatedResidualNetwork(nn.Module):
    """Two-layer feed-forward block with ELU, optional context and a gated skip.

    Args:
        input_size: size of the last axis of the input.
        hidden_size: width of the two inner linear layers.
        output_size: size of the last axis of the output.
        dropout: dropout probability forwarded to the gating layer.
        context_size: size of the optional context vector; when ``None`` no
            context projection is created.
        residual: when True the residual is never resampled, even if
            ``input_size`` and ``output_size`` differ.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        dropout: float = 0.1,
        context_size: int = None,
        residual: bool = False,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.context_size = context_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.residual = residual

        # The skip path only needs resampling when it carries the raw input
        # and the input width does not already match the output width.
        if self.input_size != self.output_size and not self.residual:
            self.resample_norm = ResampleNorm(self.input_size, self.output_size)

        self.fc1 = nn.Linear(self.input_size, self.hidden_size)
        self.elu = nn.ELU()

        if self.context_size is not None:
            self.context = nn.Linear(self.context_size, self.hidden_size, bias=False)

        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)
        # Initialise before the gate is attached so its parameters are not visited here.
        self.init_weights()

        self.gate_norm = GateAddNorm(
            input_size=self.hidden_size,
            skip_size=self.output_size,
            hidden_size=self.output_size,
            dropout=self.dropout,
            trainable_add=False,
        )

    def init_weights(self):
        """Zero biases, Kaiming-init ``fc1``/``fc2`` and Xavier-init ``context``."""
        for param_name, param in self.named_parameters():
            if "bias" in param_name:
                torch.nn.init.zeros_(param)
            elif "fc1" in param_name or "fc2" in param_name:
                torch.nn.init.kaiming_normal_(param, a=0, mode="fan_in", nonlinearity="leaky_relu")
            elif "context" in param_name:
                torch.nn.init.xavier_uniform_(param)

    def forward(self, x, context=None, residual=None):
        """Transform ``x`` and combine it with the (possibly resampled) residual.

        Args:
            x: input tensor.
            context: optional tensor projected and added after the first layer.
            residual: optional skip tensor; defaults to ``x``.
        """
        skip = x if residual is None else residual
        if self.input_size != self.output_size and not self.residual:
            skip = self.resample_norm(skip)

        hidden = self.fc1(x)
        if context is not None:
            hidden = hidden + self.context(context)
        hidden = self.fc2(self.elu(hidden))
        return self.gate_norm(hidden, skip)
class VariableSelectionNetwork(nn.Module):
    """Encode several named inputs and combine them with learned softmax weights."""

    def __init__(
        self,
        input_sizes: Dict[str, int],
        hidden_size: int,
        input_embedding_flags: Dict[str, bool] = None,
        dropout: float = 0.1,
        context_size: int = None,
        single_variable_grns: Dict[str, GatedResidualNetwork] = None,
        prescalers: Dict[str, nn.Linear] = None,
    ):
        """
        Calculate weights for ``num_inputs`` variables which are each of size ``input_size``

        Args:
            input_sizes: maps each variable name to the size of its embedding.
            hidden_size: size of the last axis of every encoded variable.
            input_embedding_flags: names flagged True are encoded with a
                :class:`ResampleNorm` and get no prescaler. ``None`` (the
                default) is treated as an empty dict.
            dropout: dropout probability forwarded to the residual networks.
            context_size: size of the optional context vector fed to the
                weight-producing network.
            single_variable_grns: pre-built encoders to reuse, keyed by
                variable name. ``None`` is treated as an empty dict.
            prescalers: pre-built prescaling layers to reuse, keyed by
                variable name. ``None`` is treated as an empty dict.
        """
        super().__init__()

        # Fresh dicts per instance: a mutable ``{}`` default would be shared
        # by every network constructed without these arguments.
        if input_embedding_flags is None:
            input_embedding_flags = {}
        if single_variable_grns is None:
            single_variable_grns = {}
        if prescalers is None:
            prescalers = {}

        self.hidden_size = hidden_size
        self.input_sizes = input_sizes
        self.input_embedding_flags = input_embedding_flags
        self.dropout = dropout
        self.context_size = context_size

        if self.num_inputs > 1:
            # ``context_size=None`` is the GRN default, so a single call covers
            # both the with-context and without-context cases.
            self.flattened_grn = GatedResidualNetwork(
                self.input_size_total,
                min(self.hidden_size, self.num_inputs),
                self.num_inputs,
                self.dropout,
                self.context_size,
                residual=False,
            )

        self.single_variable_grns = nn.ModuleDict()
        self.prescalers = nn.ModuleDict()
        for name, input_size in self.input_sizes.items():
            is_embedding = self.input_embedding_flags.get(name, False)

            if name in single_variable_grns:
                self.single_variable_grns[name] = single_variable_grns[name]
            elif is_embedding:
                self.single_variable_grns[name] = ResampleNorm(input_size, self.hidden_size)
            else:
                self.single_variable_grns[name] = GatedResidualNetwork(
                    input_size,
                    min(input_size, self.hidden_size),
                    output_size=self.hidden_size,
                    dropout=self.dropout,
                )

            if name in prescalers:  # reals need to be first scaled up
                self.prescalers[name] = prescalers[name]
            elif not is_embedding:
                self.prescalers[name] = nn.Linear(1, input_size)

        self.softmax = nn.Softmax(dim=-1)

    @property
    def input_size_total(self):
        """Sum of all per-variable input sizes."""
        return sum(self.input_sizes.values())

    @property
    def num_inputs(self):
        """Number of input variables."""
        return len(self.input_sizes)

    def _prescale(self, name: str, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Fetch variable ``name`` from ``x`` and apply its prescaler, if any."""
        variable_embedding = x[name]
        if name in self.prescalers:
            variable_embedding = self.prescalers[name](variable_embedding)
        return variable_embedding

    def forward(self, x: Dict[str, torch.Tensor], context: torch.Tensor = None):
        """Encode every variable and return the weighted combination.

        Args:
            x: maps each variable name in ``input_sizes`` to its tensor.
            context: optional context passed to the weight-producing network.

        Returns:
            Tuple ``(outputs, sparse_weights)``. With a single input variable
            no selection is performed and the weights are all ones.
        """
        if self.num_inputs > 1:
            # transform single variables
            var_outputs = []
            weight_inputs = []
            for name in self.input_sizes.keys():
                variable_embedding = self._prescale(name, x)
                weight_inputs.append(variable_embedding)
                var_outputs.append(self.single_variable_grns[name](variable_embedding))
            var_outputs = torch.stack(var_outputs, dim=-1)

            # calculate variable weights
            flat_embedding = torch.cat(weight_inputs, dim=-1)
            sparse_weights = self.flattened_grn(flat_embedding, context)
            sparse_weights = self.softmax(sparse_weights).unsqueeze(-2)

            outputs = (var_outputs * sparse_weights).sum(dim=-1)
        else:
            # for one input, do not perform variable selection but just encoding
            name = next(iter(self.single_variable_grns.keys()))
            outputs = self.single_variable_grns[name](self._prescale(name, x))
            if outputs.ndim == 3:  # -> batch size, time, hidden size, n_variables
                sparse_weights = torch.ones(outputs.size(0), outputs.size(1), 1, 1, device=outputs.device)
            else:  # ndim == 2 -> batch size, hidden size, n_variables
                sparse_weights = torch.ones(outputs.size(0), 1, 1, device=outputs.device)
        return outputs, sparse_weights
class PositionalEncoder(torch.nn.Module):
    """Add a fixed sine/cosine position table to a scaled input.

    Args:
        d_model: size of the last axis of the input; must be even.
        max_seq_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_seq_len=160):
        super().__init__()
        assert d_model % 2 == 0, "model dimension has to be multiple of 2 (encode sin(pos) and cos(pos))"
        self.d_model = d_model

        # Even columns hold a sine term, the following odd column a cosine term.
        table = torch.zeros(max_seq_len, d_model)
        for position in range(max_seq_len):
            for even_col in range(0, d_model, 2):
                odd_col = even_col + 1
                table[position, even_col] = math.sin(position / (10000 ** ((2 * even_col) / d_model)))
                table[position, odd_col] = math.cos(position / (10000 ** ((2 * odd_col) / d_model)))

        # Stored with a leading axis of size 1: (1, max_seq_len, d_model)
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        """Scale ``x`` by ``sqrt(d_model)`` and add the position table.

        The first axis of ``x`` is used as the sequence length. The whole
        computation runs under ``torch.no_grad()``.
        """
        with torch.no_grad():
            n_steps = x.size(0)
            scaled = x * math.sqrt(self.d_model)
            positions = self.pe[:, :n_steps].view(n_steps, 1, self.d_model)
            return scaled + positions
class ScaledDotProductAttention(nn.Module):
    """Batched dot-product attention with optional scaling, masking and dropout.

    Args:
        dropout: dropout probability applied to the attention weights, or
            ``None`` to skip dropout.
        scale: if True, scores are divided by the square root of the size of
            the last axis of the keys.
    """

    def __init__(self, dropout: float = None, scale: bool = True):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout) if dropout is not None else None
        self.softmax = nn.Softmax(dim=2)
        self.scale = scale

    def forward(self, q, k, v, mask=None):
        """Return ``(output, attention_weights)`` for 3-d ``q``, ``k`` and ``v``.

        Positions where ``mask`` is True are filled with ``-1e9`` before the
        softmax over the last axis.
        """
        # query-key overlap
        scores = torch.bmm(q, k.permute(0, 2, 1))

        if self.scale:
            key_width = torch.tensor(k.shape[-1]).to(torch.float32)
            scores = scores / torch.sqrt(key_width)

        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)

        weights = self.softmax(scores)
        if self.dropout is not None:
            weights = self.dropout(weights)

        return torch.bmm(weights, v), weights
class InterpretableMultiHeadAttention(nn.Module):
    """Multi-head attention whose heads share one value projection.

    Each head has its own query and key projection; the per-head outputs are
    averaged (not concatenated) before the final linear layer.

    Args:
        n_head: number of attention heads.
        d_model: size of the last axis of the inputs and of the output.
        dropout: dropout probability applied to each head output and to the
            final projection.
    """

    def __init__(self, n_head: int, d_model: int, dropout: float = 0.0):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = self.d_q = self.d_v = d_model // n_head
        self.dropout = nn.Dropout(p=dropout)

        self.v_layer = nn.Linear(self.d_model, self.d_v)
        self.q_layers = nn.ModuleList([nn.Linear(self.d_model, self.d_q) for _ in range(self.n_head)])
        self.k_layers = nn.ModuleList([nn.Linear(self.d_model, self.d_k) for _ in range(self.n_head)])
        self.attention = ScaledDotProductAttention()
        self.w_h = nn.Linear(self.d_v, self.d_model, bias=False)

        self.init_weights()

    def init_weights(self):
        """Zero every bias and Xavier-initialise every other parameter."""
        for param_name, param in self.named_parameters():
            if "bias" in param_name:
                torch.nn.init.zeros_(param)
            else:
                torch.nn.init.xavier_uniform_(param)

    def forward(self, q, k, v, mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(outputs, attn)``; ``attn`` stacks the per-head weights on axis 2."""
        shared_values = self.v_layer(v)

        head_outputs = []
        head_weights = []
        for q_proj, k_proj in zip(self.q_layers, self.k_layers):
            head_out, weights = self.attention(q_proj(q), k_proj(k), shared_values, mask)
            head_outputs.append(self.dropout(head_out))
            head_weights.append(weights)

        attn = torch.stack(head_weights, dim=2)

        if self.n_head > 1:
            # Average over the stacked head axis.
            combined = torch.mean(torch.stack(head_outputs, dim=2), dim=2)
        else:
            combined = head_outputs[0]

        outputs = self.dropout(self.w_h(combined))
        return outputs, attn