# Source code for blox.blocks

"""Standard neural network building blocks.

This module provides a collection of pre-built, functional layers ready
for use.

Included layers:

* **Core**: ``Embed``, ``Linear``.
* **Convolution**: ``Conv``, ``ConvTranspose``.
* **Recurrent**: ``LSTM``, ``GRU``.
* **Normalization**: ``LayerNorm``, ``RMSNorm``, ``BatchNorm``.
* **Regularization**: ``Dropout``.
* **Containers**: ``Sequential``.
* **Pooling**: ``max_pool``, ``min_pool``, ``avg_pool``.

Blocks are usually subclasses of :class:`blox.Module` which strictly
follows the ``(params, inputs) -> (outputs, params)`` functional
signature.
"""

import math
from typing import Any, Callable, NamedTuple, Sequence

import jax
import jax.numpy as jnp

from . import interfaces as bx

Initializer = jax.nn.initializers.Initializer
PaddingLike = str | Sequence[tuple[int, int]]


class Embed(bx.Module):
  """Lookup table that turns integer indices into dense vectors.

  The same table can double as an output projection through :meth:`attend`,
  which multiplies by the transposed embedding matrix. This is the usual way
  to tie input and output weights in language models.

  Example::

    embed = Embed(
        graph.child('embed'),
        num_embeddings=10000,
        embedding_size=512,
        rng=rng,
    )

    # Forward pass: indices -> embeddings.
    embeddings, params = embed(params, token_ids)

    # Weight-tied output projection: hidden -> logits.
    logits, params = embed.attend(params, hidden_states)
  """

  def __init__(
      self,
      graph: bx.Graph,
      num_embeddings: int,
      embedding_size: int,
      rng: bx.Rng | None,
      embedding_init: Initializer = jax.nn.initializers.variance_scaling(
          1.0, 'fan_in', 'normal', out_axis=0
      ),
      embedding_metadata: dict[str, Any] | None = None,
  ) -> None:
    """Stores the configuration of the embedding table.

    Args:
      graph: The graph node for this module.
      num_embeddings: Number of rows in the table (vocabulary size).
      embedding_size: Length of each embedding vector.
      rng: Rng module used when the parameter is initialized. It may be None
        when ``embedding_init`` is constant, although passing one is still
        recommended.
      embedding_init: Initializer for the embedding matrix.
      embedding_metadata: Optional metadata attached to the embedding
        parameter.
    """
    super().__init__(graph)
    self.num_embeddings = num_embeddings
    self.embedding_size = embedding_size
    self.rng = rng
    self.embedding_init = embedding_init
    self.embedding_metadata = embedding_metadata

  def __call__(
      self,
      params: bx.Params,
      indices: jax.Array,
  ) -> tuple[jax.Array, bx.Params]:
    """Gathers the embedding rows selected by ``indices``.

    Args:
      params: The parameters container.
      indices: Integer array of token indices. Shape [...].

    Returns:
      A tuple (embeddings, params). Embeddings have shape
      [..., embedding_size].
    """
    table_shape = (self.num_embeddings, self.embedding_size)
    table, params = self.get_param(
        params=params,
        name='embedding',
        shape=table_shape,
        init=self.embedding_init,
        metadata=self.embedding_metadata,
        rng=self.rng,
    )
    embeddings = table[indices]
    return embeddings, params

  def attend(
      self,
      params: bx.Params,
      inputs: jax.Array,
  ) -> tuple[jax.Array, bx.Params]:
    """Projects ``inputs`` onto the vocabulary with the transposed table.

    Reads the same 'embedding' parameter as ``__call__``, so the output
    projection shares its weights with the input embedding.

    Args:
      params: The parameters container.
      inputs: Input array of shape [..., embedding_size].

    Returns:
      A tuple (logits, params). Logits have shape [..., num_embeddings].
    """
    table_shape = (self.num_embeddings, self.embedding_size)
    table, params = self.get_param(
        params=params,
        name='embedding',
        shape=table_shape,
        init=self.embedding_init,
        metadata=self.embedding_metadata,
        rng=self.rng,
    )
    # Equivalent to inputs @ table.T.
    logits = jnp.dot(inputs, table.T)
    return logits, params
class Linear(bx.Module):
  """Affine map over the last axis of the input.

  Computes ``output = input @ kernel + bias``. The kernel's input dimension
  is taken from the last axis of the array passed to ``__call__``.

  Parameter metadata can be used to describe model parallelism. Example for
  sharding weights::

    linear = Linear(
        graph.child('linear'),
        output_size=1024,
        rng=rng,
        kernel_metadata={'sharding': (None, 'model')},  # Shard output.
        bias_metadata={'sharding': ('model',)},
    )
  """

  def __init__(
      self,
      graph: bx.Graph,
      output_size: int,
      rng: bx.Rng | None,
      use_bias: bool = True,
      kernel_init: Initializer = jax.nn.initializers.lecun_normal(),
      bias_init: Initializer = jax.nn.initializers.zeros,
      kernel_metadata: dict[str, Any] | None = None,
      bias_metadata: dict[str, Any] | None = None,
  ) -> None:
    """Stores the configuration of the linear layer.

    Args:
      graph: The graph node for this module.
      output_size: The dimensionality of the output features.
      rng: Rng module used when parameters are initialized. It may be None
        when ``kernel_init`` and ``bias_init`` are constant, although passing
        one is still recommended.
      use_bias: Whether to add a learnable bias vector.
      kernel_init: Initializer for the weight matrix.
      bias_init: Initializer for the bias vector.
      kernel_metadata: Optional metadata for the kernel parameter, e.g.
        ``{'sharding': (None, 'model')}`` for model parallelism.
      bias_metadata: Optional metadata for the bias parameter, e.g.
        ``{'sharding': ('model',)}`` for model parallelism.
    """
    super().__init__(graph)
    self.output_size = output_size
    self.rng = rng
    self.use_bias = use_bias
    self.kernel_init = kernel_init
    self.bias_init = bias_init
    self.kernel_metadata = kernel_metadata
    self.bias_metadata = bias_metadata

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
      precision: jax.lax.Precision | None = None,
  ) -> tuple[jax.Array, bx.Params]:
    """Applies the linear transformation.

    Args:
      params: The parameters container.
      inputs: The input array of shape [..., input_features]. Must have at
        least one dimension.
      precision: Optional precision for the matrix multiplication.

    Returns:
      A tuple (output, params). The output has shape [..., output_size].

    Raises:
      ValueError: If the input is a scalar (rank 0).
    """
    if len(inputs.shape) == 0:
      raise ValueError('Input must not be scalar.')

    kernel_shape = (inputs.shape[-1], self.output_size)
    kernel, params = self.get_param(
        params,
        'kernel',
        kernel_shape,
        self.kernel_init,
        metadata=self.kernel_metadata,
        rng=self.rng,
    )
    projected = jnp.dot(inputs, kernel, precision=precision)

    if not self.use_bias:
      return projected, params

    bias, params = self.get_param(
        params,
        'bias',
        (self.output_size,),
        self.bias_init,
        metadata=self.bias_metadata,
        rng=self.rng,
    )
    # Expand the bias to the full output shape before adding it.
    projected = projected + jnp.broadcast_to(bias, projected.shape)
    return projected, params
