Source code for spflow.utils.replace
"""Context manager for temporarily replacing class methods.
Provides a clean API for method substitution with automatic handling of
decorators like @cached. Useful for testing, debugging, and experimentation.
"""
from __future__ import annotations
import inspect
from contextlib import contextmanager
from typing import Callable, TypeVar
from spflow.exceptions import InvalidTypeError
T = TypeVar("T")
[docs]
@contextmanager
def replace(method_ref: Callable, replacement_func: Callable):
    """Temporarily replace a class method with a custom implementation.

    The owning class and the attribute name are resolved from ``method_ref``
    via ``_extract_class_and_name``. The replacement is installed on the class
    itself, so every instance of that class is affected for the duration of
    the ``with`` block. On exit (normal or exceptional) the class attribute is
    put back exactly as it was stored in the class namespace.

    If the attribute currently installed on the class exposes a
    ``__wrapped__`` attribute, the replacement is wrapped with
    ``spflow.utils.cache.cached`` before being installed.

    Args:
        method_ref: Reference to the method to replace (e.g., Sum.log_likelihood).
            This should be accessed via the class, not an instance.
        replacement_func: The function to use as replacement. Must have a
            signature compatible with the original method (including 'self'
            as first parameter).

    Yields:
        None.

    Example:
        ::

            def my_custom_ll(self, data, cache=None):
                # Custom implementation
                return torch.ones(len(data))

            model = Product(Sum(Product(Normal(...))))

            # Normal inference
            model.log_likelihood(data)

            # Use custom implementation for Sum modules
            with replace(Sum.log_likelihood, my_custom_ll):
                model.log_likelihood(data)  # Sum instances now use my_custom_ll

    Raises:
        ValueError: If the class cannot be inferred from the method reference.
        InvalidTypeError: If method_ref is not callable or has no ``__name__``.
        AttributeError: If the resolved class has no attribute with that name.
    """
    target_class, method_name = _extract_class_and_name(method_ref)

    # Resolved lookup: fails loudly if the attribute does not exist and is the
    # object inspected for decorator detection below.
    original_method = getattr(target_class, method_name)

    # Raw entry as stored in the class namespace. A class-level getattr() runs
    # the descriptor protocol, so it hands back e.g. the plain function behind
    # a staticmethod; writing that back would silently change the attribute's
    # kind. The raw entry is what must be restored.
    missing = object()
    namespace = getattr(target_class, "__dict__", {})
    saved_entry = namespace.get(method_name, missing)

    # NOTE(review): any wrapper exposing ``__wrapped__`` (i.e. anything built
    # with functools.wraps) takes this branch, not only ``@cached`` itself.
    if hasattr(original_method, "__wrapped__"):
        # Imported here to avoid circular imports.
        from spflow.utils.cache import cached

        new_method = cached(replacement_func)
    else:
        new_method = replacement_func

    setattr(target_class, method_name, new_method)
    try:
        yield
    finally:
        if saved_entry is not missing:
            setattr(target_class, method_name, saved_entry)
        elif method_name in getattr(target_class, "__dict__", {}):
            # The attribute was only inherited before; drop the temporary
            # override so lookup falls through to the base class again.
            delattr(target_class, method_name)
def _extract_class_and_name(method_ref: Callable) -> tuple[type, str]:
    """Extract the owner class and method name from a method reference.

    The owner is located by taking everything before the last ``.`` in the
    reference's ``__qualname__`` and resolving that dotted path, one component
    at a time, starting from the module returned by ``inspect.getmodule``.
    Resolving component by component lets methods of nested classes
    (``Outer.Inner.method``) be found as well as methods of top-level classes.

    Args:
        method_ref: A method reference accessed via its class
            (e.g., Sum.log_likelihood).

    Returns:
        A tuple of (owner_class, method_name).

    Raises:
        InvalidTypeError: If method_ref is not callable or has no ``__name__``.
        ValueError: If ``__qualname__`` is missing or has no ``.`` separator,
            if the defining module cannot be determined, or if the class path
            cannot be resolved inside that module.
    """
    if not callable(method_ref):
        raise InvalidTypeError(f"Expected a callable method reference, got {type(method_ref)}")

    method_name = getattr(method_ref, "__name__", None)
    if method_name is None:
        raise InvalidTypeError("Method reference must have a __name__ attribute")

    # __qualname__ looks like "ClassName.method_name" (or
    # "Outer.Inner.method_name" for nested classes).
    qualname = getattr(method_ref, "__qualname__", None)
    if qualname is None:
        raise ValueError("Cannot determine class from method reference: missing __qualname__")

    if "." not in qualname:
        raise ValueError(
            f"Cannot determine class from method reference: "
            f"__qualname__='{qualname}' has no '.' separator"
        )
    class_name = qualname.rsplit(".", 1)[0]

    module = inspect.getmodule(method_ref)
    if module is None:
        raise ValueError("Cannot determine module for method reference")

    # Walk the dotted path from the module downwards. Any component that
    # cannot be resolved (including "<locals>" for function-local classes)
    # ends the walk with None and is reported below.
    target_class = module
    for component in class_name.split("."):
        target_class = getattr(target_class, component, None)
        if target_class is None:
            break

    if target_class is None:
        raise ValueError(
            f"Cannot find class '{class_name}' in module '{module.__name__}'. "
            f"Method reference qualname: {qualname}"
        )
    return target_class, method_name