"""Standard neural network building blocks.
This module provides a collection of pre-built, functional layers ready for use.
Included Layers:
- **Core:** `Embed`, `Linear`
- **Convolution:** `Conv`, `ConvTranspose`
- **Recurrent:** `LSTM`, `GRU`
- **Normalization:** `LayerNorm`, `RMSNorm`, `BatchNorm`
- **Regularization:** `Dropout`
- **Containers:** `Sequential`
- **Pooling:** `max_pool`, `min_pool`, `avg_pool`
Blocks are usually subclasses of `bx.Module` which strictly follows the
`(params, inputs) -> (outputs, params)` functional signature.
"""
import math
from typing import Any, Callable, NamedTuple, Sequence
import jax
import jax.numpy as jnp
from . import interfaces as bx
Initializer = jax.nn.initializers.Initializer
PaddingLike = str | Sequence[tuple[int, int]]
[docs]
class Embed(bx.Module):
"""Embedding layer that maps integer indices to dense vectors.
Supports weight tying for language models via the `attend` method, which
applies the transpose of the embedding matrix (useful for output projections).
Example:
embed = Embed(
graph.child('embed'), num_embeddings=10000, embedding_size=512, rng=rng
)
# Forward pass: indices -> embeddings
embeddings, params = embed(params, token_ids)
# Weight-tied output projection: hidden -> logits
logits, params = embed.attend(params, hidden_states)
"""
def __init__(
self,
graph: bx.Graph,
num_embeddings: int,
embedding_size: int,
rng: bx.Rng | None,
embedding_init: Initializer = jax.nn.initializers.variance_scaling(
1.0, 'fan_in', 'normal', out_axis=0
),
embedding_metadata: dict[str, Any] | None = None,
) -> None:
"""Initializes the Embed module.
Args:
graph: The graph node for this module.
num_embeddings: Size of the vocabulary (number of unique tokens).
embedding_size: Dimensionality of the embedding vectors.
rng: Rng module for random initialization. If embedding_init is constant,
Rng is not required (but still recommended).
embedding_init: Initializer for the embedding matrix.
embedding_metadata: Optional metadata for the embedding parameter.
"""
super().__init__(graph)
self.num_embeddings = num_embeddings
self.embedding_size = embedding_size
self.rng = rng
self.embedding_init = embedding_init
self.embedding_metadata = embedding_metadata
def __call__(
self,
params: bx.Params,
indices: jax.Array,
) -> tuple[jax.Array, bx.Params]:
"""Looks up embeddings for the given indices.
Args:
params: The parameters container.
indices: Integer array of token indices. Shape [...].
Returns:
A tuple (embeddings, params). Embeddings have shape [..., embedding_size].
"""
embedding_matrix, params = self.get_param(
params=params,
name='embedding',
shape=(self.num_embeddings, self.embedding_size),
init=self.embedding_init,
metadata=self.embedding_metadata,
rng=self.rng,
)
return embedding_matrix[indices], params
[docs]
def attend(
self,
params: bx.Params,
inputs: jax.Array,
) -> tuple[jax.Array, bx.Params]:
"""Applies the transpose of the embedding matrix (for weight tying).
This is commonly used in language models where the output projection
shares weights with the input embedding.
Args:
params: The parameters container.
inputs: Input array of shape [..., embedding_size].
Returns:
A tuple (logits, params). Logits have shape [..., num_embeddings].
"""
embedding_matrix, params = self.get_param(
params=params,
name='embedding',
shape=(self.num_embeddings, self.embedding_size),
init=self.embedding_init,
metadata=self.embedding_metadata,
rng=self.rng,
)
# inputs @ embedding_matrix.T
return jnp.dot(inputs, embedding_matrix.T), params
[docs]
class Linear(bx.Module):
"""A standard linear transformation layer.
Computes `output = input @ kernel + bias`.
Supports model parallelism via metadata. Example for sharding weights:
linear = Linear(
graph.child('linear'),
output_size=1024,
rng=rng,
kernel_metadata={'sharding': (None, 'model')}, # Shard output dim
bias_metadata={'sharding': ('model',)},
)
"""
def __init__(
self,
graph: bx.Graph,
output_size: int,
rng: bx.Rng | None,
use_bias: bool = True,
kernel_init: Initializer = jax.nn.initializers.lecun_normal(),
bias_init: Initializer = jax.nn.initializers.zeros,
kernel_metadata: dict[str, Any] | None = None,
bias_metadata: dict[str, Any] | None = None,
) -> None:
"""Initializes the Linear module.
Args:
graph: The graph node for this module.
output_size: The dimensionality of the output features.
rng: Rng module for random initialization. If kernel_init and bias_init
are constant, Rng is not required (but still recommended).
use_bias: Whether to add a learnable bias vector.
kernel_init: Initializer for the weight matrix.
bias_init: Initializer for the bias vector.
kernel_metadata: Optional metadata for the kernel parameter. Common keys:
- 'sharding': tuple like (None, 'model') for model parallelism
bias_metadata: Optional metadata for the bias parameter. Common keys:
- 'sharding': tuple like ('model',) for model parallelism
"""
super().__init__(graph)
self.output_size = output_size
self.rng = rng
self.use_bias = use_bias
self.kernel_init = kernel_init
self.bias_init = bias_init
self.kernel_metadata = kernel_metadata
self.bias_metadata = bias_metadata
def __call__(
self,
params: bx.Params,
inputs: jax.Array,
precision: jax.lax.Precision | None = None,
) -> tuple[jax.Array, bx.Params]:
"""Applies the linear transformation.
Args:
params: The parameters container.
inputs: The input array. Must have at least one dimension.
Shape should be [..., input_features].
precision: Optional precision for the matrix multiplication.
Returns:
A tuple (output, params). The output has shape [..., output_size].
Raises:
ValueError: If the input is a scalar (rank 0).
"""
if not inputs.shape:
raise ValueError('Input must not be scalar.')
input_size = inputs.shape[-1]
kernel, params = self.get_param(
params,
'kernel',
(input_size, self.output_size),
self.kernel_init,
metadata=self.kernel_metadata,
rng=self.rng,
)
outputs = jnp.dot(inputs, kernel, precision=precision)
if self.use_bias:
bias, params = self.get_param(
params,
'bias',
(self.output_size,),
self.bias_init,
metadata=self.bias_metadata,
rng=self.rng,
)
bias = jnp.broadcast_to(bias, outputs.shape)
outputs = outputs + bias
return outputs, params
[docs]
class Sequential(bx.Module):
"""A sequential container.
Modules will be added to it in the order they are passed in the constructor.
Alternatively, an ordered dict of modules can also be passed in.
Example:
mlp = Sequential(
graph.child('mlp'),
[
bx.Linear(graph.child('l1'), 32),
jax.nn.relu,
bx.Linear(graph.child('l2'), 10),
],
)
y, params = mlp(params, x)
"""
def __init__(
self,
graph: bx.Graph,
layers: Sequence[bx.Module | Callable[[jax.Array], jax.Array]],
) -> None:
"""Initializes the Sequential module.
Args:
graph: The graph node for this module.
layers: A list of blox Modules or callables.
If a layer is a blox Module, it must accept (params, inputs) and return
(output, params).
If a layer is a simple callable (like jax.nn.relu), it must accept
inputs and return output.
"""
super().__init__(graph)
self.layers = layers
def __call__(
self,
params: bx.Params,
inputs: jax.Array,
) -> tuple[jax.Array, bx.Params]:
"""Applies the sequential model.
Args:
params: The parameters container.
inputs: The input array.
Returns:
A tuple (output, params).
"""
x = inputs
for layer in self.layers:
if isinstance(layer, bx.Module):
x, params = layer(params, x)
else:
# Assume it's a pure activation function like jax.nn.relu.
x = layer(x)
return x, params
[docs]
class LSTMState(NamedTuple):
"""Holds the hidden and cell states for an LSTM."""
hidden: jax.Array
cell: jax.Array
[docs]
class LSTM(bx.RecurrenceBase[jax.Array, LSTMState, jax.Array, jax.Array]):
r"""Long Short-Term Memory (LSTM) Recurrent Neural Network.
The mathematical definition of the cell is as follows:
.. math::
i = \\sigma(W_{ii} x + W_{hi} h + b_{hi}) \\\\
f = \\sigma(W_{if} x + W_{hf} h + b_{hf}) \\\\
g = \\tanh(W_{ig} x + W_{hg} h + b_{hg}) \\\\
o = \\sigma(W_{io} x + W_{ho} h + b_{ho}) \\\\
c' = f * c + i * g \\\\
h' = o * \\tanh(c')
where x is the input, h is the output of the previous time step, and c is
the memory.
This module implements a standard LSTM cell. It inherits from RecurrenceBase,
automatically providing support for both single-step execution (`__call__`)
and efficient sequence processing (`apply` with scanning).
Example:
lstm = LSTM(graph.child('lstm'), hidden_size=128, rng=rng)
# Initialize state first:
state, params = lstm.initial_state(params, inputs)
# Single-step processing (e.g., for interactive use):
(outputs, state), params = lstm(params, inputs, state)
# Sequence processing (uses jax.lax.scan for efficiency):
(outputs, state), params = lstm.apply(params, inputs_sequence, state)
"""
def __init__(
self,
graph: bx.Graph,
hidden_size: int,
rng: bx.Rng | None,
is_static: bool = False,
) -> None:
"""Initializes the LSTM.
Args:
graph: The graph node for this module.
hidden_size: The dimensionality of the hidden and cell states.
rng: Rng module for random initialization.
is_static: If True, uses Python loops for sequence processing.
If False, uses jax.lax.scan (default).
"""
super().__init__(graph, is_static)
self.hidden_size = hidden_size
# Using a single Linear layer to project inputs to the 4 gates (i, g, f, o).
self.gates = Linear(
graph.child('gates'), output_size=4 * hidden_size, rng=rng
)
[docs]
def initial_state(
self, params: bx.Params, inputs: jax.Array
) -> tuple[LSTMState, bx.Params]:
"""Creates the initial zero state.
Args:
params: The parameters container.
inputs: The input array, used to infer the batch size (dimension 0).
Returns:
A tuple (LSTMState, params), where both hidden and cell states are zeros.
"""
batch_size = inputs.shape[0]
return (
LSTMState(
hidden=jnp.zeros((batch_size, self.hidden_size)),
cell=jnp.zeros((batch_size, self.hidden_size)),
),
params,
)
def __call__(
self,
params: bx.Params,
inputs: jax.Array,
prev_state: LSTMState | None,
is_reset: jax.Array | None = None,
is_training: bool = True,
) -> tuple[tuple[jax.Array, LSTMState], bx.Params]:
"""Computes a single step of the LSTM recurrence.
Args:
params: The parameters container.
inputs: The input at the current time step. Shape [Batch, input_size].
prev_state: The previous LSTM state. Must not be None.
is_reset: Optional boolean array [Batch]. If True for a batch element,
the state is reset to zero *before* computing the step.
is_training: Unused.
Returns:
A nested tuple ((output, new_state), params).
The output of the LSTM is the hidden state.
Raises:
ValueError: If prev_state is None.
"""
del is_training # Currently unused.
if prev_state is None:
raise ValueError('The LSTM __call__ method requires a valid prev_state.')
# Apply reset mask if provided.
prev_state = self.maybe_reset_state(params, prev_state, inputs, is_reset)
prev_h, prev_c = prev_state.hidden, prev_state.cell
# Concatenate input and previous hidden state.
x_and_h = jnp.concatenate([inputs, prev_h], axis=-1)
# Project to gates.
gated, params = self.gates(params, x_and_h)
# Split into input, gate, forget, and output components.
i, g, f, o = jnp.split(gated, indices_or_sections=4, axis=-1)
# Apply activations.
f = jax.nn.sigmoid(f)
c = f * prev_c + jax.nn.sigmoid(i) * jnp.tanh(g)
h = jax.nn.sigmoid(o) * jnp.tanh(c)
new_state = LSTMState(hidden=h, cell=c)
# Output is h, state is (h, c).
return (h, new_state), params
[docs]
class GRUState(NamedTuple):
"""Holds the hidden state for a GRU."""
hidden: jax.Array
[docs]
class GRU(bx.RecurrenceBase[jax.Array, GRUState, jax.Array, jax.Array]):
r"""Gated Recurrent Unit (GRU).
The mathematical definition of the cell is as follows:
.. math::
r = \\sigma(W_{ir} x + W_{hr} h + b_{hr}) \\\\
z = \\sigma(W_{iz} x + W_{hz} h + b_{hz}) \\\\
n = \\tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\\\
h' = (1 - z) * n + z * h
where x is the input and h is the output of the previous time step.
Example:
gru = GRU(graph.child('gru'), hidden_size=128, rng=rng)
state, params = gru.initial_state(params, inputs)
(outputs, state), params = gru(params, inputs, state)
"""
def __init__(
self,
graph: bx.Graph,
hidden_size: int,
rng: bx.Rng | None,
is_static: bool = False,
) -> None:
"""Initializes the GRU.
Args:
graph: The graph node for this module.
hidden_size: The dimensionality of the hidden state.
rng: Rng module for random initialization.
is_static: If True, uses Python loops. If False, uses jax.lax.scan.
"""
super().__init__(graph, is_static)
self.hidden_size = hidden_size
# We use two linear layers:
# Update and reset gates (z, r) computed from x and h.
# Candidate hidden state (h_tilde) computed from x and (r * h).
self.gates = Linear(
graph.child('gates'), output_size=2 * hidden_size, rng=rng
)
self.candidate = Linear(
graph.child('candidate'), output_size=hidden_size, rng=rng
)
[docs]
def initial_state(
self, params: bx.Params, inputs: jax.Array
) -> tuple[GRUState, bx.Params]:
"""Creates the initial zero state.
Args:
params: The parameters container.
inputs: The input array, used to infer the batch size (dimension 0).
Returns:
A tuple (GRUState, params).
"""
batch_size = inputs.shape[0]
return GRUState(hidden=jnp.zeros((batch_size, self.hidden_size))), params
def __call__(
self,
params: bx.Params,
inputs: jax.Array,
prev_state: GRUState | None,
is_reset: jax.Array | None = None,
is_training: bool = True,
) -> tuple[tuple[jax.Array, GRUState], bx.Params]:
"""Computes a single step of the GRU recurrence.
Args:
params: The parameters container.
inputs: The input at the current time step. Shape [Batch, input_size].
prev_state: The previous GRU state. Must not be None.
is_reset: Optional boolean array [Batch].
is_training: Unused.
Returns:
A nested tuple ((output, new_state), params).
The output of the GRU is the hidden state.
"""
del is_training
if prev_state is None:
raise ValueError('The GRU __call__ method requires a valid prev_state.')
prev_state = self.maybe_reset_state(params, prev_state, inputs, is_reset)
prev_h = prev_state.hidden
x_and_h = jnp.concatenate([inputs, prev_h], axis=-1)
gates_out, params = self.gates(params, x_and_h)
z, r = jnp.split(gates_out, indices_or_sections=2, axis=-1)
z = jax.nn.sigmoid(z)
r = jax.nn.sigmoid(r)
r_h = r * prev_h
x_and_rh = jnp.concatenate([inputs, r_h], axis=-1)
h_tilde, params = self.candidate(params, x_and_rh)
h_tilde = jnp.tanh(h_tilde)
h = (1 - z) * prev_h + z * h_tilde
new_state = GRUState(hidden=h)
return (h, new_state), params
[docs]
class Dropout(bx.Module):
"""Dropout layer for regularization.
During training, randomly zeros elements with probability `rate` and scales
the remaining elements by `1 / (1 - rate)` to maintain expected values.
During inference, this layer is a no-op.
Example:
dropout = Dropout(graph.child('dropout'), rate=0.5, rng=rng)
y, params = dropout(params, x, is_training=True)
"""
def __init__(
self,
graph: bx.Graph,
rate: float,
rng: bx.Rng,
) -> None:
"""Initializes the Dropout module.
Args:
graph: The graph node for this module.
rate: The probability of dropping each element (0.0 to 1.0).
rng: Rng module for generating dropout masks.
"""
super().__init__(graph)
if not 0.0 <= rate < 1.0:
raise ValueError(f'Dropout rate must be in [0.0, 1.0), got {rate}.')
self.rate = rate
self.rng = rng
def __call__(
self,
params: bx.Params,
inputs: jax.Array,
is_training: bool = True,
) -> tuple[jax.Array, bx.Params]:
"""Applies dropout to the inputs.
Args:
params: The parameters container.
inputs: The input array.
is_training: If True, applies dropout. If False, returns inputs unchanged.
Returns:
A tuple (output, params).
"""
if not is_training or self.rate == 0.0:
return inputs, params
key, params = self.rng(params)
keep_rate = 1.0 - self.rate
mask = jax.random.bernoulli(key, keep_rate, inputs.shape)
return inputs * mask / keep_rate, params
[docs]
class LayerNorm(bx.Module):
"""Layer Normalization.
Normalizes over the last axis (features) of the input. Supports cross-device
statistics aggregation via axis_name for use with jax.shard_map.
"""
def __init__(
self,
graph: bx.Graph,
epsilon: float = 1e-5,
use_scale: bool = True,
use_bias: bool = True,
scale_init: Initializer = jax.nn.initializers.ones,
bias_init: Initializer = jax.nn.initializers.zeros,
axis_name: str | None = None,
axis_index_groups: Sequence[Sequence[int]] | None = None,
rng: bx.Rng | None = None,
) -> None:
"""Initializes the LayerNorm module.
Args:
graph: The graph node for this module.
epsilon: Small constant for numerical stability.
use_scale: Whether to use a learnable scale parameter.
use_bias: Whether to use a learnable bias parameter.
scale_init: Initializer for scale.
bias_init: Initializer for bias.
axis_name: The axis name used to combine statistics from multiple devices.
See jax.shard_map for a description of axis names.
axis_index_groups: Groups of axis indices within that named axis
representing subsets of devices to reduce over. For example,
[[0, 1], [2, 3]] would independently normalize over the first two and
last two devices. See jax.lax.psum for more details.
rng: Rng module. Required if using stochastic initializers for scale/bias.
"""
super().__init__(graph)
self.epsilon = epsilon
self.use_scale = use_scale
self.use_bias = use_bias
self.scale_init = scale_init
self.bias_init = bias_init
self.axis_name = axis_name
self.axis_index_groups = axis_index_groups
self.rng = rng
def __call__(
self,
params: bx.Params,
inputs: jax.Array,
) -> tuple[jax.Array, bx.Params]:
"""Applies layer normalization.
Args:
params: The parameters container.
inputs: The input array with shape [..., features].
Returns:
A tuple (normalized_output, params).
"""
features = inputs.shape[-1]
# Compute mean and variance over last axis.
mean = jnp.mean(inputs, axis=-1, keepdims=True)
var = jnp.var(inputs, axis=-1, keepdims=True)
# Cross-device aggregation if axis_name is provided.
if self.axis_name is not None:
mean = jax.lax.pmean(
mean, self.axis_name, axis_index_groups=self.axis_index_groups
)
var = jax.lax.pmean(
var, self.axis_name, axis_index_groups=self.axis_index_groups
)
# Normalize.
normalized = (inputs - mean) / jnp.sqrt(var + self.epsilon)
# Scale and shift.
if self.use_scale:
scale, params = self.get_param(
params=params,
name='scale',
shape=(features,),
init=self.scale_init,
rng=self.rng,
)
normalized = normalized * scale
if self.use_bias:
bias, params = self.get_param(
params=params,
name='bias',
shape=(features,),
init=self.bias_init,
rng=self.rng,
)
normalized = normalized + bias
return normalized, params
[docs]
class RMSNorm(bx.Module):
"""Root Mean Square Layer Normalization.
Normalizes using only the RMS of the input (no mean subtraction).
This is computationally simpler than LayerNorm. Supports cross-device
statistics aggregation via axis_name for use with jax.shard_map.
"""
def __init__(
self,
graph: bx.Graph,
epsilon: float = 1e-5,
use_scale: bool = True,
scale_init: Initializer = jax.nn.initializers.ones,
axis_name: str | None = None,
axis_index_groups: Sequence[Sequence[int]] | None = None,
rng: bx.Rng | None = None,
) -> None:
"""Initializes the RMSNorm module.
Args:
graph: The graph node for this module.
epsilon: Small constant for numerical stability.
use_scale: Whether to use a learnable scale parameter.
scale_init: Initializer for scale.
axis_name: The axis name used to combine statistics from multiple devices.
See jax.shard_map for a description of axis names.
axis_index_groups: Groups of axis indices within that named axis
representing subsets of devices to reduce over. For example,
[[0, 1], [2, 3]] would independently normalize over the first two and
last two devices. See jax.lax.psum for more details.
rng: Rng module. Required if using stochastic initializers for scale.
"""
super().__init__(graph)
self.epsilon = epsilon
self.use_scale = use_scale
self.scale_init = scale_init
self.axis_name = axis_name
self.axis_index_groups = axis_index_groups
self.rng = rng
def __call__(
self,
params: bx.Params,
inputs: jax.Array,
) -> tuple[jax.Array, bx.Params]:
"""Applies RMS normalization.
Args:
params: The parameters container.
inputs: The input array with shape [..., features].
Returns:
A tuple (normalized_output, params).
"""
features = inputs.shape[-1]
# Compute mean of squares.
mean_sq = jnp.mean(inputs**2, axis=-1, keepdims=True)
# Cross-device aggregation if axis_name is provided.
if self.axis_name is not None:
mean_sq = jax.lax.pmean(
mean_sq, self.axis_name, axis_index_groups=self.axis_index_groups
)
# Compute RMS.
rms = jnp.sqrt(mean_sq + self.epsilon)
# Normalize.
normalized = inputs / rms
# Scale.
if self.use_scale:
scale, params = self.get_param(
params, 'scale', (features,), self.scale_init, rng=self.rng
)
normalized = normalized * scale
return normalized, params
[docs]
class BatchNorm(bx.Module):
"""Batch Normalization.
Normalizes over the batch and spatial dimensions, maintaining running
statistics for inference. During training, computes batch statistics and
updates the running averages. During inference, uses the stored running
statistics.
The input is expected to have shape (batch, *spatial_dims, features).
Normalization is applied over all axes except the last (features) axis.
Example:
bn = BatchNorm(graph.child('bn'))
# Training: uses batch statistics and updates running stats.
y, params = bn(params, x, is_training=True)
# Inference: uses running statistics.
y, params = bn(params, x, is_training=False)
"""
def __init__(
self,
graph: bx.Graph,
momentum: float = 0.9,
epsilon: float = 1e-5,
use_scale: bool = True,
use_bias: bool = True,
scale_init: Initializer = jax.nn.initializers.ones,
bias_init: Initializer = jax.nn.initializers.zeros,
axis_name: str | None = None,
axis_index_groups: Sequence[Sequence[int]] | None = None,
rng: bx.Rng | None = None,
) -> None:
"""Initializes the BatchNorm module.
Args:
graph: The graph node for this module.
momentum: Momentum for the exponential moving average of running stats.
running_stat = momentum * running_stat + (1.0 - momentum) * batch_stat.
epsilon: Small constant for numerical stability.
use_scale: Whether to use a learnable scale parameter (gamma).
use_bias: Whether to use a learnable bias parameter (beta).
scale_init: Initializer for scale.
bias_init: Initializer for bias.
axis_name: The axis name used to combine statistics from multiple devices.
See jax.shard_map for a description of axis names.
axis_index_groups: Groups of axis indices within that named axis
representing subsets of devices to reduce over.
rng: Rng module. Required if using stochastic initializers for scale/bias.
"""
super().__init__(graph)
self.momentum = momentum
self.epsilon = epsilon
self.use_scale = use_scale
self.use_bias = use_bias
self.scale_init = scale_init
self.bias_init = bias_init
self.axis_name = axis_name
self.axis_index_groups = axis_index_groups
self.rng = rng
def __call__(
self,
params: bx.Params,
inputs: jax.Array,
is_training: bool = True,
) -> tuple[jax.Array, bx.Params]:
"""Applies batch normalization.
Args:
params: The parameters container.
inputs: The input array with shape (batch, *spatial_dims, features).
is_training: If True, uses batch statistics and updates running stats.
If False, uses stored running statistics (inference).
Returns:
A tuple (normalized_output, params).
"""
features = inputs.shape[-1]
# Axes to reduce over: all except the last (features) axis.
reduce_axes = tuple(range(inputs.ndim - 1))
# Get or create running statistics (non-trainable).
running_mean, params = self.get_param(
params=params,
name='running_mean',
shape=(features,),
init=jax.nn.initializers.zeros,
trainable=False,
rng=self.rng,
)
running_var, params = self.get_param(
params=params,
name='running_var',
shape=(features,),
init=jax.nn.initializers.ones,
trainable=False,
rng=self.rng,
)
if is_training:
# Compute batch statistics.
mean = jnp.mean(inputs, axis=reduce_axes)
var = jnp.var(inputs, axis=reduce_axes)
# Cross-device aggregation if axis_name is provided.
if self.axis_name is not None:
mean = jax.lax.pmean(
mean, self.axis_name, axis_index_groups=self.axis_index_groups
)
var = jax.lax.pmean(
var, self.axis_name, axis_index_groups=self.axis_index_groups
)
# Update running statistics with exponential moving average.
# stop_gradient prevents backprop through running stats.
new_running_mean = self.momentum * running_mean + (
1 - self.momentum
) * jax.lax.stop_gradient(mean)
new_running_var = self.momentum * running_var + (
1 - self.momentum
) * jax.lax.stop_gradient(var)
# Store updated running statistics.
params = self.set_param(
params=params, name='running_mean', value=new_running_mean
)
params = self.set_param(
params=params, name='running_var', value=new_running_var
)
else:
# Use running statistics for inference.
mean = running_mean
var = running_var
# Normalize.
normalized = (inputs - mean) / jnp.sqrt(var + self.epsilon)
# Scale and shift.
if self.use_scale:
scale, params = self.get_param(
params, 'scale', (features,), self.scale_init, rng=self.rng
)
normalized = normalized * scale
if self.use_bias:
bias, params = self.get_param(
params, 'bias', (features,), self.bias_init, rng=self.rng
)
normalized = normalized + bias
return normalized, params
def _normalize_tuple(x: int | Sequence[int], n: int) -> tuple[int, ...]:
"""Converts int or sequence to a tuple of length n."""
if isinstance(x, int):
return (x,) * n
return tuple(x)
[docs]
class Conv(bx.Module):
"""General N-dimensional convolution layer.
The number of spatial dimensions is inferred from kernel_size:
- 1-tuple or int: 1D convolution (single spatial dimension)
- 2-tuple: 2D convolution (height, width)
- 3-tuple: 3D convolution (depth, height, width)
Supports arbitrary batch dimensions (0 or more). Uses channels-last
convention: (*batch, *spatial_dims, channels).
Example:
# 2D convolution for images (NHWC format)
conv = Conv(
graph.child('conv'), kernel_size=(3, 3), output_channels=64, rng=rng
)
y, params = conv(params, x) # x: [batch, height, width, channels]
# 1D convolution for sequences (NLC format)
conv = Conv(graph.child('conv'), kernel_size=3, output_channels=64, rng=rng)
y, params = conv(params, x) # x: [batch, length, channels]
# Unbatched input.
conv = Conv(
graph.child('conv'), kernel_size=(3, 3), output_channels=64, rng=rng
)
y, params = conv(params, x) # x: [height, width, channels]
"""
def __init__(
self,
graph: bx.Graph,
kernel_size: int | Sequence[int],
output_channels: int,
rng: bx.Rng | None,
strides: int | Sequence[int] = 1,
padding: PaddingLike = 'SAME',
input_dilation: int | Sequence[int] = 1,
kernel_dilation: int | Sequence[int] = 1,
feature_group_count: int = 1,
use_bias: bool = True,
kernel_init: Initializer = jax.nn.initializers.lecun_normal(),
bias_init: Initializer = jax.nn.initializers.zeros,
kernel_metadata: dict[str, Any] | None = None,
bias_metadata: dict[str, Any] | None = None,
) -> None:
"""Initializes the Conv module.
Args:
graph: The graph node for this module.
kernel_size: Shape of the convolutional kernel as a tuple, determining
the number of spatial dimensions (e.g., (3, 3) for 2D conv). For 1D
convolution, either an int or a 1-tuple can be used.
output_channels: Number of output channels.
rng: Rng module for random initialization. If kernel_init and bias_init
are constant, Rng is not required (but still recommended).
strides: Stride of the convolution. An int is broadcast to all spatial
dimensions.
padding: Padding mode. Either 'SAME', 'VALID', or a sequence of
(low, high) padding pairs for each spatial dimension.
input_dilation: Dilation of the input (transposed convolution).
kernel_dilation: Dilation of the kernel (atrous convolution).
feature_group_count: Number of feature groups for grouped convolution.
Set to input_channels for depthwise convolution.
use_bias: Whether to add a learnable bias.
kernel_init: Initializer for the kernel.
bias_init: Initializer for the bias.
kernel_metadata: Optional metadata for the kernel parameter.
bias_metadata: Optional metadata for the bias parameter.
"""
super().__init__(graph)
self.kernel_size = (
(kernel_size,) if isinstance(kernel_size, int) else tuple(kernel_size)
)
self.output_channels = output_channels
self.rng = rng
self.strides = strides
self.padding = padding
self.input_dilation = input_dilation
self.kernel_dilation = kernel_dilation
self.feature_group_count = feature_group_count
self.use_bias = use_bias
self.kernel_init = kernel_init
self.bias_init = bias_init
self.kernel_metadata = kernel_metadata
self.bias_metadata = bias_metadata
def __call__(
self,
params: bx.Params,
inputs: jax.Array,
precision: jax.lax.Precision | None = None,
) -> tuple[jax.Array, bx.Params]:
"""Applies the convolution.
Args:
params: The parameters container.
inputs: Input array with shape (*batch, *spatial_dims, input_channels).
Supports arbitrary batch dimensions (0 or more).
precision: Optional precision for the convolution.
Returns:
A tuple (output, params), where output has shape
(*batch, *out_spatial, output_channels).
Raises:
ValueError: If input has fewer dimensions than needed for conv.
"""
num_spatial = len(self.kernel_size)
min_rank = num_spatial + 1 # spatial + channels (batch is optional)
if inputs.ndim < min_rank:
raise ValueError(
f'Expected input rank >= {min_rank} for {num_spatial}D conv '
f'(at least {num_spatial} spatial dims + 1 channel dim), '
f'got {inputs.ndim}.'
)
# Handle arbitrary batch dimensions by reshaping.
# (*batch, *spatial, channels) -> (combined_batch, *spatial, channels)
batch_shape = inputs.shape[: -num_spatial - 1]
spatial_shape = inputs.shape[-num_spatial - 1 : -1]
input_channels = inputs.shape[-1]
batch_size = math.prod(batch_shape) if batch_shape else 1
inputs_flat = inputs.reshape(
(batch_size,) + spatial_shape + (input_channels,)
)
if input_channels % self.feature_group_count != 0:
raise ValueError(
f'input_channels ({input_channels}) must be divisible by '
f'feature_group_count ({self.feature_group_count}).'
)
# Kernel shape: (*kernel_size, input_channels // groups, output_channels)
kernel_shape = self.kernel_size + (
input_channels // self.feature_group_count,
self.output_channels,
)
kernel, params = self.get_param(
params,
'kernel',
kernel_shape,
self.kernel_init,
metadata=self.kernel_metadata,
rng=self.rng,
)
# Normalize strides and dilations to tuples.
strides = _normalize_tuple(self.strides, num_spatial)
input_dilation = _normalize_tuple(self.input_dilation, num_spatial)
kernel_dilation = _normalize_tuple(self.kernel_dilation, num_spatial)
# Build dimension numbers for channels-last format.
# (N, *spatial, C) -> lax expects (N, C, *spatial)
# We use dimension_numbers to avoid transposing.
spatial_dims = tuple(range(1, num_spatial + 1))
lhs_spec = (0, num_spatial + 1) + spatial_dims # (N, C, *spatial)
rhs_spec = (num_spatial + 1, num_spatial) + tuple(range(num_spatial))
out_spec = (0, num_spatial + 1) + spatial_dims
dimension_numbers = jax.lax.ConvDimensionNumbers(
lhs_spec=lhs_spec, rhs_spec=rhs_spec, out_spec=out_spec
)
# Apply convolution.
outputs = jax.lax.conv_general_dilated(
inputs_flat,
kernel,
window_strides=strides,
padding=self.padding,
lhs_dilation=input_dilation,
rhs_dilation=kernel_dilation,
dimension_numbers=dimension_numbers,
feature_group_count=self.feature_group_count,
precision=precision,
)
# Add bias.
if self.use_bias:
bias, params = self.get_param(
params,
'bias',
(self.output_channels,),
self.bias_init,
metadata=self.bias_metadata,
rng=self.rng,
)
outputs = outputs + bias
# Reshape output back to original batch shape.
# (combined_batch, *out_spatial, out_channels)
# -> (*batch, *out_spatial, out_channels)
out_spatial = outputs.shape[1:-1]
outputs = outputs.reshape(
batch_shape + out_spatial + (self.output_channels,)
)
return outputs, params
[docs]
class ConvTranspose(bx.Module):
"""General N-dimensional transposed convolution layer.
Also known as deconvolution or fractionally-strided convolution.
The number of spatial dimensions is inferred from kernel_size:
- 1-tuple or int: 1D convolution (single spatial dimension)
- 2-tuple: 2D convolution (height, width)
- 3-tuple: 3D convolution (depth, height, width)
Supports arbitrary batch dimensions (0 or more). Uses channels-last
convention: (*batch, *spatial_dims, channels).
Example:
# 2D transposed convolution for images (NHWC format).
conv_t = ConvTranspose(
graph.child('conv_t'), kernel_size=(3, 3), output_channels=3, rng=rng
)
y, params = conv_t(params, x) # x: [batch, height, width, channels]
"""
def __init__(
self,
graph: bx.Graph,
kernel_size: int | Sequence[int],
output_channels: int,
rng: bx.Rng | None,
strides: int | Sequence[int] = 1,
padding: PaddingLike = 'SAME',
kernel_dilation: int | Sequence[int] = 1,
feature_group_count: int = 1,
use_bias: bool = True,
kernel_init: Initializer = jax.nn.initializers.lecun_normal(),
bias_init: Initializer = jax.nn.initializers.zeros,
kernel_metadata: dict[str, Any] | None = None,
bias_metadata: dict[str, Any] | None = None,
) -> None:
"""Initializes the ConvTranspose module.
Args:
graph: The graph node for this module.
kernel_size: Shape of the convolutional kernel as a tuple, determining
the number of spatial dimensions (e.g., (3, 3) for 2D conv). For 1D
convolution, either an int or a 1-tuple can be used.
output_channels: Number of output channels.
rng: Rng module for random initialization. If kernel_init and bias_init
are constant, Rng is not required (but still recommended).
strides: Stride of the convolution. An int is broadcast to all spatial
dimensions.
padding: Padding mode. Either 'SAME', 'VALID', or a sequence of
(low, high) padding pairs for each spatial dimension.
kernel_dilation: Dilation of the kernel (atrous convolution).
feature_group_count: Number of feature groups for grouped convolution.
Set to input_channels for depthwise convolution.
use_bias: Whether to add a learnable bias.
kernel_init: Initializer for the kernel.
bias_init: Initializer for the bias.
kernel_metadata: Optional metadata for the kernel parameter.
bias_metadata: Optional metadata for the bias parameter.
"""
super().__init__(graph)
self.kernel_size = (
(kernel_size,) if isinstance(kernel_size, int) else tuple(kernel_size)
)
self.output_channels = output_channels
self.rng = rng
self.strides = strides
self.padding = padding
self.kernel_dilation = kernel_dilation
self.feature_group_count = feature_group_count
self.use_bias = use_bias
self.kernel_init = kernel_init
self.bias_init = bias_init
self.kernel_metadata = kernel_metadata
self.bias_metadata = bias_metadata
def __call__(
self,
params: bx.Params,
inputs: jax.Array,
precision: jax.lax.Precision | None = None,
) -> tuple[jax.Array, bx.Params]:
"""Applies the transposed convolution.
Args:
params: The parameters container.
inputs: Input array with shape (*batch, *spatial_dims, input_channels).
Supports arbitrary batch dimensions (0 or more).
precision: Optional precision for the convolution.
Returns:
A tuple (output, params), where output has shape
(*batch, *out_spatial, output_channels).
Raises:
ValueError: If input has fewer dimensions than needed for conv.
"""
num_spatial = len(self.kernel_size)
min_rank = num_spatial + 1 # spatial + channels (batch is optional)
if inputs.ndim < min_rank:
raise ValueError(
f'Expected input rank >= {min_rank} for {num_spatial}D conv_transpose'
f' (at least {num_spatial} spatial dims + 1 channel dim), '
f'got {inputs.ndim}.'
)
# Handle arbitrary batch dimensions by reshaping.
# (*batch, *spatial, channels) -> (combined_batch, *spatial, channels)
batch_shape = inputs.shape[: -num_spatial - 1]
spatial_shape = inputs.shape[-num_spatial - 1 : -1]
input_channels = inputs.shape[-1]
batch_size = math.prod(batch_shape) if batch_shape else 1
inputs_flat = inputs.reshape(
(batch_size,) + spatial_shape + (input_channels,)
)
if input_channels % self.feature_group_count != 0:
raise ValueError(
f'input_channels ({input_channels}) must be divisible by '
f'feature_group_count ({self.feature_group_count}).'
)
# Kernel shape for ConvTranspose:
# (*kernel_size, output_channels, input_channels // groups)
kernel_shape = self.kernel_size + (
self.output_channels,
input_channels // self.feature_group_count,
)
kernel, params = self.get_param(
params,
'kernel',
kernel_shape,
self.kernel_init,
metadata=self.kernel_metadata,
rng=self.rng,
)
# Normalize strides and dilations to tuples.
strides = _normalize_tuple(self.strides, num_spatial)
kernel_dilation = _normalize_tuple(self.kernel_dilation, num_spatial)
# Build dimension numbers for channels-last format.
# Input: (N, *spatial, C_in) -> lax expects (N, C_in, *spatial)
# Kernel: (*kernel_spatial, C_out, C_in_per_group)
# Output: (N, *spatial, C_out) -> lax expects (N, C_out, *spatial)
spatial_dims = tuple(range(1, num_spatial + 1))
lhs_spec = (0, num_spatial + 1) + spatial_dims # (N, C_in, *spatial)
# For conv_transpose, rhs_spec maps kernel to (out_c, in_c, *spatial)
# We want O -> C_out (index num_spatial)
# We want I -> C_in (index num_spatial + 1)
rhs_spec = (num_spatial, num_spatial + 1) + tuple(range(num_spatial))
out_spec = (0, num_spatial + 1) + spatial_dims # (N, C_out, *spatial)
dimension_numbers = jax.lax.ConvDimensionNumbers(
lhs_spec=lhs_spec, rhs_spec=rhs_spec, out_spec=out_spec
)
# Apply transposed convolution.
outputs = jax.lax.conv_transpose(
inputs_flat,
kernel,
strides=strides,
padding=self.padding,
rhs_dilation=kernel_dilation,
dimension_numbers=dimension_numbers,
precision=precision,
)
# Add bias.
if self.use_bias:
bias, params = self.get_param(
params,
'bias',
(self.output_channels,),
self.bias_init,
metadata=self.bias_metadata,
rng=self.rng,
)
outputs = outputs + bias
# Reshape output back to original batch shape.
# (combined_batch, *out_spatial, out_channels)
# -> (*batch, *out_spatial, out_channels)
out_spatial = outputs.shape[1:-1]
outputs = outputs.reshape(
batch_shape + out_spatial + (self.output_channels,)
)
return outputs, params
[docs]
def max_pool(
inputs: jax.Array,
window_shape: int | Sequence[int],
strides: int | Sequence[int] | None = None,
padding: PaddingLike = 'VALID',
) -> jax.Array:
"""Applies max pooling over spatial dimensions.
The number of spatial dimensions is inferred from window_shape:
- 1-tuple or int: 1D pooling (single spatial dimension)
- 2-tuple: 2D pooling (height, width)
- 3-tuple: 3D pooling (depth, height, width)
Supports arbitrary batch dimensions (0 or more). Uses channels-last
convention: (*batch, *spatial_dims, channels).
Args:
inputs: Input array with shape (*batch, *spatial_dims, channels).
window_shape: Shape of the pooling window as a tuple, determining the
number of spatial dimensions. For 1D pooling, an int can be used.
strides: Stride of the pooling. If None, uses window_shape (no overlap).
padding: Padding mode. Either 'SAME', 'VALID', or a sequence of
(low, high) padding pairs for each spatial dimension.
Returns:
Pooled output array.
Example:
# 2x2 max pooling with stride 2
y = max_pool(x, window_shape=(2, 2), strides=2)
"""
window = (
(window_shape,) if isinstance(window_shape, int) else tuple(window_shape)
)
num_spatial = len(window)
min_rank = num_spatial + 1 # spatial + channels (batch optional)
if inputs.ndim < min_rank:
raise ValueError(
f'Expected input rank >= {min_rank} for {num_spatial}D pooling '
f'(at least {num_spatial} spatial dims + 1 channel dim), '
f'got {inputs.ndim}.'
)
# Handle arbitrary batch dimensions by reshaping.
batch_shape = inputs.shape[: -num_spatial - 1]
spatial_shape = inputs.shape[-num_spatial - 1 : -1]
channels = inputs.shape[-1]
batch_size = math.prod(batch_shape) if batch_shape else 1
inputs_flat = inputs.reshape((batch_size,) + spatial_shape + (channels,))
strides_tuple = (
window if strides is None else _normalize_tuple(strides, num_spatial)
)
# jax.lax.reduce_window expects window and strides for all dims.
full_window = (1,) + window + (1,)
full_strides = (1,) + strides_tuple + (1,)
# Normalize padding for all dimensions (including batch and channel).
if isinstance(padding, str):
full_padding = padding
else:
full_padding = ((0, 0),) + tuple(padding) + ((0, 0),)
outputs = jax.lax.reduce_window(
inputs_flat,
init_value=-jnp.inf,
computation=jax.lax.max,
window_dimensions=full_window,
window_strides=full_strides,
padding=full_padding,
)
# Reshape output back to original batch shape.
out_spatial = outputs.shape[1:-1]
return outputs.reshape(batch_shape + out_spatial + (channels,))
[docs]
def min_pool(
inputs: jax.Array,
window_shape: int | Sequence[int],
strides: int | Sequence[int] | None = None,
padding: PaddingLike = 'VALID',
) -> jax.Array:
"""Applies min pooling over spatial dimensions.
The number of spatial dimensions is inferred from window_shape:
- 1-tuple or int: 1D pooling (single spatial dimension)
- 2-tuple: 2D pooling (height, width)
- 3-tuple: 3D pooling (depth, height, width)
Supports arbitrary batch dimensions (0 or more). Uses channels-last
convention: (*batch, *spatial_dims, channels).
Args:
inputs: Input array with shape (*batch, *spatial_dims, channels).
window_shape: Shape of the pooling window as a tuple, determining the
number of spatial dimensions. For 1D pooling, an int can be used.
strides: Stride of the pooling. If None, uses window_shape (no overlap).
padding: Padding mode. Either 'SAME', 'VALID', or a sequence of
(low, high) padding pairs for each spatial dimension.
Returns:
Pooled output array.
"""
window = (
(window_shape,) if isinstance(window_shape, int) else tuple(window_shape)
)
num_spatial = len(window)
min_rank = num_spatial + 1
if inputs.ndim < min_rank:
raise ValueError(
f'Expected input rank >= {min_rank} for {num_spatial}D pooling '
f'(at least {num_spatial} spatial dims + 1 channel dim), '
f'got {inputs.ndim}.'
)
batch_shape = inputs.shape[: -num_spatial - 1]
spatial_shape = inputs.shape[-num_spatial - 1 : -1]
channels = inputs.shape[-1]
batch_size = math.prod(batch_shape) if batch_shape else 1
inputs_flat = inputs.reshape((batch_size,) + spatial_shape + (channels,))
strides_tuple = (
window if strides is None else _normalize_tuple(strides, num_spatial)
)
full_window = (1,) + window + (1,)
full_strides = (1,) + strides_tuple + (1,)
if isinstance(padding, str):
full_padding = padding
else:
full_padding = ((0, 0),) + tuple(padding) + ((0, 0),)
outputs = jax.lax.reduce_window(
inputs_flat,
init_value=jnp.inf,
computation=jax.lax.min,
window_dimensions=full_window,
window_strides=full_strides,
padding=full_padding,
)
out_spatial = outputs.shape[1:-1]
return outputs.reshape(batch_shape + out_spatial + (channels,))
[docs]
def avg_pool(
inputs: jax.Array,
window_shape: int | Sequence[int],
strides: int | Sequence[int] | None = None,
padding: PaddingLike = 'VALID',
) -> jax.Array:
"""Applies average pooling over spatial dimensions.
The number of spatial dimensions is inferred from window_shape:
- 1-tuple or int: 1D pooling (single spatial dimension)
- 2-tuple: 2D pooling (height, width)
- 3-tuple: 3D pooling (depth, height, width)
Supports arbitrary batch dimensions (0 or more). Uses channels-last
convention: (*batch, *spatial_dims, channels).
Note: When padding='SAME', the average is computed over the valid (non-padded)
pixels in the window, ignoring the zeros added by padding.
Args:
inputs: Input array with shape (*batch, *spatial_dims, channels).
window_shape: Shape of the pooling window as a tuple, determining the
number of spatial dimensions. For 1D pooling, an int can be used.
strides: Stride of the pooling. If None, uses window_shape (no overlap).
padding: Padding mode. Either 'SAME', 'VALID', or a sequence of
(low, high) padding pairs for each spatial dimension.
Returns:
Pooled output array.
Example:
# 2x2 average pooling with stride 2
y = avg_pool(x, window_shape=(2, 2), strides=2)
"""
window = (
(window_shape,) if isinstance(window_shape, int) else tuple(window_shape)
)
num_spatial = len(window)
min_rank = num_spatial + 1
if inputs.ndim < min_rank:
raise ValueError(
f'Expected input rank >= {min_rank} for {num_spatial}D pooling '
f'(at least {num_spatial} spatial dims + 1 channel dim), '
f'got {inputs.ndim}.'
)
batch_shape = inputs.shape[: -num_spatial - 1]
spatial_shape = inputs.shape[-num_spatial - 1 : -1]
channels = inputs.shape[-1]
batch_size = math.prod(batch_shape) if batch_shape else 1
inputs_flat = inputs.reshape((batch_size,) + spatial_shape + (channels,))
strides_tuple = (
window if strides is None else _normalize_tuple(strides, num_spatial)
)
full_window = (1,) + window + (1,)
full_strides = (1,) + strides_tuple + (1,)
if isinstance(padding, str):
full_padding = padding
else:
full_padding = ((0, 0),) + tuple(padding) + ((0, 0),)
# Sum pooling.
pooled_sum = jax.lax.reduce_window(
inputs_flat,
init_value=0.0,
computation=jax.lax.add,
window_dimensions=full_window,
window_strides=full_strides,
padding=full_padding,
)
# For 'VALID' padding, all windows are full so we divide by a constant.
# For 'SAME' or explicit padding, we count valid elements per window.
if padding == 'VALID':
window_size = math.prod(window)
outputs = pooled_sum / window_size
else:
# Count valid elements in each window. Crucial for 'SAME' padding where
# boundary windows may have fewer valid elements.
mask = jnp.ones_like(inputs_flat)
window_counts = jax.lax.reduce_window(
mask,
init_value=0.0,
computation=jax.lax.add,
window_dimensions=full_window,
window_strides=full_strides,
padding=full_padding,
)
outputs = pooled_sum / window_counts
out_spatial = outputs.shape[1:-1]
return outputs.reshape(batch_shape + out_spatial + (channels,))