class Sequential(bx.Module):
  """A sequential container.

  Layers are applied in the order they appear in the ``layers`` sequence
  passed to the constructor. Each entry is either a blox Module, which is
  called as ``layer(params, x)`` and returns ``(x, params)``, or a plain
  callable such as ``jax.nn.relu``, which is called as ``layer(x)``.

  Example::

    mlp = Sequential(
        graph.child('mlp'),
        [
            bx.Linear(graph.child('l1'), 32, rng=rng),
            jax.nn.relu,
            bx.Linear(graph.child('l2'), 10, rng=rng),
        ],
    )
    y, params = mlp(params, x)
  """

  def __init__(
      self,
      graph: bx.Graph,
      layers: Sequence[bx.Module | Callable[[jax.Array], jax.Array]],
  ) -> None:
    """Initializes the Sequential module.

    Args:
      graph: The graph node for this module.
      layers: A sequence of blox Modules or callables. If a layer is a blox
        Module, it must accept (params, inputs) and return (output, params).
        If a layer is a simple callable (like jax.nn.relu), it must accept
        inputs and return output. The sequence is stored as given (not
        copied).
    """
    super().__init__(graph)
    self.layers = layers

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
  ) -> tuple[jax.Array, bx.Params]:
    """Applies every layer in order, threading ``params`` through Modules.

    Args:
      params: The parameters container.
      inputs: The input array.

    Returns:
      A tuple (output, params), where params is the container returned by
      the last Module that was applied (or the input container when the
      sequence holds no Modules).
    """
    x = inputs
    for layer in self.layers:
      if isinstance(layer, bx.Module):
        x, params = layer(params, x)
      else:
        # Anything that is not a Module is treated as a pure function of the
        # activations, like jax.nn.relu.
        x = layer(x)
    return x, params
class LSTMState(NamedTuple):
  """Recurrent state carried between LSTM steps.

  Fields, in positional order: ``hidden`` (also returned as the LSTM's
  per-step output) and ``cell`` (the memory).
  """

  # Hidden state h.
  hidden: jax.Array
  # Cell (memory) state c.
  cell: jax.Array
class LSTM(bx.RecurrenceBase[jax.Array, LSTMState, jax.Array, jax.Array]):
  r"""Long Short-Term Memory (LSTM) recurrent cell.

  The cell computes:

  .. math::

    i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
    f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
    g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
    o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
    c' = f * c + i * g \\
    h' = o * \tanh(c')

  where x is the input, h is the output of the previous time step, and c is
  the memory.

  All four gates come from one Linear layer applied to the concatenation of
  x and h. The class inherits from RecurrenceBase, which provides both
  single-step execution (``__call__``) and sequence processing (``apply``).

  Example::

    lstm = LSTM(graph.child('lstm'), hidden_size=128, rng=rng)

    # Initialize state first:
    state, params = lstm.initial_state(params, inputs)

    # Single-step processing (e.g., for interactive use):
    (outputs, state), params = lstm(params, inputs, state)

    # Sequence processing (uses jax.lax.scan for efficiency):
    (outputs, state), params = lstm.apply(params, inputs_sequence, state)
  """

  def __init__(
      self,
      graph: bx.Graph,
      hidden_size: int,
      rng: bx.Rng | None,
      is_static: bool = False,
  ) -> None:
    """Builds the LSTM and its gate projection.

    Args:
      graph: The graph node for this module.
      hidden_size: The dimensionality of the hidden and cell states.
      rng: Rng module for random initialization.
      is_static: If True, uses Python loops for sequence processing. If
        False, uses jax.lax.scan (default).
    """
    super().__init__(graph, is_static)
    self.hidden_size = hidden_size
    # One projection yields all four gates; its output is later split into
    # equal chunks laid out as (i, g, f, o).
    self.gates = Linear(
        graph.child('gates'), output_size=4 * hidden_size, rng=rng
    )

  def initial_state(
      self, params: bx.Params, inputs: jax.Array
  ) -> tuple[LSTMState, bx.Params]:
    """Returns an all-zero state sized for ``inputs``.

    Args:
      params: The parameters container. Returned unchanged.
      inputs: The input array, used to infer the batch size (dimension 0).

    Returns:
      A tuple (LSTMState, params), where both hidden and cell states are
      zeros of shape [batch_size, hidden_size].
    """
    state_shape = (inputs.shape[0], self.hidden_size)
    zero_state = LSTMState(
        hidden=jnp.zeros(state_shape),
        cell=jnp.zeros(state_shape),
    )
    return zero_state, params

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
      prev_state: LSTMState | None,
      is_reset: jax.Array | None = None,
      is_training: bool = True,
  ) -> tuple[tuple[jax.Array, LSTMState], bx.Params]:
    """Advances the LSTM by one time step.

    Args:
      params: The parameters container.
      inputs: The input at the current time step. Shape [Batch, input_size].
      prev_state: The previous LSTM state. Must not be None.
      is_reset: Optional boolean array [Batch]. If True for a batch element,
        the state is reset to zero *before* computing the step.
      is_training: Unused.

    Returns:
      A nested tuple ((output, new_state), params). The output of the LSTM
      is the new hidden state.

    Raises:
      ValueError: If prev_state is None.
    """
    del is_training  # Currently unused.
    if prev_state is None:
      raise ValueError('The LSTM __call__ method requires a valid prev_state.')

    # Apply the reset mask (if any) before reading the previous state.
    state = self.maybe_reset_state(params, prev_state, inputs, is_reset)

    # A single projection of [x, h] yields the four gate pre-activations.
    projected, params = self.gates(
        params, jnp.concatenate([inputs, state.hidden], axis=-1)
    )
    in_gate, candidate, forget_gate, out_gate = jnp.split(
        projected, 4, axis=-1
    )

    forget_gate = jax.nn.sigmoid(forget_gate)
    new_cell = forget_gate * state.cell + (
        jax.nn.sigmoid(in_gate) * jnp.tanh(candidate)
    )
    new_hidden = jax.nn.sigmoid(out_gate) * jnp.tanh(new_cell)

    # The hidden state is both the step output and part of the new state.
    return (new_hidden, LSTMState(hidden=new_hidden, cell=new_cell)), params
class GRUState(NamedTuple):
  """Recurrent state carried between GRU steps.

  A single-field tuple; ``hidden`` is also the GRU's per-step output.
  """

  # Hidden state h.
  hidden: jax.Array
