# Source code for genformer.utils

import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import typing
from typing import Optional
try:
    import cv2
except ImportError:
    cv2 = None
try:
    from transformers import set_seed
except ImportError:
    set_seed = None

try:
    from datasets import disable_progress_bar
except ImportError:
    disable_progress_bar = None
import logging

def setup_logger():
    """Quieten third-party loggers used alongside this package.

    Raises the level of both Lightning logger names to ``ERROR`` and
    installs a filter on the ``darts.timeseries`` logger that drops every
    record whose formatted message contains ``"Integer index out of range"``.
    """
    # "lightning.pytorch" is the Lightning 2.x name, "pytorch_lightning" the
    # name used by older releases; silence whichever one is installed.
    for lightning_name in ("lightning.pytorch", "pytorch_lightning"):
        logging.getLogger(lightning_name).setLevel(logging.ERROR)

    class IgnoreDartsIndexError(logging.Filter):
        """Reject records that mention the Darts integer-index message."""

        def filter(self, record):
            return "Integer index out of range" not in record.getMessage()

    logging.getLogger("darts.timeseries").addFilter(IgnoreDartsIndexError())

class Deterministic:
    """Helper that seeds the available RNGs and enables deterministic backends."""

    def __init__(self):
        pass

    def init_all(self, seed=0, disable_list=('cuda_block',)):
        """Seed every RNG this module knows about and request determinism.

        Args:
            seed: Integer seed applied to ``random``, NumPy, PyTorch (CPU and
                CUDA) and, when they were importable, ``transformers`` and
                OpenCV.
            disable_list: Collection of switch names to skip. Recognised
                names are ``'cuda_block'`` (do not set
                ``CUDA_LAUNCH_BLOCKING``) and ``'torch_deter_algo'`` (do not
                call ``torch.use_deterministic_algorithms``). The default is
                an immutable tuple so it cannot be mutated across calls.

        Side effects:
            Sets ``PYTHONHASHSEED``, ``CUBLAS_WORKSPACE_CONFIG`` and
            ``TF_CUDNN_DETERMINISTIC`` in ``os.environ``, switches cuDNN to
            deterministic / non-benchmark mode, and disables the ``datasets``
            progress bar when that package was importable.
        """
        random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        if 'cuda_block' not in disable_list:  # gets stuck when training deberta
            os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
        os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        if 'torch_deter_algo' not in disable_list:  # sometimes consumes more GPU memory
            torch.use_deterministic_algorithms(True, warn_only=True)

        # Optional dependencies: each name is None when its import failed.
        if set_seed is not None:
            set_seed(seed)
        if cv2 is not None:
            cv2.setRNGSeed(seed)
        if disable_progress_bar is not None:
            disable_progress_bar()


deterministic = Deterministic()


