"""Graph visualization utilities for SPFlow modules.
This module provides functions to visualize SPFlow module graphs using pydot and graphviz.
"""
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Literal
from spflow.exceptions import GraphvizError, OptionalDependencyError
try:
import pydot
from pydot.exceptions import PydotException
except ImportError as e:
raise OptionalDependencyError(
"The 'pydot' package is required for visualization functionality.\n\n"
"To install pydot and graphviz dependencies:\n"
" 1. Install the Graphviz system dependency:\n"
" - On macOS: brew install graphviz\n"
" - On Ubuntu/Debian: sudo apt-get install graphviz\n"
" - On Windows: Download from https://graphviz.org/download/\n"
" 2. Install pydot Python package:\n"
" - pip install pydot\n"
" - OR install SPFlow with visualization extras: pip install spflow[viz]\n\n"
"For more details, see the README.md file in the SPFlow repository."
) from e
if TYPE_CHECKING:
from spflow.modules.module import Module
from spflow.modules.leaves.leaf import LeafModule
from spflow.modules.ops.cat import Cat
from spflow.modules.ops.split import Split
from spflow.modules.ops.split_by_index import SplitByIndex
from spflow.modules.ops.split_consecutive import SplitConsecutive
from spflow.modules.ops.split_interleaved import SplitInterleaved
class Color(str, Enum):
    """Hex color constants assigned to graph nodes according to module type.

    The values are the first eight entries of matplotlib's ``tab10`` qualitative
    (categorical) palette. Because the enum also subclasses ``str``, a member can
    be passed directly wherever a plain color string is expected.

    Derived from: https://matplotlib.org/stable/users/explain/colors/colormaps.html
    """

    # Tab10 colors (entries 0-7); each trailing comment names the module group
    # that the color is intended for.
    BLUE = "#1f77b4"  # tab10[0] - Sum-related modules
    ORANGE = "#ff7f0e"  # tab10[1] - Product-related modules
    GREEN = "#2ca02c"  # tab10[2] - Leaf modules
    RED = "#d62728"  # tab10[3] - RatSPN
    PURPLE = "#9467bd"  # tab10[4] - Split-related modules
    BROWN = "#8c564b"  # tab10[5] - Factorize
    PINK = "#e377c2"  # tab10[6] - Cat
    GRAY = "#7f7f7f"  # tab10[7] - Default/Unknown types
# Pass-through "ops" module types that are hidden from the rendered graph.
# When graph building meets an instance of one of these types below a parent
# (and skipping is enabled), no node is drawn for it; its inputs are connected
# directly to that parent instead.
SKIP_OPS = {Cat, Split, SplitByIndex, SplitConsecutive, SplitInterleaved}
def _format_param_count(count: int) -> str:
"""Format parameter count with K/M suffixes for readability.
Args:
count: Number of parameters.
Returns:
Formatted string (e.g., "1.2K", "3.5M", "42").
"""
if count < 1000:
return str(count)
elif count < 1_000_000:
return f"{count / 1000:.1f}K"
else:
return f"{count / 1_000_000:.1f}M"
def _count_parameters(module: Module) -> int:
    """Return the number of scalar parameters attributed to ``module``.

    Leaf modules are counted with everything ``module.parameters()`` yields,
    which includes the parameters of their distribution child module. Every
    other module is counted with ``recurse=False`` so that parameters owned by
    its children are not attributed to it.

    Args:
        module: The module to count parameters for.

    Returns:
        Sum of ``numel()`` over the selected parameters.
    """
    if isinstance(module, LeafModule):
        # Leaves: full count, including the distribution child.
        tensors = module.parameters()
    else:
        # Non-leaves: only parameters registered directly on this module.
        tensors = module.parameters(recurse=False)

    return sum(tensor.numel() for tensor in tensors)