class GRU(bx.RecurrenceBase[jax.Array, GRUState, jax.Array, jax.Array]):
  r"""Gated Recurrent Unit (GRU).

  The cell implemented here computes:

  .. math::

    r = \sigma(W_{ir} x + W_{hr} h + b_{r}) \\
    z = \sigma(W_{iz} x + W_{hz} h + b_{z}) \\
    n = \tanh(W_{in} x + W_{hn} (r * h) + b_{n}) \\
    h' = (1 - z) * h + z * n

  where x is the input and h is the output of the previous time step.

  Note the conventions: the reset gate r is applied to h *before* the
  candidate projection, and the update gate z weights the candidate n (so
  z close to 1 replaces the state, z close to 0 keeps it).

  Example::

    gru = GRU(graph.child('gru'), hidden_size=128, rng=rng)
    state, params = gru.initial_state(params, inputs)
    (outputs, state), params = gru(params, inputs, state)
  """

  def __init__(
      self,
      graph: bx.Graph,
      hidden_size: int,
      rng: bx.Rng | None,
      is_static: bool = False,
  ) -> None:
    """Initializes the GRU.

    Args:
      graph: The graph node for this module.
      hidden_size: The dimensionality of the hidden state.
      rng: Rng module for random initialization.
      is_static: If True, uses Python loops. If False, uses jax.lax.scan.
    """
    super().__init__(graph, is_static)
    self.hidden_size = hidden_size
    # Two linear layers are used:
    #   * `gates`: update and reset gates (z, r), computed from [x, h].
    #   * `candidate`: candidate hidden state, computed from [x, r * h].
    self.gates = Linear(
        graph.child('gates'), output_size=2 * hidden_size, rng=rng
    )
    self.candidate = Linear(
        graph.child('candidate'), output_size=hidden_size, rng=rng
    )

  def initial_state(
      self, params: bx.Params, inputs: jax.Array
  ) -> tuple[GRUState, bx.Params]:
    """Creates the initial zero state.

    Args:
      params: The parameters container. Returned unchanged.
      inputs: The input array, used to infer the batch size (dimension 0).

    Returns:
      A tuple (GRUState, params), where the hidden state is zeros of shape
      [batch_size, hidden_size].
    """
    batch_size = inputs.shape[0]
    return GRUState(hidden=jnp.zeros((batch_size, self.hidden_size))), params

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
      prev_state: GRUState | None,
      is_reset: jax.Array | None = None,
      is_training: bool = True,
  ) -> tuple[tuple[jax.Array, GRUState], bx.Params]:
    """Computes a single step of the GRU recurrence.

    Args:
      params: The parameters container.
      inputs: The input at the current time step. Shape [Batch, input_size].
      prev_state: The previous GRU state. Must not be None.
      is_reset: Optional boolean array [Batch], forwarded to
        ``maybe_reset_state`` before the step is computed.
      is_training: Unused.

    Returns:
      A nested tuple ((output, new_state), params). The output of the GRU is
      the new hidden state.

    Raises:
      ValueError: If prev_state is None.
    """
    del is_training  # Currently unused.
    if prev_state is None:
      raise ValueError('The GRU __call__ method requires a valid prev_state.')

    # Apply reset mask if provided.
    prev_state = self.maybe_reset_state(params, prev_state, inputs, is_reset)
    prev_h = prev_state.hidden

    # Update (z) and reset (r) gates from the concatenated [x, h].
    x_and_h = jnp.concatenate([inputs, prev_h], axis=-1)
    gates_out, params = self.gates(params, x_and_h)
    z, r = jnp.split(gates_out, indices_or_sections=2, axis=-1)
    z = jax.nn.sigmoid(z)
    r = jax.nn.sigmoid(r)

    # Candidate state from [x, r * h]: the reset gate scales h before the
    # projection.
    x_and_rh = jnp.concatenate([inputs, r * prev_h], axis=-1)
    h_tilde, params = self.candidate(params, x_and_rh)
    h_tilde = jnp.tanh(h_tilde)

    # z interpolates from the previous state (z = 0) to the candidate (z = 1).
    h = (1 - z) * prev_h + z * h_tilde
    return (h, GRUState(hidden=h)), params
class Dropout(bx.Module):
  """Dropout layer for regularization.

  During training, randomly zeros elements with probability `rate` and scales
  the remaining elements by `1 / (1 - rate)` to maintain expected values.
  During inference, this layer is a no-op.

  Example::

    dropout = Dropout(graph.child('dropout'), rate=0.5, rng=rng)
    y, params = dropout(params, x, is_training=True)
  """

  def __init__(
      self,
      graph: bx.Graph,
      rate: float,
      rng: bx.Rng,
  ) -> None:
    """Initializes the Dropout module.

    Args:
      graph: The graph node for this module.
      rate: The probability of dropping each element. Must satisfy
        ``0.0 <= rate < 1.0``.
      rng: Rng module for generating dropout masks.

    Raises:
      ValueError: If ``rate`` is outside ``[0.0, 1.0)``.
    """
    super().__init__(graph)
    if not 0.0 <= rate < 1.0:
      raise ValueError(f'Dropout rate must be in [0.0, 1.0), got {rate}.')
    self.rate = rate
    self.rng = rng

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
      is_training: bool = True,
  ) -> tuple[jax.Array, bx.Params]:
    """Applies dropout to the inputs.

    Args:
      params: The parameters container.
      inputs: The input array.
      is_training: If True, applies dropout. If False, returns inputs
        unchanged.

    Returns:
      A tuple (output, params). When dropout is skipped (inference or
      ``rate == 0.0``) the inputs and params are returned untouched and no
      random key is drawn.
    """
    if not is_training or self.rate == 0.0:
      return inputs, params

    key, params = self.rng(params)
    keep_rate = 1.0 - self.rate
    keep_mask = jax.random.bernoulli(key, keep_rate, inputs.shape)
    # Select rather than multiply by the mask: a dropped element must be
    # exactly zero even when the input there is inf/NaN (inf * 0 is NaN).
    outputs = jnp.where(keep_mask, inputs / keep_rate, 0)
    return outputs, params
