# Source code for spflow.modules.products.elementwise_product

from __future__ import annotations

import numpy as np
import torch
from einops import rearrange, repeat
from torch import Tensor

from spflow.exceptions import ShapeError
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.split import Split, SplitMode
from spflow.modules.ops.split_consecutive import SplitConsecutive
from spflow.modules.ops.split_interleaved import SplitInterleaved
from spflow.modules.products.base_product import BaseProduct
from spflow.utils.cache import Cache, cached


class ElementwiseProduct(BaseProduct):
    """Elementwise product with automatic broadcasting.

    Multiplies inputs element-wise with broadcasting support. All input scopes
    must be pairwise disjoint. Commonly used in RAT-SPN architectures.
    """

    def __init__(
        self,
        inputs: Module | tuple[Module, Module] | list[Module],
        num_splits: int | None = 2,
        split_mode: SplitMode | None = None,
    ) -> None:
        """Initialize elementwise product.

        Args:
            inputs: A single Module or a list/tuple of Modules. The scopes of
                all child modules need to be pairwise disjoint.

                If ``inputs`` is a list of Modules, they must have an equal
                number of features or a single feature (which will then be
                broadcast), and an equal number of channels or a single
                channel (which will be broadcast).

                Example shapes::

                    inputs = ((3, 4), (3, 4))  output = (3, 4)
                    inputs = ((3, 4), (3, 1))  output = (3, 4)  # broadcasted
                    inputs = ((3, 4), (1, 4))  output = (3, 4)  # broadcasted
                    inputs = ((3, 1), (1, 4))  output = (3, 4)  # broadcasted

                A single Module that is not already a ``Split`` is wrapped in
                a split module (see ``split_mode``).
            num_splits: Number of splits when wrapping a single input in a
                Split.
            split_mode: Optional split configuration for single input mode.
                Use SplitMode.consecutive() or SplitMode.interleaved().
                Defaults to SplitMode.consecutive(num_splits=num_splits) if
                not specified.

        Raises:
            ShapeError: The inputs' channel or feature counts cannot be
                broadcast against each other.
            ValueError: ``num_splits`` disagrees with the split input's own
                ``num_splits``.
        """
        # Handle single non-Split input: wrap it so it yields several parts.
        if isinstance(inputs, Module) and not isinstance(inputs, (list, tuple, Split)):
            if split_mode is not None:
                inputs = split_mode.create(inputs)
            else:
                inputs = SplitConsecutive(inputs, num_splits=num_splits)

        super().__init__(inputs=inputs)

        # Every input must either match the common channel count or have a
        # single channel (relies on in_shape.channels set by BaseProduct).
        if not all(inp.out_shape.channels in (1, self.in_shape.channels) for inp in self.inputs):
            raise ShapeError(
                f"Inputs must have equal number of channels or one of them must be '1', but were {[inp.out_shape.channels for inp in self.inputs]}"
            )

        # Only fill in num_splits when the base class left it unset.
        if self.num_splits is None:
            self.num_splits = num_splits

        self.check_shapes()

        # Shape computation: derive out_shape of the elementwise product.
        # NOTE(review): the feature count is taken from the first input only;
        # confirm this is intended when the first input has a single feature
        # and a later one has more.
        input_features = self.inputs[0].out_shape.features
        if input_features == 1:
            out_features = 1
        elif self.input_is_split:
            out_features = int(input_features // self.num_splits)
        else:
            out_features = input_features

        # A single-channel input is broadcast, so the common channel count wins.
        out_channels = self.in_shape.channels

        self.out_shape = ModuleShape(out_features, out_channels, self.in_shape.repetitions)

    def check_shapes(self) -> bool:
        """Check if input shapes are compatible for broadcasting.

        The ``(features, channels)`` shapes of all inputs are compatible when,
        on each of the two axes separately, every size is either 1 or equal to
        one common value.

        Returns:
            True if the shapes are broadcastable, False if there are no shapes
            to check.

        Raises:
            ValueError: In split mode, ``self.num_splits`` differs from the
                split input's ``num_splits``.
            ShapeError: The shapes are not broadcastable.
        """
        inputs = self.inputs

        if self.input_is_split:
            if self.num_splits != inputs[0].num_splits:
                raise ValueError("num_splits must be the same for all inputs")

            total_features = inputs[0].out_shape.features
            channels = inputs[0].out_shape.channels
            features_per_split = 1 if total_features == 1 else int(total_features // self.num_splits)
            shapes = inputs[0].get_out_shapes((features_per_split, channels))
        else:
            shapes = [(inp.out_shape.features, inp.out_shape.channels) for inp in inputs]

        if not shapes:
            return False  # No shapes to check

        # Sizes of 1 broadcast freely, so only the remaining distinct sizes
        # matter. BOTH axes must agree: accepting on a single matching axis
        # would let e.g. (3, 4) and (5, 4) through.
        feature_sizes = {shape[0] for shape in shapes} - {1}
        channel_sizes = {shape[1] for shape in shapes} - {1}
        if len(feature_sizes) <= 1 and len(channel_sizes) <= 1:
            return True

        raise ShapeError(f"the shapes of the inputs {shapes} are not broadcastable")

    @property
    def feature_to_scope(self) -> np.ndarray:
        """Scope of every output feature, per repetition.

        For each repetition, the scopes found at the same feature index across
        all inputs (or across all splits of a split input) are joined with
        ``Scope.join_all``.

        Returns:
            Array built by stacking the per-repetition scope arrays along
            axis 1.
        """
        per_repetition = []
        for rep in range(self.out_shape.repetitions):
            if self.input_is_split:
                # Shape: (num_features_per_split, num_splits)
                split_scopes = self.inputs[0].feature_to_scope[..., rep]
                scope_columns = [split_scopes[:, i] for i in range(self.num_splits)]
            else:
                scope_columns = [module.feature_to_scope[..., rep] for module in self.inputs]

            # zip(*) groups the scopes that share a feature index.
            joined = [Scope.join_all(group) for group in zip(*scope_columns)]
            per_repetition.append(np.array(joined))

        return np.stack(per_repetition, axis=1)

    def map_out_channels_to_in_channels(self, output_ids: Tensor) -> Tensor:
        """Map output channel indices to input channel indices.

        Args:
            output_ids: Tensor of output channel indices to map. Integer
                tensors are treated as ``(batch, features)`` indices;
                floating-point tensors as ``(batch, features, channels)``
                per-channel weights.

        Returns:
            Tensor: Mapped input channel indices. In split mode the feature
            axis is repeated ``num_splits`` times (ordered to match the split
            type) and a singleton axis is inserted after it. Otherwise a new
            axis of size ``len(self.inputs)`` is added after the feature axis.

        Raises:
            NotImplementedError: The split input is neither SplitConsecutive
                nor SplitInterleaved.
        """
        is_float = output_ids.is_floating_point()

        if self.input_is_split:
            num_splits = self.num_splits
            split = self.inputs[0]

            if isinstance(split, SplitConsecutive):
                # Consecutive splits: whole copies of the feature axis follow
                # one another.
                if is_float:
                    expanded_ids = repeat(output_ids, "b f c -> b i f c", i=num_splits)
                    return rearrange(expanded_ids, "b i f c -> b (i f) 1 c")
                expanded_ids = repeat(output_ids, "b f -> b i f", i=num_splits)
                return rearrange(expanded_ids, "b i f -> b (i f) 1")

            if isinstance(split, SplitInterleaved):
                # Interleaved splits: each feature is repeated in place.
                if is_float:
                    expanded_ids = repeat(output_ids, "b f c -> b f i c", i=num_splits)
                    return rearrange(expanded_ids, "b f i c -> b (f i) 1 c")
                expanded_ids = repeat(output_ids, "b f -> b f i", i=num_splits)
                return rearrange(expanded_ids, "b f i -> b (f i) 1")

            raise NotImplementedError("Other Split types are not implemented yet.")

        num_inputs = len(self.inputs)

        if not is_float:
            mapped = repeat(output_ids, "b f -> b f i", i=num_inputs)
            single_channel = [
                i for i, inp in enumerate(self.inputs) if int(inp.out_shape.channels) == 1
            ]
            if single_channel:
                # einops may return an expanded view over `output_ids`; writing
                # through it would hit every input's column at once and mutate
                # the caller's tensor, so take a private copy first.
                mapped = mapped.clone()
                for i in single_channel:
                    # A single-channel input can only be addressed at index 0.
                    mapped[..., i] = 0
            return mapped

        batch_size = int(output_ids.shape[0])
        num_features = int(output_ids.shape[1])
        max_channels = int(self.in_shape.channels)
        mapped = output_ids.new_zeros((batch_size, num_features, num_inputs, max_channels))
        for i, inp in enumerate(self.inputs):
            child_channels = int(inp.out_shape.channels)
            if child_channels == 1:
                # All weight goes to the only channel of a broadcast input.
                mapped[:, :, i, 0] = 1.0
            else:
                mapped[:, :, i, :child_channels] = output_ids[:, :, :child_channels]
        return mapped

    def map_out_mask_to_in_mask(self, mask: Tensor) -> Tensor:
        """Map output mask to input mask.

        Args:
            mask: Output mask tensor to map, indexed as ``(batch, features)``.

        Returns:
            Tensor: Mapped input mask tensor. In split mode the feature axis
            is repeated ``num_splits`` times (ordered to match the split type)
            and a trailing singleton axis is added; otherwise a trailing axis
            of size ``len(self.inputs)`` is added.

        Raises:
            NotImplementedError: The split input is neither SplitConsecutive
                nor SplitInterleaved.
        """
        if not self.input_is_split:
            return repeat(mask, "b f -> b f i", i=len(self.inputs))

        num_splits = self.num_splits
        split = self.inputs[0]
        if isinstance(split, SplitConsecutive):
            expanded_mask = repeat(mask, "b f -> b i f", i=num_splits)
            return rearrange(expanded_mask, "b i f -> b (i f) 1")
        if isinstance(split, SplitInterleaved):
            expanded_mask = repeat(mask, "b f -> b f i", i=num_splits)
            return rearrange(expanded_mask, "b f i -> b (f i) 1")
        raise NotImplementedError("Other Split types are not implemented yet.")

    @cached
    def log_likelihood(
        self,
        data: Tensor,
        cache: Cache | None = None,
    ) -> Tensor:
        """Compute log likelihood by element-wise summing inputs.

        Args:
            data: Input data tensor.
            cache: Optional cache for storing intermediate computations.

        Returns:
            Tensor: Computed log likelihood values, viewed as
            ``(batch, out features, out channels, out repetitions)``.
        """
        lls = self._get_input_log_likelihoods(data, cache)

        # Expand single-channel inputs so all terms can be stacked. A new list
        # is built rather than overwriting entries of the one handed back by
        # the helper.
        num_output_channels = self.out_shape.channels
        expanded = []
        for ll in lls:
            if ll.shape[2] == 1:
                if ll.ndim == 4:
                    ll = repeat(ll, "b f 1 r -> b f c r", c=num_output_channels)
                else:
                    ll = repeat(ll, "b f 1 -> b f c", c=num_output_channels)
            expanded.append(ll)

        # A product of probabilities is a sum in log space.
        output = torch.sum(torch.stack(expanded, dim=-1), dim=-1)

        return output.view(
            output.size(0),
            self.out_shape.features,
            self.out_shape.channels,
            self.out_shape.repetitions,
        )