class GraphConv(nn.Module):
    """Graph convolution over a fixed edge list.

    For every node the layer projects the node's own features with
    ``weight_self``, aggregates the features gathered along the edges,
    projects that aggregate with ``weight_neigh``, combines both parts,
    applies the activation and finally a ``LayerNorm``.

    Args:
        in_feat: Size of the last dimension of the input features.
        out_feat: Size of each projection's output.
        edges: Pair of equal-length integer sequences. ``edges[0]`` holds the
            node each edge aggregates *into* (stored as ``src_nodes``) and
            ``edges[1]`` the node whose features are gathered for that edge
            (stored as ``dst_nodes``).
        num_nodes: Number of nodes; sets the first dimension of the
            aggregated tensor.
        aggregation_type: ``"sum"``, ``"mean"`` or ``"max"``.
        combination_type: ``"add"`` or ``"concat"``. With ``"concat"`` the
            output has ``2 * out_feat`` features.
        activation: Name of a function in ``torch.nn.functional`` or ``None``
            for the identity.

    Raises:
        ValueError: If ``activation`` is neither ``None`` nor a string.
    """

    def __init__(
        self,
        in_feat: int,
        out_feat: int,
        edges,
        num_nodes,
        aggregation_type: str = "mean",
        combination_type: str = "add",
        activation: Optional[str] = None
    ):
        super().__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Buffers so the edge indices follow the module across devices.
        self.register_buffer("src_nodes", torch.tensor(edges[0], dtype=torch.long))
        self.register_buffer("dst_nodes", torch.tensor(edges[1], dtype=torch.long))
        self.num_nodes = num_nodes
        self.aggregation_type = aggregation_type
        self.combination_type = combination_type

        # Projection weights, Xavier/Glorot initialised.
        self.weight_self = nn.Parameter(torch.empty(in_feat, out_feat))
        self.weight_neigh = nn.Parameter(torch.empty(in_feat, out_feat))
        nn.init.xavier_uniform_(self.weight_self)
        nn.init.xavier_uniform_(self.weight_neigh)

        # "concat" doubles the feature dimension, so the norm must match it.
        norm_feat = 2 * out_feat if combination_type == "concat" else out_feat
        self.norm = nn.LayerNorm(norm_feat)

        if activation is None:
            # nn.Identity (unlike a local lambda) keeps the module picklable.
            self.activation = nn.Identity()
        elif isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            raise ValueError(f"Unsupported activation: {activation}")

    def aggregate(self, neighbour_representations: torch.Tensor):
        """Reduce per-edge messages into one row per node.

        Args:
            neighbour_representations: Tensor whose first dimension indexes
                edges (one row per entry of ``src_nodes``).

        Returns:
            Tensor of shape ``(num_nodes,) + neighbour_representations.shape[1:]``
            with the same dtype and device as the input. Nodes that receive
            no message are zero for every aggregation type.

        Raises:
            ValueError: If ``aggregation_type`` is not supported.
        """
        target = self.src_nodes
        out_shape = (self.num_nodes,) + tuple(neighbour_representations.shape[1:])
        # new_zeros inherits dtype and device, so index_add_ never sees a
        # dtype mismatch (e.g. float64 or half inputs).
        zeros = neighbour_representations.new_zeros(out_shape)

        if self.aggregation_type == "sum":
            return zeros.index_add_(0, target, neighbour_representations)

        if self.aggregation_type == "mean":
            totals = zeros.index_add_(0, target, neighbour_representations)
            counts = torch.zeros(
                self.num_nodes, dtype=torch.float, device=neighbour_representations.device
            )
            counts.index_add_(0, target, torch.ones_like(target, dtype=torch.float))
            # clamp avoids dividing by zero for nodes without messages.
            counts = counts.clamp(min=1).to(neighbour_representations.dtype)
            counts = counts.view(-1, *([1] * (neighbour_representations.dim() - 1)))
            return totals / counts

        if self.aggregation_type == "max":
            # include_self=False keeps the zero fill out of the reduction, so
            # negative maxima are preserved while untouched rows stay zero.
            return zeros.index_reduce(
                0, target, neighbour_representations, "amax", include_self=False
            )

        raise ValueError(f"Invalid aggregation type: {self.aggregation_type}")

    def compute_nodes_representation(self, features: torch.Tensor):
        """
        features: (num_nodes, batch_size, input_seq_len, in_feat)
        returns: (num_nodes, batch_size, input_seq_len, out_feat)
        """
        return torch.matmul(features, self.weight_self)  # last-dim matmul

    def compute_aggregated_messages(self, features: torch.Tensor):
        """Gather features along the edges, aggregate them and project.

        features: (num_nodes, batch_size, input_seq_len, in_feat)
        returns: (num_nodes, batch_size, input_seq_len, out_feat)
        """
        neighbour_representations = features[self.dst_nodes]
        aggregated_messages = self.aggregate(neighbour_representations)
        return torch.matmul(aggregated_messages, self.weight_neigh)

    def update(self, nodes_representation: torch.Tensor, aggregated_messages: torch.Tensor):
        """Combine self and neighbour parts, then activate and normalise.

        Raises:
            ValueError: If ``combination_type`` is not ``"concat"`` or ``"add"``.
        """
        if self.combination_type == "concat":
            h = torch.cat([nodes_representation, aggregated_messages], dim=-1)
        elif self.combination_type == "add":
            h = nodes_representation + aggregated_messages
        else:
            raise ValueError(f"Invalid combination type: {self.combination_type}")
        # self.norm is a registered submodule, so it already lives on the
        # same device as the projection weights that produced ``h``.
        return self.norm(self.activation(h))

    def forward(self, features: torch.Tensor):
        """
        features: (num_nodes, batch_size, input_seq_len, in_feat)
        returns:  (num_nodes, batch_size, input_seq_len, out_feat), or
                  2 * out_feat in the last dimension for "concat"
        """
        nodes_representation = self.compute_nodes_representation(features)
        aggregated_messages = self.compute_aggregated_messages(features)
        return self.update(nodes_representation, aggregated_messages)


