Temporal Forecasting with Enformer
This example shows how to use the Enformer model from the genformer package to forecast multivariate time series. We generate a noisy two-component sine/cosine dataset, train the model briefly, and produce a probabilistic forecast plot.
[1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from darts import TimeSeries
from genformer.models import Enformer
# Style for beautiful plots
plt.style.use('seaborn-v0_8-darkgrid')
[2]:
# Generate dummy multivariate data: a noisy sine and a noisy cosine, 200 steps each
rng = np.random.default_rng(42)  # fixed seed so the example is reproducible
time_steps = 200
x = np.linspace(0, 50, time_steps)
series1 = (np.sin(x) + rng.normal(0, 0.1, time_steps)).astype(np.float32)
series2 = (np.cos(x) + rng.normal(0, 0.1, time_steps)).astype(np.float32)
df = pd.DataFrame({'sin_wave': series1, 'cos_wave': series2})
series = TimeSeries.from_dataframe(df)  # uses the DataFrame's integer index as the time axis
# Hold out the last 50 steps for validation (150 training steps, 50 validation steps)
train, val = series[:-50], series[-50:]
# Plot the dummy data
plt.figure(figsize=(10, 4))
train.plot(label='Training')
val.plot(label='Validation')
plt.title('Dummy Multivariate Time Series Data')
plt.show()
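Darts-style forecasting models train on fixed-length windows cut from the series. The sketch below illustrates that windowing on the 150-step training split, assuming Enformer follows darts' default of one window per time step (stride 1) with no cap on samples per series; this is an assumption, not something this page confirms. Each sample pairs 24 input steps with the next 12 target steps, which gives 150 - 36 + 1 = 115 windows, or ceil(115 / 16) = 8 batches at batch size 16. That agrees with the 8/8 batches reported in the training progress bar of the next cell.

```python
import math

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Values copied from this example; stride-1 windowing is an assumption
input_chunk_length = 24
output_chunk_length = 12
batch_size = 16

# Noise-free stand-in for the 150-step, 2-component training split
x = np.linspace(0, 50, 200)
train_values = np.stack([np.sin(x), np.cos(x)], axis=1)[:150]  # shape (150, 2)

window = input_chunk_length + output_chunk_length  # 36 consecutive steps per sample
# sliding_window_view yields (115, 2, 36); transpose to (samples, time, components)
windows = sliding_window_view(train_values, window, axis=0).transpose(0, 2, 1)
inputs = windows[:, :input_chunk_length, :]   # (115, 24, 2) past values fed to the model
targets = windows[:, input_chunk_length:, :]  # (115, 12, 2) values the model must predict

num_samples = windows.shape[0]
num_batches = math.ceil(num_samples / batch_size)
print(num_samples, num_batches)  # 115 8
```

Shortening input_chunk_length or output_chunk_length increases the number of windows drawn from the same history, at the cost of less context or a shorter horizon per forward pass.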
[3]:
# Initialize Enformer
model = Enformer(
    input_chunk_length=24,   # steps of history the model reads
    output_chunk_length=12,  # steps the model predicts per forward pass
    num_samples_engression=10,
    n_epochs=2,              # Keep low for quick demo
    batch_size=16,
)
# Train the model (2 epochs only, for demonstration)
model.fit(train)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
| Name | Type | Params | Mode | FLOPs
----------------------------------------------------------------------------
0 | criterion | MSELoss | 0 | train | 0
1 | train_criterion | MSELoss | 0 | train | 0
2 | val_criterion | MSELoss | 0 | train | 0
3 | train_metrics | MetricCollection | 0 | train | 0
4 | val_metrics | MetricCollection | 0 | train | 0
5 | encoder | Sequential | 192 | train | 0
6 | positional_encoding | _PositionalEncoding | 0 | train | 0
7 | transformer | Transformer | 548 K | train | 0
8 | decoder | Linear | 1.6 K | train | 0
----------------------------------------------------------------------------
550 K Trainable params
0 Non-trainable params
550 K Total params
2.201 Total estimated model params size (MB)
90 Modules in train mode
0 Modules in eval mode
0 Total Flops
Epoch 1: 100%|██████████| 8/8 [00:01<00:00, 4.98it/s, energy_score_train_loss_step=0.199, energy_score_train_loss_epoch=0.290]
`Trainer.fit` stopped: `max_epochs=2` reached.
[3]:
Enformer(output_chunk_shift=0, d_model=64, nhead=4, num_encoder_layers=3, num_decoder_layers=3, dim_feedforward=512, dropout=0.1, activation=relu, norm_type=None, custom_encoder=None, custom_decoder=None, input_chunk_length=24, output_chunk_length=12)
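The progress bar logs an energy_score_train_loss rather than the MSELoss listed in the model summary, which suggests that training minimizes the energy score, the loss used by engression. For m predictive samples X_1 to X_m and an observation y, the standard unbiased estimate is ES = (1/m) * sum_i ||X_i - y|| - 1/(2m(m-1)) * sum_{i != j} ||X_i - X_j||. The sketch below implements that estimator in NumPy. It assumes, without confirmation from this page, that num_samples_engression plays the role of m; genformer's actual loss may differ in normalization, norm exponent, or how it averages over batches and time steps.

```python
import numpy as np


def energy_score(samples: np.ndarray, y: np.ndarray) -> float:
    """Unbiased Monte Carlo estimate of the energy score (lower is better).

    samples: array of shape (m, d), m >= 2 draws from the predictive distribution
    y:       array of shape (d,), the observed target
    """
    m = samples.shape[0]
    if m < 2:
        raise ValueError("energy_score needs at least two samples")
    # Accuracy term: average Euclidean distance from each sample to the observation
    accuracy = np.linalg.norm(samples - y, axis=1).mean()
    # Spread term: average distance over ordered pairs i != j (the diagonal is zero)
    pairwise = np.linalg.norm(samples[:, None, :] - samples[None, :, :], axis=-1)
    spread = pairwise.sum() / (m * (m - 1))
    return float(accuracy - 0.5 * spread)


# Usage with 10 samples (mirroring num_samples_engression=10) of a flattened
# 12-step x 2-component target; the data here is synthetic
rng = np.random.default_rng(0)
y = rng.normal(size=24)
sharp = y + 0.1 * rng.normal(size=(10, 24))  # samples concentrated near y
vague = y + 1.0 * rng.normal(size=(10, 24))  # samples widely spread around y
print(energy_score(sharp, y), energy_score(vague, y))
```

The spread term rewards sample diversity, so the score cannot be minimized by collapsing all samples onto a single point forecast; this is what lets a model trained this way produce the sample paths behind a probabilistic forecast.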