# Source code for spflow.utils.visualization

"""Graph visualization utilities for SPFlow modules.

This module provides functions to visualize SPFlow module graphs using pydot and graphviz.
"""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Literal

from spflow.exceptions import GraphvizError, OptionalDependencyError

try:
    import pydot
    from pydot.exceptions import PydotException
except ImportError as e:
    raise OptionalDependencyError(
        "The 'pydot' package is required for visualization functionality.\n\n"
        "To install pydot and graphviz dependencies:\n"
        "  1. Install the Graphviz system dependency:\n"
        "     - On macOS: brew install graphviz\n"
        "     - On Ubuntu/Debian: sudo apt-get install graphviz\n"
        "     - On Windows: Download from https://graphviz.org/download/\n"
        "  2. Install pydot Python package:\n"
        "     - pip install pydot\n"
        "     - OR install SPFlow with visualization extras: pip install spflow[viz]\n\n"
        "For more details, see the README.md file in the SPFlow repository."
    ) from e

if TYPE_CHECKING:
    from spflow.modules.module import Module

from spflow.modules.leaves.leaf import LeafModule
from spflow.modules.ops.cat import Cat
from spflow.modules.ops.split import Split
from spflow.modules.ops.split_by_index import SplitByIndex
from spflow.modules.ops.split_consecutive import SplitConsecutive
from spflow.modules.ops.split_interleaved import SplitInterleaved


class Color(str, Enum):
    """Hex color codes assigned to module nodes in the rendered graph.

    The values are the first eight entries of matplotlib's qualitative ``tab10``
    palette, chosen so that different module families are easy to tell apart.
    Because the enum also derives from ``str``, every member is itself a string
    holding its hex code.
    Derived from: https://matplotlib.org/stable/users/explain/colors/colormaps.html
    """

    # Tab10 colors (indices 0-7 of the ten-entry palette; 8 and 9 are unused here)
    BLUE = "#1f77b4"  # tab10[0] - Sum-related modules
    ORANGE = "#ff7f0e"  # tab10[1] - Product-related modules
    GREEN = "#2ca02c"  # tab10[2] - Leaf modules
    RED = "#d62728"  # tab10[3] - RatSPN
    PURPLE = "#9467bd"  # tab10[4] - Split-related modules
    BROWN = "#8c564b"  # tab10[5] - Factorize
    PINK = "#e377c2"  # tab10[6] - Cat
    GRAY = "#7f7f7f"  # tab10[7] - Default/Unknown types


# Pass-through/helper ops modules that are hidden from the rendered graph.
# When `_build_graph` meets an instance of one of these classes below the root
# (and skipping is enabled), it draws no node for that module and instead
# connects the module's inputs directly to the nearest drawn parent.
SKIP_OPS = {Cat, Split, SplitByIndex, SplitConsecutive, SplitInterleaved}


def _format_param_count(count: int) -> str:
    """Format parameter count with K/M suffixes for readability.

    Args:
        count: Number of parameters.

    Returns:
        Formatted string (e.g., "1.2K", "3.5M", "42").
    """
    if count < 1000:
        return str(count)
    elif count < 1_000_000:
        return f"{count / 1000:.1f}K"
    else:
        return f"{count / 1_000_000:.1f}M"


def _count_parameters(module: Module) -> int:
    """Return the number of scalar parameters attributed to ``module``.

    Leaf modules report every parameter reachable from them, which includes
    the parameters of their child distribution module. Any other module
    reports only the parameters it owns directly, so parameters of its
    children are not counted twice when each child is drawn as its own node.

    Args:
        module: The module whose parameters are counted.

    Returns:
        Total element count over the selected parameters.
    """
    if isinstance(module, LeafModule):
        # Leaves: recurse into submodules (default behaviour of parameters()).
        selected = module.parameters()
    else:
        # Everything else: restrict to parameters registered on this module.
        selected = module.parameters(recurse=False)

    total = 0
    for param in selected:
        total += param.numel()
    return total