def energy_score_loss_st(y_true, y_pred_samples):
    """
    Energy Score loss for spatiotemporal targets.

    Computes ``E||Y_hat - y|| - 0.5 * E||Y_hat - Y_hat'||`` per batch element,
    with the norm taken over the flattened (T_out, N, D) dimensions, and
    returns the mean over the batch.

    Args:
        y_true (torch.Tensor): Shape (B, T_out, N, D)
        y_pred_samples (torch.Tensor): Shape (M, B, T_out, N, D), M = num_samples
    Returns:
        torch.Tensor: scalar (0-dim) loss.
    """
    # Unpacking five names also rejects inputs that are not 5-D.
    num_samples, batch, _, _, _ = y_pred_samples.shape
    samples_bm = y_pred_samples.transpose(0, 1)  # (B, M, T, N, D)
    target = y_true[:, None]                     # (B, 1, T, N, D)

    # Fidelity term: mean distance between each sample and the target.
    residual = samples_bm - target
    fidelity = torch.norm(residual, p=2, dim=(-3, -2, -1)).mean(dim=1)

    # Spread term: half the mean pairwise distance between samples.
    flat = samples_bm.reshape(batch, num_samples, -1)  # (B, M, T*N*D)
    pairwise_diff = flat[:, :, None] - flat[:, None]   # (B, M, M, T*N*D)
    # eps keeps sqrt differentiable on the zero-distance diagonal.
    eps = torch.tensor(1e-8, device=pairwise_diff.device)
    pairwise_dist = torch.sqrt(pairwise_diff.pow(2).sum(dim=-1) + eps)
    spread = 0.5 * pairwise_dist.mean(dim=(1, 2))

    return (fidelity - spread).mean()


def calibration_loss(future_target, samples, alpha, temperature, width_penalty):
    """Differentiable interval-calibration loss plus the hard coverage.

    A central interval with nominal mass ``alpha`` is taken from the
    empirical quantiles of ``samples`` along dim 0. A sigmoid-smoothed
    coverage is compared against ``alpha`` and a width term is combined
    with it.

    Args:
        future_target: Observed values; reduced over dims (0, 2, 3) for the
            hard coverage, so it must be at least 4-D.
        samples: Sampled forecasts with the sample axis first.
        alpha: Nominal coverage of the central interval.
        temperature: Sharpness of the sigmoid used for soft coverage.
        width_penalty: Weight of the mean interval width term.

    Returns:
        Tuple ``(loss, coverage)`` of scalar tensors.
    """
    tail = (1 - alpha) / 2
    lower = torch.quantile(samples, tail, dim=0)
    upper = torch.quantile(samples, 1 - tail, dim=0)

    # Soft indicator of "lower <= target <= upper".
    above_lower = torch.sigmoid(temperature * (future_target - lower))
    below_upper = torch.sigmoid(temperature * (upper - future_target))
    soft_coverage = torch.clamp(
        above_lower * below_upper, min=1e-6, max=1 - 1e-6
    ).mean()

    # Interval width regularization.
    # NOTE(review): the width term is subtracted, so a positive
    # ``width_penalty`` lowers the loss for wider intervals - confirm the
    # intended sign against the callers.
    width = (upper - lower).abs().mean()
    loss = (soft_coverage - alpha) ** 2 - width_penalty * width

    inside = (future_target >= lower) & (future_target <= upper)
    coverage = inside.float().mean(dim=(0, 2, 3)).mean()

    return loss, coverage


