Design Principles
blox unlocks the full potential of JAX by embracing its functional nature instead of fighting it.
The Mental Model
The entire mental model fits in one line:
outputs, params = model(params, inputs)
Parameters go in, outputs and updated parameters come out. This is the standard pattern for stateful computations in JAX. Because state flows explicitly through your code, all JAX transformations work out of the box.
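As a minimal illustration of this pattern in plain JAX (not the blox API; the `model` function and the dictionary keys are made up for the example), a function that keeps a call counter in its parameters can be jitted directly, because the state is an ordinary argument and return value:

```python
import jax
import jax.numpy as jnp


def model(params, inputs):
    """Scale the inputs and count how many times the model was called."""
    outputs = inputs * params["scale"]
    # State is never mutated in place; an updated copy is returned instead.
    new_params = {**params, "calls": params["calls"] + 1}
    return outputs, new_params


params = {"scale": jnp.float32(2.0), "calls": jnp.int32(0)}
jitted = jax.jit(model)

# Thread the returned params into the next call.
outputs, params = jitted(params, jnp.ones(3))
outputs, params = jitted(params, jnp.ones(3))
```

After the two calls, `params["calls"]` holds 2 and no hidden state exists anywhere else.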
Functional Purity
Most JAX neural network libraries try to force Object-Oriented paradigms to make JAX feel like PyTorch, usually by introducing implicit global state, hidden contexts, or clever magic.
blox takes the opposite approach. Instead of hiding JAX’s functional approach, it leans into it.
Models are stateless: A model is just a transformation definition.
State is explicit: Parameters are passed explicitly as arguments.
No side effects: Every function returns (outputs, params).
Core Abstractions
We separate “structure” from “state”.
Graph
A Graph object represents the hierarchical structure of your model (e.g., net -> mlp -> dense1). It provides unique namespaces (paths) for parameters.
Key points:
Paths are stored as tuples (e.g., ('net', 'mlp', 'kernel')), allowing any characters in names
Use graph.child('name') or graph / 'name' to create child nodes
Modules bind to graph nodes and cannot bind to root nodes (must use graph.child())
Use graph.walk() to iterate over all descendant modules
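To see why tuple paths allow any characters in names, here is a rough sketch with a hypothetical `Node` class. It is a stand-in written for this example, not the blox `Graph` implementation. Because each path component is a separate tuple element, a name containing a separator such as `/` stays unambiguous:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for a graph node; stores its path as a tuple."""

    path: tuple

    def child(self, name: str) -> "Node":
        return Node(self.path + (name,))

    def __truediv__(self, name: str) -> "Node":
        # `node / 'name'` is sugar for `node.child('name')`.
        return self.child(name)


root = Node(("net",))
dense = root / "mlp" / "dense/1"  # A '/' inside a name causes no ambiguity.
kernel_path = dense.child("kernel").path
```

Here `kernel_path` is the tuple `('net', 'mlp', 'dense/1', 'kernel')`; a string-joined path would have needed an escaping rule for the embedded `/`.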
Params
A Params object is a flat, immutable container holding all state (weights, RNG keys, batch stats). It is keyed by the paths defined in the Graph.
Key points:
All state is stored as Param objects with a trainable flag
Use params.locked() after initialization to prevent accidental param creation
Use params.split() to separate trainable from non-trainable state
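The ideas behind these points (a flat mapping keyed by paths, a trainable flag, and locking) can be sketched in plain Python. The names `Param`, `set_param`, and `split` below are hypothetical helpers written for this example and should not be read as the blox implementation:

```python
from dataclasses import dataclass
from types import MappingProxyType


@dataclass(frozen=True)
class Param:
    value: object
    trainable: bool = True


def set_param(params, path, param, locked=False):
    """Return a new read-only mapping with `path` set.

    When `locked` is true, only existing paths may be updated.
    """
    if locked and path not in params:
        raise KeyError(f"params are locked; cannot create {path}")
    return MappingProxyType({**params, path: param})


def split(params):
    """Separate trainable entries from non-trainable ones."""
    trainable = {k: p for k, p in params.items() if p.trainable}
    non_trainable = {k: p for k, p in params.items() if not p.trainable}
    return trainable, non_trainable


params = MappingProxyType({})
params = set_param(params, ("net", "dense", "kernel"), Param(1.0))
params = set_param(params, ("net", "rng", "counter"), Param(0, trainable=False))
trainable, non_trainable = split(params)
```

Locking turns a misspelled path during training into an immediate error instead of a silently created, never-trained parameter.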
Rng
An Rng module handles randomness. It is passed to modules on construction and stores its state (seed, counter) in the Params container.
Key points:
Must be seeded first via rng.seed(params, seed=42)
Uses counter-based key generation: new_key = jax.random.fold_in(seed, counter)
Modules that need randomness (e.g., Dropout) accept an Rng on construction
In vmap/shard_map, users must manually fold in axis indices for unique randomness (see below)
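Counter-based key generation can be sketched in plain JAX as follows. The `next_key` helper and the state dictionary layout are assumptions for this example; only `jax.random.fold_in` comes from the text above:

```python
import jax
import jax.numpy as jnp


def next_key(state):
    """Derive a fresh key from (seed, counter) and advance the counter."""
    key = jax.random.fold_in(state["seed"], state["counter"])
    new_state = {"seed": state["seed"], "counter": state["counter"] + 1}
    return key, new_state


state = {"seed": jax.random.PRNGKey(42), "counter": jnp.uint32(0)}
key_a, state = next_key(state)
key_b, state = next_key(state)

# Different counters give different keys, hence different samples.
sample_a = jax.random.normal(key_a, (3,))
sample_b = jax.random.normal(key_b, (3,))
```

The seed never changes; only the counter advances, so the same (seed, counter) pair always reproduces the same key.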
RNG in Parallel Contexts
In vmap or shard_map, all lanes share the same seed and counter, so they get identical random keys by default. To get unique randomness per lane, fold in a unique index.
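The effect is easy to reproduce in plain JAX, independent of blox. In this small sketch (function names are made up for the example), the first mapped function reuses one key in every lane, while the second folds the lane index into the key:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.ones((4, 3))


def shared(key, row):
    # Every lane sees the same key, so every lane draws the same noise.
    return row * jax.random.normal(key, row.shape)


def per_lane(key, row):
    # Folding in the lane index gives each lane its own key.
    lane_key = jax.random.fold_in(key, jax.lax.axis_index("batch"))
    return row * jax.random.normal(lane_key, row.shape)


same = jax.vmap(shared, in_axes=(None, 0))(key, x)
unique = jax.vmap(per_lane, in_axes=(None, 0), axis_name="batch")(key, x)
```

All rows of `same` are identical, whereas the rows of `unique` differ from one another.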
Simplest approach - pass the index explicitly:
def apply_with_explicit_index(params, x, batch_idx):
    original_seed = rng.get_seed(params)
    folded_seed = jax.random.fold_in(original_seed, batch_idx)
    params = rng.seed(params, seed=folded_seed)
    out, params = dropout(params, x, is_training=True)
    params = rng.seed(params, seed=original_seed)  # Restore for replicated params
    return out, params

# Pass jnp.arange(batch_size) as the batch indices
jax.vmap(apply_with_explicit_index, in_axes=(None, 0, 0))(params, x, jnp.arange(4))
Or use jax.lax.axis_index with axis_name. This is the recommended approach as it’s more idiomatic:
def apply_with_axis_index(params, x):
    original_seed = rng.get_seed(params)
    folded_seed = jax.random.fold_in(original_seed, jax.lax.axis_index('batch'))
    params = rng.seed(params, seed=folded_seed)
    out, params = dropout(params, x, is_training=True)
    params = rng.seed(params, seed=original_seed)
    return out, params

jax.vmap(apply_with_axis_index, in_axes=(None, 0), axis_name='batch')(params, x)
Why restore the seed? When params is replicated (out_axes=None), JAX requires all lanes to return identical pytrees. Since we run the same function in each lane, the counter increments identically everywhere. The seed is the only thing that differs (due to folding), so restoring it ensures params match.
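The constraint can be reproduced in plain JAX. In the sketch below, the state dictionary and the `noisy` and `run` helpers are assumptions for illustration, not blox code. Returning the state with `out_axes=None` succeeds only when the original seed is put back, because the folded seed differs per lane:

```python
import jax
import jax.numpy as jnp


def noisy(state, row, restore):
    seed = state["seed"]
    lane_seed = jax.random.fold_in(seed, jax.lax.axis_index("batch"))
    key = jax.random.fold_in(lane_seed, state["counter"])
    out = row + jax.random.normal(key, row.shape)
    # The counter advances identically in every lane; only the seed differs.
    new_state = {
        "seed": seed if restore else lane_seed,
        "counter": state["counter"] + 1,
    }
    return out, new_state


def run(state, x, restore):
    def fn(s, row):
        return noisy(s, row, restore)

    # State goes in replicated (None) and must come out replicated (None).
    mapped = jax.vmap(fn, in_axes=(None, 0), out_axes=(0, None), axis_name="batch")
    return mapped(state, x)


state = {"seed": jax.random.PRNGKey(0), "counter": jnp.uint32(0)}
x = jnp.zeros((4, 3))

out, new_state = run(state, x, restore=True)
# run(state, x, restore=False) raises an error: the returned seed is per-lane.
```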
Init vs Runtime: During init (params unlocked), don’t fold: you want identical params. During runtime (params locked), do fold for unique randomness. Use params.is_locked to detect which mode.
Initialization with Sharded Parameters
When initializing models with sharded parameters, use jax.jit with out_shardings rather than shard_map.
Why shard_map is tricky for initialization:
Different parameters may need different axes folded in (e.g., model axis but not batch axis)
Multiple model axes mean different params have different sharding requirements
Managing which axes to fold for which params becomes complex
Why jax.jit is better:
Just specify out_shardings and JIT handles partitioning
Use replicated RNG params during init
JIT is smart about parameter placement during initialization
# RECOMMENDED: Initialize via JIT with out_shardings
mesh = jax.make_mesh((4,), ('model',))
params_sharding = ...  # Build from param metadata

@jax.jit(out_shardings=params_sharding)
def init():
    params = rng.seed(bx.Params(), seed=42)
    _, params = model(params, dummy_input)
    return params.locked()

params = init()  # JIT handles sharding automatically
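For a self-contained picture of the mechanism without blox, the sketch below shards a plain dictionary of arrays. The parameter names, the shapes, and the choice to shard the kernel's last dimension are assumptions for illustration. The mesh is built from whatever devices are available, so the code also runs on a single device:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis spanning every available device.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("model",))

# Hypothetical layout: shard the kernel's output dimension, replicate the bias.
out_shardings = {
    "kernel": NamedSharding(mesh, P(None, "model")),
    "bias": NamedSharding(mesh, P()),
}


def init():
    key = jax.random.PRNGKey(42)  # The same key on every device.
    width = 8 * devices.size  # Keeps the sharded dimension evenly divisible.
    return {
        "kernel": jax.random.normal(key, (16, width)),
        "bias": jnp.zeros((width,)),
    }


# No per-axis key folding is needed; JIT partitions the outputs as requested.
params = jax.jit(init, out_shardings=out_shardings)()
```

Each device ends up holding a `(16, 8)` slice of the kernel and a full copy of the bias.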
Note: These RNG patterns are a sharp edge in JAX that blox inherits. We hope JAX will provide better primitives in the future.
The Params Container
The Params container holds all model state in one place, including weights, RNG state, batch norm statistics, and moving averages.
Why put RNG in Params? In pure functional programming, randomness is state. If your dropout layer consumes a random key, that’s a state change. By threading RNG through Params, the signature outputs, params = model(params, inputs) tells the whole truth.
Every parameter is either:
| Type | Examples | Gradients? | Updated during forward? |
|---|---|---|---|
| Trainable | weights, embeddings | Yes | No |
| Non-trainable | RNG counters, batch stats, EMA | No | Yes |
The params.split() method separates these two categories for training.
JAX Compatibility
blox works with all JAX transformations out of the box:
jax.jit - Just wrap and call
jax.grad - Use params.split() to separate trainable params
jax.vmap - Params can be replicated (in_axes=None) or batched
jax.shard_map - Use parameter metadata for sharding specs
jax.checkpoint - Works with the functional state pattern
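The difference between replicated and batched params under jax.vmap can be sketched in plain JAX. The `model` function and the dictionary params below are made up for the example and stand in for blox objects:

```python
import jax
import jax.numpy as jnp


def model(params, x):
    return x @ params["w"], params


x = jnp.ones((5, 3))

# Replicated params: one set of weights applied to every example.
shared = {"w": jnp.ones((3, 2))}
out_shared, _ = jax.vmap(model, in_axes=(None, 0), out_axes=(0, None))(shared, x)

# Batched params: a leading axis of 4 ensemble members, same inputs for each.
ensemble = {"w": jnp.ones((4, 3, 2))}
out_ensemble, _ = jax.vmap(model, in_axes=(0, None))(ensemble, x)
```

`out_shared` has shape `(5, 2)` (one row per example) and `out_ensemble` has shape `(4, 5, 2)` (one full output per ensemble member).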
No special wrappers or decorators are required.
Training Pattern
The standard training pattern uses split() and merge():
@jax.jit(donate_argnames='params')
def train_step(params, inputs, targets):
    # Split into trainable (weights) and non-trainable (RNG, etc).
    trainable, non_trainable = params.split()

    def loss_fn(t, nt):
        # Merge to run the forward pass.
        preds, new_params = model(t.merge(nt), inputs)
        loss = jnp.mean((preds - targets) ** 2)
        # Extract non-trainable parameters updated during the forward pass.
        _, new_nt = new_params.split()
        return loss, new_nt

    # Gradients for trainable, updated state for non-trainable.
    grads, new_non_trainable = jax.grad(loss_fn, has_aux=True)(
        trainable, non_trainable
    )
    # Update the trainable parameters using SGD.
    new_trainable = jax.tree.map(lambda w, g: w - 0.01 * g, trainable, grads)
    # Merge updated trainable and updated non-trainable parameters.
    return new_trainable.merge(new_non_trainable)
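The same split, differentiate, and merge flow can be tried without blox. In this runnable sketch, plain dictionaries stand in for the trainable and non-trainable halves, and the toy linear `model` with a step counter is an assumption for illustration:

```python
import jax
import jax.numpy as jnp


def model(params, inputs):
    preds = inputs @ params["w"]
    return preds, {**params, "step": params["step"] + 1}


@jax.jit
def train_step(trainable, non_trainable, inputs, targets):
    def loss_fn(t, nt):
        preds, new_params = model({**t, **nt}, inputs)
        loss = jnp.mean((preds - targets) ** 2)
        # Hand the updated non-trainable state back as the auxiliary output.
        return loss, {"step": new_params["step"]}

    # Differentiates with respect to the first argument only.
    grads, new_nt = jax.grad(loss_fn, has_aux=True)(trainable, non_trainable)
    new_t = jax.tree_util.tree_map(lambda w, g: w - 0.1 * g, trainable, grads)
    return new_t, new_nt


trainable = {"w": jnp.zeros((2, 1))}
non_trainable = {"step": jnp.int32(0)}
inputs = jnp.array([[1.0, 0.0], [0.0, 1.0]])
targets = jnp.array([[1.0], [2.0]])

for _ in range(3):
    trainable, non_trainable = train_step(trainable, non_trainable, inputs, targets)
```

The `has_aux=True` argument is what lets the forward pass report updated state (here the step counter) without that state being differentiated.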
Lazy Initialization
blox uses lazy initialization:
Define model structure abstractly (no memory allocation)
Run a forward pass to trigger parameter creation
Lock params to prevent accidental creation during training
# Define structure.
graph = bx.Graph('net')
rng = bx.Rng(graph.child('rng'))
model = MLP(graph.child('mlp'), hidden_size=128, output_size=10, rng=rng)
# Initialize.
params = bx.Params()
params = rng.seed(params, seed=42)
_, params = model(params, dummy_input)
params = params.locked()
You can also use jax.eval_shape() to get the parameter structure without allocating memory, which is useful for setting up sharding.
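A short plain-JAX sketch of that idea, using a hypothetical `init` function that returns a dictionary of arrays rather than a blox `Params` object:

```python
import jax
import jax.numpy as jnp


def init(key):
    return {
        "kernel": jax.random.normal(key, (1024, 4096)),
        "bias": jnp.zeros((4096,)),
    }


# Traces `init` abstractly: only shapes and dtypes are computed.
abstract = jax.eval_shape(init, jax.random.PRNGKey(0))
kernel_shape = abstract["kernel"].shape
kernel_dtype = abstract["kernel"].dtype
```

Each leaf of `abstract` is a `jax.ShapeDtypeStruct`, which carries enough information to decide on a sharding before any real array exists.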
Why the Verbosity?
blox chooses clarity over brevity.
Most frameworks rely on implicit global state or thread-local contexts to hide parameters. While this saves a few keystrokes, it creates a “black box.”
| OOP-style Wrappers | blox |
|---|---|
| Implicit global state | Explicit state passing |
| Opaque variable scopes | Explicit |
| Custom | Standard |
By accepting slightly more verbose function signatures, you gain:
Total transparency: You know exactly what data your function touches.
Full control: No global state means no unknown side-effects.
Maximum performance: Zero overhead.
Decoupled Params and Graph
A key design principle in blox is the clean separation between parameters and the model graph. This is different from frameworks like Flax or Equinox where parameters are tightly coupled to modules.
Why this separation?
No single module owns params. Parameters are passed into functions, not stored in modules. This enables multiple models to operate on the same params without ownership conflicts.
Avoids pytree complexity. When modules are pytrees containing both static configuration and JAX arrays, you get finicky behavior requiring magic handling and special wrappers. blox keeps a clean split: the Graph is purely static (Python objects describing structure), while Params is purely dynamic (JAX arrays ready for transformations).
Graph is static, params are dynamic. The graph describes what operations to perform. The params provide what values to use. This separation is maintained throughout execution; graph structure doesn’t change at runtime, only param values do.
Multiple models can share the same params:
# Create two LSTM variants with the same parameter structure
def create_lstm(is_static: bool):
    graph = bx.Graph('model')
    rng = bx.Rng(graph.child('rng'))
    lstm = bx.LSTM(graph.child('lstm'), hidden_size=64, rng=rng, is_static=is_static)
    return lstm

lstm_static = create_lstm(is_static=True)    # Python loop (debuggable)
lstm_dynamic = create_lstm(is_static=False)  # lax.scan (fast)

# Initialize params once
params = ...

# Both models work with the same params!
out_static, _ = lstm_static.apply(params, inputs, prev_state=state)
out_dynamic, _ = lstm_dynamic.apply(params, inputs, prev_state=state)
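The same idea can be checked end to end in plain JAX. The sketch below uses a toy recurrent cell in place of an LSTM, and all names in it are assumptions for illustration. Two apply functions, one with a Python loop and one with `jax.lax.scan`, read the same parameter dictionary and produce matching outputs:

```python
import jax
import jax.numpy as jnp


def cell(params, carry, x):
    """A toy recurrent cell standing in for an LSTM step."""
    new_carry = jnp.tanh(carry @ params["w_h"] + x @ params["w_x"])
    return new_carry, new_carry


def apply_static(params, inputs, carry):
    # Python loop: unrolled step by step, easy to inspect in a debugger.
    outputs = []
    for x in inputs:
        carry, out = cell(params, carry, x)
        outputs.append(out)
    return jnp.stack(outputs), carry


def apply_dynamic(params, inputs, carry):
    # lax.scan: the loop is expressed as a single JAX primitive.
    carry, outputs = jax.lax.scan(lambda c, x: cell(params, c, x), carry, inputs)
    return outputs, carry


key_h, key_x, key_in = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w_h": jax.random.normal(key_h, (4, 4)) * 0.1,
    "w_x": jax.random.normal(key_x, (3, 4)) * 0.1,
}
inputs = jax.random.normal(key_in, (5, 3))
carry = jnp.zeros((4,))

out_static, _ = apply_static(params, inputs, carry)
out_dynamic, _ = apply_dynamic(params, inputs, carry)
```

Neither function owns the weights, so switching between the debuggable and the fast variant requires no parameter conversion.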
Use cases enabled by this design:
Actor/Learner pattern: In RL, the Actor (collects data) and Learner (updates weights) can be separate models sharing params.
Train/Eval pattern: Training model (with dropout) and evaluation model (without) share the same params.
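As a rough sketch of the Train/Eval pattern in plain JAX (the `make_model` factory, the dropout logic, and the parameter layout are assumptions for illustration, not blox modules), the training flag is fixed when each variant is built and both variants read the same params:

```python
import jax
import jax.numpy as jnp


def make_model(is_training: bool, rate: float = 0.5):
    """Build an apply function whose behavior is fixed at construction time."""

    def apply(params, x):
        h = x @ params["w"]
        if not is_training:
            return h, params
        # Dropout consumes randomness, so the key in params is advanced.
        key, subkey = jax.random.split(params["key"])
        keep = jax.random.bernoulli(subkey, 1.0 - rate, h.shape)
        h = jnp.where(keep, h / (1.0 - rate), 0.0)
        return h, {**params, "key": key}

    return apply


train_model = make_model(is_training=True)
eval_model = make_model(is_training=False)

params = {"w": jnp.ones((3, 2)), "key": jax.random.PRNGKey(0)}
x = jnp.ones((4, 3))

train_out, new_params = train_model(params, x)
eval_out, _ = eval_model(params, x)
```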
Design recommendations:
Create static models with deliberate purpose. Use factory functions to create model variants at construction time. Each model’s purpose should be clear from how it’s built, not from runtime modifications.
Avoid runtime graph mutation. While users have freedom to manipulate the graph, a clean design creates different models for different purposes rather than mutating one model at runtime.