Source code for pytorch_forecasting.models.nn.embeddings

from typing import Dict, List, Tuple

import torch.nn as nn


class TimeDistributedEmbeddingBag(nn.EmbeddingBag):
    """``nn.EmbeddingBag`` that additionally accepts inputs with a time axis.

    Inputs with at most two dimensions are handed straight to
    ``nn.EmbeddingBag.forward``. Larger inputs are flattened to
    ``(-1, bag_size)``, embedded bag-by-bag, and reshaped back so that a 3-D
    input yields ``(samples, timesteps, output_size)`` when ``batch_first`` is
    set and ``(timesteps, samples, output_size)`` otherwise.

    NOTE(review): only 3-D inputs are reshaped back to their original leading
    axes; for higher-rank inputs the trailing leading axes get merged -- confirm
    no caller relies on that.

    Args:
        *args: forwarded to ``nn.EmbeddingBag``.
        batch_first: whether the first input axis is the sample axis (``True``)
            or the time axis (``False``, the default).
        **kwargs: forwarded to ``nn.EmbeddingBag``.
    """

    def __init__(self, *args, batch_first: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_first = batch_first

    def forward(self, x):
        """Embed ``x``, treating its last axis as the members of each bag."""
        if x.dim() <= 2:
            # already in the layout nn.EmbeddingBag understands
            return super().forward(x)

        bag_size = x.size(-1)
        # merge every leading axis into one: (samples * timesteps, bag_size)
        flat = super().forward(x.contiguous().view(-1, bag_size))
        out_size = flat.size(-1)

        if not self.batch_first:
            # (timesteps, samples, output_size)
            return flat.view(-1, x.size(1), out_size)
        # (samples, timesteps, output_size)
        return flat.contiguous().view(x.size(0), -1, out_size)
class MultiEmbedding(nn.Module):
    """Collection of named embeddings applied to columns of a categorical tensor.

    Every key of ``embedding_sizes`` gets its own module. Names listed in
    ``categorical_groups`` get a ``TimeDistributedEmbeddingBag`` (``mode="sum"``,
    ``batch_first=True``) over the group's member columns. Every other name gets
    an ``nn.Embedding`` over the single column of the same name.

    Args:
        embedding_sizes: maps each name to ``(number of categories, embedding
            dimension)``. NOTE: this dict is modified in place -- every value
            is replaced by a list whose second entry is capped at
            ``max_embedding_size``, so any code sharing the dict sees the
            capped sizes.
        categorical_groups: maps a group name to the names of its member
            columns (each looked up in ``x_categoricals``).
        embedding_paddings: names whose ``nn.Embedding`` is built with
            ``padding_idx=0`` (ignored for group names).
        x_categoricals: names of the categorical columns, in the order they
            appear along the last axis of the input to :meth:`forward`.
        max_embedding_size: optional upper bound on every embedding dimension.
    """

    def __init__(
        self,
        embedding_sizes: Dict[str, Tuple[int, int]],
        categorical_groups: Dict[str, List[str]],
        embedding_paddings: List[str],
        x_categoricals: List[str],
        max_embedding_size: int = None,
    ):
        super().__init__()
        self.embedding_sizes = embedding_sizes
        self.categorical_groups = categorical_groups
        self.embedding_paddings = embedding_paddings
        self.x_categoricals = x_categoricals
        self.max_embedding_size = max_embedding_size
        self.init_embeddings()

    def init_embeddings(self):
        """(Re)build ``self.embeddings`` from the stored configuration.

        Side effect: each ``self.embedding_sizes`` entry is rewritten as a list
        whose dimension is capped at ``max_embedding_size``.
        """
        self.embeddings = nn.ModuleDict()
        for name in self.embedding_sizes.keys():
            dim = self.embedding_sizes[name][1]
            if self.max_embedding_size is not None:
                dim = min(dim, self.max_embedding_size)

            # store a mutable copy carrying the (possibly) capped dimension
            capped = list(self.embedding_sizes[name])
            capped[1] = dim
            self.embedding_sizes[name] = capped
            n_categories = capped[0]

            if name not in self.categorical_groups:
                padding = 0 if name in self.embedding_paddings else None
                self.embeddings[name] = nn.Embedding(n_categories, dim, padding_idx=padding)
            else:
                # related categoricals share one table; their vectors are summed
                self.embeddings[name] = TimeDistributedEmbeddingBag(
                    n_categories, dim, mode="sum", batch_first=True
                )

    def names(self):
        """Return the embedding names as a list."""
        return list(self.keys())

    def items(self):
        """Return ``(name, module)`` pairs of the underlying ``ModuleDict``."""
        return self.embeddings.items()

    def keys(self):
        """Return the embedding names of the underlying ``ModuleDict``."""
        return self.embeddings.keys()

    def values(self):
        """Return the embedding modules of the underlying ``ModuleDict``."""
        return self.embeddings.values()

    def __getitem__(self, name: str):
        """Return the embedding module registered under ``name``."""
        return self.embeddings[name]

    def forward(self, x):
        """Embed every configured name.

        Args:
            x: tensor whose last axis holds the categorical columns in
                ``x_categoricals`` order (presumably integer-encoded
                ``(batch, time, n_categoricals)`` -- TODO confirm with callers).

        Returns:
            dict mapping each embedding name to the output of its module.
        """
        embedded = {}
        for name, module in self.embeddings.items():
            if name not in self.categorical_groups:
                embedded[name] = module(x[..., self.x_categoricals.index(name)])
                continue
            # a group reads all of its member columns at once
            member_columns = [
                self.x_categoricals.index(member) for member in self.categorical_groups[name]
            ]
            embedded[name] = module(x[..., member_columns])
        return embedded