def get_values_safe(ts_or_tensor):
    """Return the raw values behind a series-like object.

    Objects exposing an ``all_values`` method (e.g. a Darts TimeSeries) are
    unwrapped by calling it; anything else (tensor, numpy array, ...) is
    returned unchanged.
    """
    if not hasattr(ts_or_tensor, "all_values"):
        # Already a tensor or numpy array.
        return ts_or_tensor
    return ts_or_tensor.all_values()


def generate_forecasts(model, history, m_samples, past_covs=None, fut_covs=None, static_covs=None, unstandardize=None, device="cpu"):
    """Draw ``m_samples`` forecasts from ``model`` for a single history window.

    Args:
        model: Either the network itself or a wrapper exposing it as
            ``model.model``. The network must provide ``gcn`` and
            ``gcn_out_feat`` and accept a ``(x, past_covs, fut_covs)`` tuple.
        history: Series-like object, tensor, array or nested sequence with
            shape (T, N, D), or (T, N) in which case D=1 is appended.
        m_samples: Number of forward passes to run.
        past_covs: Passed through to the network unchanged.
        fut_covs: Passed through to the network unchanged.
        static_covs: Accepted for interface compatibility; currently unused.
        unstandardize: Optional ``(mean, std)`` pair; each is reshaped to
            (1, 1, N, -1) and applied as ``forecasts * std + mean``.
        device: Device the network and tensors are moved to.

    Returns:
        torch.Tensor of shape (m_samples, T_out, N, gcn_out_feat).

    Raises:
        ValueError: If ``history`` is not 2-D or 3-D.

    Note:
        The network is moved to ``device`` and left in eval mode.
    """
    net = model.model if hasattr(model, 'model') else model
    net = net.to(device)
    net.eval()

    # Always work in float32, whatever container / dtype the history came in
    # (tensor inputs previously kept their original dtype).
    h_values = torch.as_tensor(get_values_safe(history), dtype=torch.float32)
    if h_values.ndim == 2:  # D missing
        h_values = h_values.unsqueeze(-1)  # add D=1
    if h_values.ndim != 3:
        raise ValueError(
            f"history must have shape (T, N) or (T, N, D), got {tuple(h_values.shape)}"
        )
    num_nodes = h_values.shape[1]

    # (1, T, N, D); unsqueeze also works for non-contiguous inputs.
    history_tensor = h_values.unsqueeze(0).to(device)
    batch, steps = history_tensor.shape[:2]
    gcn_input = history_tensor.permute(2, 0, 1, 3)  # (N, B, T, D)

    forecasts = []
    with torch.no_grad():
        for _ in range(m_samples):
            # Spatial processing (GCN); kept inside the loop so every sample
            # runs the complete forward pass.
            x = net.gcn(gcn_input)
            x = x.permute(1, 2, 0, 3)  # (B, T, N, Gcn_Out)

            # Temporal processing: flatten spatial dim into features for the Transformer
            x = x.reshape(batch, steps, num_nodes * net.gcn_out_feat)
            y = net((x.float(), past_covs, fut_covs))

            # Reshape output back to spatial dimensions (B, T_out, N, gcn_out_feat)
            y = y.reshape(batch, -1, num_nodes, net.gcn_out_feat)
            forecasts.append(y.squeeze(0))  # Remove batch dim for stacking

    # Stack to get (m_samples, T_out, N, D_out)
    forecasts = torch.stack(forecasts, dim=0).to(device)

    # Unstandardize if provided
    if unstandardize is not None:
        mean, std = unstandardize
        mean = torch.as_tensor(mean, dtype=torch.float32, device=device)
        std = torch.as_tensor(std, dtype=torch.float32, device=device)
        # Reshape mean/std to (1, 1, N, D) for broadcasting
        mean = mean.reshape(1, 1, num_nodes, -1)
        std = std.reshape(1, 1, num_nodes, -1)
        forecasts = forecasts * std + mean

    return forecasts