# Public entry point
def visualize(
    module: Module,
    output_path: str,
    show_scope: bool = True,
    show_shape: bool = True,
    show_params: bool = True,
    format: str = "pdf",
    dpi: int = 300,
    engine: str = "dot",
    rankdir: Literal["TB", "LR", "BT", "RL"] = "BT",
    node_shape: str = "box",
    skip_ops: bool = True,
) -> None:
    """Visualize a SPFlow module as a directed graph and save to file.

    Args:
        module: The root module to visualize.
        output_path: Path to save the visualization (without extension). The
            file is written to ``f"{output_path}.{format}"``.
        show_scope: Whether to display scope information in node labels.
        show_shape: Whether to display shape information in node labels.
        show_params: Whether to display parameter count in node labels. Parameter counts are
            formatted with K/M suffixes for readability (e.g., "1.2K", "3.5M").
        format: Output format - 'png', 'pdf', 'svg', 'dot', 'plain', or 'canon'.
            Text-based formats are useful for viewing graph structure in the terminal:
            - 'dot'/'canon': Graphviz DOT language source code
            - 'plain': Simple text format with node positions and edges
        dpi: DPI for rasterized formats (png). Applied via graph-level dpi attribute.
        engine: Graphviz layout engine. Options:
            - 'dot' (default): Hierarchical layout, best for directed acyclic graphs
            - 'dot-lr': Hierarchical left-right layout (uses 'dot' and forces rankdir='LR')
            - 'neato': Spring model layout (force-directed)
            - 'fdp': Force-directed placement, similar to neato
            - 'circo': Circular layout
            - 'twopi': Radial layout
            - 'osage': Clustered layout
        rankdir: Direction of graph layout (only used with 'dot' and 'dot-lr' engines):
            - 'TB': Top to bottom
            - 'LR': Left to right
            - 'BT': Bottom to top (default)
            - 'RL': Right to left
        node_shape: Shape of nodes. Common options: 'box' (default), 'circle', 'ellipse',
            'diamond', 'triangle', 'plaintext', 'record', 'Mrecord'.
        skip_ops: Whether to skip ops modules in visualization (Cat, Split, SplitByIndex,
            SplitConsecutive, SplitInterleaved). These are pass-through modules that are
            bypassed and their inputs connected directly to parent. Defaults to True.

    Returns:
        None. The visualization is saved to the specified output path.

    Raises:
        ValueError: If ``format`` is not one of the supported formats. This is
            checked before any graph construction takes place.
        GraphvizError: If writing the output fails with ``FileNotFoundError``,
            ``AssertionError``, ``OSError`` or ``PydotException``.
    """
    # Fail fast: reject an unsupported format before spending time on building
    # the graph for a potentially large module.
    supported_formats = ("png", "pdf", "svg", "dot", "plain", "canon")
    if format not in supported_formats:
        raise ValueError(
            f"Unsupported format: {format}. Supported formats: png, pdf, svg, dot, plain, canon"
        )

    # 'dot-lr' is a convenience alias: the dot engine with a left-to-right layout.
    if engine == "dot-lr":
        engine = "dot"
        rankdir = "LR"

    # Create the pydot graph
    graph = pydot.Dot(graph_type="digraph", rankdir=rankdir, dpi=str(dpi))

    # Graph-level aesthetics
    graph.set_graph_defaults(
        fontname="Helvetica",
        fontsize="11",
        nodesep="0.5",
        ranksep="0.8",
    )

    # Node defaults
    graph.set_node_defaults(
        shape=node_shape,
        style="rounded,filled",
        fillcolor="white",
        fontname="Helvetica",
        fontsize="11",
        penwidth="2.5",
        margin="0.15,0.08",  # Horizontal, vertical padding
    )

    # Edge defaults
    graph.set_edge_defaults(
        color="#333333",
        penwidth="2.0",
        arrowsize="0.8",
    )

    # Populate the graph with nodes and edges for the module and its inputs
    _build_graph(
        module,
        graph,
        show_scope=show_scope,
        show_shape=show_shape,
        show_params=show_params,
        skip_ops=skip_ops,
    )

    output_file = f"{output_path}.{format}"

    # Render with the requested engine. ``format`` has already been validated,
    # so a single generic ``write`` call covers every supported format.
    try:
        graph.write(output_file, format=format, prog=engine)
    except FileNotFoundError as e:
        # This error occurs when Graphviz is not installed or not in PATH.
        # NOTE(review): a FileNotFoundError caused by a missing output directory
        # would presumably be reported here as a missing executable too - confirm.
        raise GraphvizError(
            f"Graphviz executable '{engine}' not found. This usually means Graphviz is not installed or not in your system PATH."
        ) from e
    except (AssertionError, OSError, PydotException) as e:
        # Any other failure raised while pydot runs Graphviz or writes the file
        raise GraphvizError(
            f"Error executing Graphviz: {str(e)}\n\n"
            f"This error typically indicates a problem with your Graphviz installation."
        ) from e