def visualize(
    module: Module,
    output_path: str,
    show_scope: bool = True,
    show_shape: bool = True,
    show_params: bool = True,
    format: str = "pdf",
    dpi: int = 300,
    engine: str = "dot",
    rankdir: Literal["TB", "LR", "BT", "RL"] = "BT",
    node_shape: str = "box",
    skip_ops: bool = True,
) -> None:
    """Visualize a SPFlow module as a directed graph and save to file.

    Args:
        module: The root module to visualize.
        output_path: Path to save the visualization (without extension); the
            file is written to ``f"{output_path}.{format}"``.
        show_scope: Whether to display scope information in node labels.
        show_shape: Whether to display shape information (D: out_features, C: out_channels)
            in node labels.
        show_params: Whether to display parameter count in node labels. Parameter counts
            are formatted with K/M suffixes for readability (e.g., "1.2K", "3.5M").
        format: Output format - 'png', 'pdf', 'svg', 'dot', 'plain', or 'canon'.
            Text-based formats are useful for viewing graph structure in the terminal:
            - 'dot'/'canon': Graphviz DOT language source code
            - 'plain': Simple text format with node positions and edges
        dpi: DPI for rasterized formats (png). Applied via graph-level dpi attribute.
        engine: Graphviz layout engine. Options:
            - 'dot' (default): Hierarchical layout, best for directed acyclic graphs
            - 'dot-lr': Hierarchical left-right layout (automatically sets rankdir='LR')
            - 'neato': Spring model layout (force-directed)
            - 'fdp': Force-directed placement, similar to neato
            - 'circo': Circular layout
            - 'twopi': Radial layout
            - 'osage': Clustered layout
        rankdir: Direction of graph layout (only used with 'dot' and 'dot-lr' engines):
            - 'TB': Top to bottom
            - 'LR': Left to right
            - 'BT': Bottom to top (default)
            - 'RL': Right to left
        node_shape: Shape of nodes. Common options: 'box' (default), 'circle', 'ellipse',
            'diamond', 'triangle', 'plaintext', 'record', 'Mrecord'.
        skip_ops: Whether to skip ops modules in visualization (Cat, Split, SplitByIndex,
            SplitConsecutive, SplitInterleaved). These are pass-through modules that are
            bypassed and their inputs connected directly to parent. Defaults to True.

    Returns:
        None. The visualization is saved to the specified output path.

    Raises:
        ValueError: If ``format`` is not one of the supported formats.
        GraphvizError: If the Graphviz executable cannot be found or fails while
            rendering the graph.
    """
    # Reject unsupported formats up front so no graph is built for nothing.
    supported_formats = ("png", "pdf", "svg", "dot", "plain", "canon")
    if format not in supported_formats:
        raise ValueError(
            f"Unsupported format: {format}. Supported formats: png, pdf, svg, dot, plain, canon"
        )

    # Handle special engine variants
    if engine == "dot-lr":
        engine = "dot"
        rankdir = "LR"

    # Create the pydot graph
    graph = pydot.Dot(graph_type="digraph", rankdir=rankdir, dpi=str(dpi))

    # Set graph attributes for better aesthetics
    graph.set_graph_defaults(
        fontname="Helvetica",
        fontsize="11",
        nodesep="0.5",
        ranksep="0.8",
    )

    # Set node defaults
    graph.set_node_defaults(
        shape=node_shape,
        style="rounded,filled",
        fillcolor="white",
        fontname="Helvetica",
        fontsize="11",
        penwidth="2.5",
        margin="0.15,0.08",  # Horizontal, vertical padding
    )

    # Set edge defaults
    graph.set_edge_defaults(
        color="#333333",
        penwidth="2.0",
        arrowsize="0.8",
    )

    # Build the graph
    _build_graph(
        module,
        graph,
        show_scope=show_scope,
        show_shape=show_shape,
        show_params=show_params,
        skip_ops=skip_ops,
    )

    # Generate output file
    output_file = f"{output_path}.{format}"

    # Write output using the specified engine. A single write() call with an
    # explicit format replaces the per-format write_<fmt>() shortcuts.
    try:
        graph.write(output_file, format=format, prog=engine)
    except FileNotFoundError as e:
        # This error occurs when Graphviz is not installed or not in PATH.
        # It must be handled before the OSError clause below, which would
        # otherwise also match it.
        raise GraphvizError(
            f"Graphviz executable '{engine}' not found. This usually means Graphviz is not installed or not in your system PATH."
        ) from e
    except (AssertionError, OSError, PydotException) as e:
        # Catch errors from pydot/graphviz execution
        raise GraphvizError(
            f"Error executing Graphviz: {str(e)}\n\n"
            f"This error typically indicates a problem with your Graphviz installation."
        ) from e
