API Reference
Core Interfaces
- class blox.Graph(name: str)
Bases: object
The structural graph of a model.
A Graph represents the hierarchical structure of your model. Each node in the graph corresponds to a module (layer), and edges represent the parent-child relationships between them. When you create a child node with graph.child(), you’re extending this structure.
The graph serves two purposes:
It defines how your model is organized: which modules contain which.
It provides unique namespaces for parameters. Each node’s path (e.g. ('net', 'encoder', 'dense')) becomes the prefix for that module’s params.
Dependency injection creates additional relationships in the graph. When a module is created externally and passed into another, it retains its original position in the graph (as a sibling rather than a child), enabling flexible parameter sharing patterns.
The graph does not store parameters; that’s the job of the Params container. Graph defines structure; Params holds state.
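To make the namespacing rule concrete, here is a toy stand-in. `ToyGraph` is hypothetical and only mimics two behaviors described on this page (tuple paths that prefix parameter names, and a `ValueError` on duplicate child names); it is a sketch, not blox’s implementation.

```python
class ToyGraph:
    """Hypothetical stand-in for blox.Graph: tracks names and parents only."""

    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent
        self.children = {}

    @property
    def path(self):
        # Collect names from this node up to the root, then reverse.
        names = []
        node = self
        while node is not None:
            names.append(node.name)
            node = node.parent
        return tuple(reversed(names))

    def child(self, name):
        if name in self.children:
            raise ValueError(f'child {name!r} already exists')
        node = ToyGraph(name, parent=self)
        self.children[name] = node
        return node

    def param_path(self, param_name):
        # A parameter's full key is the node path plus its local name.
        return self.path + (param_name,)


root = ToyGraph('net')
dense = root.child('encoder').child('dense')
print(dense.param_path('kernel'))  # ('net', 'encoder', 'dense', 'kernel')
```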
- child(name: str) → Graph
Creates a child node in the graph hierarchy.
- Parameters:
name – The name of the child node.
- Returns:
A new Graph instance representing the child.
- Raises:
ValueError – If a child with the same name already exists.
- walk() → Iterator[tuple[tuple[str, ...], Module]]
Recursively yields (path, module) for all descendant modules.
Iterates depth-first through all children, yielding only nodes that have a bound module. Does not include self.
- Yields:
Tuples of (path, module) for each descendant with a bound module.
Example:

    for path, module in graph.walk():
        if isinstance(module, bx.Linear):
            print(f'Found Linear at {path}')
- class blox.Module(graph: Graph)
Bases: object
Base class for neural network layers.
Module provides the foundation for building neural network layers in blox. It connects layers to the Graph hierarchy for parameter namespacing and provides helper methods for parameter creation.
Key features:
Graph binding: each module owns a Graph node that namespaces its params.
Constructor capture: arguments are automatically saved to graph metadata for visualization and serialization.
Parameter helpers: get_param and set_param simplify parameter handling.
All subclasses must:
Accept graph as the first constructor argument.
Call super().__init__(graph) in their __init__.
Implement __call__(self, params, ...) -> (output, params).
Example:

    import jax
    import blox as bx

    class Linear(bx.Module):
        def __init__(self, graph, output_size):
            super().__init__(graph)
            self.output_size = output_size

        def __call__(self, params, x):
            kernel, params = self.get_param(
                params,
                'kernel',
                (x.shape[-1], self.output_size),
                jax.nn.initializers.lecun_normal(),
            )
            return x @ kernel, params

    graph = bx.Graph('net')
    linear = Linear(graph.child('linear'), output_size=32)
- get_param(params: ~blox.interfaces.Params, name: str, shape: tuple[int, ...] | None = None, init: ~jax.nn.initializers.Initializer | None = None, dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, trainable: bool = True, metadata: dict[str, ~typing.Any] | None = None, rng: ~blox.interfaces.Rng | None = None) → tuple[Array, Params]
Gets or creates a parameter in this module’s namespace.
On first call, creates a new parameter using the initializer. On subsequent calls, returns the existing parameter value.
For existing parameters, shape and init can be omitted: kernel, params = self.get_param(params, 'kernel')
- Parameters:
params – The parameter container.
name – Local parameter name (e.g., 'kernel', 'bias').
shape – Shape of the parameter tensor. Required for new params.
init – JAX initializer function. Required for new params.
dtype – Data type (default: float32).
trainable – Whether gradients should be computed (default: True).
metadata – Optional metadata dict. Common keys include 'sharding': a tuple of mesh axis names for model parallelism.
rng – Optional Rng module for stochastic initialization.
- Returns:
Tuple of (parameter_value, updated_params).
- Raises:
KeyError – If the parameter doesn’t exist and shape or init is not provided, or if the params are locked (see Params.locked()).
Example:

    # Creating a new parameter:
    kernel, params = self.get_param(
        params,
        'kernel',
        shape=(in_size, out_size),
        init=jax.nn.initializers.lecun_normal(),
        rng=self.rng,
        metadata={'sharding': (None, 'model')},
    )

    # Getting an existing parameter:
    kernel, params = self.get_param(params, 'kernel')
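The get-or-create behavior, together with the locking rule described under Params.locked(), can be sketched with a plain dict keyed by tuple paths. `get_or_create` is a hypothetical helper for illustration only; the real get_param also handles trainable flags, metadata, and Rng-driven keys, which this sketch omits.

```python
import jax


def get_or_create(store, path, shape=None, init=None, key=None, locked=False):
    """Hypothetical get-or-create on a dict; returns (value, new_store)."""
    if path in store:
        return store[path], store  # Existing param: shape/init not needed.
    if locked:
        raise KeyError(f'{path} does not exist and params are locked')
    if shape is None or init is None:
        raise KeyError(f'{path} does not exist and shape/init were not given')
    value = init(key, shape)
    # Build a new dict rather than mutating the caller's store.
    return value, {**store, path: value}


path = ('net', 'linear', 'kernel')
init = jax.nn.initializers.lecun_normal()
kernel, store = get_or_create({}, path, (4, 8), init, jax.random.PRNGKey(0))
same_kernel, _ = get_or_create(store, path, locked=True)
```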
- param_path(name: str) → tuple[str, ...]
Returns the full path for a parameter in this module’s namespace.
- Parameters:
name – Local parameter name.
- Returns:
Full tuple path like ('net', 'linear', 'kernel').
Example:

    # Check if a param exists.
    if module.param_path('kernel') in params:
        kernel, params = module.get_param(...)
- set_param(params: Params, name: str, value: Array | None, trainable: bool | None = None, metadata: dict[str, Any] | None = None) → Params
Updates a parameter in this module’s namespace.
- Parameters:
params – The parameter container.
name – Local parameter name.
value – New value for the parameter, or None to keep existing.
trainable – New trainable flag, or None to keep existing.
metadata – Metadata to merge with existing, or None to keep existing.
- Returns:
Updated Params container.
- Raises:
ValueError – If value is None and neither trainable nor metadata provided.
Example:

    # Update just the value.
    params = module.set_param(params, 'kernel', new_kernel)

    # Freeze a parameter.
    params = module.set_param(params, 'kernel', None, trainable=False)

    # Add metadata.
    params = module.set_param(
        params, 'kernel', None, metadata={'tag': 'lora'}
    )
- class blox.Params
Bases: object
Immutable container for model parameters and state.
Params is a pure state store holding all model state: trainable weights, non-trainable values (like batch norm statistics), and RNG state. It enforces functional purity by returning new instances on every modification.
Key features:
Functional updates: all methods return new Params instances.
Tuple paths: parameters are keyed by tuples like ('net', 'linear', 'w').
Trainable split: use split() to separate trainable from non-trainable.
Example:

    import jax
    import blox as bx

    graph = bx.Graph('net')
    rng = bx.Rng(graph.child('rng'))
    model = MyModel(graph.child('model'), rng=rng)

    # Create params and seed the Rng.
    params = rng.seed(bx.Params(), seed=42)

    # Forward pass creates parameters.
    _, params = model(params, x)
    params = params.locked()

    # Training loop.
    trainable, non_trainable = params.split()
    grads = jax.grad(loss_fn)(trainable, non_trainable, x)
    trainable = jax.tree.map(lambda w, g: w - lr * g, trainable, grads)
    params = trainable.merge(non_trainable)
- locked() → Params
Returns locked params that prevent new parameter creation.
After locking, attempting to create new parameters via get_param will raise KeyError. This catches bugs where parameter names change between training runs.
- merge(other: Params) → Params
Combines this container with another.
Parameters from other override those in self if paths conflict. Both containers must have the same locked state.
- Parameters:
other – Another Params container to merge in.
- Returns:
A new merged Params container.
- Raises:
ValueError – If locked state doesn’t match.
- split(predicate: Callable[[tuple[str, ...], Param], bool] | None = None) → tuple[Params, Params]
Partitions parameters into two containers.
Without arguments, splits into trainable and non-trainable parameters. This is the standard pattern for computing gradients:

    trainable, non_trainable = params.split()
    grads = jax.grad(loss_fn)(trainable, non_trainable, x)
- Parameters:
predicate – Optional function (path, param) -> bool. Parameters where the predicate returns True go in the first container. Defaults to splitting by trainable flag.
- Returns:
Tuple of (matching_params, non_matching_params).
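The partition and merge semantics can be mirrored on a plain dict. The helpers below are hypothetical and operate on `{path: {'value': ..., 'trainable': ...}}` dicts rather than real Params/Param objects; they illustrate the default trainable split, a custom path predicate, and the rule that the second container wins on conflicting paths.

```python
def split(params, predicate=None):
    """Hypothetical analogue of Params.split on a dict of param records."""
    if predicate is None:
        # Default: partition by the trainable flag.
        predicate = lambda path, p: p['trainable']
    first = {k: v for k, v in params.items() if predicate(k, v)}
    second = {k: v for k, v in params.items() if not predicate(k, v)}
    return first, second


def merge(base, other):
    """Entries from `other` override `base` on conflicting paths."""
    return {**base, **other}


params = {
    ('net', 'linear', 'w'): {'value': 1.0, 'trainable': True},
    ('net', 'bn', 'mean'): {'value': 0.0, 'trainable': False},
}
trainable, non_trainable = split(params)
bn_only, rest = split(params, lambda path, p: 'bn' in path)
```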
- tree_flatten() → tuple[tuple[dict[tuple[str, ...], Param]], tuple[bool]]
Flattens for JAX pytree operations.
- classmethod tree_unflatten(aux: tuple[bool], children: tuple[dict[tuple[str, ...], Param]]) → Params
Unflattens from JAX pytree operations.
- unlocked() → Params
Returns unlocked params that allow new parameter creation.
Use this when you need to add parameters after initial locking, such as adding LoRA adapters to a pretrained model.
Example:

    # Load pretrained model.
    params = load_pretrained()
    params = params.locked()

    # Later, add LoRA.
    params = params.unlocked()
    apply_lora(model)
    _, params = model(params, dummy_input)  # Initialize LoRA params.
    params = params.locked()
- values() → ValuesView[Param]
Returns all Param wrappers.
- class blox.Param(value: Any, trainable: bool = True, metadata: dict[str, Any] | None = None)
Bases: object
A wrapper around a parameter value that holds metadata.
- value
The actual JAX array or PyTree stored.
- trainable
Boolean flag indicating if gradients should be computed.
- metadata
Dictionary for arbitrary tags. Common keys include 'sharding' (a tuple of axis names like (None, 'model') for partitioning) and 'tag' (a string identifier like 'rng' or 'optimizer_state').
- replace(**updates: Any) → Param
Creates a new Param with updated fields.
- Parameters:
**updates – Keyword arguments matching the attribute names to update.
- Returns:
A new Param instance.
- class blox.Rng(graph: Graph)
Bases: Module
A random number generator stream stored as non-trainable params.
Produces deterministic, counter-based random keys. The seed is stored in Params, not the Rng, which allows the same Rng module to be used with different seeds without changing the model structure.
Integer seeds are converted to a JAX key array before being stored.
Example:

    graph = bx.Graph('net')
    rng = bx.Rng(graph.child('rng'))
    model = MyModel(graph.child('model'), rng=rng)

    # Create params and seed the Rng.
    params = rng.seed(bx.Params(), seed=42)

    # Forward pass creates parameters.
    _, params = model(params, x)
    params = params.locked()
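Modules that need randomness should accept an Rng on construction:

    import jax
    import jax.numpy as jnp
    import blox as bx

    class Dropout(bx.Module):
        def __init__(self, graph, rate, rng):
            super().__init__(graph)
            self.rate = rate
            self.rng = rng

        def __call__(self, params, x, is_training=True):
            if not is_training:
                return x, params
            key, params = self.rng(params)
            # jax.random has no dropout helper, so sample a keep mask.
            keep_prob = 1.0 - self.rate
            keep = jax.random.bernoulli(key, keep_prob, x.shape)
            return jnp.where(keep, x / keep_prob, 0.0), params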
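Counter-based key derivation can be sketched with jax.random.fold_in: the same (seed, counter) pair always yields the same key, and advancing the counter yields a fresh one. Whether blox derives keys exactly this way is an assumption; `next_key` is a hypothetical helper.

```python
import jax


def next_key(seed_key, counter):
    """Hypothetical: derive a key from (seed, counter), then bump the counter."""
    key = jax.random.fold_in(seed_key, counter)
    return key, counter + 1


seed_key = jax.random.PRNGKey(42)
k0, counter = next_key(seed_key, 0)
k1, counter = next_key(seed_key, counter)
k0_again, _ = next_key(seed_key, 0)  # Replaying counter 0 reproduces k0.
```

Because the counter lives alongside the seed in Params, replaying a forward pass from the same params reproduces the same random draws.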
- get_counter(params: Params) → Array
Returns the counter value.
- Parameters:
params – The params container.
- Raises:
KeyError – If this Rng is not initialized.
- get_seed(params: Params) → Array
Returns the seed key.
The seed is stored internally as a JAX key array.
- Parameters:
params – The params container.
- Raises:
KeyError – If this Rng is not initialized.
- seed(params: Params, *, seed: int | Array | None = None, counter: int | Array | None = None) → Params
Sets the seed and/or counter for this Rng.
If this Rng is not yet initialized, creates the params with the given seed (required) and counter (defaults to 0). If already initialized, updates the specified values.
- Parameters:
params – The params container.
seed – Seed value (int or JAX key). Required if not initialized.
counter – Counter value. Defaults to 0 when initializing; when updating, None keeps the existing counter.
- Returns:
Updated params.
- Raises:
ValueError – If not initialized and seed is None, or if initialized and both seed and counter are None.
Layers
- class blox.Embed(graph: ~blox.interfaces.Graph, num_embeddings: int, embedding_size: int, rng: ~blox.interfaces.Rng | None, embedding_init: ~jax.nn.initializers.Initializer = <function variance_scaling.<locals>.init>, embedding_metadata: dict[str, ~typing.Any] | None = None)
Bases: Module
Embedding layer that maps integer indices to dense vectors.
Supports weight tying for language models via the attend method, which applies the transpose of the embedding matrix (useful for output projections).
Example:

    embed = Embed(
        graph.child('embed'),
        num_embeddings=10000,
        embedding_size=512,
        rng=rng,
    )

    # Forward pass: indices -> embeddings.
    embeddings, params = embed(params, token_ids)

    # Weight-tied output projection: hidden -> logits.
    logits, params = embed.attend(params, hidden_states)
- attend(params: Params, inputs: Array) → tuple[Array, Params]
Applies the transpose of the embedding matrix (for weight tying).
This is commonly used in language models where the output projection shares weights with the input embedding.
- Parameters:
params – The parameters container.
inputs – Input array of shape […, embedding_size].
- Returns:
A tuple (logits, params). Logits have shape […, num_embeddings].
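Weight tying reuses one [num_embeddings, embedding_size] table in both directions: the forward pass gathers rows, and attend multiplies by the transpose. A minimal standalone sketch in plain jax.numpy (the function names are illustrative, not blox’s code):

```python
import jax
import jax.numpy as jnp

num_embeddings, embedding_size = 10, 4
table = jax.random.normal(jax.random.PRNGKey(0), (num_embeddings, embedding_size))


def embed(table, ids):
    # Gather rows: int ids [...] -> [..., embedding_size].
    return table[ids]


def attend(table, hidden):
    # Shared weights, transposed: [..., embedding_size] -> [..., num_embeddings].
    return hidden @ table.T


ids = jnp.array([[1, 2, 3]])
hidden = embed(table, ids)
logits = attend(table, hidden)
```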
- class blox.Linear(graph: ~blox.interfaces.Graph, output_size: int, rng: ~blox.interfaces.Rng | None, use_bias: bool = True, kernel_init: ~jax.nn.initializers.Initializer = <function variance_scaling.<locals>.init>, bias_init: ~jax.nn.initializers.Initializer = <function zeros>, kernel_metadata: dict[str, ~typing.Any] | None = None, bias_metadata: dict[str, ~typing.Any] | None = None)
Bases: Module
A standard linear transformation layer.
Computes output = input @ kernel + bias (without the bias term when use_bias=False).
Supports model parallelism via metadata. Example for sharding weights:

    linear = Linear(
        graph.child('linear'),
        output_size=1024,
        rng=rng,
        kernel_metadata={'sharding': (None, 'model')},  # Shard output.
        bias_metadata={'sharding': ('model',)},
    )
- class blox.Sequential(graph: Graph, layers: Sequence[Module | Callable[[Array], Array]])
Bases: Module
A sequential container.
Layers are applied in the order they are passed to the constructor. As the layers annotation indicates, each entry may be a Module or a plain array-to-array callable such as jax.nn.relu.
Example:

    mlp = Sequential(
        graph.child('mlp'),
        [
            bx.Linear(graph.child('l1'), 32, rng=rng),
            jax.nn.relu,
            bx.Linear(graph.child('l2'), 10, rng=rng),
        ],
    )
    y, params = mlp(params, x)
- class blox.Conv(graph: ~blox.interfaces.Graph, kernel_size: int | ~typing.Sequence[int], output_channels: int, rng: ~blox.interfaces.Rng | None, strides: int | ~typing.Sequence[int] = 1, padding: str | ~typing.Sequence[tuple[int, int]] = 'SAME', input_dilation: int | ~typing.Sequence[int] = 1, kernel_dilation: int | ~typing.Sequence[int] = 1, feature_group_count: int = 1, use_bias: bool = True, kernel_init: ~jax.nn.initializers.Initializer = <function variance_scaling.<locals>.init>, bias_init: ~jax.nn.initializers.Initializer = <function zeros>, kernel_metadata: dict[str, ~typing.Any] | None = None, bias_metadata: dict[str, ~typing.Any] | None = None)
Bases: Module
General N-dimensional convolution layer.
The number of spatial dimensions is inferred from kernel_size:
1-tuple or int: 1D convolution (single spatial dimension).
2-tuple: 2D convolution (height, width).
3-tuple: 3D convolution (depth, height, width).
Supports arbitrary batch dimensions (0 or more). Uses the channels-last convention: (*batch, *spatial_dims, channels).
Example:

    # 2D convolution for images (NHWC format).
    conv = Conv(
        graph.child('conv2d'),
        kernel_size=(3, 3),
        output_channels=64,
        rng=rng,
    )
    y, params = conv(params, x)  # x: [batch, height, width, channels]

    # 1D convolution for sequences (NLC format).
    conv = Conv(
        graph.child('conv1d'), kernel_size=3, output_channels=64, rng=rng
    )
    y, params = conv(params, x)  # x: [batch, length, channels]

    # Unbatched input.
    conv = Conv(
        graph.child('conv_unbatched'),
        kernel_size=(3, 3),
        output_channels=64,
        rng=rng,
    )
    y, params = conv(params, x)  # x: [height, width, channels]
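The rank inference and batch handling described above can be sketched directly on jax.lax.conv_general_dilated. `conv_channels_last` is a hypothetical helper, not blox’s implementation: it reads the spatial rank from the kernel (assumed layout [*kernel_spatial, in_channels, out_channels]), flattens zero or more leading batch dimensions into one, and uses 'SAME' padding with stride 1.

```python
import jax
import jax.numpy as jnp

# Channels-last dimension numbers per spatial rank (input, kernel, output).
_DIMS = {
    1: ('NWC', 'WIO', 'NWC'),
    2: ('NHWC', 'HWIO', 'NHWC'),
    3: ('NDHWC', 'DHWIO', 'NDHWC'),
}


def conv_channels_last(x, kernel):
    n_spatial = kernel.ndim - 2  # Kernel is [*spatial, C_in, C_out].
    batch_shape = x.shape[: x.ndim - n_spatial - 1]
    # Collapse zero or more batch dims into a single leading dim.
    x = x.reshape((-1,) + x.shape[len(batch_shape):])
    y = jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1,) * n_spatial,
        padding='SAME',
        dimension_numbers=_DIMS[n_spatial],
    )
    # Restore the original batch dims (none for unbatched input).
    return y.reshape(batch_shape + y.shape[1:])


k2d = jnp.ones((3, 3, 3, 16))
print(conv_channels_last(jnp.ones((2, 8, 8, 3)), k2d).shape)  # (2, 8, 8, 16)
print(conv_channels_last(jnp.ones((8, 8, 3)), k2d).shape)  # (8, 8, 16)
```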
- class blox.ConvTranspose(graph: ~blox.interfaces.Graph, kernel_size: int | ~typing.Sequence[int], output_channels: int, rng: ~blox.interfaces.Rng | None, strides: int | ~typing.Sequence[int] = 1, padding: str | ~typing.Sequence[tuple[int, int]] = 'SAME', kernel_dilation: int | ~typing.Sequence[int] = 1, feature_group_count: int = 1, use_bias: bool = True, kernel_init: ~jax.nn.initializers.Initializer = <function variance_scaling.<locals>.init>, bias_init: ~jax.nn.initializers.Initializer = <function zeros>, kernel_metadata: dict[str, ~typing.Any] | None = None, bias_metadata: dict[str, ~typing.Any] | None = None)
Bases: Module
General N-dimensional transposed convolution layer.
Also known as deconvolution or fractionally-strided convolution. The number of spatial dimensions is inferred from kernel_size:
1-tuple or int: 1D transposed convolution (single spatial dimension).
2-tuple: 2D transposed convolution (height, width).
3-tuple: 3D transposed convolution (depth, height, width).
Supports arbitrary batch dimensions (0 or more). Uses the channels-last convention: (*batch, *spatial_dims, channels).
Example:

    # 2D transposed convolution for images (NHWC format).
    conv_t = ConvTranspose(
        graph.child('conv_t'),
        kernel_size=(3, 3),
        output_channels=3,
        rng=rng,
    )
    y, params = conv_t(params, x)  # x: [batch, height, width, channels]
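One property worth knowing: with 'SAME' padding, a transposed convolution with stride s multiplies each spatial size by s. A standalone check using jax.lax.conv_transpose (whether blox’s ConvTranspose wraps this exact function is an assumption):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1, 4, 4, 5))  # [batch, height, width, in_channels]
kernel = jnp.ones((3, 3, 5, 3))  # [kh, kw, in_channels, out_channels]
y = jax.lax.conv_transpose(
    x,
    kernel,
    strides=(2, 2),
    padding='SAME',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)
print(y.shape)  # (1, 8, 8, 3)
```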
- class blox.Dropout(graph: Graph, rate: float, rng: Rng)
Bases: Module
Dropout layer for regularization.
During training, randomly zeros elements with probability rate and scales the remaining elements by 1 / (1 - rate) to maintain expected values. During inference, this layer is a no-op.
Example:

    dropout = Dropout(graph.child('dropout'), rate=0.5, rng=rng)
    y, params = dropout(params, x, is_training=True)
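The inverted-dropout rule can be checked with a standalone function. `dropout` here is a hypothetical sketch in plain JAX (a Bernoulli keep mask plus 1 / (1 - rate) scaling), not blox’s Dropout.__call__; it returns only the array, not an (output, params) pair.

```python
import jax
import jax.numpy as jnp


def dropout(key, rate, x, is_training=True):
    if not is_training:
        return x  # Inference: identity.
    keep_prob = 1.0 - rate
    keep = jax.random.bernoulli(key, keep_prob, x.shape)
    # Survivors are scaled up so that E[output] == x.
    return jnp.where(keep, x / keep_prob, 0.0)


x = jnp.ones((1000,))
y = dropout(jax.random.PRNGKey(0), 0.5, x)
```

With rate=0.5, every element of y is either 0.0 or 2.0, and the mean stays close to 1.0.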
- class blox.LayerNorm(graph: ~blox.interfaces.Graph, epsilon: float = 1e-05, use_scale: bool = True, use_bias: bool = True, scale_init: ~jax.nn.initializers.Initializer = <function ones>, bias_init: ~jax.nn.initializers.Initializer = <function zeros>, axis_name: str | None = None, axis_index_groups: ~typing.Sequence[~typing.Sequence[int]] | None = None, rng: ~blox.interfaces.Rng | None = None)
Bases: Module
Layer Normalization.
Normalizes over the last axis (features) of the input. Supports cross-device statistics aggregation via axis_name for use with jax.shard_map.
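Numerically, the per-example computation is the standard layer-norm formula. A single-device sketch in jax.numpy follows; placing epsilon inside the square root is an assumption, and the cross-device path via axis_name is not shown.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, scale, bias, epsilon=1e-5):
    # Statistics over the last (feature) axis only.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + epsilon) * scale + bias


x = jax.random.normal(jax.random.PRNGKey(0), (4, 16)) * 3.0 + 1.0
y = layer_norm(x, scale=jnp.ones(16), bias=jnp.zeros(16))
```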
- class blox.RMSNorm(graph: ~blox.interfaces.Graph, epsilon: float = 1e-05, use_scale: bool = True, scale_init: ~jax.nn.initializers.Initializer = <function ones>, axis_name: str | None = None, axis_index_groups: ~typing.Sequence[~typing.Sequence[int]] | None = None, rng: ~blox.interfaces.Rng | None = None)
Bases: Module
Root Mean Square Layer Normalization.
Normalizes using only the RMS of the input (no mean subtraction). This is computationally simpler than LayerNorm. Supports cross-device statistics aggregation via axis_name for use with jax.shard_map.
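For comparison with the layer-norm formula, RMS normalization skips the mean subtraction and divides by the root mean square. A single-device sketch (epsilon placement is an assumption):

```python
import jax
import jax.numpy as jnp


def rms_norm(x, scale, epsilon=1e-5):
    # No mean subtraction: divide by the root mean square of the features.
    mean_square = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(mean_square + epsilon) * scale


x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
y = rms_norm(x, scale=jnp.ones(4))
```

Because the mean is not removed, an all-positive input stays all-positive; only its scale changes.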
- class blox.BatchNorm(graph: ~blox.interfaces.Graph, momentum: float = 0.9, epsilon: float = 1e-05, use_scale: bool = True, use_bias: bool = True, scale_init: ~jax.nn.initializers.Initializer = <function ones>, bias_init: ~jax.nn.initializers.Initializer = <function zeros>, axis_name: str | None = None, axis_index_groups: ~typing.Sequence[~typing.Sequence[int]] | None = None, rng: ~blox.interfaces.Rng | None = None)
Bases: Module
Batch Normalization.
Normalizes over the batch and spatial dimensions, maintaining running statistics for inference. During training, computes batch statistics and updates the running averages. During inference, uses the stored running statistics.
The input is expected to have shape (batch, *spatial_dims, features). Normalization is applied over all axes except the last (features) axis.
Example:

    bn = BatchNorm(graph.child('bn'))

    # Training: uses batch statistics and updates running stats.
    y, params = bn(params, x, is_training=True)

    # Inference: uses running statistics.
    y, params = bn(params, x, is_training=False)
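Running statistics are typically maintained as an exponential moving average. The sketch below assumes the common convention running = momentum * running + (1 - momentum) * batch_stat, under which the default momentum=0.9 keeps 90% of the old value at each step; blox’s exact update rule is not stated on this page, and scale/bias are omitted.

```python
import jax.numpy as jnp


def batch_norm_train(x, running_mean, running_var, momentum=0.9, epsilon=1e-5):
    axes = tuple(range(x.ndim - 1))  # Every axis except features.
    mean = x.mean(axis=axes)
    var = x.var(axis=axes)
    y = (x - mean) / jnp.sqrt(var + epsilon)
    new_mean = momentum * running_mean + (1 - momentum) * mean
    new_var = momentum * running_var + (1 - momentum) * var
    return y, new_mean, new_var


def batch_norm_eval(x, running_mean, running_var, epsilon=1e-5):
    # Inference: only the stored statistics are used.
    return (x - running_mean) / jnp.sqrt(running_var + epsilon)


x = jnp.arange(24.0).reshape(2, 3, 4)  # (batch, spatial, features)
y, new_mean, new_var = batch_norm_train(x, jnp.zeros(4), jnp.ones(4))
```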
Pooling
- blox.max_pool(inputs: Array, window_shape: int | Sequence[int], strides: int | Sequence[int] | None = None, padding: str | Sequence[tuple[int, int]] = 'VALID') → Array
Applies max pooling over spatial dimensions.
The number of spatial dimensions is inferred from window_shape:
1-tuple or int: 1D pooling (single spatial dimension).
2-tuple: 2D pooling (height, width).
3-tuple: 3D pooling (depth, height, width).
Supports arbitrary batch dimensions (0 or more). Uses the channels-last convention: (*batch, *spatial_dims, channels).
- Parameters:
inputs – Input array with shape (*batch, *spatial_dims, channels).
window_shape – Shape of the pooling window as a tuple, determining the number of spatial dimensions. For 1D pooling, an int can be used.
strides – Stride of the pooling. If None, uses window_shape (no overlap).
padding – Padding mode. Either 'SAME', 'VALID', or a sequence of (low, high) padding pairs for each spatial dimension.
- Returns:
Pooled output array.
Example:

    # 2x2 max pooling with stride 2.
    y = max_pool(x, window_shape=(2, 2), strides=2)
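Pooling of this kind maps naturally onto jax.lax.reduce_window with a window size of 1 on the batch and channel axes. A hypothetical 2D-only sketch (one batch dimension, channels last); whether blox’s max_pool is built this way is an assumption:

```python
import jax
import jax.numpy as jnp


def max_pool_2d(x, window=(2, 2), strides=(2, 2), padding='VALID'):
    # x: [batch, height, width, channels]; pool over height and width only.
    return jax.lax.reduce_window(
        x,
        -jnp.inf,  # Identity element for max.
        jax.lax.max,
        window_dimensions=(1, *window, 1),
        window_strides=(1, *strides, 1),
        padding=padding,
    )


x = jnp.arange(16.0).reshape(1, 4, 4, 1)
y = max_pool_2d(x)  # y[0, :, :, 0] == [[5., 7.], [13., 15.]]
```

Swapping -jnp.inf and jax.lax.max for jnp.inf and jax.lax.min gives the min-pooling analogue.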
- blox.min_pool(inputs: Array, window_shape: int | Sequence[int], strides: int | Sequence[int] | None = None, padding: str | Sequence[tuple[int, int]] = 'VALID') → Array
Applies min pooling over spatial dimensions.
The number of spatial dimensions is inferred from window_shape:
1-tuple or int: 1D pooling (single spatial dimension).
2-tuple: 2D pooling (height, width).
3-tuple: 3D pooling (depth, height, width).
Supports arbitrary batch dimensions (0 or more). Uses the channels-last convention: (*batch, *spatial_dims, channels).
- Parameters:
inputs – Input array with shape (*batch, *spatial_dims, channels).
window_shape – Shape of the pooling window as a tuple, determining the number of spatial dimensions. For 1D pooling, an int can be used.
strides – Stride of the pooling. If None, uses window_shape (no overlap).
padding – Padding mode. Either 'SAME', 'VALID', or a sequence of (low, high) padding pairs for each spatial dimension.
- Returns:
Pooled output array.
- blox.avg_pool(inputs: Array, window_shape: int | Sequence[int], strides: int | Sequence[int] | None = None, padding: str | Sequence[tuple[int, int]] = 'VALID') → Array
Applies average pooling over spatial dimensions.
The number of spatial dimensions is inferred from window_shape:
1-tuple or int: 1D pooling (single spatial dimension).
2-tuple: 2D pooling (height, width).
3-tuple: 3D pooling (depth, height, width).
Supports arbitrary batch dimensions (0 or more). Uses the channels-last convention: (*batch, *spatial_dims, channels).
Note
When padding='SAME', the average is computed over the valid (non-padded) pixels in the window, ignoring the zeros added by padding.
- Parameters:
inputs – Input array with shape (*batch, *spatial_dims, channels).
window_shape – Shape of the pooling window as a tuple, determining the number of spatial dimensions. For 1D pooling, an int can be used.
strides – Stride of the pooling. If None, uses window_shape (no overlap).
padding – Padding mode. Either 'SAME', 'VALID', or a sequence of (low, high) padding pairs for each spatial dimension.
- Returns:
Pooled output array.
Example:

    # 2x2 average pooling with stride 2.
    y = avg_pool(x, window_shape=(2, 2), strides=2)
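The 'SAME' behavior described in the Note can be reproduced by dividing a window sum by the count of real (non-padded) elements, where the count comes from pooling an array of ones the same way. A hypothetical 2D sketch, not blox’s implementation:

```python
import jax
import jax.numpy as jnp


def avg_pool_2d_same(x, window=(2, 2), strides=(1, 1)):
    # x: [batch, height, width, channels].
    dims = (1, *window, 1)
    steps = (1, *strides, 1)
    window_sum = jax.lax.reduce_window(x, 0.0, jax.lax.add, dims, steps, 'SAME')
    # Pool a ones array to count the non-padded elements in each window.
    ones = jnp.ones(x.shape[:-1] + (1,), x.dtype)
    counts = jax.lax.reduce_window(ones, 0.0, jax.lax.add, dims, steps, 'SAME')
    return window_sum / counts


x = jnp.full((1, 3, 3, 1), 4.0)
y = avg_pool_2d_same(x)  # Stays 4.0 everywhere, including the padded border.
```

Dividing by the full window size instead would pull border values toward zero.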
Sequence Processing
- class blox.SequenceBase(graph: Graph)
Bases: Module, Generic[InputsT, StateT, OutputsT, ResetT]
Base class for sequence-processing modules.
This abstract class defines the interface for modules that process sequences. It supports both ‘chunk’ processing (e.g., Transformers) and ‘step’ processing (e.g., RNNs). Unlike the base Module, SequenceBase enforces a specific call signature.
The primary method is __call__ for single-step processing. For whole sequences, use apply; recurrent subclasses (see RecurrenceBase) implement it by scanning __call__ with static_scan or dynamic_scan.
- abstractmethod apply(params: Params, inputs: InputsT, prev_state: StateT | None = None, is_reset: ResetT | None = None, is_training: bool = True) → tuple[tuple[OutputsT, StateT], Params]
Processes a sequence of data [Batch, Time, …].
This method processes entire sequences, either step-by-step, which is the default behavior of RNN modules (see RecurrenceBase below), or in full, which is the default behavior of Transformer and Attention modules.
- Parameters:
params – The parameters container.
inputs – The input sequence Pytree. Leaves should have shape [Batch, Time, …].
prev_state – Optional initial state. If None, initial_state is called.
is_reset – Optional reset signal Pytree. Leaves should have shape [Batch, Time].
is_training – Boolean flag indicating if the model is in training mode.
- Returns:
A nested tuple ((outputs, final_state), updated_params).
- abstractmethod initial_state(params: Params, inputs: InputsT) → tuple[StateT, Params]
Computes the initial state for the sequence processing.
- Parameters:
params – The parameters container.
inputs – The input Pytree. Used to infer batch size or other structural properties.
- Returns:
A tuple containing the initial state and the parameters container.
- class blox.RecurrenceBase(graph: Graph, is_static: bool = False)
Bases: SequenceBase[InputsT, StateT, OutputsT, ResetT]
Base class for Recurrent Neural Networks (RNNs).
Implements sequence processing apply by applying the __call__ method step-by-step (using either static or dynamic scan).
Subclasses must implement:
initial_state: returns the initial hidden state.
__call__: processes a single time step.
- apply(params: Params, inputs: InputsT, prev_state: StateT | None = None, is_reset: ResetT | None = None, is_training: bool = True) → tuple[tuple[OutputsT, StateT], Params]
Processes a sequence by scanning over __call__.
This method automatically handles initialization: if parameters are not yet locked (initialized), it forces a single-step execution expanded to the full sequence length to safely create parameters without violating JAX scan invariants.
- Parameters:
params – The parameters container.
inputs – The input sequence Pytree. Leaves must have shape [Batch, Time, …].
prev_state – Optional initial state. If None, initial_state is called.
is_reset – Optional reset signal Pytree. Leaves must have shape [Batch, Time].
is_training – Boolean flag indicating if the model is in training mode.
- Returns:
A nested tuple ((outputs, final_state), updated_params).
- Raises:
ValueError – If inputs have rank < 2.
- maybe_reset_state(params: Params, prev_state: StateT, inputs: InputsT, is_reset: ResetT | None = None) → StateT
Helper to reset state based on boolean signal.
- Parameters:
params – The parameters container.
prev_state – The current state Pytree.
inputs – The current input step. Used to infer batch size for fresh state.
is_reset – A boolean Pytree indicating which batch elements to reset.
- Returns:
The updated state with resets applied where indicated.
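Per-example resets are commonly implemented with a broadcasted jnp.where over every leaf of the state pytree. The helper below is a hypothetical sketch: it takes an already-computed fresh state instead of calling initial_state, and expects a [batch] boolean mask for a single time step.

```python
import jax
import jax.numpy as jnp


def reset_where(prev_state, fresh_state, is_reset):
    """Hypothetical: rows with is_reset[b] == True take the fresh state."""

    def select(prev, fresh):
        # Reshape the [batch] mask so it broadcasts over trailing dims.
        mask = is_reset.reshape(is_reset.shape + (1,) * (prev.ndim - 1))
        return jnp.where(mask, fresh, prev)

    return jax.tree_util.tree_map(select, prev_state, fresh_state)


prev = (jnp.ones((3, 2)), jnp.full((3, 4), 5.0))
fresh = jax.tree_util.tree_map(jnp.zeros_like, prev)
out = reset_where(prev, fresh, jnp.array([True, False, True]))
```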
- class blox.LSTM(graph: Graph, hidden_size: int, rng: Rng | None, is_static: bool = False)
Bases: RecurrenceBase[Array, LSTMState, Array, Array]
Long Short-Term Memory (LSTM) Recurrent Neural Network.
The mathematical definition of the cell is as follows:

\[\begin{split}
i &= \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
f &= \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
g &= \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
o &= \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
c' &= f * c + i * g \\
h' &= o * \tanh(c')
\end{split}\]

where x is the input, h is the hidden state (the output of the previous time step), c is the cell state (memory), and * denotes element-wise multiplication.
This module implements a standard LSTM cell. It inherits from RecurrenceBase, automatically providing support for both single-step execution (__call__) and efficient sequence processing (apply with scanning).
Example:

    lstm = LSTM(graph.child('lstm'), hidden_size=128, rng=rng)

    # Initialize state first:
    state, params = lstm.initial_state(params, inputs)

    # Single-step processing (e.g., for interactive use):
    (outputs, state), params = lstm(params, inputs, state)

    # Sequence processing (uses jax.lax.scan for efficiency):
    (outputs, state), params = lstm.apply(params, inputs_sequence, state)
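The cell equations translate almost line for line into jax.numpy. In this hypothetical sketch the four gate projections are fused into single [*, 4 * hidden] matrices, and the gate order (i, f, g, o) inside them is an assumption; blox’s parameter layout may differ.

```python
import jax
import jax.numpy as jnp


def lstm_cell(w_i, w_h, b, x, h, c):
    """One LSTM step. w_i: [in, 4H], w_h: [H, 4H], b: [4H]."""
    gates = x @ w_i + h @ w_h + b
    i, f, g, o = jnp.split(gates, 4, axis=-1)
    i, f, o = jax.nn.sigmoid(i), jax.nn.sigmoid(f), jax.nn.sigmoid(o)
    g = jnp.tanh(g)
    c_new = f * c + i * g  # c' = f * c + i * g
    h_new = o * jnp.tanh(c_new)  # h' = o * tanh(c')
    return h_new, c_new


batch, in_size, hidden = 2, 3, 4
w_i = jnp.zeros((in_size, 4 * hidden))
w_h = jnp.zeros((hidden, 4 * hidden))
b = jnp.zeros((4 * hidden,))
x = jnp.ones((batch, in_size))
h, c = jnp.zeros((batch, hidden)), jnp.ones((batch, hidden))
# With all-zero weights every sigmoid gate is 0.5 and g is 0.
h_new, c_new = lstm_cell(w_i, w_h, b, x, h, c)
```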
- initial_state(params: Params, inputs: Array) → tuple[LSTMState, Params]
Creates the initial zero state.
- Parameters:
params – The parameters container.
inputs – The input array, used to infer the batch size (dimension 0).
- Returns:
A tuple (LSTMState, params), where both hidden and cell states are zeros.
- class blox.LSTMState(hidden: Array, cell: Array)
Bases: NamedTuple
Holds the hidden and cell states for an LSTM.
Fields: hidden (alias for field number 0) and cell (alias for field number 1).
- class blox.GRU(graph: Graph, hidden_size: int, rng: Rng | None, is_static: bool = False)
Bases: RecurrenceBase[Array, GRUState, Array, Array]
Gated Recurrent Unit (GRU).
The mathematical definition of the cell is as follows:

\[\begin{split}
r &= \sigma(W_{ir} x + W_{hr} h + b_{hr}) \\
z &= \sigma(W_{iz} x + W_{hz} h + b_{hz}) \\
n &= \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
h' &= (1 - z) * n + z * h
\end{split}\]

where x is the input, h is the hidden state (the output of the previous time step), and * denotes element-wise multiplication.
Example:

    gru = GRU(graph.child('gru'), hidden_size=128, rng=rng)
    state, params = gru.initial_state(params, inputs)
    (outputs, state), params = gru(params, inputs, state)
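A matching sketch of one GRU step. Following the equations above, the reset gate r scales only the hidden-side candidate term (W_{hn} h + b_{hn}), and the input side carries a bias only for n. The fused [*, 3 * hidden] layout and the gate order (r, z, n) are assumptions.

```python
import jax
import jax.numpy as jnp


def gru_cell(w_i, w_h, b_h, b_in, x, h):
    """One GRU step. w_i: [in, 3H], w_h: [H, 3H], b_h: [3H], b_in: [H]."""
    xr, xz, xn = jnp.split(x @ w_i, 3, axis=-1)
    hr, hz, hn = jnp.split(h @ w_h + b_h, 3, axis=-1)  # Adds b_hr, b_hz, b_hn.
    r = jax.nn.sigmoid(xr + hr)
    z = jax.nn.sigmoid(xz + hz)
    n = jnp.tanh(xn + b_in + r * hn)
    return (1 - z) * n + z * h  # h' = (1 - z) * n + z * h


batch, in_size, hidden = 2, 3, 4
h = jnp.ones((batch, hidden))
# With all-zero weights, z = 0.5 and n = 0, so h' = 0.5 * h.
h_new = gru_cell(
    jnp.zeros((in_size, 3 * hidden)),
    jnp.zeros((hidden, 3 * hidden)),
    jnp.zeros((3 * hidden,)),
    jnp.zeros((hidden,)),
    jnp.ones((batch, in_size)),
    h,
)
```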
- class blox.GRUState(hidden: Array)
Bases: NamedTuple
Holds the hidden state for a GRU.
Fields: hidden (alias for field number 0).
- blox.static_scan(step_fn: Callable[[Params, InputsT, StateT, ResetT | None, bool], tuple[tuple[OutputsT, StateT], Params]], params: Params, inputs: InputsT, prev_state: StateT, is_reset: ResetT | None, is_training: bool) → tuple[tuple[OutputsT, StateT], Params]
Performs a Python loop scan over the time dimension.
This function explicitly iterates over the time dimension (axis 1) of the inputs using a Python for loop. This is useful for debugging, handling control flow that jax.lax.scan cannot compile, or when the sequence length is very short.
- Parameters:
step_fn – A callable that processes a single time step.
params – The parameters container.
inputs – Input sequence Pytree [Batch, Time, …].
prev_state – Initial state.
is_reset – Optional reset signal [Batch, Time].
is_training – Training flag.
- Returns:
((outputs, final_state), updated_params)
- Raises:
ValueError – If inputs are empty or have invalid rank.
- blox.dynamic_scan(step_fn: Callable[[Params, InputsT, StateT, ResetT | None, bool], tuple[tuple[OutputsT, StateT], Params]], params: Params, inputs: InputsT, prev_state: StateT, is_reset: ResetT | None, is_training: bool) → tuple[tuple[OutputsT, StateT], Params]
Performs a compiled jax.lax.scan over the time dimension.
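Because jax.lax.scan traces the step function once and compiles it as a loop, compile time does not grow with sequence length, which makes it the better choice for long sequences.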
- Parameters:
step_fn – A callable that processes a single time step.
params – The parameters container.
inputs – Input sequence Pytree [Batch, Time, …].
prev_state – Initial state.
is_reset – Optional reset signal [Batch, Time].
is_training – Training flag.
- Returns:
((outputs, final_state), updated_params)
- Raises:
ValueError – If inputs have invalid rank.
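The difference between the two scans can be seen on a toy recurrence. Both functions below are hypothetical simplifications that scan over axis 1 (the step takes only (state, x_t), with no params, resets, or training flag). Under jax.jit the Python loop is unrolled into one copy of the step per time step, while jax.lax.scan traces the step once; the results are identical.

```python
import jax
import jax.numpy as jnp


def step(state, x_t):
    # Toy recurrence: a running sum; returns (new_carry, output).
    new_state = state + x_t
    return new_state, new_state


def python_loop_scan(inputs, state):
    outputs = []
    for t in range(inputs.shape[1]):  # Time is axis 1.
        state, out = step(state, inputs[:, t])
        outputs.append(out)
    return jnp.stack(outputs, axis=1), state


def lax_scan(inputs, state):
    # lax.scan iterates over axis 0, so move time to the front and back.
    state, outputs = jax.lax.scan(step, state, jnp.swapaxes(inputs, 0, 1))
    return jnp.swapaxes(outputs, 0, 1), state


inputs = jnp.arange(12.0).reshape(2, 6, 1)  # [batch, time, features]
init = jnp.zeros((2, 1))
out_a, final_a = python_loop_scan(inputs, init)
out_b, final_b = lax_scan(inputs, init)
```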
Visualization
- blox.display(graph: Graph | tuple[Graph, ...], params: Params | None = None) → None
Display model structure and parameters as an interactive tree.
Builds a visual tree showing:
Module hierarchy with type names.
Parameter counts and memory usage (if params provided).
Constructor arguments.
Parameter shapes, dtypes, and value statistics (if params provided).
References between modules (dependency injection).
- Parameters:
graph – Root Graph node(s). Pass a tuple to display multiple graphs together in a single view.
params – Optional Params container. If None, shows only the module hierarchy and constructor arguments (structure-only mode).
Example:

    # Full display with params:
    bx.display(graph, params)

    # Structure only (no params):
    bx.display(graph)

    # Multiple graphs in one view:
    bx.display((encoder_graph, decoder_graph), params)