Source code for genformer.models
import torch
import torch.nn as nn
import torch.nn.functional as F
import typing
from typing import Optional
import numpy as np
from darts.models.forecasting.transformer_model import TransformerModel, _TransformerModule
from genformer.metrics import energy_score_loss
from genformer.utils import GraphConv, energy_score_loss_st, calibration_loss, generate_forecasts
from genformer.noise import GaussianNoise, UniformNoise
class EnformerModule(_TransformerModule):
    """Darts Transformer module trained with the engression energy-score objective.

    A noise-injection layer (Gaussian or uniform, scale ``noise_std``) is placed in
    front of the parent's ``encoder``. During training every look-back window is
    replicated ``num_samples_engression`` times, so each forward pass yields an
    ensemble of stochastic forecasts that is scored with ``energy_score_loss``.

    Raises:
        ValueError: If ``noise_dist`` is neither ``"gaussian"`` nor ``"uniform"``.
    """

    def __init__(self, input_size, output_size,
                 nr_params, *args, num_samples_engression=10, noise_dist="gaussian", noise_std=0.1, **kwargs):
        super().__init__(
            *args,
            input_size=input_size,
            output_size=output_size,
            nr_params=nr_params,
            **kwargs,
        )
        self.M = num_samples_engression
        self.noise_std = noise_std
        self.noise_dist = noise_dist

        # Pick the noise layer class; the seed 42 is fixed for both variants.
        if self.noise_dist == "gaussian":
            noise_cls = GaussianNoise
        elif self.noise_dist == "uniform":
            noise_cls = UniformNoise
        else:
            raise ValueError("noise_dist must be either `gaussian` or `uniform`.")
        self.encoder = nn.Sequential(noise_cls(self.noise_std, 42), self.encoder)

    def forward(self, x_in: tuple, *args, **kwargs):
        """Cast every floating-point tensor in ``x_in`` to float32, then defer to the parent."""

        def _as_float32(item):
            if isinstance(item, torch.Tensor) and torch.is_floating_point(item):
                return item.float()
            return item

        return super().forward(tuple(_as_float32(item) for item in x_in), *args, **kwargs)

    def training_step(self, batch, batch_idx):
        """Energy-score loss over ``M`` noisy replicas of each training window.

        Reads past target (index 0), past covariates (1), future covariates (3),
        static covariates (4) and the future target (last) from ``batch``.
        """
        reps = self.M
        past_target, past_covs = batch[0], batch[1]
        future_covs, static_covs = batch[3], batch[4]
        future_target = batch[-1]
        n_windows = future_target.size(0)

        def _expand(tensor):
            # M consecutive copies of each window along the batch axis
            return None if tensor is None else tensor.repeat_interleave(reps, dim=0)

        encoder_input = _expand(past_target)
        if past_covs is not None:
            encoder_input = torch.cat([encoder_input, _expand(past_covs)], dim=2)

        x_expanded = (encoder_input.float(), _expand(future_covs), _expand(static_covs))
        y_hat_raw = self(x_expanded)

        # Rows are grouped window-major, so (B*M, T, D) -> (B, M, T, D) -> (M, B, T, D).
        per_window = y_hat_raw.view(n_windows, reps, y_hat_raw.shape[1], y_hat_raw.shape[2])
        samples = per_window.permute(1, 0, 2, 3)

        loss = energy_score_loss(samples, future_target)
        self.log("energy_score_train_loss", loss, prog_bar=True, on_epoch=True)
        return loss