class LayerNorm(bx.Module):
  """Layer Normalization.

  Normalizes over the last axis (features) of the input.

  Supports cross-device statistics aggregation via axis_name for use with
  jax.shard_map. In that mode the first and second moments are averaged
  across devices and the variance is derived from them, so the result is the
  variance around the aggregated mean.
  """

  def __init__(
      self,
      graph: bx.Graph,
      epsilon: float = 1e-5,
      use_scale: bool = True,
      use_bias: bool = True,
      scale_init: Initializer = jax.nn.initializers.ones,
      bias_init: Initializer = jax.nn.initializers.zeros,
      axis_name: str | None = None,
      axis_index_groups: Sequence[Sequence[int]] | None = None,
      rng: bx.Rng | None = None,
  ) -> None:
    """Initializes the LayerNorm module.

    Args:
      graph: The graph node for this module.
      epsilon: Small constant for numerical stability.
      use_scale: Whether to use a learnable scale parameter.
      use_bias: Whether to use a learnable bias parameter.
      scale_init: Initializer for scale.
      bias_init: Initializer for bias.
      axis_name: The axis name used to combine statistics from multiple
        devices. See jax.shard_map for a description of axis names.
      axis_index_groups: Groups of axis indices within that named axis
        representing subsets of devices to reduce over. For example,
        [[0, 1], [2, 3]] would independently normalize over the first two and
        last two devices. See jax.lax.psum for more details.
      rng: Rng module. Required if using stochastic initializers for
        scale/bias.
    """
    super().__init__(graph)
    self.epsilon = epsilon
    self.use_scale = use_scale
    self.use_bias = use_bias
    self.scale_init = scale_init
    self.bias_init = bias_init
    self.axis_name = axis_name
    self.axis_index_groups = axis_index_groups
    self.rng = rng

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
  ) -> tuple[jax.Array, bx.Params]:
    """Applies layer normalization.

    Args:
      params: The parameters container.
      inputs: The input array with shape [..., features].

    Returns:
      A tuple (normalized_output, params).
    """
    features = inputs.shape[-1]

    mean = jnp.mean(inputs, axis=-1, keepdims=True)
    if self.axis_name is None:
      var = jnp.var(inputs, axis=-1, keepdims=True)
    else:
      # Averaging per-device variances would ignore how far each device's
      # local mean lies from the aggregated mean. Aggregate E[x] and E[x^2]
      # instead and derive the variance from them.
      mean_sq = jnp.mean(jnp.square(inputs), axis=-1, keepdims=True)
      mean = jax.lax.pmean(
          mean, self.axis_name, axis_index_groups=self.axis_index_groups
      )
      mean_sq = jax.lax.pmean(
          mean_sq, self.axis_name, axis_index_groups=self.axis_index_groups
      )
      # Clamp at zero: E[x^2] - E[x]^2 can dip slightly negative through
      # floating-point cancellation.
      var = jnp.maximum(mean_sq - jnp.square(mean), 0.0)

    # Normalize.
    normalized = (inputs - mean) / jnp.sqrt(var + self.epsilon)

    # Scale and shift.
    if self.use_scale:
      scale, params = self.get_param(
          params=params,
          name='scale',
          shape=(features,),
          init=self.scale_init,
          rng=self.rng,
      )
      normalized = normalized * scale

    if self.use_bias:
      bias, params = self.get_param(
          params=params,
          name='bias',
          shape=(features,),
          init=self.bias_init,
          rng=self.rng,
      )
      normalized = normalized + bias

    return normalized, params
class RMSNorm(bx.Module):
  """Root Mean Square Layer Normalization.

  Divides the input by the root mean square of its last axis; no mean is
  subtracted, which makes it computationally simpler than LayerNorm.

  Supports cross-device statistics aggregation via axis_name for use with
  jax.shard_map.
  """

  def __init__(
      self,
      graph: bx.Graph,
      epsilon: float = 1e-5,
      use_scale: bool = True,
      scale_init: Initializer = jax.nn.initializers.ones,
      axis_name: str | None = None,
      axis_index_groups: Sequence[Sequence[int]] | None = None,
      rng: bx.Rng | None = None,
  ) -> None:
    """Stores the configuration of the RMSNorm module.

    Args:
      graph: The graph node for this module.
      epsilon: Small constant added to the mean of squares before the square
        root, for numerical stability.
      use_scale: Whether to use a learnable scale parameter.
      scale_init: Initializer for scale.
      axis_name: The axis name used to combine statistics from multiple
        devices. See jax.shard_map for a description of axis names.
      axis_index_groups: Groups of axis indices within that named axis
        representing subsets of devices to reduce over. For example,
        [[0, 1], [2, 3]] would independently normalize over the first two and
        last two devices. See jax.lax.psum for more details.
      rng: Rng module. Required if using stochastic initializers for scale.
    """
    super().__init__(graph)
    self.epsilon = epsilon
    self.use_scale = use_scale
    self.scale_init = scale_init
    self.axis_name = axis_name
    self.axis_index_groups = axis_index_groups
    self.rng = rng

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
  ) -> tuple[jax.Array, bx.Params]:
    """Applies RMS normalization.

    Args:
      params: The parameters container.
      inputs: The input array with shape [..., features].

    Returns:
      A tuple (normalized_output, params).
    """
    # Second moment over the feature axis.
    second_moment = jnp.mean(jnp.square(inputs), axis=-1, keepdims=True)

    # Average the second moment across devices when an axis name is set.
    if self.axis_name is not None:
      second_moment = jax.lax.pmean(
          second_moment,
          self.axis_name,
          axis_index_groups=self.axis_index_groups,
      )

    outputs = inputs / jnp.sqrt(second_moment + self.epsilon)

    if not self.use_scale:
      return outputs, params

    num_features = inputs.shape[-1]
    scale, params = self.get_param(
        params, 'scale', (num_features,), self.scale_init, rng=self.rng
    )
    return outputs * scale, params
