"""Construction-only DSL for building SPFlow circuits.
This module provides a small, non-invasive expression layer for writing examples
with algebraic syntax while keeping the core `spflow.modules` API unchanged.
The DSL is intentionally minimal:
- Products: `term(A) * term(B)`
- Weighted sums (mixtures): `0.4 * term(A) + 0.6 * term(B)`
To obtain an actual `Module`, call `.build()` on the resulting expression.
Notes:
- Weights must be provided for sums; `term(A) + term(B)` is intentionally disallowed.
- Weighted sums are restricted to terms with `out_shape.channels == 1` for simplicity.
For convenience in docs/examples, `dsl()` can temporarily enable operator overloads on
`spflow.modules.module.Module` within a context manager, so that expressions like
`0.4 * Normal(0) * Normal(1) + 0.6 * Normal(0) * Normal(1)` work without wrapping leaves.
"""
from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Protocol, runtime_checkable
import torch
from einops import repeat
from spflow.exceptions import (
InvalidParameterCombinationError,
InvalidParameterError,
InvalidWeightsError,
ScopeError,
ShapeError,
)
from spflow.meta.data.scope import Scope
from spflow.modules.module import Module
from spflow.modules.products.product import Product
from spflow.modules.sums.sum import Sum
@runtime_checkable
class Buildable(Protocol):
    """Structural interface for objects that can be materialized as a `Module`.

    Marked `runtime_checkable` so `isinstance(x, Buildable)` works; note such
    checks only verify that a `build` attribute exists, not its signature.
    """

    def build(self) -> Module:
        """Materialize this expression as a concrete `Module`."""
        ...
def as_expr(value: Module | Buildable) -> Buildable:
    """Lift *value* into the DSL.

    A `Module` is wrapped in a `Term`; an existing DSL expression is passed
    through unchanged.

    Raises:
        InvalidParameterError: if *value* is neither a `Module` nor a DSL
            expression.
    """
    if isinstance(value, Module):
        return Term(value)
    if not isinstance(value, Buildable):
        raise InvalidParameterError(f"Expected a Module or DSL expression, got {type(value)}.")
    return value
def term(module: Module) -> "Term":
    """Wrap a concrete `Module` as a leaf DSL expression."""
    leaf = Term(module)
    return leaf
def w(weight: float, value: Module | Buildable) -> "WeightedExpr":
    """Attach *weight* to *value*, lifting it into the DSL first if needed."""
    expr = as_expr(value)
    return WeightedExpr(weight=weight, expr=expr)
def build(value: Module | Buildable) -> Module:
    """Build a concrete `Module` from a DSL expression.

    A `Module` passed in directly is returned unchanged.
    """
    return value if isinstance(value, Module) else value.build()
@dataclass(frozen=True)
class Term(Buildable):
    """Leaf DSL node wrapping a concrete `Module`."""

    # The concrete module this leaf stands for; returned untouched by `build`.
    module: Module

    def build(self) -> Module:
        """Return the wrapped module as-is."""
        return self.module

    def __mul__(self, other: object) -> "ProductExpr | WeightedExpr":
        # A numeric right operand weights this term; anything else becomes
        # another factor of a product expression.
        if not isinstance(other, (int, float)):
            return ProductExpr([self, as_expr(other)])  # type: ignore[arg-type]
        return WeightedExpr(weight=float(other), expr=self)

    def __rmul__(self, weight: float) -> "WeightedExpr":
        if isinstance(weight, (int, float)):
            return WeightedExpr(weight=float(weight), expr=self)
        raise InvalidParameterError(f"Expected numeric weight, got {type(weight)}.")

    def __add__(self, other: object) -> "SumExpr":  # pragma: no cover
        # Mixtures must carry explicit weights, so bare '+' is rejected.
        raise InvalidParameterError(
            "Unweighted '+' is not supported in the DSL. Use 'a * term(x) + b * term(y)'."
        )
