Source code for spflow.modules.products.elementwise_product

from __future__ import annotations

import numpy as np
import torch
from torch import Tensor

from spflow.exceptions import ShapeError
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.split import Split, SplitMode
from spflow.modules.ops.split_consecutive import SplitConsecutive
from spflow.modules.ops.split_interleaved import SplitInterleaved
from spflow.modules.products.base_product import BaseProduct
from spflow.utils.cache import Cache, cached


class ElementwiseProduct(BaseProduct):
    """Elementwise product with automatic broadcasting.

    Multiplies inputs element-wise with broadcasting support. All input
    scopes must be pairwise disjoint. Commonly used in RAT-SPN architectures.
    """
    def __init__(
        self,
        inputs: Module | tuple[Module, Module] | list[Module],
        num_splits: int | None = 2,
        split_mode: SplitMode | None = None,
    ) -> None:
        """Initialize elementwise product.

        Args:
            inputs: Single module or list/tuple of modules. The scopes of all child
                modules must be pairwise disjoint. If `inputs` is a list of modules,
                they must have an equal number of features (or a single feature,
                which will then be broadcast) and an equal number of channels (or a
                single channel, which will be broadcast).

                Example shapes:
                    inputs = ((3, 4), (3, 4))  ->  output = (3, 4)
                    inputs = ((3, 4), (3, 1))  ->  output = (3, 4)  # broadcast
                    inputs = ((3, 4), (1, 4))  ->  output = (3, 4)  # broadcast
                    inputs = ((3, 1), (1, 4))  ->  output = (3, 4)  # broadcast
            num_splits: Number of splits when wrapping a single input in a Split.
            split_mode: Optional split configuration for single-input mode. Use
                SplitMode.consecutive() or SplitMode.interleaved(). Defaults to
                SplitMode.consecutive(num_splits=num_splits) if not specified.

        Raises:
            ValueError: Invalid arguments.
        """
        # Handle a single non-Split input: wrap it according to split_mode
        if isinstance(inputs, Module) and not isinstance(inputs, (list, tuple, Split)):
            if split_mode is not None:
                inputs = split_mode.create(inputs)
            else:
                inputs = SplitConsecutive(inputs, num_splits=num_splits)

        super().__init__(inputs=inputs)

        # All inputs must have an equal number of out_channels or exactly 1
        # (in_shape.channels is set by BaseProduct)
        if not all(inp.out_shape.channels in (1, self.in_shape.channels) for inp in self.inputs):
            raise ShapeError(
                f"Inputs must have equal number of channels or one of them must be '1', "
                f"but were {[inp.out_shape.channels for inp in self.inputs]}"
            )

        if self.num_splits is None:
            self.num_splits = num_splits

        self.check_shapes()

        # Shape computation: compute out_shape of the elementwise product
        input_features = self.inputs[0].out_shape.features
        if input_features == 1:
            out_features = 1
        elif self.input_is_split:
            out_features = int(input_features // self.num_splits)
        else:
            out_features = input_features

        # out_channels is the maximum since an input with a single channel is broadcast
        out_channels = self.in_shape.channels
        self.out_shape = ModuleShape(out_features, out_channels, self.in_shape.repetitions)
    def check_shapes(self):
        """Check whether the input shapes are compatible for broadcasting."""
        inputs = self.inputs
        if self.input_is_split:
            if self.num_splits != inputs[0].num_splits:
                raise ValueError("num_splits must be the same for all inputs")
            out_f = self.inputs[0].out_shape.features
            out_c = self.inputs[0].out_shape.channels
            if out_f == 1:
                out_f_computed = 1
            else:
                out_f_computed = int(out_f // self.num_splits)
            shapes = inputs[0].get_out_shapes((out_f_computed, out_c))
        else:
            shapes = [(inp.out_shape.features, inp.out_shape.channels) for inp in inputs]

        if not shapes:
            return False  # No shapes to check

        # Extract dimensions
        dim0_values = [shape[0] for shape in shapes]
        dim1_values = [shape[1] for shape in shapes]

        # Condition 1: all shapes are identical
        if all(shape == shapes[0] for shape in shapes):
            return True

        # Condition 2: all values in dim0 (or in dim1) agree, apart from entries equal to 1
        if len(set(dim0_values) - {1}) == 1:
            return True
        if len(set(dim1_values) - {1}) == 1:
            return True

        # Condition 3: in dim0 at most one value differs from 1,
        # and in dim1 at most one value differs from 1
        dim0_non_ones = [value for value in dim0_values if value != 1]
        dim1_non_ones = [value for value in dim1_values if value != 1]
        if len(dim0_non_ones) <= 1 and len(dim1_non_ones) <= 1:
            return True

        # None of the conditions are satisfied
        raise ShapeError(f"the shapes of the inputs {shapes} are not broadcastable")
    @property
    def feature_to_scope(self) -> np.ndarray:
        out = []
        for r in range(self.out_shape.repetitions):
            if self.input_is_split:
                # Shape: (num_features_per_split, num_splits)
                scope_lists_r = self.inputs[0].feature_to_scope[..., r]
                scope_lists_r = [scope_lists_r[:, i] for i in range(self.num_splits)]
            else:
                # One array of scopes per input module
                scope_lists_r = [module.feature_to_scope[..., r] for module in self.inputs]

            feature_to_scope_r = []

            # Group elements by index and join the scopes of each group
            grouped_scopes_r = list(zip(*scope_lists_r))
            for scopes_r in grouped_scopes_r:
                feature_to_scope_r.append(Scope.join_all(scopes_r))

            out.append(np.array(feature_to_scope_r))
        return np.stack(out, axis=1)
    def map_out_channels_to_in_channels(self, output_ids: Tensor) -> Tensor:
        """Map output channel indices to input channel indices.

        Args:
            output_ids: Tensor of output channel indices to map.

        Returns:
            Tensor: Mapped input channel indices.
        """
        if self.input_is_split:
            num_splits = self.num_splits
            if isinstance(self.inputs[0], SplitConsecutive):
                return output_ids.repeat((1, num_splits)).unsqueeze(-1)
            elif isinstance(self.inputs[0], SplitInterleaved):
                return output_ids.repeat_interleave(num_splits, dim=1).unsqueeze(-1)
            else:
                raise NotImplementedError("Other Split types are not implemented yet.")
        else:
            num_splits = len(self.inputs)
            return output_ids.unsqueeze(-1).repeat(1, 1, num_splits)
    def map_out_mask_to_in_mask(self, mask: Tensor) -> Tensor:
        """Map output mask to input mask.

        Args:
            mask: Output mask tensor to map.

        Returns:
            Tensor: Mapped input mask tensor.
        """
        if self.input_is_split:
            num_splits = self.num_splits
            if isinstance(self.inputs[0], SplitConsecutive):
                return mask.repeat((1, num_splits)).unsqueeze(-1)
            elif isinstance(self.inputs[0], SplitInterleaved):
                return mask.repeat_interleave(num_splits, dim=1).unsqueeze(-1)
            else:
                raise NotImplementedError("Other Split types are not implemented yet.")
        else:
            num_splits = len(self.inputs)
            return mask.unsqueeze(-1).repeat(1, 1, num_splits)
    @cached
    def log_likelihood(
        self,
        data: Tensor,
        cache: Cache | None = None,
    ) -> Tensor:
        """Compute the log-likelihood by element-wise summing the input log-likelihoods.

        Args:
            data: Input data tensor.
            cache: Optional cache for storing intermediate computations.

        Returns:
            Tensor: Computed log likelihood values.
        """
        # Gather the log-likelihoods of all inputs (reusing the cache where possible)
        lls = self._get_input_log_likelihoods(data, cache)

        # Expand single-channel inputs so that they broadcast along the channel dimension
        for i, ll in enumerate(lls):
            if ll.shape[2] == 1:
                if ll.ndim == 4:
                    lls[i] = ll.expand(-1, -1, self.out_shape.channels, -1)
                else:
                    lls[i] = ll.expand(-1, -1, self.out_shape.channels)

        # Element-wise sum of the input log-likelihoods (product of densities in log-space)
        output = torch.sum(torch.stack(lls, dim=-1), dim=-1)
        output = output.view(
            output.size(0),
            self.out_shape.features,
            self.out_shape.channels,
            self.out_shape.repetitions,
        )
        return output
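
# ---------------------------------------------------------------------------
# A minimal, self-contained sketch (plain PyTorch, independent of the spflow API)
# of the semantics that ElementwiseProduct.log_likelihood implements: multiplying
# densities over disjoint scopes corresponds to an element-wise sum of their
# log-densities, and inputs with a single channel are broadcast to the common
# channel count. The tensor shapes and the helper name below are illustrative
# assumptions only, not part of the library.

import torch


def _elementwise_log_product(lls: list[torch.Tensor], out_channels: int) -> torch.Tensor:
    """Sum per-input log-likelihoods of shape (batch, features, channels)."""
    expanded = []
    for ll in lls:
        # Broadcast single-channel inputs up to the common number of channels.
        if ll.shape[2] == 1:
            ll = ll.expand(-1, -1, out_channels)
        expanded.append(ll)
    # Element-wise sum in log-space == element-wise product of densities.
    return torch.sum(torch.stack(expanded, dim=-1), dim=-1)


# Example: two inputs over disjoint scopes with shapes (batch=2, features=3,
# channels=4) and (2, 3, 1); the second is broadcast along the channel axis and
# the result has shape (2, 3, 4), matching the docstring examples above.
_ll_a = torch.randn(2, 3, 4)
_ll_b = torch.randn(2, 3, 1)
_out = _elementwise_log_product([_ll_a, _ll_b], out_channels=4)
assert _out.shape == (2, 3, 4)
assert torch.allclose(_out, _ll_a + _ll_b)  # plain broadcasting gives the same result

# The channel-index mapping in map_out_channels_to_in_channels differs by split type:
# SplitConsecutive tiles the indices via `repeat`, while SplitInterleaved repeats each
# index in place via `repeat_interleave`. With num_splits = 2:
_ids = torch.tensor([[0, 1, 2]])
assert _ids.repeat((1, 2)).tolist() == [[0, 1, 2, 0, 1, 2]]
assert _ids.repeat_interleave(2, dim=1).tolist() == [[0, 0, 1, 1, 2, 2]]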