class BatchNorm(bx.Module):
  """Batch Normalization.

  Normalizes over the batch and spatial dimensions, maintaining running
  statistics for inference. During training, computes batch statistics and
  updates the running averages. During inference, uses the stored running
  statistics.

  The input is expected to have shape ``(batch, *spatial_dims, features)``.
  Normalization is applied over all axes except the last (features) axis.

  When ``axis_name`` is set, the first and second moments are averaged across
  devices and the variance is derived from them, so that both the
  normalization and the running variance reflect the aggregated batch.

  Example::

    bn = BatchNorm(graph.child('bn'))

    # Training: uses batch statistics and updates running stats.
    y, params = bn(params, x, is_training=True)

    # Inference: uses running statistics.
    y, params = bn(params, x, is_training=False)
  """

  def __init__(
      self,
      graph: bx.Graph,
      momentum: float = 0.9,
      epsilon: float = 1e-5,
      use_scale: bool = True,
      use_bias: bool = True,
      scale_init: Initializer = jax.nn.initializers.ones,
      bias_init: Initializer = jax.nn.initializers.zeros,
      axis_name: str | None = None,
      axis_index_groups: Sequence[Sequence[int]] | None = None,
      rng: bx.Rng | None = None,
  ) -> None:
    """Initializes the BatchNorm module.

    Args:
      graph: The graph node for this module.
      momentum: Momentum for the exponential moving average of running stats.
        running_stat = momentum * running_stat + (1.0 - momentum) * batch_stat.
      epsilon: Small constant for numerical stability.
      use_scale: Whether to use a learnable scale parameter (gamma).
      use_bias: Whether to use a learnable bias parameter (beta).
      scale_init: Initializer for scale.
      bias_init: Initializer for bias.
      axis_name: The axis name used to combine statistics from multiple
        devices. See jax.shard_map for a description of axis names.
      axis_index_groups: Groups of axis indices within that named axis
        representing subsets of devices to reduce over.
      rng: Rng module. Required if using stochastic initializers for
        scale/bias.
    """
    super().__init__(graph)
    self.momentum = momentum
    self.epsilon = epsilon
    self.use_scale = use_scale
    self.use_bias = use_bias
    self.scale_init = scale_init
    self.bias_init = bias_init
    self.axis_name = axis_name
    self.axis_index_groups = axis_index_groups
    self.rng = rng

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
      is_training: bool = True,
  ) -> tuple[jax.Array, bx.Params]:
    """Applies batch normalization.

    Args:
      params: The parameters container.
      inputs: The input array with shape ``(batch, *spatial_dims, features)``.
      is_training: If True, uses batch statistics and updates running stats.
        If False, uses stored running statistics (inference).

    Returns:
      A tuple (normalized_output, params). In training mode the returned
      params carry the updated 'running_mean' and 'running_var'.
    """
    features = inputs.shape[-1]
    # Axes to reduce over: all except the last (features) axis.
    reduce_axes = tuple(range(inputs.ndim - 1))

    # Get or create running statistics (non-trainable).
    running_mean, params = self.get_param(
        params=params,
        name='running_mean',
        shape=(features,),
        init=jax.nn.initializers.zeros,
        trainable=False,
        rng=self.rng,
    )
    running_var, params = self.get_param(
        params=params,
        name='running_var',
        shape=(features,),
        init=jax.nn.initializers.ones,
        trainable=False,
        rng=self.rng,
    )

    if is_training:
      mean = jnp.mean(inputs, axis=reduce_axes)
      if self.axis_name is None:
        var = jnp.var(inputs, axis=reduce_axes)
      else:
        # Averaging per-device variances would ignore how far each device's
        # local mean lies from the aggregated mean. Aggregate E[x] and
        # E[x^2] instead and derive the variance from them.
        mean_sq = jnp.mean(jnp.square(inputs), axis=reduce_axes)
        mean = jax.lax.pmean(
            mean, self.axis_name, axis_index_groups=self.axis_index_groups
        )
        mean_sq = jax.lax.pmean(
            mean_sq, self.axis_name, axis_index_groups=self.axis_index_groups
        )
        # Clamp at zero: E[x^2] - E[x]^2 can dip slightly negative through
        # floating-point cancellation.
        var = jnp.maximum(mean_sq - jnp.square(mean), 0.0)

      # Update running statistics with an exponential moving average.
      # stop_gradient prevents backprop through the running stats.
      new_running_mean = self.momentum * running_mean + (
          1 - self.momentum
      ) * jax.lax.stop_gradient(mean)
      new_running_var = self.momentum * running_var + (
          1 - self.momentum
      ) * jax.lax.stop_gradient(var)

      # Store updated running statistics.
      params = self.set_param(
          params=params, name='running_mean', value=new_running_mean
      )
      params = self.set_param(
          params=params, name='running_var', value=new_running_var
      )
    else:
      # Use running statistics for inference.
      mean = running_mean
      var = running_var

    # Normalize; mean/var have shape (features,) and broadcast over the
    # leading axes.
    normalized = (inputs - mean) / jnp.sqrt(var + self.epsilon)

    # Scale and shift.
    if self.use_scale:
      scale, params = self.get_param(
          params, 'scale', (features,), self.scale_init, rng=self.rng
      )
      normalized = normalized * scale

    if self.use_bias:
      bias, params = self.get_param(
          params, 'bias', (features,), self.bias_init, rng=self.rng
      )
      normalized = normalized + bias

    return normalized, params
def _normalize_tuple(x: int | Sequence[int], n: int) -> tuple[int, ...]:
  """Converts an int or a sequence of ints to a tuple of length ``n``.

  Args:
    x: Either a single int, which is repeated ``n`` times, or a sequence
      that must already contain exactly ``n`` entries.
    n: The required length of the result.

  Returns:
    A tuple with ``n`` entries.

  Raises:
    ValueError: If ``x`` is a sequence whose length is not ``n``.
  """
  if isinstance(x, int):
    return (x,) * n

  values = tuple(x)
  # Fail here with a clear message instead of handing a mis-sized tuple to
  # the caller, which promised its consumers one entry per dimension.
  if len(values) != n:
    raise ValueError(
        f'Expected an int or a sequence of length {n}, '
        f'got {values!r} with length {len(values)}.'
    )
  return values