@dataclass(frozen=True)
class ProductExpr(Buildable):
    """Product node over one or more sub-expressions."""

    # Sub-expressions multiplied together; appended to (flat) by `__mul__`.
    factors: list[Buildable]

    def __mul__(self, other: object) -> "ProductExpr | WeightedExpr":
        # A numeric operand weights the whole product; anything else is
        # appended as an additional factor (keeping the product flat).
        if not isinstance(other, (int, float)):
            return ProductExpr([*self.factors, as_expr(other)])  # type: ignore[arg-type]
        return WeightedExpr(weight=float(other), expr=self)

    def __rmul__(self, weight: float) -> "WeightedExpr":
        if isinstance(weight, (int, float)):
            return WeightedExpr(weight=float(weight), expr=self)
        raise InvalidParameterError(f"Expected numeric weight, got {type(weight)}.")

    def __add__(self, other: object) -> "SumExpr":  # pragma: no cover
        # Mixtures must carry explicit weights, so bare '+' is rejected.
        raise InvalidParameterError(
            "Unweighted '+' is not supported in the DSL. Use 'a * term(x) + b * term(y)'."
        )

    def build(self) -> Module:
        """Build all factors, validate their compatibility, assemble a `Product`."""
        built = [factor.build() for factor in self.factors]
        _validate_product_modules(built)
        return Product(inputs=built)
@dataclass(frozen=True)
class WeightedExpr:
    """A weighted expression term used as an input to mixtures."""

    # Mixture weight; must be finite and strictly positive (validated below).
    weight: float
    # The expression this weight applies to.
    expr: Buildable

    def __post_init__(self) -> None:
        """Validate the weight eagerly so bad mixtures fail at construction time."""
        if not isinstance(self.weight, (int, float)):
            raise InvalidParameterError(f"Weight must be numeric, got {type(self.weight)}.")
        if not torch.isfinite(torch.as_tensor(float(self.weight))):
            raise InvalidWeightsError("Weight must be finite.")
        if float(self.weight) <= 0.0:
            raise InvalidWeightsError("Weights must be strictly positive.")

    def build(self) -> Module:
        """A weighted term is not a standalone circuit; building it is an error."""
        raise InvalidParameterError(
            "A weighted term cannot be built directly. Combine weighted terms with '+' to form a mixture."
        )

    def __add__(self, other: "WeightedExpr | SumExpr") -> "SumExpr":
        """Combine with another weighted term, or prepend to an existing mixture."""
        if isinstance(other, WeightedExpr):
            return SumExpr([(self.weight, self.expr), (other.weight, other.expr)])
        if isinstance(other, SumExpr):
            return SumExpr([(self.weight, self.expr), *other.terms])
        raise InvalidParameterError(
            f"Can only add a weighted term to another weighted term or mixture, got {type(other)}."
        )

    def __radd__(self, other: object) -> "SumExpr":
        if isinstance(other, SumExpr):
            return SumExpr([*other.terms, (self.weight, self.expr)])
        return NotImplemented  # type: ignore[return-value]

    def __mul__(self, other: Module | Buildable) -> "WeightedExpr":
        """Fold another factor into this weighted term's product.

        This enables compact example syntax like ``0.4 * Normal(0) * Normal(1)``.

        Fix: when the underlying expression is already a `ProductExpr`, the new
        factor is appended to its factor list instead of nesting products, so
        ``0.4 * A * B * C`` yields one flat product — consistent with
        `ProductExpr.__mul__`, which also keeps products flat.
        """
        if isinstance(self.expr, ProductExpr):
            factors = [*self.expr.factors, as_expr(other)]
        else:
            factors = [self.expr, as_expr(other)]
        return WeightedExpr(weight=self.weight, expr=ProductExpr(factors))
