Source code for spflow.modules.ops.cat

from __future__ import annotations

from typing import Optional

import numpy as np
import torch
from torch import Tensor, nn

from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.utils.cache import Cache, cached
from spflow.utils.sampling_context import (
    SamplingContext,
    init_default_sampling_context,
)


class Cat(Module):
    """Concatenate the outputs of multiple modules along the feature or channel dimension."""

    def __init__(self, inputs: list[Module], dim: int = -1):
        """Initialize concatenation operation.

        Args:
            inputs: Modules to concatenate.
            dim: Concatenation dimension; only 1 (features) and 2 (channels) are supported.
        """
        super().__init__()
        self.inputs = nn.ModuleList(inputs)
        self.dim = dim

        if self.dim == 1:
            # Check if all inputs have the same number of channels
            if not all(
                [module.out_shape.channels == self.inputs[0].out_shape.channels for module in self.inputs]
            ):
                raise ValueError("All inputs must have the same number of channels.")

            # Check that all scopes are disjoint
            if not Scope.all_pairwise_disjoint([module.scope for module in self.inputs]):
                raise ValueError("All inputs must have disjoint scopes.")

            # Scope is the join of all input scopes
            self._scope = Scope.join_all([inp.scope for inp in self.inputs])
        elif self.dim == 2:
            # Check if all inputs have the same number of features and scopes
            if not all(
                [module.out_shape.features == self.inputs[0].out_shape.features for module in self.inputs]
            ):
                raise ValueError("All inputs must have the same number of features.")

            if not Scope.all_equal([module.scope for module in self.inputs]):
                raise ValueError("All inputs must have the same scope.")

            # Scope is the same as all inputs
            self._scope = self.inputs[0].scope
        else:
            raise ValueError("Invalid dimension for concatenation.")

        # Shape computation
        self.in_shape = self.inputs[0].out_shape
        if self.dim == 1:
            out_features = sum([module.out_shape.features for module in self.inputs])
            out_channels = self.inputs[0].out_shape.channels
        else:  # dim == 2
            out_features = self.inputs[0].out_shape.features
            out_channels = sum([module.out_shape.channels for module in self.inputs])
        self.out_shape = ModuleShape(out_features, out_channels, self.inputs[0].out_shape.repetitions)
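
    # Usage sketch (hedged): the leaf constructor below is hypothetical ("Normal" and its
    # "out_channels" keyword are assumptions, not part of this file) and only illustrates the
    # contract enforced above. For dim=1 the inputs need pairwise-disjoint scopes and equal
    # channel counts, and the output feature count is the sum of the input feature counts;
    # for dim=2 the inputs need identical scopes and feature counts, and the channel counts
    # are summed instead.
    #
    #   leaf_a = Normal(scope=Scope([0, 1]), out_channels=4)   # 2 features, 4 channels
    #   leaf_b = Normal(scope=Scope([2, 3]), out_channels=4)   # 2 features, 4 channels
    #   cat = Cat(inputs=[leaf_a, leaf_b], dim=1)
    #   # cat.out_shape -> ModuleShape(features=4, channels=4, repetitions=...)
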
    @property
    def feature_to_scope(self) -> np.ndarray:
        if self.dim == 1:
            # Concatenate along the feature dimension (axis=0) since we're concatenating features
            return np.concatenate([module.feature_to_scope for module in self.inputs], axis=0)
        else:
            return self.inputs[0].feature_to_scope

    def extra_repr(self) -> str:
        return f"{super().extra_repr()}, dim={self.dim}"
    @cached
    def log_likelihood(
        self,
        data: Tensor,
        cache: Cache | None = None,
    ) -> Tensor:
        """Compute log likelihood by concatenating input log-likelihoods.

        Args:
            data: Input data tensor.
            cache: Optional cache for storing intermediate results.

        Returns:
            Tensor: Concatenated log-likelihood tensor.
        """
        # Get log likelihoods for all inputs
        lls = []
        for input_module in self.inputs:
            input_ll = input_module.log_likelihood(data, cache=cache)
            lls.append(input_ll)

        # Concatenate log likelihoods along the concatenation dimension
        output = torch.cat(lls, dim=self.dim)
        return output
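
    # Shape sketch (assumption: the exact rank of the per-input log-likelihood tensors is
    # inferred from out_shape, not stated in this file). If each input returns a tensor of
    # shape (batch, features_i, channels, repetitions), then dim=1 concatenation yields
    # (batch, sum(features_i), channels, repetitions) and dim=2 concatenation yields
    # (batch, features, sum(channels_i), repetitions).
    #
    #   x = torch.randn(32, 4)        # hypothetical batch over the 4 scope variables above
    #   ll = cat.log_likelihood(x)    # torch.cat of the per-input log-likelihoods along dim=1
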
    def sample(
        self,
        num_samples: int | None = None,
        data: Tensor | None = None,
        is_mpe: bool = False,
        cache: Cache | None = None,
        sampling_ctx: Optional[SamplingContext] = None,
    ) -> Tensor:
        """Generate samples by delegating to concatenated inputs.

        Args:
            num_samples: Number of samples to generate.
            data: Optional data tensor to store samples.
            is_mpe: Whether to perform most probable explanation sampling.
            cache: Optional cache for storing intermediate results.
            sampling_ctx: Sampling context for controlling sample generation.

        Returns:
            Tensor: Generated samples tensor.
        """
        # Prepare data tensor
        data = self._prepare_sample_data(num_samples, data)
        sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0])

        if self.dim == 1:
            # When concatenating features (dim=1), we need to split the sampling context
            # for each input module based on which INTERNAL feature indices belong to that module.
            #
            # IMPORTANT: sampling_ctx.channel_index and mask are indexed by internal feature
            # position (0, 1, 2, ..., total_features-1), NOT by scope indices. Each module's
            # features occupy a contiguous range in the concatenated output.
            channel_index_per_module = []
            mask_per_module = []
            feature_offset = 0
            for module in self.inputs:
                # Get the internal feature indices for this module (contiguous range)
                num_features = module.out_shape.features
                feature_indices = list(range(feature_offset, feature_offset + num_features))
                channel_index_per_module.append(sampling_ctx.channel_index[:, feature_indices])
                mask_per_module.append(sampling_ctx.mask[:, feature_indices])
                feature_offset += num_features
        elif self.dim == 2:
            # Concatenation happens at out_channels.
            # Therefore, we need to use modulo to get the correct output ids.
            channel_index_per_module = []
            mask_per_module = []

            # Get split assignments
            split_size = self.out_shape.channels // len(self.inputs)
            split_assignment = sampling_ctx.channel_index // split_size

            for i, _ in enumerate(self.inputs):
                oids = sampling_ctx.channel_index
                oids_mod = oids.remainder(split_size)
                channel_index_per_module.append(oids_mod)
                mask = (split_assignment == i) & sampling_ctx.mask
                mask_per_module.append(mask)
        else:
            raise ValueError("Invalid dimension for concatenation.")

        # Iterate over inputs and delegate sampling with the per-module context
        for i in range(len(self.inputs)):
            input_module = self.inputs[i]
            sampling_ctx_copy = sampling_ctx.copy()
            sampling_ctx_copy.update(channel_index=channel_index_per_module[i], mask=mask_per_module[i])
            input_module.sample(
                data=data,
                is_mpe=is_mpe,
                cache=cache,
                sampling_ctx=sampling_ctx_copy,
            )

        return data
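
    # Worked example for the dim=2 branch above: with two inputs of 4 channels each,
    # out_shape.channels == 8 and split_size == 8 // 2 == 4. A requested channel_index of 5
    # gives split_assignment == 5 // 4 == 1 (route to the second input) and a local index of
    # 5 % 4 == 1 within that input. Note that the uniform split_size implicitly assumes all
    # inputs contribute the same number of channels.
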
    def marginalize(
        self,
        marg_rvs: list[int],
        prune: bool = True,
        cache: Cache | None = None,
    ) -> Optional["Module"]:
        """Marginalize out specified random variables.

        Args:
            marg_rvs: List of random variable indices to marginalize.
            prune: Whether to prune unnecessary modules after marginalization.
            cache: Optional cache for storing intermediate results.

        Returns:
            Optional[Module]: Marginalized module or None if fully marginalized.
        """
        # Compute module scope (same for all outputs)
        module_scope = self.scope
        mutual_rvs = set(module_scope.query).intersection(set(marg_rvs))

        # Node scope is only being partially marginalized
        if mutual_rvs:
            inputs = []

            # Marginalize child modules
            for input_module in self.inputs:
                marg_child_module = input_module.marginalize(marg_rvs, prune=prune, cache=cache)

                # Keep the marginalized child if it is not None
                if marg_child_module:
                    inputs.append(marg_child_module)

            # If all children were marginalized, return None
            if len(inputs) == 0:
                return None

            # If only a single input survived marginalization, return it if pruning is enabled
            if prune and len(inputs) == 1:
                return inputs[0]

            return Cat(inputs=inputs, dim=self.dim)
        else:
            return self
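
# Marginalization sketch (hypothetical leaf modules, reusing the leaf_a/leaf_b assumption from
# the usage sketch above): marginalizing the variables of leaf_b from a dim=1 Cat over
# leaf_a (scope {0, 1}) and leaf_b (scope {2, 3}) drops leaf_b entirely; with prune=True the
# single surviving input leaf_a is returned directly instead of being re-wrapped in a Cat.
#
#   marg = cat.marginalize([2, 3])    # -> leaf_a when prune=True, Cat([leaf_a]) when prune=False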