[docs]
class Enformer(TransformerModel):
    r"""
    Enformer: A Deep Generative Transformer for Probabilistic Time Series Forecasting.

    This model integrates the engression (distributional regression) principle with
    a sequence-to-sequence Transformer architecture. Instead of producing point
    predictions or relying on restrictive parametric likelihoods, Enformer
    directly estimates the conditional predictive distribution of future observations
    via noise-driven sampling.

    During training, each look-back window is replicated into an ensemble of `M`
    copies, and independent stochastic noise is injected into each copy before it
    passes through the Transformer backbone. Training minimizes the Energy Score (ES),
    a strictly proper scoring rule.

    Args:
        input_chunk_length (int): The length of the historical look-back window (\( p \)).
        output_chunk_length (int): The prediction horizon (\( q \)).
        output_chunk_shift (int, optional): Shift of the output chunk. Defaults to 0.
        d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
        dropout, activation, norm_type, custom_encoder, custom_decoder:
            Transformer hyper-parameters with the same meaning as in Darts'
            `TransformerModel`.
        num_samples_engression (int, optional): The number of in-sample forecast
            trajectories (\( M \)) to generate for the ensemble. Defaults to 10.
        noise_dist (str, optional): The type of noise to inject ('gaussian' or 'uniform').
            Defaults to 'gaussian'.
        noise_std (float, optional): The standard deviation/scale (\( \sigma \)) of
            the injected noise. Defaults to 0.1.
        random_state (int, optional): Seed for reproducibility. Defaults to 23.
        save_checkpoints (bool, optional): Whether to enable PyTorch Lightning
            checkpointing (the Trainer's `default_root_dir` is then "checkpoints").
            Defaults to False.
        **kwargs: Additional parameters passed to the underlying Darts
            `TransformerModel` (e.g. `batch_size`, `n_epochs`). A user-supplied
            `pl_trainer_kwargs` dict is merged over this model's default Trainer
            settings; `save_checkpoints` still decides checkpointing.

    Raises:
        ValueError: If `num_samples_engression` is not positive or `noise_dist`
            is not 'gaussian' or 'uniform'.
    """
    def __init__(self,
                 input_chunk_length: int,
                 output_chunk_length: int,
                 output_chunk_shift: int = 0,
                 d_model: int = 64,
                 nhead: int = 4,
                 num_encoder_layers: int = 3,
                 num_decoder_layers: int = 3,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1,
                 activation: str = "relu",
                 norm_type: typing.Union[str, torch.nn.Module, None] = None,
                 custom_encoder: typing.Optional[torch.nn.Module] = None,
                 custom_decoder: typing.Optional[torch.nn.Module] = None,
                 num_samples_engression: int = 10,
                 noise_dist: str = "gaussian",
                 noise_std: float = 0.1,
                 random_state: typing.Optional[int] = 23,
                 save_checkpoints: bool = False,
                 **kwargs
                 ):
        if num_samples_engression <= 0:
            raise ValueError("num_samples_engression must be positive.")
        if noise_dist not in ["gaussian", "uniform"]:
            raise ValueError("noise_dist must be either `gaussian` or `uniform`.")
        self.num_samples_engression = num_samples_engression
        self.noise_std = noise_std
        self.noise_dist = noise_dist
        # Lightning Trainer arguments
        pl_kwargs = {
            "deterministic": True,
            "logger": False,
            "enable_progress_bar": False,
            "enable_model_summary": False,
        }
        # A caller-supplied `pl_trainer_kwargs` used to collide with the explicit
        # `pl_trainer_kwargs=` below (TypeError: multiple values for keyword);
        # merge it over the defaults instead, as GEnformer does.
        user_pl_kwargs = kwargs.pop("pl_trainer_kwargs", None)
        if user_pl_kwargs is not None:
            pl_kwargs.update(user_pl_kwargs)
        # Determine whether to allow lightning to handle epoch checkpoints (or just run cleanly)
        if save_checkpoints:
            pl_kwargs["enable_checkpointing"] = True
            pl_kwargs["default_root_dir"] = "checkpoints"
        else:
            pl_kwargs["enable_checkpointing"] = False
        super().__init__(input_chunk_length=input_chunk_length,
                         output_chunk_length=output_chunk_length,
                         output_chunk_shift=output_chunk_shift,
                         random_state=random_state,
                         pl_trainer_kwargs=pl_kwargs,
                         **kwargs)
        # `_create_model` forwards `self.model_params` to EnformerModule, so drop the
        # entries the module must not receive (engression settings are passed
        # explicitly; trainer/fit settings are not module arguments).
        # NOTE(review): this mutates the params Darts stores for the model, which
        # presumably also feed model re-creation (e.g. `untrained_model()`) — verify.
        for key in (
            'num_samples_engression', 'random_state', 'noise_std', 'noise_dist',
            'save_checkpoints', 'batch_size', 'n_epochs', 'optimizer_kwargs',
            'lr_scheduler_cls', 'lr_scheduler_kwargs', 'pl_trainer_kwargs',
        ):
            self.model_params.pop(key, None)

    def _create_model(self, train_sample):
        """Build the EnformerModule sized from one training sample.

        The encoder input width is the target width plus the past-covariate width
        (0 without covariates); the output width is the future target's width.
        """
        past_target = train_sample[0]
        past_covs = train_sample[1]
        target_dim = past_target.shape[-1]
        past_cov_dim = past_covs.shape[-1] if past_covs is not None else 0
        input_size = target_dim + past_cov_dim
        future_target = train_sample[-1]
        output_size = future_target.shape[-1]
        return EnformerModule(
            input_size=input_size,
            output_size=output_size,
            nr_params=getattr(self, "nr_params", 1),
            num_samples_engression=self.num_samples_engression,
            noise_std=self.noise_std,
            noise_dist=self.noise_dist,
            **self.model_params
        )