@dataclass(frozen=True)
class SumExpr(Buildable):
    """A weighted mixture of expressions.

    Terms are stored as (weight, expr) pairs; weights are normalized when the
    mixture is built.
    """

    # (weight, expression) pairs; validated in __post_init__.
    terms: list[tuple[float, Buildable]]

    def __post_init__(self) -> None:
        """Reject degenerate mixtures and non-positive weights up front."""
        if len(self.terms) < 2:
            raise InvalidParameterError("A mixture requires at least two weighted terms.")
        if any(float(weight) <= 0.0 for weight, _ in self.terms):
            raise InvalidWeightsError("Weights must be strictly positive.")

    def __add__(self, other: WeightedExpr | "SumExpr") -> "SumExpr":
        """Extend the mixture with a weighted term or another mixture's terms."""
        if isinstance(other, WeightedExpr):
            extra = [(other.weight, other.expr)]
        elif isinstance(other, SumExpr):
            extra = list(other.terms)
        else:
            raise InvalidParameterError(
                f"Can only add a weighted term or mixture to a mixture, got {type(other)}."
            )
        return SumExpr([*self.terms, *extra])

    def __mul__(self, other: object) -> ProductExpr | WeightedExpr:
        # A numeric operand weights the mixture; anything else multiplies it
        # into a product expression.
        if not isinstance(other, (int, float)):
            return ProductExpr([self, as_expr(other)])  # type: ignore[arg-type]
        return WeightedExpr(weight=float(other), expr=self)

    def __rmul__(self, weight: float) -> WeightedExpr:
        if isinstance(weight, (int, float)):
            return WeightedExpr(weight=float(weight), expr=self)
        raise InvalidParameterError(f"Expected numeric weight, got {type(weight)}.")

    def build(self) -> Module:
        """Build all children, validate them, and assemble a normalized `Sum`."""
        children = [expr.build() for _, expr in self.terms]
        _validate_sum_modules(children)
        reference = children[0]
        weight_tensor = _make_sum_weights(
            weights=[float(weight) for weight, _ in self.terms],
            features=reference.out_shape.features,
            repetitions=reference.out_shape.repetitions,
            device=reference.device,
            dtype=torch.get_default_dtype(),
        )
        return Sum(inputs=children, weights=weight_tensor)
def _validate_product_modules(modules: list[Module]) -> None:
    """Check that *modules* can legally be combined as product factors.

    Raises:
        InvalidParameterError: fewer than two factors.
        ScopeError: factor scopes overlap.
        ShapeError: mismatched channel or repetition counts.
        InvalidParameterCombinationError: factors live on different devices.
    """
    if len(modules) < 2:
        raise InvalidParameterError("Product requires at least two factors.")
    # Factors of a product must cover disjoint parts of the variable scope.
    if not Scope.all_pairwise_disjoint([m.scope for m in modules]):
        raise ScopeError("Product factors must have disjoint scopes.")
    channels = {m.out_shape.channels for m in modules}
    if len(channels) > 1:
        raise ShapeError(f"Product factors must have the same out_channels; got {sorted(channels)}.")
    repetitions = {m.out_shape.repetitions for m in modules}
    if len(repetitions) > 1:
        raise ShapeError(f"Product factors must have the same num_repetitions; got {sorted(repetitions)}.")
    devices = {str(m.device) for m in modules}
    if len(devices) > 1:
        raise InvalidParameterCombinationError(
            f"Product factors must be on the same device; got {sorted(devices)}."
        )
def _validate_sum_modules(modules: list[Module]) -> None:
    """Check that *modules* are compatible as inputs to a mixture `Sum`.

    Raises:
        InvalidParameterError: fewer than two terms.
        ScopeError: term scopes differ.
        ShapeError: mismatched feature counts, multi-channel terms, or
            mismatched repetition counts.
        InvalidParameterCombinationError: terms live on different devices.
    """
    if len(modules) < 2:
        raise InvalidParameterError("Sum requires at least two terms.")
    # Every mixture component must model exactly the same set of variables.
    if not Scope.all_equal([m.scope for m in modules]):
        raise ScopeError("Sum terms must have identical scopes.")
    features = {m.out_shape.features for m in modules}
    if len(features) > 1:
        raise ShapeError(f"Sum terms must have the same number of features; got {sorted(features)}.")
    channels = {m.out_shape.channels for m in modules}
    if channels != {1}:
        # The DSL keeps mixtures simple: every input must expose one channel.
        raise ShapeError(
            "Sum DSL only supports terms with out_shape.channels == 1. "
            f"Got out_channels: {sorted(channels)}."
        )
    repetitions = {m.out_shape.repetitions for m in modules}
    if len(repetitions) > 1:
        raise ShapeError(f"Sum terms must have the same num_repetitions; got {sorted(repetitions)}.")
    devices = {str(m.device) for m in modules}
    if len(devices) > 1:
        raise InvalidParameterCombinationError(
            f"Sum terms must be on the same device; got {sorted(devices)}."
        )