def _build_graph( module: Module, graph: pydot.Dot, show_scope: bool = False, show_shape: bool = False, show_params: bool = False, visited: set | None = None, parent_id: int | None = None, skip_ops: bool = True, ) -> int | None: """Recursively build a pydot graph from a module tree. Args: module: Current module to add to the graph. graph: pydot Dot graph to populate. show_scope: Whether to include scope information in labels. show_shape: Whether to include shape information in labels. show_params: Whether to include parameter counts in labels. visited: Set of module IDs already visited (to avoid duplicates). parent_id: ID of the parent node (used when skipping modules). skip_ops: Whether to skip ops modules in visualization. Returns: The node ID for the current module, or None if the module was skipped. """ from torch import nn from spflow.modules.module import Module if visited is None: visited = set() node_id = id(module) # Check if this module should be skipped in the visualization if skip_ops and SKIP_OPS and isinstance(module, tuple(SKIP_OPS)) and parent_id is not None: # This is a pass-through module - skip it and connect its inputs directly to parent # Check for input attribute (unified) if hasattr(module, "inputs") and module.inputs is not None: inputs = module.inputs # Check for ModuleList (Cat, ElementwiseSum, BaseProduct) if ( hasattr(inputs, "__iter__") and not isinstance(inputs, (tuple, list)) and inputs.__class__.__name__ == "ModuleList" ): # Convert to list for iteration inputs_list = list(inputs) for input_module in inputs_list: child_id = _build_graph( input_module, graph, show_scope, show_shape, show_params, visited, parent_id, skip_ops ) if child_id is not None: edge = pydot.Edge(str(child_id), str(parent_id)) graph.add_edge(edge) # Handle regular list (unlikely but possible) elif isinstance(inputs, list): for input_module in inputs: child_id = _build_graph( input_module, graph, show_scope, show_shape, show_params, visited, parent_id, skip_ops ) 
if child_id is not None: edge = pydot.Edge(str(child_id), str(parent_id)) graph.add_edge(edge) # Handle single Module elif isinstance(inputs, Module): # Skip Cat wrapper check? Original code didn't skip nested Cat here, it recursed. # Just treat as single child. child_id = _build_graph( inputs, graph, show_scope, show_shape, show_params, visited, parent_id, skip_ops ) if child_id is not None: edge = pydot.Edge(str(child_id), str(parent_id)) graph.add_edge(edge) return None # Return None to indicate this module was skipped # Skip if already visited if node_id in visited: return node_id visited.add(node_id) # Create node label label = _get_module_label(module, show_scope=show_scope, show_shape=show_shape, show_params=show_params) # Get color for this module type color = _get_module_color(module) # Add node to graph node = pydot.Node( str(node_id), label=label, color=color, ) graph.add_node(node) # Traverse inputs if they exist # Check for input attribute (unified) if hasattr(module, "inputs") and module.inputs is not None: inputs = module.inputs # Check for ModuleList (Cat, ElementwiseSum, BaseProduct) if ( hasattr(inputs, "__iter__") and not isinstance(inputs, (tuple, list)) and inputs.__class__.__name__ == "ModuleList" ): # Convert to list for iteration inputs_list = list(inputs) for input_module in inputs_list: child_id = _build_graph( input_module, graph, show_scope, show_shape, show_params, visited, parent_id=node_id, skip_ops=skip_ops, ) if child_id is not None: edge = pydot.Edge(str(child_id), str(node_id)) graph.add_edge(edge) # Handle regular list elif isinstance(inputs, list): for input_module in inputs: child_id = _build_graph( input_module, graph, show_scope, show_shape, show_params, visited, parent_id=node_id, skip_ops=skip_ops, ) if child_id is not None: edge = pydot.Edge(str(child_id), str(node_id)) graph.add_edge(edge) # Handle single Module elif isinstance(inputs, Module): child_id = _build_graph( inputs, graph, show_scope, show_shape, show_params, 
visited, parent_id=node_id, skip_ops=skip_ops, ) if child_id is not None: edge = pydot.Edge(str(child_id), str(node_id)) graph.add_edge(edge) # Special handling for RatSPN: traverse through root_node if hasattr(module, "root_node"): child_id = _build_graph( module.root_node, graph, show_scope, show_shape, show_params, visited, parent_id=node_id, skip_ops=skip_ops, ) # Only add edge if child was actually added to graph (not skipped) if child_id is not None: edge = pydot.Edge(str(child_id), str(node_id)) graph.add_edge(edge) return node_id def _format_scope_string(scopes: list[int]) -> str: """Format a list of scope indices, using ranges for consecutive sequences. Consecutive sequences of 3 or more indices are represented as ranges (e.g., "0...4"). Shorter consecutive sequences are listed individually. Args: scopes: List of scope indices. Returns: Formatted scope string with ranges for consecutive sequences. """ if not scopes: return "" # Sort and deduplicate sorted_scopes = sorted(set(scopes)) result = [] i = 0 while i < len(sorted_scopes): start = sorted_scopes[i] end = start # Find the end of the consecutive sequence while i + 1 < len(sorted_scopes) and sorted_scopes[i + 1] == sorted_scopes[i] + 1: i += 1 end = sorted_scopes[i] # Determine if we should use a range (4 or more consecutive) sequence_length = end - start + 1 if sequence_length >= 3: # Use range for 3+ consecutive numbers result.append(f"{start}...{end}") else: # List individually for 1-3 numbers result.append(", ".join(str(x) for x in range(start, end + 1))) i += 1 return ", ".join(result) def _build_vis_label( module: Module, show_shape: bool = True, show_params: bool = False, show_scope: bool = False, ) -> str: """Build a visualization label for a module showing shape, params, and scope. All labels (Out, In, Scope, Params) are right-aligned to the same width for consistent display with monospace font. Order: Out, In, Scope, Params Args: module: The module to generate a label for. 
show_shape: Whether to include shape info (Out/In with D, C, R values). show_params: Whether to include parameter count. show_scope: Whether to include scope information. Returns: A multi-line label string with properly aligned information. """ # Collect all labels we'll display to calculate max width all_labels = [] # Shape labels (Out first, then In) shapes = [] shape_labels = [] if show_shape: # Out shape first shapes.append((module.out_shape.features, module.out_shape.channels, module.out_shape.repetitions)) shape_labels.append("Out") all_labels.append("Out") # In shape second (if available) if module.in_shape is not None: shapes.append((module.in_shape.features, module.in_shape.channels, module.in_shape.repetitions)) shape_labels.append("In") all_labels.append("In") # Check scope (before params in the new order) scope_str = "" if show_scope: scope_str = _format_scope_string(sorted(module.scope.query)) all_labels.append("Scope") # Check params (last in the new order) param_count = 0 if show_params: param_count = _count_parameters(module) if param_count > 0: all_labels.append("Params") # Calculate max label width across all labels max_label_width = max(len(label) for label in all_labels) if all_labels else 0 lines = [] # Shape information (Out/In) if show_shape and shapes: # Calculate max width for each parameter value across all shapes max_d_width = max(len(str(s[0])) for s in shapes) max_c_width = max(len(str(s[1])) for s in shapes) max_r_width = max(len(str(s[2])) for s in shapes) # Build aligned shape lines for label, (d, c, r) in zip(shape_labels, shapes): # Right-align the label padded_label = label.rjust(max_label_width) # Left-pad each parameter value to align with the widest d_str = str(d).ljust(max_d_width) c_str = str(c).ljust(max_c_width) r_str = str(r).ljust(max_r_width) lines.append(f"{padded_label}: D={d_str} C={c_str} R={r_str}") # Scope information (before params) if show_scope: padded_label = "Scope".rjust(max_label_width) 
lines.append(f"{padded_label}: {scope_str}") # Parameter count (last) if show_params and param_count > 0: formatted_count = _format_param_count(param_count) padded_label = "Params".rjust(max_label_width) lines.append(f"{padded_label}: {formatted_count}") # Check for extra visualization info from the module extra_info = module._extra_vis_info() if extra_info is not None: lines.append(extra_info) return "\n".join(lines) def _get_module_label( module: Module, show_scope: bool = False, show_shape: bool = False, show_params: bool = False ) -> str: """Generate a label for a module node. Uses HTML labels with bold module names and monospace font for proper alignment of shape, params, and scope information. Uses HTML table with left-aligned cells to preserve alignment. Args: module: The module to generate a label for. show_scope: Whether to include scope information. show_shape: Whether to include shape information. show_params: Whether to include parameter count. Returns: An HTML label string for the module (compatible with pydot/graphviz). 
""" # Get the class name class_name = module.__class__.__name__ # Build HTML table with left-aligned cells for proper alignment # Using a table structure ensures the text is left-aligned within the node rows = [] # Module name row (bold, centered as header) rows.append(f'<TR><TD ALIGN="CENTER"><font face="Courier"><B>{class_name}</B></font></TD></TR>') # Build visualization label if show_shape or show_params or show_scope: vis_label = _build_vis_label( module, show_shape=show_shape, show_params=show_params, show_scope=show_scope, ) if vis_label: # Each line becomes a left-aligned table row for line in vis_label.split("\n"): rows.append(f'<TR><TD ALIGN="LEFT"><font face="Courier">{line}</font></TD></TR>') # Build table structure with no borders and minimal spacing table = '<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">' + "".join(rows) + "</TABLE>" # Wrap in angle brackets for HTML label return "<" + table + ">" def _get_module_color(module: Module) -> str: """Get the color for a module based on its type. Uses matplotlib tab10 colormap for consistent, distinguishable colors. Related module types share colors within groups. Args: module: The module to get a color for. Returns: A color string (hex) for the module type based on tab10 colormap. 
""" class_name = module.__class__.__name__ # Check if this is a leaves module (all leaves modules get the same color) try: from spflow.modules.leaves import LeafModule if isinstance(module, LeafModule): return Color.GREEN except ImportError: # Fallback to class name checking if LeafModule can't be imported leaf_modules = { "Normal", "Categorical", "Bernoulli", "Poisson", "Exponential", "CondNormal", "CondCategorical", } if class_name in leaf_modules: return Color.GREEN # Define color mapping for different module types using Color enum color_map = { # Sum modules "Sum": Color.BLUE, "ElementwiseSum": Color.BLUE, "MixingLayer": Color.BLUE, "LinsumLayer": Color.BLUE, "EinsumLayer": Color.BLUE, "SumConv": Color.BLUE, # Product modules "Product": Color.ORANGE, "ElementwiseProduct": Color.ORANGE, "OuterProduct": Color.ORANGE, "ProdConv": Color.ORANGE, "ConvPc": Color.ORANGE, # Operations - Cat "Cat": Color.PINK, # Operations - Split "Split": Color.PURPLE, "SplitByIndex": Color.PURPLE, "SplitConsecutive": Color.PURPLE, "SplitInterleaved": Color.PURPLE, # Operations - Factorize "Factorize": Color.BROWN, # RAT-SPN "RatSPN": Color.RED, } # Return color if found, otherwise use GRAY as default return color_map.get(class_name, Color.GRAY)