Source code for spflow.utils.cache
"""Cache utilities for efficient inference in SPFlow.
Provides thread-safe caching to optimize inference, learning, and sampling by
avoiding redundant computations in DAG traversals. Uses WeakKeyDictionary to
allow garbage collection of cached modules.
"""
from __future__ import annotations
import functools
import threading
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from weakref import WeakKeyDictionary
if TYPE_CHECKING: # Avoid circular imports
from spflow.modules.module import Module
T = TypeVar("T")
[docs]
class Cache:
    """Memo table for one traversal, indexed by method name and module.

    Every method name owns its own ``WeakKeyDictionary`` that maps module
    instances to stored results; because the keys are weak references, an
    entry does not by itself keep its module alive. Storing a value and
    lazily creating a method's table both happen while holding a lock that
    is dedicated to that method name. Plain lookups take no lock.

    Attributes:
        extras: Mutable dictionary for user-defined, per-traversal state. It
            lets callers hand custom information through recursive module
            calls without changing public method signatures.
    """

    def __init__(self):
        """Create an empty cache with no tables, no locks and no extras."""
        self.extras: dict[str, Any] = {}
        # method name -> (module -> cached value); tables are created lazily.
        self._cache: dict[str, WeakKeyDictionary[Module, Any]] = {}
        # method name -> lock; defaultdict creates a lock on first access.
        self._locks: dict[str, threading.Lock] = defaultdict(threading.Lock)

    def get(self, method_name: str, module: Module) -> Any | None:
        """Look up the value stored for ``module`` under ``method_name``.

        Args:
            method_name: Name of the cached method (e.g., "log_likelihood").
            module: Module instance that serves as the cache key.

        Returns:
            The stored value, or None when either the method name or the
            module has no entry. A stored None is indistinguishable from a
            miss through this method.
        """
        table = self._cache.get(method_name)
        # Compare against None explicitly: an empty table is falsy.
        if table is None:
            return None
        return table.get(module)

    def set(self, method_name: str, module: Module, value: Any) -> None:
        """Record ``value`` for ``module`` under ``method_name``.

        The table for ``method_name`` is created on first use. The whole
        operation runs under that method name's lock.

        Args:
            method_name: Name of the cached method (e.g., "log_likelihood").
            module: Module instance that serves as the cache key.
            value: Value to cache.
        """
        with self._locks[method_name]:
            table = self._cache.get(method_name)
            if table is None:
                table = WeakKeyDictionary()
                self._cache[method_name] = table
            table[module] = value

    def __getitem__(self, method_name: str) -> WeakKeyDictionary[Module, Any]:
        """Return the table for a method name (for backward compatibility).

        A missing table is created, stored and returned, so this never
        raises ``KeyError``.

        Args:
            method_name: Name of the cached method (e.g., "log_likelihood").

        Returns:
            WeakKeyDictionary holding the entries for this method name.
        """
        table = self._cache.get(method_name)
        if table is None:
            # Only take the lock on the slow path, then look again in case
            # the table was created while waiting for it.
            with self._locks[method_name]:
                table = self._cache.get(method_name)
                if table is None:
                    table = WeakKeyDictionary()
                    self._cache[method_name] = table
        return table

    def __contains__(self, method_name: str) -> bool:
        """Report whether at least one value is stored for a method name.

        Args:
            method_name: Name of the cached method.

        Returns:
            True if a table exists for the method name and it is non-empty;
            False for an unknown name or an empty table.
        """
        table = self._cache.get(method_name)
        return table is not None and len(table) > 0
def cached(func: Callable[..., T]) -> Callable[..., T]:
    """Decorator for automatically caching method results.

    Uses the function's ``__name__`` as the cache key for the method and the
    instance (``self``) as the key within that method's table. The decorated
    method must accept a ``cache`` parameter (can be None), and callers must
    pass it by keyword: the wrapper declares ``cache`` as keyword-only.

    Any stored result counts as a hit, including ``None``. A miss is
    detected with a private sentinel rather than by comparing to ``None``,
    so a method that returns ``None`` is still computed only once per cache.

    Example:
        ```python
        @cached
        def log_likelihood(self, data, cache=None):
            # Computation here
            return result
        ```

    Args:
        func: The function to decorate.

    Returns:
        Decorated function with caching functionality.
    """
    method_name = func.__name__
    # Unique marker meaning "no entry"; it can never equal a real result.
    missing = object()

    @functools.wraps(func)
    def wrapper(self: Module, *args, cache: Cache | None = None, **kwargs) -> T:
        # Start a fresh cache when the caller did not supply one.
        if cache is None:
            cache = Cache()

        # Read the per-method table directly so a stored None is a hit.
        hit = cache[method_name].get(self, missing)
        if hit is not missing:
            return hit

        # Miss: compute, passing the cache on so nested calls share it.
        result = func(self, *args, cache=cache, **kwargs)
        cache.set(method_name, self, result)
        return result

    return wrapper