Source code for spflow.zoo.cms.rqmc
"""Randomized quasi-Monte Carlo (RQMC) utilities.
This module provides small, dependency-free helpers to generate low-discrepancy
integration points for continuous mixtures, following the paper:
"Continuous Mixtures of Tractable Probabilistic Models"
Correia et al., 2023
We currently implement Sobol-based RQMC with a uniform random shift modulo 1
(drawn afresh on each call, or reproducibly from an optional seed) and an
inverse-CDF transform to map from U(0,1)^d to N(0, I).
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import Tensor
from spflow.exceptions import InvalidParameterError
[docs]
@dataclass(eq=True, frozen=True)
class RqmcPoints:
    """Immutable container for a set of RQMC quadrature nodes and weights.

    The tensors follow the standard numerical-integration layout:

    - ``z``: integration points, shape ``(num_points, latent_dim)``.
    - ``weights``: per-point weights, shape ``(num_points,)``, summing to 1.

    The dataclass stores the tensors exactly as given; it does not itself
    check these shapes or that the weights are normalized.
    """

    # Integration points, one row per point.
    z: Tensor
    # Quadrature weights, aligned with the rows of ``z``.
    weights: Tensor
[docs]
def rqmc_sobol_normal(
    *,
    num_points: int,
    latent_dim: int,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
    seed: int | None = None,
    eps: float = 1e-6,
) -> RqmcPoints:
    """Generate Sobol-RQMC points for a standard normal prior N(0, I).

    Draws the first ``num_points`` points of an unscrambled Sobol sequence in
    [0, 1)^d. It then applies one uniform random shift modulo 1, shared by all
    points, and clamps to ``[eps, 1 - eps]``. Finally it maps the result to
    N(0, I) via the normal inverse CDF (icdf).

    Args:
        num_points: Number of integration points N (must be >= 1).
        latent_dim: Latent dimension d (must be >= 1).
        device: Target device. Defaults to CPU.
        dtype: Target dtype for returned tensors. Defaults to
            ``torch.get_default_dtype()``.
        seed: Optional seed used to sample the random shift. If provided, this
            yields deterministic points; otherwise the shift is drawn from
            torch's global RNG.
        eps: Clamp epsilon for icdf numerical stability; must lie in the open
            interval (0, 0.1).

    Returns:
        RqmcPoints with ``z`` of shape ``(num_points, latent_dim)`` and uniform
        ``weights`` of shape ``(num_points,)`` equal to ``1 / num_points``.

    Raises:
        InvalidParameterError: If ``num_points < 1``, ``latent_dim < 1``, or
            ``eps`` is not in (0, 0.1) (NaN included).
    """
    if num_points < 1:
        raise InvalidParameterError("num_points must be >= 1.")
    if latent_dim < 1:
        raise InvalidParameterError("latent_dim must be >= 1.")
    # The chained comparison also rejects NaN. Every comparison with NaN is
    # False, so an ``eps <= 0 or eps >= 0.1`` test would accept NaN and the
    # clamp below would then silently turn every point into NaN.
    if not (0.0 < eps < 0.1):
        raise InvalidParameterError("eps must be in (0, 0.1).")

    device = torch.device("cpu") if device is None else device
    dtype = torch.get_default_dtype() if dtype is None else dtype

    # scramble=False makes the Sobol sequence fixed (the engine seed has no
    # effect); all randomness comes from the shift applied below.
    engine = torch.quasirandom.SobolEngine(dimension=latent_dim, scramble=False, seed=0)
    u = engine.draw(num_points).to(device=device, dtype=dtype)  # (N, d) in [0, 1)

    gen = None
    if seed is not None:
        gen = torch.Generator(device=device)
        gen.manual_seed(int(seed))

    # Random shift modulo 1 (RQMC); one shift per dimension, broadcast over points.
    shift = torch.rand((1, latent_dim), generator=gen, device=device, dtype=dtype)
    u = torch.remainder(u + shift, 1.0)
    # Keep u away from 0 and 1, where the normal icdf is -inf / +inf.
    u = u.clamp(min=eps, max=1.0 - eps)

    # Inverse-CDF transform to standard normal.
    normal = torch.distributions.Normal(
        loc=torch.zeros((), device=device, dtype=dtype),
        scale=torch.ones((), device=device, dtype=dtype),
    )
    z = normal.icdf(u)

    # Equal-weight quadrature: each point contributes 1/N.
    weights = torch.full((num_points,), 1.0 / float(num_points), device=device, dtype=dtype)
    return RqmcPoints(z=z, weights=weights)