Utilities
Helper functions and utilities for model visualization, I/O, and analysis.
Visualization
Visualize probabilistic circuit structures as graphs.
- spflow.utils.visualization.visualize(module, output_path, show_scope=True, show_shape=True, show_params=True, format='pdf', dpi=300, engine='dot', rankdir='BT', node_shape='box', skip_ops=True)
Visualize a SPFlow module as a directed graph and save to file.
- Parameters:
  - module (Module) – The root module to visualize.
  - output_path (str) – Path to save the visualization (without extension).
  - show_scope (bool) – Whether to display scope information in node labels.
  - show_shape (bool) – Whether to display shape information (D: out_features, C: out_channels) in node labels.
  - show_params (bool) – Whether to display the parameter count in node labels. Parameter counts are formatted with K/M suffixes for readability (e.g., “1.2K”, “3.5M”).
  - format (str) – Output format: ‘png’, ‘pdf’, ‘svg’, ‘dot’, ‘plain’, or ‘canon’. Text-based formats are useful for viewing the graph structure in the terminal:
    - ‘dot’/‘canon’: Graphviz DOT language source code
    - ‘plain’: Simple text format with node positions and edges
  - dpi (int) – DPI for rasterized formats (png). Applied via the graph-level dpi attribute.
  - engine (str) – Graphviz layout engine. Options:
    - ‘dot’ (default): Hierarchical layout (direction set by rankdir), best for directed acyclic graphs
    - ‘dot-lr’: Hierarchical left-right layout (automatically sets rankdir=‘LR’)
    - ‘neato’: Spring model layout (force-directed)
    - ‘fdp’: Force-directed placement, similar to neato
    - ‘circo’: Circular layout
    - ‘twopi’: Radial layout
    - ‘osage’: Clustered layout
  - rankdir (Literal['TB','LR','BT','RL']) – Direction of the graph layout (only used with the ‘dot’ and ‘dot-lr’ engines):
    - ‘TB’: Top to bottom
    - ‘LR’: Left to right
    - ‘BT’: Bottom to top (default)
    - ‘RL’: Right to left
  - node_shape (str) – Shape of nodes. Common options: ‘box’ (default), ‘circle’, ‘ellipse’, ‘diamond’, ‘triangle’, ‘plaintext’, ‘record’, ‘Mrecord’.
  - skip_ops (bool) – Whether to skip ops modules (Cat, Split, SplitConsecutive, SplitInterleaved) in the visualization. These pass-through modules are bypassed, and their inputs are connected directly to their parent. Defaults to True.
- Returns:
  None. The visualization is saved to the specified output path.
- Return type:
  None
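To make the skip_ops and show_params behavior concrete, here is a minimal, hypothetical sketch of how a module tree could be exported as DOT text. It does not call SPFlow or Graphviz: Node, format_count, visible_children, and to_dot are invented for illustration, and the edge direction (child to parent) and the one-decimal K/M rounding are assumptions, not the library's documented behavior.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterator


@dataclass
class Node:
    """Hypothetical stand-in for a circuit module."""
    name: str
    num_params: int = 0
    is_op: bool = False  # True for pass-through ops such as Cat or Split
    children: list[Node] = field(default_factory=list)


def format_count(n: int) -> str:
    """Format a parameter count with K/M suffixes, e.g. 1200 -> '1.2K'."""
    if n >= 1_000_000:
        return f"{n / 1_000_000:.1f}M"
    if n >= 1_000:
        return f"{n / 1_000:.1f}K"
    return str(n)


def visible_children(node: Node, skip_ops: bool) -> Iterator[Node]:
    """Yield children, replacing each skipped op by its own visible inputs."""
    for child in node.children:
        if skip_ops and child.is_op:
            yield from visible_children(child, skip_ops)
        else:
            yield child


def to_dot(root: Node, rankdir: str = "BT", node_shape: str = "box",
           skip_ops: bool = True) -> str:
    """Emit DOT source; edges point from each child to its parent (assumption)."""
    lines = ["digraph G {", f"  rankdir={rankdir};", f"  node [shape={node_shape}];"]
    ids: dict[int, str] = {}  # id(node) -> DOT name, so shared nodes are emitted once

    def visit(node: Node) -> str:
        if id(node) in ids:
            return ids[id(node)]
        nid = f"n{len(ids)}"
        ids[id(node)] = nid
        # "\\n" writes a literal backslash-n, which DOT renders as a line break.
        label = f"{node.name}\\nparams: {format_count(node.num_params)}"
        lines.append(f'  {nid} [label="{label}"];')
        for child in visible_children(node, skip_ops):
            lines.append(f"  {visit(child)} -> {nid};")
        return nid

    visit(root)
    lines.append("}")
    return "\n".join(lines)


# Usage: a Sum over a Cat op that concatenates two leaves.
leaves = [Node("Normal", num_params=1200), Node("Normal", num_params=3_500_000)]
tree = Node("Sum", num_params=2, children=[Node("Cat", is_op=True, children=leaves)])
print(to_dot(tree))  # the Cat node is bypassed: both leaves connect to Sum
```

Feeding the resulting text to a Graphviz tool such as `dot -Tpdf` would render it; with rankdir=BT and child-to-parent edges, each child is drawn below its parent.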
Model I/O
Save and load models from disk.
Cache
Utilities for caching intermediate computations to speed up inference.
- class spflow.utils.cache.Cache
Bases: object

Thread-safe cache with per-method-type locking and weak key references.
Uses a WeakKeyDictionary to store cached values keyed by module instances, so entries can be garbage-collected once their modules are no longer referenced elsewhere.
- extras
Mutable dictionary for user-defined, per-traversal state. This can be used to pass custom information through recursive module calls without changing public method signatures.
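A minimal sketch of the locking and weak-reference scheme described above, assuming one store and one lock per method name. SketchCache and its get, set, get_or_compute, and size methods are hypothetical and are not the Cache API; only the extras attribute mirrors the documented one.

```python
from __future__ import annotations

import threading
import weakref
from typing import Any, Callable


class SketchCache:
    """Hypothetical cache: one WeakKeyDictionary and one lock per method type."""

    def __init__(self) -> None:
        self._stores: dict[str, weakref.WeakKeyDictionary] = {}
        self._locks: dict[str, threading.Lock] = {}
        self._guard = threading.Lock()  # protects creation of per-method entries
        self.extras: dict[str, Any] = {}  # free-form per-traversal state

    def _entry(self, method: str) -> tuple[weakref.WeakKeyDictionary, threading.Lock]:
        with self._guard:
            if method not in self._stores:
                self._stores[method] = weakref.WeakKeyDictionary()
                self._locks[method] = threading.Lock()
            return self._stores[method], self._locks[method]

    def get(self, method: str, module: Any, default: Any = None) -> Any:
        store, lock = self._entry(method)
        with lock:
            return store.get(module, default)

    def set(self, method: str, module: Any, value: Any) -> None:
        store, lock = self._entry(method)
        with lock:
            store[module] = value

    def get_or_compute(self, method: str, module: Any, compute: Callable[[], Any]) -> Any:
        store, lock = self._entry(method)
        with lock:
            if module in store:
                return store[module]
        # Compute outside the lock so a slow computation does not block other
        # threads using the same method type; if two threads race, the first
        # stored value wins.
        value = compute()
        with lock:
            return store.setdefault(module, value)

    def size(self, method: str) -> int:
        store, lock = self._entry(method)
        with lock:
            return len(store)
```

Keying by the module object itself rather than by id(module) is what lets entries disappear once a module is garbage-collected; an id-keyed dict would keep stale entries alive and could return them for a new object that happens to reuse the same id.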
Method Replacement
Temporarily replace module methods for testing or experimentation.
- spflow.utils.replace.replace(method_ref, replacement_func)
Temporarily replace a class method with a custom implementation.
Automatically detects and preserves decorators (like @cached) by re-applying them to the replacement function. Works at the class level, affecting all instances of the class.
- Parameters:
  - method_ref (Callable) – Reference to the method to replace (e.g., Sum.log_likelihood). This should be an unbound method (accessed via the class, not an instance).
  - replacement_func (Callable) – The function to use as the replacement. Its signature must be compatible with the original method's (including ‘self’ as the first parameter).
- Yields:
None.
Example

    import torch
    from spflow.utils.replace import replace

    def my_custom_ll(self, data, cache=None):
        # Custom implementation: a log-likelihood of 1 for every sample
        return torch.ones(len(data))

    model = Product(Sum(Product(Normal(...))))

    # Default inference
    model.log_likelihood(data)

    # Use the custom implementation for all Sum modules
    with replace(Sum.log_likelihood, my_custom_ll):
        model.log_likelihood(data)  # Sum instances now use my_custom_ll
- Raises:
ValueError – If the class cannot be inferred from the method reference.
TypeError – If method_ref is not a valid method reference.
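The decorator-preserving, class-level patching described above can be sketched with contextlib.contextmanager. This is a hypothetical illustration, not SPFlow's implementation: counted stands in for a decorator such as @cached and marks its wrapper with an invented _decorator attribute, and the class is inferred from the method's __qualname__, which in this sketch only works for classes defined at module level.

```python
from __future__ import annotations

import functools
import inspect
from contextlib import contextmanager
from typing import Callable, Iterator


def counted(func: Callable) -> Callable:
    """Stand-in for a decorator like @cached; counts calls on its wrapper."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        wrapper.calls += 1
        return func(self, *args, **kwargs)
    wrapper.calls = 0
    wrapper._decorator = counted  # hypothetical marker read by replace_sketch
    return wrapper


def _infer_class(method_ref: Callable) -> type:
    if not inspect.isfunction(method_ref):
        raise TypeError(f"{method_ref!r} is not a method accessed via its class")
    parts = method_ref.__qualname__.split(".")
    if len(parts) != 2:  # nested or function-local classes are not handled here
        raise ValueError(f"cannot infer class from {method_ref.__qualname__!r}")
    # Look the class up in the globals of the undecorated function's module.
    cls = inspect.unwrap(method_ref).__globals__.get(parts[0])
    if not isinstance(cls, type):
        raise ValueError(f"cannot infer class from {method_ref.__qualname__!r}")
    return cls


@contextmanager
def replace_sketch(method_ref: Callable, replacement_func: Callable) -> Iterator[None]:
    cls = _infer_class(method_ref)
    name = method_ref.__name__
    original = cls.__dict__[name]
    # Re-apply the original decorator (if marked) to the replacement.
    decorator = getattr(original, "_decorator", None)
    patched = decorator(replacement_func) if decorator else replacement_func
    setattr(cls, name, patched)  # class-level: affects every instance
    try:
        yield
    finally:
        setattr(cls, name, original)  # always restore, even if the block raises


class ToySum:
    @counted
    def log_likelihood(self, data):
        return float(sum(data))


def constant_ll(self, data):
    return -1.0


s = ToySum()
with replace_sketch(ToySum.log_likelihood, constant_ll):
    print(s.log_likelihood([1, 2]))  # -1.0, via a freshly decorated wrapper
print(s.log_likelihood([1, 2]))  # 3.0 again after the block exits
```

Because the patch is applied to the class and restored in a finally clause, every instance sees the replacement only while the with block is active, and the original method comes back even if the block raises.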