class GEnformerModule(_TransformerModule):
    """Darts Transformer module with a graph-convolution front end and engression loss.

    In ``training_step`` the target window is reshaped to ``(B, T, N, D)``, passed
    node-wise through ``self.gcn``, flattened, replicated ``num_samples_engression``
    times and fed through the noise-injected Transformer. The loss is
    ``energy_score_loss_st``, optionally plus ``lambda_calib`` times
    ``calibration_loss``.

    NOTE(review): ``forward`` does not call ``self.gcn``. Callers other than
    ``training_step`` presumably have to pass already GCN-embedded inputs — verify
    how prediction/validation reach this module.

    Args:
        input_size: Width of the encoder input.
        output_size: Width of the decoder output.
        nr_params: Number of output parameters, forwarded to Darts.
        edges: Graph connectivity forwarded to ``GraphConv``.
        num_nodes: Number of graph nodes ``N``.
        gcn_out_feat: Per-node output width of the GCN.
        node_feat_dim: Per-node input width of the GCN.
        num_samples_engression: Ensemble size ``M`` used during training.
        noise_dist: ``"gaussian"`` or ``"uniform"`` noise layer.
        noise_std: Scale passed to the noise layer.
        target_coverage: Passed as ``alpha`` to ``calibration_loss``.
        graph_conv_params: Optional ``GraphConv`` settings (``aggregation_type``,
            ``combination_type``, ``activation``).
        lambda_calib: Weight of the calibration term; values ``<= 0`` disable it.
        **kwargs: Forwarded to Darts' ``_TransformerModule``. ``temperature``
            (default 10.0) and ``width_penalty`` (default 0.5) are consumed here for
            the calibration loss; ``pl_trainer_kwargs`` is discarded.

    Raises:
        ValueError: If ``noise_dist`` is neither ``"gaussian"`` nor ``"uniform"``.
    """
    def __init__(
        self,
        input_size,
        output_size,
        nr_params,
        edges,
        num_nodes,
        gcn_out_feat,
        node_feat_dim: int,
        num_samples_engression=10,
        noise_dist="gaussian",
        noise_std=1,
        target_coverage=0.95,
        graph_conv_params: typing.Optional[dict] = None,
        lambda_calib=0,
        *args,
        **kwargs
    ):
        self.activation_name = kwargs.get("activation", "relu")
        kwargs.pop("activation", None)
        # calibration-loss knobs are not parent-module arguments
        self.temperature = kwargs.pop("temperature", 10.0)
        self.width_penalty = kwargs.pop("width_penalty", 0.5)
        kwargs.pop("pl_trainer_kwargs", None)
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            nr_params=nr_params,
            activation=self.activation_name,
            *args,
            **kwargs
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        self.M = num_samples_engression
        self.noise_std = noise_std
        self.alpha = target_coverage  # target coverage
        self.lambda_calib = lambda_calib  # calibration strength
        self.noise_dist = noise_dist
        self.edges = edges
        self.num_nodes = num_nodes
        self.gcn_out_feat = gcn_out_feat
        if graph_conv_params is None:
            graph_conv_params = {
                "aggregation_type": "mean",
                "combination_type": "add",
                "activation": None,
            }
        # graph convolution
        gcn_activation = graph_conv_params.get("activation", None)
        self.gcn = GraphConv(
            in_feat=node_feat_dim,
            out_feat=gcn_out_feat,
            edges=self.edges,
            num_nodes=self.num_nodes,
            aggregation_type=graph_conv_params.get("aggregation_type", "mean"),
            combination_type=graph_conv_params.get("combination_type", "add"),
            activation=gcn_activation
        ).to(device)
        # noise injection in front of the parent's encoder
        if self.noise_dist == "gaussian":
            self.encoder = nn.Sequential(GaussianNoise(self.noise_std, 42).to(device), self.encoder)
        elif self.noise_dist == "uniform":
            self.encoder = nn.Sequential(UniformNoise(self.noise_std, 42).to(device), self.encoder)
        else:
            raise ValueError("noise_dist must be either `gaussian` or `uniform`.")

    def forward(self, x_in, *args, **kwargs):
        """Cast floating-point tensors in ``x_in`` to float32 and defer to the parent."""
        x_in_float = tuple(
            t.float() if isinstance(t, torch.Tensor) and torch.is_floating_point(t) else t
            for t in x_in
        )
        return super().forward(x_in_float, *args, **kwargs)

    def training_step(self, batch, batch_idx):
        """Energy-score (+ optional calibration) loss over ``M`` noisy replicas.

        Returns ``None`` when the loss is NaN or infinite so that Lightning skips
        the optimizer step for this batch. The previous fallback returned a freshly
        built constant tensor, which has no ``grad_fn``, so ``backward()`` raised.
        """
        device = next(self.parameters()).device
        past_target = batch[0].to(device)  # (B, T, F)
        past_covs = batch[1].to(device) if batch[1] is not None else None
        static_covs = batch[4].to(device) if batch[4] is not None else None
        future_target = batch[-1].to(device)  # (B, T_out, F)
        # recover (N, D) structure; `n_features` avoids shadowing the module alias `F`
        B, T, n_features = past_target.shape
        N = self.num_nodes
        D = n_features // N
        past_target = past_target.view(B, T, N, D)
        T_out = future_target.shape[1]
        future_target = future_target.view(B, T_out, N, D)
        # apply GCN - (N, B, T, D)
        x = past_target.permute(2, 0, 1, 3)
        x = self.gcn(x)  # (N, B, T, gcn_out)
        # back to (B, T, N, gcn_out)
        x = x.permute(1, 2, 0, 3)
        # flatten for Transformer
        B, T, N, Dg = x.shape
        x = x.reshape(B, T, N * Dg)
        # engression sampling: M consecutive copies of every window
        x_m = x.repeat_interleave(self.M, dim=0)
        if past_covs is not None:
            past_covs_m = past_covs.repeat_interleave(self.M, dim=0).to(device)
            data_m = torch.cat([x_m, past_covs_m], dim=2)
        else:
            data_m = x_m
        fut_covs_m = batch[3].to(device).repeat_interleave(self.M, dim=0) if batch[3] is not None else None
        static_covs_m = static_covs.repeat_interleave(self.M, dim=0) if static_covs is not None else None
        # forward
        x_expanded = (data_m.float(), fut_covs_m, static_covs_m)
        y_hat_raw = self(x_expanded)  # (B*M, T_out, F_out)
        # reshape output back (rows are grouped window-major by repeat_interleave)
        y_hat_raw = y_hat_raw.view(B, self.M, T_out, N, self.gcn_out_feat)
        samples = y_hat_raw.permute(1, 0, 2, 3, 4)  # (M, B, T_out, N, gcn_out)
        # energy score
        es_loss = energy_score_loss_st(future_target, samples)
        # calibration (only if enabled)
        if self.lambda_calib > 0:
            cal_loss, coverage = calibration_loss(
                future_target,
                samples,
                alpha=self.alpha,
                temperature=self.temperature,
                width_penalty=self.width_penalty
            )
            loss = es_loss + self.lambda_calib * cal_loss
        else:
            loss = es_loss
            cal_loss = torch.tensor(0.0, device=es_loss.device)
            coverage = torch.tensor(float("nan"), device=es_loss.device)
        # Non-finite loss: skip this optimizer step instead of back-propagating
        # through a graph-less replacement tensor (which would raise).
        if not torch.isfinite(loss):
            return None
        # logging
        self.log("train_loss", loss, prog_bar=True, on_epoch=True)
        self.log("train_es", es_loss, on_epoch=True)
        self.log("train_calib", cal_loss, on_epoch=True)
        self.log("train_coverage", coverage, prog_bar=True, on_epoch=True)
        self.log("noise_std", self.noise_std, on_epoch=True)
        return loss