class Conv(bx.Module):
  """N-dimensional convolution over channels-last inputs.

  The spatial rank is taken from ``kernel_size``:

  * int or 1-tuple: 1D convolution (one spatial dimension).
  * 2-tuple: 2D convolution (height, width).
  * 3-tuple: 3D convolution (depth, height, width).

  Inputs are laid out as ``(*batch, *spatial_dims, channels)``. Any number of
  leading batch dimensions is accepted, including none.

  Example::

    # 2D convolution for images (NHWC format).
    conv = Conv(
      graph.child('conv'),
      kernel_size=(3, 3),
      output_channels=64,
      rng=rng,
    )
    y, params = conv(params, x)  # x: [batch, height, width, channels]

    # 1D convolution for sequences (NLC format).
    conv = Conv(
      graph.child('conv'), kernel_size=3, output_channels=64, rng=rng
    )
    y, params = conv(params, x)  # x: [batch, length, channels]

    # Unbatched input.
    conv = Conv(
      graph.child('conv'),
      kernel_size=(3, 3),
      output_channels=64,
      rng=rng,
    )
    y, params = conv(params, x)  # x: [height, width, channels]
  """

  def __init__(
      self,
      graph: bx.Graph,
      kernel_size: int | Sequence[int],
      output_channels: int,
      rng: bx.Rng | None,
      strides: int | Sequence[int] = 1,
      padding: PaddingLike = 'SAME',
      input_dilation: int | Sequence[int] = 1,
      kernel_dilation: int | Sequence[int] = 1,
      feature_group_count: int = 1,
      use_bias: bool = True,
      kernel_init: Initializer = jax.nn.initializers.lecun_normal(),
      bias_init: Initializer = jax.nn.initializers.zeros,
      kernel_metadata: dict[str, Any] | None = None,
      bias_metadata: dict[str, Any] | None = None,
  ) -> None:
    """Initializes the Conv module.

    Args:
      graph: The graph node for this module.
      kernel_size: Spatial extent of the kernel. Its length fixes the number
        of spatial dimensions (e.g. ``(3, 3)`` for a 2D convolution); a bare
        int is treated as a 1-tuple.
      output_channels: Number of output channels.
      rng: Rng module used when parameters are initialized. May be None if
        both initializers are constant (passing one is still recommended).
      strides: Convolution stride; an int applies to every spatial dimension.
      padding: ``'SAME'``, ``'VALID'``, or one ``(low, high)`` pair per
        spatial dimension.
      input_dilation: Dilation applied to the input (transposed convolution).
      kernel_dilation: Dilation applied to the kernel (atrous convolution).
      feature_group_count: Number of groups for grouped convolution; use
        ``input_channels`` for a depthwise convolution.
      use_bias: Whether a learnable bias is added to the output.
      kernel_init: Initializer for the kernel.
      bias_init: Initializer for the bias.
      kernel_metadata: Optional metadata attached to the kernel parameter.
      bias_metadata: Optional metadata attached to the bias parameter.
    """
    super().__init__(graph)
    # Always store the kernel size as a tuple so its length is the rank.
    if isinstance(kernel_size, int):
      self.kernel_size = (kernel_size,)
    else:
      self.kernel_size = tuple(kernel_size)
    self.output_channels = output_channels
    self.rng = rng
    self.strides = strides
    self.padding = padding
    self.input_dilation = input_dilation
    self.kernel_dilation = kernel_dilation
    self.feature_group_count = feature_group_count
    self.use_bias = use_bias
    self.kernel_init = kernel_init
    self.bias_init = bias_init
    self.kernel_metadata = kernel_metadata
    self.bias_metadata = bias_metadata

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
      precision: jax.lax.Precision | None = None,
  ) -> tuple[jax.Array, bx.Params]:
    """Applies the convolution.

    Args:
      params: The parameters container.
      inputs: Array shaped ``(*batch, *spatial_dims, input_channels)`` with
        zero or more batch dimensions.
      precision: Optional precision forwarded to the lax convolution.

    Returns:
      A tuple (output, params); output is shaped
      ``(*batch, *out_spatial, output_channels)``.

    Raises:
      ValueError: If the input rank is too small for the kernel, or if
        ``input_channels`` is not divisible by ``feature_group_count``.
    """
    num_spatial = len(self.kernel_size)
    min_rank = num_spatial + 1  # Spatial dims plus channels; batch optional.
    if inputs.ndim < min_rank:
      raise ValueError(
          f'Expected input rank >= {min_rank} for {num_spatial}D conv '
          f'(at least {num_spatial} spatial dims + 1 channel dim), '
          f'got {inputs.ndim}.'
      )

    input_channels = inputs.shape[-1]
    if input_channels % self.feature_group_count != 0:
      raise ValueError(
          f'input_channels ({input_channels}) must be divisible by '
          f'feature_group_count ({self.feature_group_count}).'
      )

    # Fold all leading batch dimensions into one so lax sees a single batch
    # axis: (*batch, *spatial, C) -> (prod(batch), *spatial, C).
    # math.prod of an empty shape is 1, which covers unbatched input.
    num_batch_dims = inputs.ndim - min_rank
    batch_shape = inputs.shape[:num_batch_dims]
    folded = inputs.reshape(
        (math.prod(batch_shape), *inputs.shape[num_batch_dims:])
    )

    # Kernel layout: (*kernel_size, input_channels // groups, output_channels).
    kernel, params = self.get_param(
        params,
        'kernel',
        (
            *self.kernel_size,
            input_channels // self.feature_group_count,
            self.output_channels,
        ),
        self.kernel_init,
        metadata=self.kernel_metadata,
        rng=self.rng,
    )

    # Broadcast scalar settings to one entry per spatial dimension.
    strides, input_dilation, kernel_dilation = (
        _normalize_tuple(setting, num_spatial)
        for setting in (self.strides, self.input_dilation, self.kernel_dilation)
    )

    # Each spec lists, in lax's canonical order, which axis of our arrays
    # plays that role. This lets lax consume channels-last data directly.
    channel_axis = num_spatial + 1
    activation_spec = (0, channel_axis, *range(1, channel_axis))
    kernel_spec = (channel_axis, num_spatial, *range(num_spatial))
    dimension_numbers = jax.lax.ConvDimensionNumbers(
        lhs_spec=activation_spec,
        rhs_spec=kernel_spec,
        out_spec=activation_spec,
    )

    result = jax.lax.conv_general_dilated(
        folded,
        kernel,
        window_strides=strides,
        padding=self.padding,
        lhs_dilation=input_dilation,
        rhs_dilation=kernel_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=precision,
    )

    if self.use_bias:
      bias, params = self.get_param(
          params,
          'bias',
          (self.output_channels,),
          self.bias_init,
          metadata=self.bias_metadata,
          rng=self.rng,
      )
      result = result + bias

    # Restore the caller's batch dimensions:
    # (prod(batch), *out_spatial, C_out) -> (*batch, *out_spatial, C_out).
    restored_shape = (
        *batch_shape,
        *result.shape[1:-1],
        self.output_channels,
    )
    return result.reshape(restored_shape), params
