Utilities

Helper functions and utilities for model visualization, I/O, and analysis.

Visualization

Visualize probabilistic circuit structures as graphs.

spflow.utils.visualization.visualize(module, output_path, show_scope=True, show_shape=True, show_params=True, format='pdf', dpi=300, engine='dot', rankdir='BT', node_shape='box', skip_ops=True)

Visualize a SPFlow module as a directed graph and save to file.

Parameters:
  • module (Module) – The root module to visualize.

  • output_path (str) – Path to save the visualization (without extension).

  • show_scope (bool) – Whether to display scope information in node labels.

  • show_shape (bool) – Whether to display shape information (D: out_features, C: out_channels) in node labels.

  • show_params (bool) – Whether to display parameter count in node labels. Parameter counts are formatted with K/M suffixes for readability (e.g., “1.2K”, “3.5M”).

  • format (str) – Output format: ‘png’, ‘pdf’, ‘svg’, ‘dot’, ‘plain’, or ‘canon’. Text-based formats are useful for viewing the graph structure in a terminal:
      - ‘dot’/‘canon’: Graphviz DOT language source code
      - ‘plain’: Simple text format with node positions and edges

  • dpi (int) – DPI for rasterized formats (png). Applied via the graph-level dpi attribute.

  • engine (str) – Graphviz layout engine. Options:
      - ‘dot’ (default): Hierarchical layout, best for directed acyclic graphs; the direction is set by rankdir
      - ‘dot-lr’: Hierarchical left-to-right layout (automatically sets rankdir=‘LR’)
      - ‘neato’: Spring model layout (force-directed)
      - ‘fdp’: Force-directed placement, similar to ‘neato’
      - ‘circo’: Circular layout
      - ‘twopi’: Radial layout
      - ‘osage’: Clustered layout

  • rankdir (Literal['TB', 'LR', 'BT', 'RL']) – Direction of the graph layout (only used with the ‘dot’ and ‘dot-lr’ engines):
      - ‘TB’: Top to bottom
      - ‘LR’: Left to right
      - ‘BT’: Bottom to top (default)
      - ‘RL’: Right to left

  • node_shape (str) – Shape of nodes. Common options: ‘box’ (default), ‘circle’, ‘ellipse’, ‘diamond’, ‘triangle’, ‘plaintext’, ‘record’, ‘Mrecord’.

  • skip_ops (bool) – Whether to skip ops modules (Cat, Split, SplitConsecutive, SplitInterleaved) in the visualization. These are pass-through modules, so they are bypassed and their inputs are connected directly to the parent node. Defaults to True.
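The K/M formatting described for show_params can be illustrated with a small helper. This is a sketch under assumptions, not SPFlow's code: the name format_param_count is hypothetical, and the rounding rule and the plain-integer display of counts below 1,000 are guesses chosen to reproduce the documented examples (“1.2K”, “3.5M”). In PyTorch the raw count itself is commonly obtained with sum(p.numel() for p in module.parameters()).

```python
def format_param_count(n: int) -> str:
    """Format a parameter count with K/M suffixes (hypothetical helper)."""
    if n < 1_000:
        return str(n)  # assumption: small counts are shown as plain integers
    thousands = n / 1_000
    # Promote to "M" once the K value would round to 1000.0 or more
    if round(thousands, 1) < 1_000:
        return f"{thousands:.1f}K"
    return f"{n / 1_000_000:.1f}M"


for count in (42, 1_234, 3_500_000):
    print(count, "->", format_param_count(count))
```

The promotion check avoids labels such as “1000.0K” for counts just below one million.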

Return type:

None

Returns:

None. The visualization is saved to the specified output path.
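To make the skip_ops, rankdir, node_shape and ‘dot-lr’ options concrete, the following sketch builds Graphviz DOT source by hand for a toy graph. It is a hypothetical illustration only: Node, OPS, visible_children and to_dot are invented names, the code calls neither Graphviz nor SPFlow, and the edge direction (from an input to the module that consumes it, so that the default rankdir='BT' places leaves at the bottom) is an assumption.

```python
from __future__ import annotations

from dataclasses import dataclass, field


# Hypothetical stand-in for a module graph (not SPFlow's Module class).
# eq=False keeps identity-based hashing, so nodes can be used as dict keys.
@dataclass(eq=False)
class Node:
    name: str
    children: list[Node] = field(default_factory=list)


# Pass-through ops listed in the skip_ops description
OPS = {"Cat", "Split", "SplitConsecutive", "SplitInterleaved"}


def visible_children(node: Node, skip_ops: bool) -> list[Node]:
    """Children to draw; a skipped op is replaced by its own (visible) inputs."""
    result = []
    for child in node.children:
        if skip_ops and child.name in OPS:
            result.extend(visible_children(child, skip_ops))
        else:
            result.append(child)
    return result


def to_dot(root: Node, rankdir: str = "BT", node_shape: str = "box",
           engine: str = "dot", skip_ops: bool = True) -> str:
    if engine == "dot-lr":
        rankdir = "LR"  # 'dot-lr' is documented to force a left-right layout
    lines = ["digraph G {"]
    if engine in ("dot", "dot-lr"):
        lines.append(f"  rankdir={rankdir};")  # only meaningful for dot layouts
    lines.append(f"  node [shape={node_shape}];")
    ids: dict[Node, str] = {}

    def ident(node: Node) -> str:
        # Declare each node once, the first time it is referenced
        if node not in ids:
            ids[node] = f"n{len(ids)}"
            lines.append(f'  {ids[node]} [label="{node.name}"];')
        return ids[node]

    stack, seen = [root], set()
    while stack:
        node = stack.pop()
        if node in seen:
            continue  # shared sub-circuits are drawn only once
        seen.add(node)
        parent_id = ident(node)
        for child in visible_children(node, skip_ops):
            lines.append(f"  {ident(child)} -> {parent_id};")  # input -> consumer
            stack.append(child)
    lines.append("}")
    return "\n".join(lines)


split = Node("Split", [Node("Normal"), Node("Normal")])
root = Node("Product", [Node("Sum", [split])])
print(to_dot(root))
```

With skip_ops=True the Split node disappears and both Normal leaves connect straight to Sum; passing skip_ops=False draws it as an ordinary node.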

Model I/O

Save and load models from disk.

spflow.utils.model_manager.save_model(model, path)

Save an SPFlow model to disk using pickle serialization.

Parameters:
  • model (Module) – The SPFlow module to save.

  • path (str | bytes | PathLike) – File path where the model will be saved. The file will be created or overwritten if it already exists.

Return type:

None

spflow.utils.model_manager.load_model(path)

Load an SPFlow model from disk using pickle deserialization.

Parameters:
  • path (str | bytes | PathLike) – File path of the saved model to load.

Return type:

Module

Returns:

The deserialized SPFlow module.
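Since both functions are documented as pickle-based, their behaviour can be approximated with the standard library. The sketch below is an assumption about the mechanics, not SPFlow's source: save_model_sketch and load_model_sketch are hypothetical names, and a types.SimpleNamespace stands in for an SPFlow Module.

```python
import pickle
import tempfile
from pathlib import Path
from types import SimpleNamespace


def save_model_sketch(model, path) -> None:
    # "wb" creates the file, or truncates and overwrites an existing one
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load_model_sketch(path):
    with open(path, "rb") as f:
        return pickle.load(f)


# Stand-in for an SPFlow module; any picklable object round-trips the same way
model = SimpleNamespace(name="toy", weights=[0.3, 0.7])

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.pkl"
    save_model_sketch(model, path)
    restored = load_model_sketch(path)

print(restored)
```

Because pickle stores classes by reference, loading only works when the classes used by the saved model are importable under the same module paths. The Python documentation also notes that unpickling data from an untrusted source can execute arbitrary code.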

Cache

Utilities for caching intermediate computations to speed up inference.

class spflow.utils.cache.Cache

Bases: object

Thread-safe cache with per-method-type locking and weak key references.

Uses WeakKeyDictionary to store cached values keyed by module instances, allowing garbage collection when modules are no longer referenced elsewhere.
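The weak-key behaviour can be observed with the standard library alone. In this sketch FakeModule is a hypothetical stand-in for an SPFlow module: once the last strong reference to it is gone, its entry disappears from the WeakKeyDictionary without any explicit eviction.

```python
import gc
import weakref


class FakeModule:  # hypothetical stand-in for an SPFlow module
    pass


store = weakref.WeakKeyDictionary()
module = FakeModule()
store[module] = "cached log-likelihood"
size_before = len(store)

del module    # drop the only strong reference to the key
gc.collect()  # CPython frees it via refcounting already; this makes collection explicit
size_after = len(store)

print(size_before, size_after)
```

One consequence is that the cache never keeps a module alive on its own.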

extras

Mutable dictionary for user-defined, per-traversal state. This can be used to pass custom information through recursive module calls without changing public method signatures.
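As an illustration of the extras pattern, the sketch below records the depth of every module visited during one recursive call, without adding a depth argument to log_likelihood. TraversalCache and Node are hypothetical stand-ins (only the extras attribute mirrors the documented Cache), and the keys "depth" and "visited" are arbitrary choices.

```python
class TraversalCache:
    """Hypothetical minimal stand-in that only carries an `extras` dict."""

    def __init__(self):
        self.extras = {}


class Node:
    """Hypothetical module with the usual (data, cache) call signature."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

    def log_likelihood(self, data, cache=None):
        cache = cache if cache is not None else TraversalCache()
        depth = cache.extras.get("depth", 0)
        # Record where we are, then pass depth + 1 down through extras
        cache.extras.setdefault("visited", []).append((self.name, depth))
        cache.extras["depth"] = depth + 1
        for child in self.children:
            child.log_likelihood(data, cache)
        cache.extras["depth"] = depth  # restore on the way back up
        return 0.0  # placeholder; a real module would compute a log-likelihood


tree = Node("root", [Node("sum", [Node("leaf")]), Node("leaf2")])
cache = TraversalCache()
tree.log_likelihood(data=None, cache=cache)
print(cache.extras["visited"])
```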

__init__()

Initialize cache with per-method locks and storage.

get(method_name, module)

Retrieve cached value for a module.

Parameters:
  • method_name (str) – Name of the cached method (e.g., “log_likelihood”).

  • module (Module) – Module instance to use as cache key.

Return type:

Any | None

Returns:

Cached value if present, None otherwise.

set(method_name, module, value)

Store a value in cache for a module.

Parameters:
  • method_name (str) – Name of the cached method (e.g., “log_likelihood”).

  • module (Module) – Module instance to use as cache key.

  • value (Any) – Value to cache.

Return type:

None

Method Replacement

Temporarily replace module methods for testing or experimentation.

spflow.utils.replace.replace(method_ref, replacement_func)

Temporarily replace a class method with a custom implementation.

Automatically detects and preserves decorators (like @cached) by re-applying them to the replacement function. Works at the class level, affecting all instances of the class.

Parameters:
  • method_ref (Callable) – Reference to the method to replace (e.g., Sum.log_likelihood). This should be an unbound method (accessed via the class, not instance).

  • replacement_func (Callable) – The function to use as the replacement. Its signature must be compatible with that of the original method, including ‘self’ as the first parameter.

Yields:

None.

Example

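import torch
from spflow.utils.replace import replace

def my_custom_ll(self, data, cache=None):
    # Illustrative stand-in: ignore the inputs and return 1.0 for every sample
    return torch.ones(len(data))

model = Product(Sum(Product(Normal(...))))

# Standard inference with the original implementations
model.log_likelihood(data)

# Use the custom implementation for all Sum modules
with replace(Sum.log_likelihood, my_custom_ll):
    model.log_likelihood(data)  # Sum instances now use my_custom_ll

# After the block exits, the original Sum.log_likelihood is restored
model.log_likelihood(data)

Raises: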
  • ValueError – If the class cannot be inferred from the method reference.

  • TypeError – If method_ref is not a valid method reference.