Enformer

Enformer is the temporal model in Genformer. It integrates the engression (distributional regression) principle with a sequence-to-sequence Transformer to produce probabilistic multivariate forecasts.

How it works

1 · Expand

The look-back window is replicated into \(M\) ensemble copies.

2 · Perturb

Independent pre-additive noise is injected into each copy; with the default Gaussian option, \(\epsilon^{(m)} \sim \mathcal{N}(0, \sigma^2 \mathbf{I})\) for \(m = 1, \dots, M\).

3 · Score

The Transformer maps each copy to a trajectory; training minimises the strictly proper Energy Score.
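The Expand and Perturb steps can be sketched in a few lines of NumPy. This is a minimal illustration, not Enformer's code. The helper `expand_and_perturb` is hypothetical and handles a single window of shape \((p, d)\) rather than a batch. It also assumes the noise is added element-wise to the window. The uniform branch assumes noise drawn from \([-\sqrt{3}\sigma, \sqrt{3}\sigma]\), so that its standard deviation equals \(\sigma\).

```python
import numpy as np


def expand_and_perturb(window, m, noise_std=0.1, noise_dist="gaussian", rng=None):
    """Hypothetical helper: replicate a look-back window M times and add noise.

    window: array of shape (p, d), i.e. p time steps of d components.
    Returns an array of shape (m, p, d) holding m independently perturbed copies.
    Assumption: noise is added element-wise; Enformer may inject it differently.
    """
    rng = np.random.default_rng() if rng is None else rng
    window = np.asarray(window, dtype=float)

    # 1. Expand: stack m identical copies along a new leading ensemble axis.
    copies = np.repeat(window[np.newaxis, :, :], m, axis=0)

    # 2. Perturb: draw independent noise for every copy and every entry.
    if noise_dist == "gaussian":
        noise = rng.normal(0.0, noise_std, size=copies.shape)
    elif noise_dist == "uniform":
        # Assumed convention: U(-a, a) with a = sqrt(3) * sigma has std sigma.
        half_width = np.sqrt(3.0) * noise_std
        noise = rng.uniform(-half_width, half_width, size=copies.shape)
    else:
        raise ValueError("noise_dist must be 'gaussian' or 'uniform'")
    return copies + noise


rng = np.random.default_rng(0)
window = np.zeros((24, 2))            # p = 24 steps, d = 2 components
noisy = expand_and_perturb(window, m=10, noise_std=0.1, rng=rng)
print(noisy.shape)                    # (10, 24, 2)
```

The 10 copies differ only in their noise draws. In the Score step, the same Transformer maps each of them to its own trajectory.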

Because the Energy Score is a strictly proper scoring rule, its expected value is uniquely minimised when the distribution of the sampled trajectories matches the true conditional predictive distribution. As a result, no parametric likelihood assumption is required.
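For reference, the Energy Score of a predictive distribution \(F\) at an observed vector \(\mathbf{y}\) is \(\mathrm{ES}(F, \mathbf{y}) = \mathbb{E}\lVert \mathbf{X} - \mathbf{y} \rVert - \tfrac{1}{2}\mathbb{E}\lVert \mathbf{X} - \mathbf{X}' \rVert\). Here \(\mathbf{X}\) and \(\mathbf{X}'\) are independent draws from \(F\), and \(\lVert\cdot\rVert\) is the Euclidean norm.

The NumPy sketch below estimates this score from \(M\) samples and checks the propriety claim empirically on synthetic data. When the truth is \(\mathcal{N}(0, \mathbf{I})\), the correctly scaled forecaster should get the lowest average score. The exponent-1 form and the unbiased \(M(M-1)\) normalisation of the spread term are assumptions, and Enformer's loss may differ in such details.

```python
import numpy as np


def energy_score(y, samples):
    """Sample Energy Score (lower is better) for a batch of observations.

    y:       (n, D) observed vectors, e.g. flattened q-step trajectories.
    samples: (n, M, D) M sampled trajectories per observation, M >= 2.
    Returns an array of n scores.
    """
    m = samples.shape[1]
    if m < 2:
        raise ValueError("the spread term needs at least two samples")
    # Accuracy term: average distance from each sample to the observation.
    accuracy = np.linalg.norm(samples - y[:, np.newaxis, :], axis=-1).mean(axis=1)
    # Spread term: average distance over the M * (M - 1) ordered pairs of
    # distinct samples (the diagonal of the pairwise matrix is zero).
    diffs = samples[:, :, np.newaxis, :] - samples[:, np.newaxis, :, :]
    spread = np.linalg.norm(diffs, axis=-1).sum(axis=(1, 2)) / (m * (m - 1))
    return accuracy - 0.5 * spread


rng = np.random.default_rng(0)
n_obs, n_samples, dim = 2000, 20, 3
y = rng.normal(size=(n_obs, dim))  # synthetic truth: N(0, I)

mean_scores = {}
for scale in (0.3, 1.0, 3.0):
    # Candidate forecaster N(0, scale^2 I); only scale = 1.0 matches the truth.
    samples = scale * rng.normal(size=(n_obs, n_samples, dim))
    mean_scores[scale] = float(energy_score(y, samples).mean())

print(mean_scores)  # the scale = 1.0 entry should be the smallest
```

The score penalises both the under-dispersed forecaster (0.3) and the over-dispersed one (3.0). Strict propriety guarantees exactly this behaviour in expectation.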

Usage

Enformer follows the standard Darts fit / predict workflow over a TimeSeries:

import pandas as pd
from darts import TimeSeries
from genformer import Enformer

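# 1. Load data as a Darts TimeSeries. All value columns must be numeric;
#    if the CSV has a timestamp column, pass time_col="<column name>".
series = TimeSeries.from_dataframe(pd.read_csv("your_data.csv"))
train, val = series[:-50], series[-50:]   # hold out the last 50 steps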

# 2. Configure the model
model = Enformer(
    input_chunk_length=24,        # look-back window  (p)
    output_chunk_length=12,       # forecast horizon  (q)
    num_samples_engression=10,    # ensemble size     (M)
    noise_dist="gaussian",        # 'gaussian' or 'uniform'
    noise_std=0.1,                # noise scale        (σ)
    n_epochs=30,
)

# 3. Train
model.fit(train)

# 4. Draw a probabilistic forecast (50 sampled trajectories)
prediction = model.predict(n=12, num_samples=50)

# 5. Plot the predictive interval
prediction.plot(low_quantile=0.05, high_quantile=0.95)

Where the uncertainty comes from

Every predict call injects fresh noise into the look-back window for each sample, so num_samples independent trajectories are produced. Quantiles computed across those samples at each time step form the predictive interval.
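The interval is simply a per-step quantile over the sample axis. The sketch below uses a synthetic NumPy array in place of a real forecast so that it runs on its own. With a fitted model, you would pass `prediction.all_values()`, which in Darts has shape (time steps, components, samples), and compare it against held-out values such as `val[:12].values()`. The helper names are hypothetical, and the result approximates rather than reproduces what `prediction.plot(low_quantile=0.05, high_quantile=0.95)` draws.

```python
import numpy as np


def predictive_interval(samples, low_q=0.05, high_q=0.95):
    """Per-step, per-component quantiles over the sample axis.

    samples: (time, component, sample) array, the layout of Darts'
    TimeSeries.all_values() for a stochastic series.
    Returns two (time, component) arrays: lower and upper bounds.
    """
    low = np.quantile(samples, low_q, axis=-1)
    high = np.quantile(samples, high_q, axis=-1)
    return low, high


def empirical_coverage(actual, low, high):
    """Fraction of observed values (time, component) inside [low, high]."""
    inside = (actual >= low) & (actual <= high)
    return float(inside.mean())


rng = np.random.default_rng(1)
# Synthetic stand-in for a 12-step, 2-component forecast with 50 samples.
samples = rng.normal(size=(12, 2, 50))
actual = rng.normal(size=(12, 2))  # stand-in for the held-out observations

low, high = predictive_interval(samples)
print(low.shape, high.shape)  # (12, 2) (12, 2)
print(empirical_coverage(actual, low, high))
```

For a well-calibrated model, the coverage of a 5 to 95 % interval should be close to 0.9 when averaged over many forecasts. With only 24 points, as here, the estimate is noisy.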

For a full, runnable walkthrough see Usage.

Key hyperparameters

| Argument | Symbol | Meaning |
|---|---|---|
| input_chunk_length | \(p\) | Length of the historical look-back window. |
| output_chunk_length | \(q\) | Forecast horizon. |
| num_samples_engression | \(M\) | In-sample ensemble size used to estimate the Energy Score in training. |
| noise_std | \(\sigma\) | Scale of the injected stochastic noise. |
| noise_dist | | Noise family: "gaussian" or "uniform". |
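The sample estimator of the Energy Score makes the role of num_samples_engression clearer. Consider one training window with observed future \(\mathbf{y}\) and trajectories \(\hat{\mathbf{y}}^{(1)}, \dots, \hat{\mathbf{y}}^{(M)}\). Assuming the standard unbiased form with exponent 1 (the exact estimator Enformer uses is not spelled out here), the estimate is:

\[
\widehat{\mathrm{ES}}_M(\mathbf{y})
  = \frac{1}{M}\sum_{m=1}^{M} \bigl\lVert \hat{\mathbf{y}}^{(m)} - \mathbf{y} \bigr\rVert
  \;-\; \frac{1}{2M(M-1)} \sum_{m=1}^{M} \sum_{m' \neq m} \bigl\lVert \hat{\mathbf{y}}^{(m)} - \hat{\mathbf{y}}^{(m')} \bigr\rVert
\]

The second (spread) term requires \(M \ge 2\). With a single trajectory, only the accuracy term remains. Its expectation is minimised by putting all mass on one point (the spatial median), so the model would degenerate into a point forecaster. Larger \(M\) reduces the variance of the estimate, at the cost of passing \(M\) copies of every window through the Transformer.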

API reference

class genformer.models.Enformer(*args, **kwargs)

Bases: TransformerModel

Enformer: A Deep Generative Transformer for Probabilistic Time Series Forecasting.

This model integrates the engression (distributional regression) principle with a sequence-to-sequence Transformer architecture. Instead of producing point predictions or relying on restrictive parametric likelihoods, Enformer directly estimates the conditional predictive distribution of future observations via noise-driven sampling.

The architecture explicitly expands the input look-back sequence into an ensemble of M replicas and injects independent stochastic noise into each before passing them through the Transformer backbone. Training minimises the Energy Score (ES), a strictly proper scoring rule.

Parameters:
  • input_chunk_length (int) – The length of the historical look-back window (\(p\)).

  • output_chunk_length (int) – The prediction horizon (\(q\)).

  • num_samples_engression (int, optional) – The number of in-sample forecast trajectories (\(M\)) to generate for the ensemble. Defaults to 10.

  • noise_dist (str, optional) – The type of noise to inject ("gaussian" or "uniform"). Defaults to "gaussian".

  • noise_std (float, optional) – The standard deviation/scale (\(\sigma\)) of the injected noise. Defaults to 0.1.

  • random_state (int, optional) – Seed for reproducibility. Defaults to 23.

  • save_checkpoints (bool, optional) – Whether to save PyTorch Lightning checkpoints. Defaults to False.

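  • output_chunk_shift (int) – Number of steps by which the start of the output chunk is shifted into the future, leaving a gap after the end of the input chunk.

  • d_model (int) – Number of expected features in the Transformer encoder/decoder inputs.

  • nhead (int) – Number of attention heads.

  • num_encoder_layers (int) – Number of encoder layers.

  • num_decoder_layers (int) – Number of decoder layers.

  • dim_feedforward (int) – Dimension of the feed-forward network in each layer.

  • dropout (float) – Dropout probability.

  • activation (str) – Activation function of the encoder/decoder feed-forward layers (e.g. "relu", "gelu").

  • norm_type (str | torch.nn.Module | None) – Layer-normalisation variant (e.g. "LayerNorm", "RMSNorm") or a custom module; None keeps the default.

  • custom_encoder (torch.nn.Module | None) – Optional user-provided encoder that replaces the default one.

  • custom_decoder (torch.nn.Module | None) – Optional user-provided decoder that replaces the default one.

  • **kwargs – Additional parameters passed to the underlying TransformerModel from Darts (e.g. training options such as n_epochs or batch_size).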