class ConvTranspose(bx.Module):
  """General N-dimensional transposed convolution layer.

  Also known as deconvolution or fractionally-strided convolution.

  The number of spatial dimensions is inferred from ``kernel_size``:

  * 1-tuple or int: 1D convolution (single spatial dimension).
  * 2-tuple: 2D convolution (height, width).
  * 3-tuple: 3D convolution (depth, height, width).

  Supports arbitrary batch dimensions (0 or more). Uses channels-last
  convention: ``(*batch, *spatial_dims, channels)``.

  Note:
    Only ``feature_group_count=1`` can be applied. The underlying
    ``jax.lax.conv_transpose`` call is not given a group count, so a grouped
    kernel would not match the input channels; ``__call__`` raises a
    ``ValueError`` for any other value.

  Example::

    # 2D transposed convolution for images (NHWC format).
    conv_t = ConvTranspose(
      graph.child('conv_t'),
      kernel_size=(3, 3),
      output_channels=3,
      rng=rng,
    )
    y, params = conv_t(params, x)  # x: [batch, height, width, channels]
  """

  def __init__(
      self,
      graph: bx.Graph,
      kernel_size: int | Sequence[int],
      output_channels: int,
      rng: bx.Rng | None,
      strides: int | Sequence[int] = 1,
      padding: PaddingLike = 'SAME',
      kernel_dilation: int | Sequence[int] = 1,
      feature_group_count: int = 1,
      use_bias: bool = True,
      kernel_init: Initializer = jax.nn.initializers.lecun_normal(),
      bias_init: Initializer = jax.nn.initializers.zeros,
      kernel_metadata: dict[str, Any] | None = None,
      bias_metadata: dict[str, Any] | None = None,
  ) -> None:
    """Initializes the ConvTranspose module.

    Args:
      graph: The graph node for this module.
      kernel_size: Shape of the convolutional kernel as a tuple, determining
        the number of spatial dimensions (e.g., (3, 3) for 2D conv). For 1D
        convolution, either an int or a 1-tuple can be used.
      output_channels: Number of output channels.
      rng: Rng module for random initialization. If kernel_init and bias_init
        are constant, Rng is not required (but still recommended).
      strides: Stride of the convolution. An int is broadcast to all spatial
        dimensions.
      padding: Padding mode. Either 'SAME', 'VALID', or a sequence of
        (low, high) padding pairs for each spatial dimension.
      kernel_dilation: Dilation of the kernel (atrous convolution).
      feature_group_count: Number of feature groups. Only 1 is currently
        usable; other values are stored but rejected when the module is
        called (see the class note).
      use_bias: Whether to add a learnable bias.
      kernel_init: Initializer for the kernel. The kernel is created with
        shape ``(*kernel_size, output_channels, input_channels // groups)``.
        NOTE(review): initializers that read fan-in from axis -2 (the
        default for ``jax.nn.initializers.lecun_normal()``) will presumably
        use ``output_channels`` as fan-in here -- confirm this is intended.
      bias_init: Initializer for the bias.
      kernel_metadata: Optional metadata for the kernel parameter.
      bias_metadata: Optional metadata for the bias parameter.
    """
    super().__init__(graph)
    self.kernel_size = (
        (kernel_size,) if isinstance(kernel_size, int) else tuple(kernel_size)
    )
    self.output_channels = output_channels
    self.rng = rng
    self.strides = strides
    self.padding = padding
    self.kernel_dilation = kernel_dilation
    self.feature_group_count = feature_group_count
    self.use_bias = use_bias
    self.kernel_init = kernel_init
    self.bias_init = bias_init
    self.kernel_metadata = kernel_metadata
    self.bias_metadata = bias_metadata

  def __call__(
      self,
      params: bx.Params,
      inputs: jax.Array,
      precision: jax.lax.Precision | None = None,
  ) -> tuple[jax.Array, bx.Params]:
    """Applies the transposed convolution.

    Args:
      params: The parameters container.
      inputs: Input array with shape ``(*batch, *spatial_dims,
        input_channels)``. Supports arbitrary batch dimensions (0 or more).
      precision: Optional precision for the convolution.

    Returns:
      A tuple (output, params), where output has shape
      ``(*batch, *out_spatial, output_channels)``.

    Raises:
      ValueError: If input has fewer dimensions than needed for conv, if
        ``input_channels`` is not divisible by ``feature_group_count``, or
        if ``feature_group_count`` is not 1.
    """
    num_spatial = len(self.kernel_size)
    min_rank = num_spatial + 1  # spatial + channels (batch is optional)

    if inputs.ndim < min_rank:
      raise ValueError(
          f'Expected input rank >= {min_rank} for {num_spatial}D conv_transpose'
          f' (at least {num_spatial} spatial dims + 1 channel dim), '
          f'got {inputs.ndim}.'
      )

    # Handle arbitrary batch dimensions by reshaping.
    # (*batch, *spatial, channels) -> (combined_batch, *spatial, channels)
    batch_shape = inputs.shape[: -num_spatial - 1]
    spatial_shape = inputs.shape[-num_spatial - 1 : -1]
    input_channels = inputs.shape[-1]
    batch_size = math.prod(batch_shape) if batch_shape else 1
    inputs_flat = inputs.reshape(
        (batch_size,) + spatial_shape + (input_channels,)
    )

    if input_channels % self.feature_group_count != 0:
      raise ValueError(
          f'input_channels ({input_channels}) must be divisible by '
          f'feature_group_count ({self.feature_group_count}).'
      )

    # The lax call below is never told about groups, so a kernel with
    # input_channels // groups input features cannot match the input. Fail
    # here with a clear message, before any parameter is created, instead of
    # surfacing a shape mismatch from inside lax.
    if self.feature_group_count != 1:
      raise ValueError(
          'ConvTranspose only supports feature_group_count=1, got '
          f'{self.feature_group_count}: the group count is not forwarded to '
          'jax.lax.conv_transpose.'
      )

    # Kernel shape for ConvTranspose:
    # (*kernel_size, output_channels, input_channels // groups)
    kernel_shape = self.kernel_size + (
        self.output_channels,
        input_channels // self.feature_group_count,
    )
    kernel, params = self.get_param(
        params,
        'kernel',
        kernel_shape,
        self.kernel_init,
        metadata=self.kernel_metadata,
        rng=self.rng,
    )

    # Normalize strides and dilations to tuples.
    strides = _normalize_tuple(self.strides, num_spatial)
    kernel_dilation = _normalize_tuple(self.kernel_dilation, num_spatial)

    # Build dimension numbers for channels-last format.
    # Input:  (N, *spatial, C_in)  -> lax expects (N, C_in, *spatial)
    # Kernel: (*kernel_spatial, C_out, C_in_per_group)
    # Output: (N, *spatial, C_out) -> lax expects (N, C_out, *spatial)
    spatial_dims = tuple(range(1, num_spatial + 1))
    lhs_spec = (0, num_spatial + 1) + spatial_dims  # (N, C_in, *spatial)
    # rhs_spec lists the kernel axes holding (out_c, in_c, *spatial):
    # C_out lives at index num_spatial, C_in at index num_spatial + 1.
    rhs_spec = (num_spatial, num_spatial + 1) + tuple(range(num_spatial))
    out_spec = (0, num_spatial + 1) + spatial_dims  # (N, C_out, *spatial)

    dimension_numbers = jax.lax.ConvDimensionNumbers(
        lhs_spec=lhs_spec, rhs_spec=rhs_spec, out_spec=out_spec
    )

    # Apply transposed convolution.
    outputs = jax.lax.conv_transpose(
        inputs_flat,
        kernel,
        strides=strides,
        padding=self.padding,
        rhs_dilation=kernel_dilation,
        dimension_numbers=dimension_numbers,
        precision=precision,
    )

    # Add bias.
    if self.use_bias:
      bias, params = self.get_param(
          params,
          'bias',
          (self.output_channels,),
          self.bias_init,
          metadata=self.bias_metadata,
          rng=self.rng,
      )
      outputs = outputs + bias

    # Reshape output back to original batch shape.
    # (combined_batch, *out_spatial, out_channels)
    #   -> (*batch, *out_spatial, out_channels)
    out_spatial = outputs.shape[1:-1]
    outputs = outputs.reshape(
        batch_shape + out_spatial + (self.output_channels,)
    )

    return outputs, params
def max_pool(
    inputs: jax.Array,
    window_shape: int | Sequence[int],
    strides: int | Sequence[int] | None = None,
    padding: PaddingLike = 'VALID',
) -> jax.Array:
  """Takes the maximum over sliding spatial windows.

  The spatial rank comes from ``window_shape``:

  * int or 1-tuple: 1D pooling (one spatial dimension).
  * 2-tuple: 2D pooling (height, width).
  * 3-tuple: 3D pooling (depth, height, width).

  Inputs are channels-last, ``(*batch, *spatial_dims, channels)``, with zero
  or more leading batch dimensions.

  Args:
    inputs: Array shaped ``(*batch, *spatial_dims, channels)``.
    window_shape: Pooling window extent; its length fixes the number of
      spatial dimensions. A bare int means a 1D window.
    strides: Step between windows. ``None`` reuses ``window_shape`` so that
      windows do not overlap.
    padding: ``'SAME'``, ``'VALID'``, or one ``(low, high)`` pair per spatial
      dimension.

  Returns:
    Pooled output array.

  Raises:
    ValueError: If ``inputs`` has fewer dimensions than the window needs.

  Example::

    # 2x2 max pooling with stride 2.
    y = max_pool(x, window_shape=(2, 2), strides=2)
  """
  if isinstance(window_shape, int):
    window = (window_shape,)
  else:
    window = tuple(window_shape)

  num_spatial = len(window)
  min_rank = num_spatial + 1  # Spatial dims plus channels; batch optional.
  if inputs.ndim < min_rank:
    raise ValueError(
        f'Expected input rank >= {min_rank} for {num_spatial}D pooling '
        f'(at least {num_spatial} spatial dims + 1 channel dim), '
        f'got {inputs.ndim}.'
    )

  # Fold every leading batch dimension into one; math.prod(()) is 1, which
  # covers unbatched input.
  num_batch_dims = inputs.ndim - min_rank
  batch_shape = inputs.shape[:num_batch_dims]
  channels = inputs.shape[-1]
  folded = inputs.reshape(
      (math.prod(batch_shape), *inputs.shape[num_batch_dims:])
  )

  if strides is None:
    step = window
  else:
    step = _normalize_tuple(strides, num_spatial)

  # reduce_window wants one entry per array axis, so the batch and channel
  # axes get a size-1 window, unit stride and no padding.
  if isinstance(padding, str):
    pads = padding
  else:
    pads = ((0, 0), *padding, (0, 0))

  pooled = jax.lax.reduce_window(
      folded,
      init_value=-jnp.inf,
      computation=jax.lax.max,
      window_dimensions=(1, *window, 1),
      window_strides=(1, *step, 1),
      padding=pads,
  )

  # Restore the caller's batch dimensions.
  return pooled.reshape((*batch_shape, *pooled.shape[1:-1], channels))