[docs]
class GEnformer(TransformerModel):
    r"""
    Graph-Enformer (GEnformer): Spatiotemporal Probabilistic Forecasting.

    An extension of the Enformer architecture for spatiotemporal contexts where
    spatial locations are represented as nodes in an interconnected graph. This model
    jointly captures temporal dynamics, complex spatial interactions, and predictive
    uncertainty.

    Before noise injection and Transformer processing, GEnformer applies a
    Graph Convolutional Network (GCN) layer that maps the target observations and the
    spatial topology (given by the graph edges) to spatially-aware latent embeddings.
    It then optimizes the Energy Score alongside an optional calibration objective.

    Args:
        input_chunk_length (int): The length of the historical look-back window (\( p \)).
        output_chunk_length (int): The prediction horizon (\( q \)).
        edges (torch.Tensor or List): Graph edges defining the connectivity between nodes.
        num_nodes (int): The number of spatial locations (graph nodes). The target's
            component count must be divisible by it.
        output_chunk_shift (int, optional): Shift for the output chunk. Defaults to 0.
        d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
        dropout, activation, norm_type, custom_encoder, custom_decoder:
            Transformer hyper-parameters with the same meaning as in Darts'
            `TransformerModel`.
        gcn_out_feat (int, optional): The per-node dimensionality of the latent spatial
            embeddings produced by the GCN. Defaults to 32.
        graph_conv_params (dict, optional): Dictionary of parameters for the `GraphConv` layer
            (e.g., aggregation and combination types). Defaults to None.
        num_samples_engression (int, optional): The number of in-sample forecast
            trajectories (\( M \)) for the ensemble. Defaults to 10.
        noise_dist (str, optional): The type of noise to inject ('gaussian' or 'uniform').
            Defaults to 'gaussian'.
        target_coverage (float, optional): The target prediction interval coverage
            used for the calibration loss term. Defaults to 0.9.
        noise_std (float, optional): The standard deviation/scale (\( \sigma \)) of
            the injected noise. Defaults to 1.
        random_state (int, optional): Seed for reproducibility. Defaults to 23.
        save_checkpoints (bool, optional): Whether to enable PyTorch Lightning
            checkpointing (the Trainer's `default_root_dir` is then "checkpoints").
            Defaults to False.
        lambda_calib (float, optional): The weight of the calibration loss term in
            the overall optimization objective. Defaults to 2.
        **kwargs: Additional parameters passed to the underlying `TransformerModel`.
            `temperature` (default 5.0) and `width_penalty` (default 0.01) are
            consumed here for the calibration loss. A `pl_trainer_kwargs` dict is
            merged over this model's default Trainer settings.

    Raises:
        ValueError: If `num_samples_engression` or `num_nodes` is not positive, or
            `noise_dist` is not 'gaussian' or 'uniform'.
    """
    def __init__(
        self,
        input_chunk_length: int,
        output_chunk_length: int,
        edges,
        num_nodes: int,
        output_chunk_shift: int = 0,
        d_model: int = 64,
        nhead: int = 4,
        num_encoder_layers: int = 3,
        num_decoder_layers: int = 3,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        activation: str = "relu",
        norm_type: typing.Union[str, torch.nn.Module, None] = None,
        custom_encoder: typing.Optional[torch.nn.Module] = None,
        custom_decoder: typing.Optional[torch.nn.Module] = None,
        gcn_out_feat: int = 32,
        graph_conv_params: typing.Optional[dict] = None,
        num_samples_engression: int = 10,
        noise_dist: str = "gaussian",
        target_coverage: float = 0.9,
        noise_std: float = 1,
        random_state: int = 23,
        save_checkpoints: bool = False,
        lambda_calib: float = 2,
        **kwargs
    ):
        # Validate eagerly (as Enformer does) instead of failing later inside fit().
        if num_samples_engression <= 0:
            raise ValueError("num_samples_engression must be positive.")
        if noise_dist not in ["gaussian", "uniform"]:
            raise ValueError("noise_dist must be either `gaussian` or `uniform`.")
        if num_nodes <= 0:
            raise ValueError("num_nodes must be positive.")
        self.num_samples_engression = num_samples_engression
        self.noise_std = noise_std
        self.noise_dist = noise_dist
        self.edges = edges
        self.target_coverage = target_coverage
        self.num_nodes = num_nodes
        self.gcn_out_feat = gcn_out_feat
        self.random_state = random_state
        self.save_checkpoints = save_checkpoints
        self.graph_conv_params = graph_conv_params
        self.lambda_calib = lambda_calib
        # calibration-loss knobs travel via **kwargs but are not Darts arguments
        self.temperature = kwargs.pop("temperature", 5.0)
        self.width_penalty = kwargs.pop("width_penalty", 0.01)
        # Lightning Trainer arguments
        pl_kwargs = {
            "accelerator": "cuda" if torch.cuda.is_available() else "cpu",
            "devices": 1,
            "deterministic": False,
            "logger": False,
            "enable_progress_bar": False,
            "enable_model_summary": True
        }
        pl_trainer_kwargs = kwargs.pop("pl_trainer_kwargs", None)
        if pl_trainer_kwargs is not None:
            pl_kwargs.update(pl_trainer_kwargs)
        # Determine whether to allow lightning to handle epoch checkpoints
        if save_checkpoints:
            pl_kwargs["enable_checkpointing"] = True
            pl_kwargs["default_root_dir"] = "checkpoints"
        else:
            pl_kwargs["enable_checkpointing"] = False
        super().__init__(
            input_chunk_length=input_chunk_length,
            output_chunk_length=output_chunk_length,
            output_chunk_shift=output_chunk_shift,
            random_state=random_state,
            pl_trainer_kwargs=pl_kwargs,
            **kwargs
        )
        # `_create_model` forwards `self.model_params` to GEnformerModule; drop the
        # entries that are passed explicitly or are not module arguments.
        # NOTE(review): this mutates the params Darts stores for the model (e.g.
        # `edges`/`num_nodes` disappear), which presumably also feed model
        # re-creation such as `untrained_model()` — verify.
        keys_to_remove = [
            'edges', 'num_nodes', 'gcn_out_feat', 'num_samples_engression',
            'noise_dist', 'noise_std', 'graph_conv_params', 'batch_size',
            'n_epochs', 'random_state', 'save_checkpoints', 'node_feat_dim',
            'target_coverage', 'temperature', 'width_penalty'
        ]
        for key in keys_to_remove:
            self.model_params.pop(key, None)

    def _create_model(self, train_sample):
        """Build the GEnformerModule sized from one training sample.

        The target width ``F`` is split into ``num_nodes`` nodes of ``D = F // num_nodes``
        features each. The encoder input is the flattened GCN embedding
        (``num_nodes * gcn_out_feat``) plus the past-covariate width.

        Raises:
            ValueError: If the target width is not divisible by ``num_nodes``.
        """
        past_target = train_sample[0]
        target_dim = past_target.shape[-1]  # F = N * D
        N = self.num_nodes
        if target_dim % N != 0:
            raise ValueError(f"Target dimension {target_dim} must be divisible by num_nodes {N}")
        D = target_dim // N
        past_covs = train_sample[1]
        cov_dim = past_covs.shape[-1] if past_covs is not None else 0
        input_size = self.gcn_out_feat * N + cov_dim
        # NOTE(review): this equals the target width F only when gcn_out_feat == D;
        # confirm that is intended for forecasting in target space.
        output_size = N * self.gcn_out_feat
        # Clean copy of model_params for the module: trainer kwargs are not module
        # arguments, and lambda_calib is passed explicitly like the other
        # calibration settings.
        module_params = self.model_params.copy()
        module_params.pop("pl_trainer_kwargs", None)
        module_params.pop("lambda_calib", None)
        return GEnformerModule(
            input_size=input_size,
            output_size=output_size,
            nr_params=getattr(self, "nr_params", 1),
            edges=self.edges,
            num_nodes=self.num_nodes,
            gcn_out_feat=self.gcn_out_feat,
            node_feat_dim=D,
            num_samples_engression=self.num_samples_engression,
            noise_std=self.noise_std,
            noise_dist=self.noise_dist,
            graph_conv_params=self.graph_conv_params,
            target_coverage=self.target_coverage,
            temperature=self.temperature,
            width_penalty=self.width_penalty,
            lambda_calib=self.lambda_calib,
            **module_params
        )