def _build_graph(
    module: Module,
    graph: pydot.Dot,
    show_scope: bool = False,
    show_shape: bool = False,
    show_params: bool = False,
    visited: set | None = None,
    parent_id: int | None = None,
    skip_ops: bool = True,
) -> int | None:
    """Recursively build a pydot graph from a module tree.

    A drawn module becomes one node (keyed by ``id(module)``) and every input
    contributes an edge ``input -> module``. A module whose type is listed in
    ``SKIP_OPS`` is bypassed when ``skip_ops`` is set and it has a parent: no
    node is created for it and its inputs are linked directly to that parent.

    Args:
        module: Current module to add to the graph.
        graph: pydot Dot graph to populate.
        show_scope: Whether to include scope information in labels.
        show_shape: Whether to include shape information in labels.
        show_params: Whether to include parameter counts in labels.
        visited: Set of module IDs already visited (to avoid duplicates).
        parent_id: ID of the parent node (used when skipping modules).
        skip_ops: Whether to skip ops modules in visualization.

    Returns:
        The node ID for the current module, or None if the module was skipped.
    """
    if visited is None:
        visited = set()

    node_id = id(module)

    # The root (parent_id is None) is always drawn, even if it is an ops module.
    bypassed = bool(
        skip_ops and SKIP_OPS and isinstance(module, tuple(SKIP_OPS)) and parent_id is not None
    )

    if bypassed:
        # Children of a bypassed module attach to the bypassed module's parent.
        target_id = parent_id
    else:
        # A module reachable along several paths is drawn once; the caller
        # still receives its ID so that it can add its own edge.
        if node_id in visited:
            return node_id
        visited.add(node_id)

        graph.add_node(
            pydot.Node(
                str(node_id),
                label=_get_module_label(
                    module, show_scope=show_scope, show_shape=show_shape, show_params=show_params
                ),
                color=_get_module_color(module),
            )
        )
        target_id = node_id

    def _attach(child: Module) -> None:
        """Recurse into ``child`` and link it to ``target_id`` if it was drawn."""
        child_id = _build_graph(
            child,
            graph,
            show_scope,
            show_shape,
            show_params,
            visited,
            parent_id=target_id,
            skip_ops=skip_ops,
        )
        # None means the child was bypassed and has already linked its own inputs.
        if child_id is not None:
            graph.add_edge(pydot.Edge(str(child_id), str(target_id)))

    for input_module in _iter_input_modules(module):
        _attach(input_module)

    if bypassed:
        return None  # Signals to the caller that no node exists for this module

    # Special handling for RatSPN: traverse through root_node
    if hasattr(module, "root_node"):
        _attach(module.root_node)

    return node_id