def _make_sum_weights(
*,
weights: list[float],
features: int,
repetitions: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
in_channels = len(weights)
raw = torch.as_tensor(weights, dtype=dtype, device=device)
if raw.dim() != 1 or raw.shape[0] != in_channels:
raise ShapeError("Expected a 1D weight vector.")
if not torch.isfinite(raw).all():
raise InvalidWeightsError("Weights must be finite.")
if not torch.all(raw > 0):
raise InvalidWeightsError("Weights must be strictly positive.")
total = torch.sum(raw)
if not torch.isfinite(total) or float(total) <= 0.0:
raise InvalidWeightsError("Sum of weights must be finite and > 0.")
normalized = raw / total
# Sum expects weights of shape: (features, in_channels, out_channels, repetitions)
w = repeat(normalized, "ci -> f ci 1 r", f=features, r=repetitions)
return w
@contextmanager
def dsl():
    """Temporarily enable DSL operator overloads on `Module`.

    Intended for documentation/examples. Monkeypatches arithmetic dunders on
    `spflow.modules.module.Module` for the duration of the context manager and
    restores the original attributes (or removes the patched ones) afterward.

    Within the context:
    - `Module * Module` builds a `ProductExpr`
    - `float * Module` and `Module * float` create a `WeightedExpr`
    - `WeightedExpr + WeightedExpr (+ ...)` creates a `SumExpr`
    - `Module + Module` remains disallowed (weights must be explicit)
    """

    def _dsl_mul(self: Module, other: object):
        # Order matters: numbers weight the term, an existing weighted term
        # keeps its weight and absorbs this module as a product factor.
        if isinstance(other, (int, float)):
            return WeightedExpr(weight=float(other), expr=Term(self))
        if isinstance(other, WeightedExpr):
            return WeightedExpr(weight=other.weight, expr=ProductExpr([Term(self), other.expr]))
        if isinstance(other, Module):
            return ProductExpr([Term(self), Term(other)])
        if isinstance(other, Buildable):
            return ProductExpr([Term(self), other])
        return NotImplemented

    def _dsl_rmul(self: Module, other: object):
        if isinstance(other, (int, float)):
            return WeightedExpr(weight=float(other), expr=Term(self))
        return NotImplemented

    def _dsl_add(self: Module, other: object):
        raise InvalidParameterError("Unweighted '+' is not supported in the DSL. Use 'a * X + b * Y'.")

    def _dsl_radd(self: Module, other: object):
        return NotImplemented

    overloads = {
        "__mul__": _dsl_mul,
        "__rmul__": _dsl_rmul,
        "__add__": _dsl_add,
        "__radd__": _dsl_radd,
    }
    # Snapshot the pre-context attributes. `getattr` (rather than
    # `Module.__dict__`) means an operator inherited from a base class is
    # captured too and later re-set directly on `Module`, matching how the
    # patching itself works; a sentinel marks operators that were absent.
    _absent = object()
    saved = {name: getattr(Module, name, _absent) for name in overloads}
    for name, impl in overloads.items():
        setattr(Module, name, impl)
    try:
        yield
    finally:
        # Undo the patch: delete operators that did not exist before,
        # restore the ones that did.
        for name, original in saved.items():
            if original is _absent:
                delattr(Module, name)
            else:
                setattr(Module, name, original)