def min_pool(
    inputs: jax.Array,
    window_shape: int | Sequence[int],
    strides: int | Sequence[int] | None = None,
    padding: PaddingLike = 'VALID',
) -> jax.Array:
  """Takes the minimum over sliding spatial windows.

  The spatial rank comes from ``window_shape``:

  * int or 1-tuple: 1D pooling (one spatial dimension).
  * 2-tuple: 2D pooling (height, width).
  * 3-tuple: 3D pooling (depth, height, width).

  Inputs are channels-last, ``(*batch, *spatial_dims, channels)``, with zero
  or more leading batch dimensions.

  Args:
    inputs: Array shaped ``(*batch, *spatial_dims, channels)``.
    window_shape: Pooling window extent; its length fixes the number of
      spatial dimensions. A bare int means a 1D window.
    strides: Step between windows. ``None`` reuses ``window_shape`` so that
      windows do not overlap.
    padding: ``'SAME'``, ``'VALID'``, or one ``(low, high)`` pair per spatial
      dimension.

  Returns:
    Pooled output array.

  Raises:
    ValueError: If ``inputs`` has fewer dimensions than the window needs.
  """
  if isinstance(window_shape, int):
    window = (window_shape,)
  else:
    window = tuple(window_shape)

  num_spatial = len(window)
  min_rank = num_spatial + 1  # Spatial dims plus channels; batch optional.
  if inputs.ndim < min_rank:
    raise ValueError(
        f'Expected input rank >= {min_rank} for {num_spatial}D pooling '
        f'(at least {num_spatial} spatial dims + 1 channel dim), '
        f'got {inputs.ndim}.'
    )

  # Fold every leading batch dimension into one; math.prod(()) is 1, which
  # covers unbatched input.
  num_batch_dims = inputs.ndim - min_rank
  batch_shape = inputs.shape[:num_batch_dims]
  channels = inputs.shape[-1]
  folded = inputs.reshape(
      (math.prod(batch_shape), *inputs.shape[num_batch_dims:])
  )

  if strides is None:
    step = window
  else:
    step = _normalize_tuple(strides, num_spatial)

  # reduce_window wants one entry per array axis, so the batch and channel
  # axes get a size-1 window, unit stride and no padding.
  if isinstance(padding, str):
    pads = padding
  else:
    pads = ((0, 0), *padding, (0, 0))

  pooled = jax.lax.reduce_window(
      folded,
      init_value=jnp.inf,
      computation=jax.lax.min,
      window_dimensions=(1, *window, 1),
      window_strides=(1, *step, 1),
      padding=pads,
  )

  # Restore the caller's batch dimensions.
  return pooled.reshape((*batch_shape, *pooled.shape[1:-1], channels))
def avg_pool(
    inputs: jax.Array,
    window_shape: int | Sequence[int],
    strides: int | Sequence[int] | None = None,
    padding: PaddingLike = 'VALID',
) -> jax.Array:
  """Averages over sliding spatial windows.

  The spatial rank comes from ``window_shape``:

  * int or 1-tuple: 1D pooling (one spatial dimension).
  * 2-tuple: 2D pooling (height, width).
  * 3-tuple: 3D pooling (depth, height, width).

  Inputs are channels-last, ``(*batch, *spatial_dims, channels)``, with zero
  or more leading batch dimensions.

  Note:
    Padded positions never count towards the average. For any padding other
    than ``'VALID'`` each window sum is divided by the number of real input
    elements that the window covers.

  Args:
    inputs: Array shaped ``(*batch, *spatial_dims, channels)``.
    window_shape: Pooling window extent; its length fixes the number of
      spatial dimensions. A bare int means a 1D window.
    strides: Step between windows. ``None`` reuses ``window_shape`` so that
      windows do not overlap.
    padding: ``'SAME'``, ``'VALID'``, or one ``(low, high)`` pair per spatial
      dimension.

  Returns:
    Pooled output array.

  Raises:
    ValueError: If ``inputs`` has fewer dimensions than the window needs.

  Example::

    # 2x2 average pooling with stride 2.
    y = avg_pool(x, window_shape=(2, 2), strides=2)
  """
  if isinstance(window_shape, int):
    window = (window_shape,)
  else:
    window = tuple(window_shape)

  num_spatial = len(window)
  min_rank = num_spatial + 1  # Spatial dims plus channels; batch optional.
  if inputs.ndim < min_rank:
    raise ValueError(
        f'Expected input rank >= {min_rank} for {num_spatial}D pooling '
        f'(at least {num_spatial} spatial dims + 1 channel dim), '
        f'got {inputs.ndim}.'
    )

  # Fold every leading batch dimension into one; math.prod(()) is 1, which
  # covers unbatched input.
  num_batch_dims = inputs.ndim - min_rank
  batch_shape = inputs.shape[:num_batch_dims]
  channels = inputs.shape[-1]
  folded = inputs.reshape(
      (math.prod(batch_shape), *inputs.shape[num_batch_dims:])
  )

  if strides is None:
    step = window
  else:
    step = _normalize_tuple(strides, num_spatial)

  # reduce_window wants one entry per array axis, so the batch and channel
  # axes get a size-1 window, unit stride and no padding.
  if isinstance(padding, str):
    pads = padding
  else:
    pads = ((0, 0), *padding, (0, 0))

  def _window_sum(values: jax.Array) -> jax.Array:
    # Sums `values` over each pooling window; padded positions add zero.
    return jax.lax.reduce_window(
        values,
        init_value=0.0,
        computation=jax.lax.add,
        window_dimensions=(1, *window, 1),
        window_strides=(1, *step, 1),
        padding=pads,
    )

  totals = _window_sum(folded)
  if padding == 'VALID':
    # Every window is full, so the divisor is a constant.
    averaged = totals / math.prod(window)
  else:
    # Boundary windows may overlap padding: summing a mask of ones gives the
    # number of real elements under each window.
    averaged = totals / _window_sum(jnp.ones_like(folded))

  # Restore the caller's batch dimensions.
  return averaged.reshape((*batch_shape, *averaged.shape[1:-1], channels))