def _iter_input_modules(module: Module) -> list[Module]:
    """Return the direct inputs of ``module`` as a list of modules.

    Recognized ``module.inputs`` values are an iterable whose class is named
    ``ModuleList``, a plain ``list``, or a single ``Module``. A missing or
    ``None`` attribute, or any other type, yields an empty list.
    """
    from spflow.modules.module import Module

    inputs = getattr(module, "inputs", None)
    if inputs is None:
        return []

    # Detected by class name rather than isinstance, as in the original checks.
    is_module_list = (
        hasattr(inputs, "__iter__")
        and not isinstance(inputs, (tuple, list))
        and inputs.__class__.__name__ == "ModuleList"
    )
    if is_module_list or isinstance(inputs, list):
        return list(inputs)

    if isinstance(inputs, Module):
        return [inputs]

    return []
def _format_scope_string(scopes: list[int]) -> str:
"""Format a list of scope indices, using ranges for consecutive sequences.
Consecutive sequences of 3 or more indices are represented as ranges (e.g., "0...4").
Shorter consecutive sequences are listed individually.
Args:
scopes: List of scope indices.
Returns:
Formatted scope string with ranges for consecutive sequences.
"""
if not scopes:
return ""
# Sort and deduplicate
sorted_scopes = sorted(set(scopes))
result = []
i = 0
while i < len(sorted_scopes):
start = sorted_scopes[i]
end = start
# Find the end of the consecutive sequence
while i + 1 < len(sorted_scopes) and sorted_scopes[i + 1] == sorted_scopes[i] + 1:
i += 1
end = sorted_scopes[i]
# Determine if we should use a range (4 or more consecutive)
sequence_length = end - start + 1
if sequence_length >= 3:
# Use range for 3+ consecutive numbers
result.append(f"{start}...{end}")
else:
# List individually for 1-3 numbers
result.append(", ".join(str(x) for x in range(start, end + 1)))
i += 1
return ", ".join(result)
def _build_vis_label(
module: Module,
show_shape: bool = True,
show_params: bool = False,
show_scope: bool = False,
) -> str:
"""Build a visualization label for a module showing shape, params, and scope.
All labels (Out, In, Scope, Params) are right-aligned to the same width
for consistent display with monospace font.
Order: Out, In, Scope, Params
Args:
module: The module to generate a label for.
show_shape: Whether to include shape info (Out/In with D, C, R values).
show_params: Whether to include parameter count.
show_scope: Whether to include scope information.
Returns:
A multi-line label string with properly aligned information.
"""
# Collect all labels we'll display to calculate max width
all_labels = []
# Shape labels (Out first, then In)
shapes = []
shape_labels = []
if show_shape:
# Out shape first
shapes.append((module.out_shape.features, module.out_shape.channels, module.out_shape.repetitions))
shape_labels.append("Out")
all_labels.append("Out")
# In shape second (if available)
if module.in_shape is not None:
shapes.append((module.in_shape.features, module.in_shape.channels, module.in_shape.repetitions))
shape_labels.append("In")
all_labels.append("In")
# Check scope (before params in the new order)
scope_str = ""
if show_scope:
scope_str = _format_scope_string(sorted(module.scope.query))
all_labels.append("Scope")
# Check params (last in the new order)
param_count = 0
if show_params:
param_count = _count_parameters(module)
if param_count > 0:
all_labels.append("Params")
# Calculate max label width across all labels
max_label_width = max(len(label) for label in all_labels) if all_labels else 0
lines = []
# Shape information (Out/In)
if show_shape and shapes:
# Calculate max width for each parameter value across all shapes
max_d_width = max(len(str(s[0])) for s in shapes)
max_c_width = max(len(str(s[1])) for s in shapes)
max_r_width = max(len(str(s[2])) for s in shapes)
# Build aligned shape lines
for label, (d, c, r) in zip(shape_labels, shapes):
# Right-align the label
padded_label = label.rjust(max_label_width)
# Left-pad each parameter value to align with the widest
d_str = str(d).ljust(max_d_width)
c_str = str(c).ljust(max_c_width)
r_str = str(r).ljust(max_r_width)
lines.append(f"{padded_label}: D={d_str} C={c_str} R={r_str}")
# Scope information (before params)
if show_scope:
padded_label = "Scope".rjust(max_label_width)
lines.append(f"{padded_label}: {scope_str}")
# Parameter count (last)
if show_params and param_count > 0:
formatted_count = _format_param_count(param_count)
padded_label = "Params".rjust(max_label_width)
lines.append(f"{padded_label}: {formatted_count}")
# Check for extra visualization info from the module
extra_info = module._extra_vis_info()
if extra_info is not None:
lines.append(extra_info)
return "\n".join(lines)
def _get_module_label(
module: Module, show_scope: bool = False, show_shape: bool = False, show_params: bool = False
) -> str:
"""Generate a label for a module node.
Uses HTML labels with bold module names and monospace font for
proper alignment of shape, params, and scope information.
Uses HTML table with left-aligned cells to preserve alignment.
Args:
module: The module to generate a label for.
show_scope: Whether to include scope information.
show_shape: Whether to include shape information.
show_params: Whether to include parameter count.
Returns:
An HTML label string for the module (compatible with pydot/graphviz).
"""
# Get the class name
class_name = module.__class__.__name__
# Build HTML table with left-aligned cells for proper alignment
# Using a table structure ensures the text is left-aligned within the node
rows = []
# Module name row (bold, centered as header)
rows.append(f'<TR><TD ALIGN="CENTER"><font face="Courier"><B>{class_name}</B></font></TD></TR>')
# Build visualization label
if show_shape or show_params or show_scope:
vis_label = _build_vis_label(
module,
show_shape=show_shape,
show_params=show_params,
show_scope=show_scope,
)
if vis_label:
# Each line becomes a left-aligned table row
for line in vis_label.split("\n"):
rows.append(f'<TR><TD ALIGN="LEFT"><font face="Courier">{line}</font></TD></TR>')
# Build table structure with no borders and minimal spacing
table = '<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">' + "".join(rows) + "</TABLE>"
# Wrap in angle brackets for HTML label
return "<" + table + ">"
def _get_module_color(module: Module) -> str:
    """Pick the node color for a module from its type.

    Leaf modules are always green. Every other module is looked up by class
    name in a fixed table that groups related module types under one color;
    class names that are not in the table get gray.

    Args:
        module: The module to get a color for.

    Returns:
        The ``Color`` member (a hex color string) for the module type.
    """
    type_name = module.__class__.__name__

    # Leaves: prefer a real isinstance check; if LeafModule cannot be imported,
    # fall back to a fixed list of leaf class names.
    try:
        from spflow.modules.leaves import LeafModule
    except ImportError:
        known_leaf_names = {
            "Normal",
            "Categorical",
            "Bernoulli",
            "Poisson",
            "Exponential",
            "CondNormal",
            "CondCategorical",
        }
        if type_name in known_leaf_names:
            return Color.GREEN
    else:
        if isinstance(module, LeafModule):
            return Color.GREEN

    # Class names grouped by the color they share
    groups = (
        # Sum modules
        (Color.BLUE, ("Sum", "ElementwiseSum", "MixingLayer", "LinsumLayer", "EinsumLayer", "SumConv")),
        # Product modules
        (Color.ORANGE, ("Product", "ElementwiseProduct", "OuterProduct", "ProdConv", "ConvPc")),
        # Operations - Cat
        (Color.PINK, ("Cat",)),
        # Operations - Split
        (Color.PURPLE, ("Split", "SplitByIndex", "SplitConsecutive", "SplitInterleaved")),
        # Operations - Factorize
        (Color.BROWN, ("Factorize",)),
        # RAT-SPN
        (Color.RED, ("RatSPN",)),
    )
    for color, names in groups:
        if type_name in names:
            return color

    # Unknown module type
    return Color.GRAY