Sharp Bits

This notebook covers common pitfalls when using blox with JAX. Understanding these will help you write correct, efficient code.

blox extends JAX’s stateful pattern. JAX recommends explicit state passing for stateful computations. blox makes this pattern ergonomic for neural networks: outputs, params = model(params, inputs). All state flows explicitly through function signatures.
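
As a rough illustration of explicit state passing in plain JAX (this is not blox code), the sketch below threads a counter through a jitted function using the same `outputs, state = fn(state, inputs)` shape. The name `counter_step` and the dict layout are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def counter_step(state, x):
  """Returns (output, new_state); the input state is never mutated."""
  new_state = {'count': state['count'] + 1}
  return x * state['count'], new_state


state = {'count': jnp.array(1)}
out1, state = counter_step(state, jnp.array(10.0))
out2, state = counter_step(state, jnp.array(10.0))
print(float(out1), float(out2), int(state['count']))
```

Because the state is an ordinary argument and return value, the function stays pure and works under jit, grad, and vmap without special handling.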

Contents

  1. The Params Container — locked vs unlocked, immutability, split/merge, memory ownership.

  2. RNG Handling in Parallel Contexts — unique randomness in vmap / shard_map.

  3. Init vs Runtime — the is_locked pattern, sharded initialization.

  4. Graph and Module Modifications — JIT caching, static models.

  5. Summary.

[1]:
import sys

sys.path.insert(0, '../src')

import blox as bx
import jax
import jax.numpy as jnp

1. The Params Container

The Params container holds all model state: weights, RNG counters, batch norm statistics, and any other values that need to be threaded through computations. Understanding how it works is fundamental to using blox correctly.

Locked vs Unlocked

The Params container operates in two modes that correspond to different phases of model usage:

Unlocked (default): New parameters can be created via module.get_param. This is the initialization phase where the model discovers its parameter structure by running a forward pass. When a module calls get_param for a parameter that doesn’t exist yet, the parameter is created and added to params.

Locked (after .locked()): No new parameters can be added. Any attempt to create a new parameter will raise an error. This is the runtime phase where you’re training or evaluating the model. Locking params catches bugs where you accidentally try to create new parameters during training. For example, if you have a typo in a parameter name, the error will be immediate rather than silently creating a duplicate parameter.

The locked/unlocked distinction is also used internally to detect whether code is running during initialization or runtime, which matters for RNG handling (covered in the next section).

[2]:
# Create a simple model.
graph = bx.Graph('demo')
rng = bx.Rng(graph.child('rng'))
linear = bx.Linear(graph.child('linear'), output_size=4, rng=rng)

# Start with unlocked params.
params = rng.seed(bx.Params(), seed=42)
print(f'Initial state: is_locked={params.is_locked}, num_params={len(params)}')

# First forward pass creates the parameters lazily.
_, params = linear(params, jnp.ones((1, 3)))
print(f'After forward: is_locked={params.is_locked}, num_params={len(params)}')

# Lock for runtime use.
params = params.locked()
print(f'After locking: is_locked={params.is_locked}')

# Subsequent forward passes work fine with existing params.
out, _ = linear(params, jnp.ones((1, 3)))
print(f'Forward pass works: output shape {out.shape}')
Initial state: is_locked=False, num_params=2
After forward: is_locked=False, num_params=4
After locking: is_locked=True
Forward pass works: output shape (1, 4)

If you try to create a new parameter with locked params, you get an error:

[4]:
# This would fail because params is locked.
# Uncommenting the next line would raise: RuntimeError: Params is locked.

# another_linear = bx.Linear(graph.child('another'), output_size=8, rng=rng)
# _, params = another_linear(params, jnp.ones((1, 4)))  # Error!

# If you need to add more params, unlock first.
params_unlocked = params.unlocked()
another_linear = bx.Linear(graph.child('another'), output_size=8, rng=rng)
_, params_unlocked = another_linear(params_unlocked, jnp.ones((1, 4)))
print(f'After unlocking and adding: num_params={len(params_unlocked)}')
After unlocking and adding: num_params=6

Immutability

blox’s Params is immutable by design. Every operation that modifies params returns a new params object, leaving the original unchanged. This is a deliberate design choice that:

  • Makes state changes explicit and trackable in your code.

  • Avoids subtle bugs from shared mutable state.

  • Works naturally with JAX’s functional transformations like jax.grad.

Always capture the returned params from operations:

[8]:
# Functional update pattern: set_param returns a NEW params object.
old_kernel, _ = linear.get_param(params, 'kernel')
new_kernel = old_kernel * 2

params_modified = linear.set_param(params, 'kernel', new_kernel)

kernel_path = linear.param_path('kernel')
print(f'Original kernel sum: {params[kernel_path].value.sum():.4f}')
print(f'Modified kernel sum: {params_modified[kernel_path].value.sum():.4f}')
print('Original params is unchanged!')
Original kernel sum: 0.5145
Modified kernel sum: 1.0290
Original params is unchanged!
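
To make the locked/unlocked and immutability rules concrete, here is a minimal pure-Python sketch of a container with the same behavior. `TinyParams` and its methods (`set`, `get_or_create`) are hypothetical stand-ins for illustration and are not blox's implementation.

```python
class TinyParams:
  """Hypothetical immutable container with a lock flag (not blox itself)."""

  def __init__(self, values=None, is_locked=False):
    self._values = dict(values or {})  # Private copy; never mutated later.
    self.is_locked = is_locked

  def __len__(self):
    return len(self._values)

  def __getitem__(self, path):
    return self._values[path]

  def locked(self):
    return TinyParams(self._values, is_locked=True)

  def unlocked(self):
    return TinyParams(self._values, is_locked=False)

  def set(self, path, value):
    """Returns a NEW container; creating a new path requires unlocked."""
    if path not in self._values and self.is_locked:
      raise RuntimeError(f'Params is locked; cannot create {path!r}.')
    new_values = dict(self._values)
    new_values[path] = value
    return TinyParams(new_values, self.is_locked)

  def get_or_create(self, path, init_fn):
    """Lazy creation: returns (value, params), creating the value if absent."""
    if path in self._values:
      return self._values[path], self
    new_params = self.set(path, init_fn())
    return new_params[path], new_params


p0 = TinyParams()
kernel, p1 = p0.get_or_create(('linear', 'kernel'), lambda: 1.0)
p2 = p1.locked().set(('linear', 'kernel'), 2.0)  # Updating is still allowed.

try:
  # A typo in the path fails loudly once the container is locked.
  p1.locked().get_or_create(('linear', 'kernal'), lambda: 0.0)
  typo_caught = False
except RuntimeError:
  typo_caught = True
```

Every operation returns a fresh object, so `p0` and `p1` remain valid snapshots after `p2` is created.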

Split and Merge for Gradient Computation

During training, you need gradients for trainable parameters (weights, biases) but also need to propagate updates to non-trainable state (RNG counters, batch norm running statistics). The params.split() method separates these two categories:

  • Trainable: Parameters that receive gradients (marked with trainable=True).

  • Non-trainable: State that updates during forward passes but shouldn’t receive gradients.

After computing gradients and applying updates, use .merge() to recombine them:

[10]:
trainable, non_trainable = params.split()

print('Trainable (receive gradients):')
for path in trainable.keys():
  print(f'  {path}')

print('\nNon-trainable (state propagation only):')
for path in non_trainable.keys():
  print(f'  {path}')
Trainable (receive gradients):
  ('demo', 'linear', 'kernel')
  ('demo', 'linear', 'bias')

Non-trainable (state propagation only):
  ('demo', 'rng', 'seed')
  ('demo', 'rng', 'counter')

The standard training pattern uses split/merge to handle both gradient computation and state propagation:

def train_step(params, inputs, targets):
  trainable, non_trainable = params.split()

  def loss_fn(t, nt):
    # Merge to run the forward pass.
    preds, new_params = model(t.merge(nt), inputs)
    loss = jnp.mean((preds - targets) ** 2)
    # Extract updated non-trainable state.
    _, new_nt = new_params.split()
    return loss, new_nt

  grads, new_non_trainable = jax.grad(loss_fn, has_aux=True)(trainable, non_trainable)
  new_trainable = jax.tree.map(lambda w, g: w - 0.01 * g, trainable, grads)
  return new_trainable.merge(new_non_trainable)
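
The same pattern can be sketched in plain JAX without blox, under the assumption that trainable and non-trainable state are kept as two separate dicts rather than split from one container. The toy `model` and the `calls` counter below are hypothetical.

```python
import jax
import jax.numpy as jnp


def model(trainable, non_trainable, inputs):
  """Toy forward pass that also updates a non-trainable call counter."""
  preds = inputs @ trainable['kernel'] + trainable['bias']
  new_non_trainable = {'calls': non_trainable['calls'] + 1}
  return preds, new_non_trainable


@jax.jit
def train_step(trainable, non_trainable, inputs, targets):
  def loss_fn(t, nt):
    preds, new_nt = model(t, nt, inputs)
    return jnp.mean((preds - targets) ** 2), new_nt

  # Differentiate w.r.t. the first argument only; new_nt rides along as aux.
  grads, new_nt = jax.grad(loss_fn, has_aux=True)(trainable, non_trainable)
  new_t = jax.tree_util.tree_map(lambda w, g: w - 0.1 * g, trainable, grads)
  return new_t, new_nt


trainable = {'kernel': jnp.zeros((3, 1)), 'bias': jnp.zeros((1,))}
non_trainable = {'calls': jnp.array(0)}
inputs = jnp.ones((4, 3))
targets = jnp.ones((4, 1))

trainable, non_trainable = train_step(trainable, non_trainable, inputs, targets)
```

Returning the updated non-trainable state through `has_aux=True` keeps it out of the gradient while still propagating it to the next step.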

Params Ownership and Memory

Other neural network libraries typically hide parameters inside the model pytree or in thread-local context. This allows users to work with a single, implicit copy of params—convenient, but opaque. For example, when switching a model to “eval mode,” it’s not always clear which parameters will be updated (batch norm running stats, dropout state) and which will remain unchanged, especially when building custom modules.

blox takes a different approach: params are explicit and user-controlled. You see params, you pass params, you get params back. This aligns with blox’s philosophy that explicit is better than implicit—you have full visibility and control over all state.

However, explicit ownership requires care:

  • Avoid keeping references to old params. JAX arrays live in device memory, and any live reference to an old params object keeps its buffers allocated. This is especially problematic in loops, where storing previous params makes memory usage grow with every step.

  • Donate params to JIT. Using donate_argnames=['params'] allows JAX to reuse param buffers in-place rather than allocating new ones. This is a significant memory optimization that explicit ownership enables. A donated params object must not be read again after the call; use the returned params instead.

The tradeoff is intentional: blox gives you flexibility for advanced use cases (branching params for evaluation, keeping param snapshots, using the same params across different model variants) while requiring you to be deliberate about memory. Other libraries manage this automatically but opaquely—you don’t know when buffers are shared or copied. With blox, you’re in control.
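
A minimal plain-JAX sketch of the donate-and-rebind loop follows. The function `sgd_like_step` and the dict of weights are hypothetical. Whether buffers are actually reused depends on the backend and JAX version; on some CPU setups JAX may only warn that donated buffers were not usable.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, donate_argnames=('params',))
def sgd_like_step(params, x):
  """Returns updated params; the donated input buffers may be reused."""
  return jax.tree_util.tree_map(lambda w: w - 0.1 * x, params)


params = {'w': jnp.ones((4,))}
for _ in range(3):
  # Rebind the name on every step so no reference to the old params survives.
  # After donation the old buffers may be invalid and must not be read again.
  params = sgd_like_step(params, jnp.float32(1.0))
```

Appending each `params` to a list inside this loop would defeat both points at once: the old buffers would stay referenced, and reading them after donation could fail.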


2. RNG Handling in Parallel Contexts

How blox Handles Randomness

blox uses a seed-and-counter approach for random number generation. The Rng module stores a seed and a counter in params. Each call to rng(params) returns a key and increments the counter:

key = jax.random.fold_in(seed, counter)
counter += 1

This gives deterministic, reproducible randomness while threading state functionally through your computations.
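
The seed-and-counter scheme can be sketched with plain JAX as below. The helper `next_key` is hypothetical and only mirrors the two lines above; it is not the blox Rng module.

```python
import jax
import jax.numpy as jnp


def next_key(seed_key, counter):
  """Derives a key from (seed, counter) and returns the incremented counter."""
  key = jax.random.fold_in(seed_key, counter)
  return key, counter + 1


seed_key = jax.random.PRNGKey(42)
counter = 0
key_a, counter = next_key(seed_key, counter)
key_b, counter = next_key(seed_key, counter)

# Replaying from the same seed and counter reproduces the same key.
key_a_again, _ = next_key(seed_key, 0)
```

Successive calls give different keys because the counter differs, while any (seed, counter) pair always maps to the same key.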

The Problem: Replicated RNG State

blox recommends replicating the RNG state across parallel lanes:

  • For vmap, we typically replicate all of params (using in_axes=(None, 0)).

  • For shard_map, it depends on sharding, but RNG state should generally be replicated.

When RNG state is replicated, all lanes share the same seed and counter, so they compute identical random keys. This is problematic for operations like dropout where each batch item should receive a unique mask—otherwise all items get dropped in the same positions, defeating the purpose of dropout.

[11]:
# Setup: create a dropout layer.
graph2 = bx.Graph('rng_demo')
rng2 = bx.Rng(graph2.child('rng'))
dropout = bx.Dropout(graph2.child('dropout'), rate=0.5, rng=rng2)

# Initialize and lock.
params2 = rng2.seed(bx.Params(), seed=42)
_, params2 = dropout(params2, jnp.ones((1, 8)), is_training=True)
params2 = params2.locked()


def apply_dropout_naive(params, x):
  """Naive approach with no axis folding."""
  out, params = dropout(params, x, is_training=True)
  return out, params


# vmap over 4 samples with replicated params.
x_batch = jnp.ones((4, 8))
out_naive, _ = jax.vmap(
    apply_dropout_naive,
    in_axes=(None, 0),  # None means params is replicated.
    out_axes=(0, None),
)(params2, x_batch)

print('Problem: all lanes get the SAME dropout mask!')
for i in range(4):
  mask = (out_naive[i] == 0).astype(int).tolist()
  print(f'  Lane {i}: {mask}')
Problem: all lanes get the SAME dropout mask!
  Lane 0: [1, 1, 0, 0, 0, 0, 0, 0]
  Lane 1: [1, 1, 0, 0, 0, 0, 0, 0]
  Lane 2: [1, 1, 0, 0, 0, 0, 0, 0]
  Lane 3: [1, 1, 0, 0, 0, 0, 0, 0]

The Solution: Fold in the Lane Index

To get unique randomness per lane, fold the lane index into the seed before using it. This creates a unique derived seed for each lane while keeping the counter synchronized.

Approach 2 (using axis_index) is recommended as it’s more idiomatic: you don’t need to pass extra arguments through your function signatures. However, approach 1 is useful when you need the batch index for other purposes.

[12]:
# Approach 1: Pass the batch index explicitly.
def apply_dropout_explicit(params, x, batch_idx):
  """Fold in an explicit batch index for unique randomness."""
  original_seed = rng2.get_seed(params)
  folded_seed = jax.random.fold_in(original_seed, batch_idx)
  params = rng2.seed(params, seed=folded_seed)

  out, params = dropout(params, x, is_training=True)

  # Restore original seed so all lanes return identical params.
  params = rng2.seed(params, seed=original_seed)
  return out, params


# Reset counter for fair comparison.
params2 = rng2.seed(params2, counter=0)

batch_indices = jnp.arange(4)
out_explicit, _ = jax.vmap(
    apply_dropout_explicit,
    in_axes=(None, 0, 0),
    out_axes=(0, None),
)(params2, x_batch, batch_indices)

print('Solution: each lane gets a UNIQUE dropout mask!')
for i in range(4):
  mask = (out_explicit[i] == 0).astype(int).tolist()
  print(f'  Lane {i}: {mask}')
Solution: each lane gets a UNIQUE dropout mask!
  Lane 0: [0, 1, 0, 1, 1, 1, 0, 0]
  Lane 1: [0, 1, 1, 0, 0, 0, 1, 1]
  Lane 2: [0, 1, 0, 0, 1, 1, 1, 0]
  Lane 3: [0, 0, 0, 1, 0, 1, 0, 0]

[13]:
# Approach 2: Use jax.lax.axis_index with axis_name.
def apply_dropout_axis_index(params, x):
  """Use axis_index to get the lane index implicitly."""
  original_seed = rng2.get_seed(params)
  folded_seed = jax.random.fold_in(original_seed, jax.lax.axis_index('batch'))
  params = rng2.seed(params, seed=folded_seed)

  out, params = dropout(params, x, is_training=True)

  # Restore original seed.
  params = rng2.seed(params, seed=original_seed)
  return out, params


# Reset counter.
params2 = rng2.seed(params2, counter=0)

out_axis, _ = jax.vmap(
    apply_dropout_axis_index,
    in_axes=(None, 0),
    out_axes=(0, None),
    axis_name='batch',  # Required for axis_index to work.
)(params2, x_batch)

print(
    'Same result with axis_index (cleaner when index is not needed elsewhere):'
)
for i in range(4):
  mask = (out_axis[i] == 0).astype(int).tolist()
  print(f'  Lane {i}: {mask}')
Same result with axis_index (cleaner when index is not needed elsewhere):
  Lane 0: [0, 1, 0, 1, 1, 1, 0, 0]
  Lane 1: [0, 1, 1, 0, 0, 0, 1, 1]
  Lane 2: [0, 1, 0, 0, 1, 1, 1, 0]
  Lane 3: [0, 0, 0, 1, 0, 1, 0, 0]

Why Restore the Original Seed?

When params are replicated across lanes (out_axes=None for params), JAX requires the returned params to be unbatched, that is, identical in every lane. The counter increments identically in each lane because every lane runs the same operations, but the folded seed depends on the lane index. Restoring the original seed before returning makes the returned params lane-independent again, which satisfies JAX’s requirement.
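
This requirement can be reproduced in plain JAX without blox. In the sketch below, the hypothetical function `make_mask` folds the lane index into a key and returns either the original key or the folded one. With out_axes=None, returning the folded key is expected to raise a ValueError because that output is batched.

```python
import jax
import jax.numpy as jnp


def make_mask(seed_key, x, restore_seed):
  """Draws a per-lane keep mask; optionally returns the unfolded seed."""
  lane = jax.lax.axis_index('batch')
  folded = jax.random.fold_in(seed_key, lane)
  mask = jax.random.bernoulli(folded, 0.5, x.shape)
  return mask, (seed_key if restore_seed else folded)


def run(restore_seed):
  return jax.vmap(
      lambda k, x: make_mask(k, x, restore_seed),
      in_axes=(None, 0),
      out_axes=(0, None),  # The returned seed must be unbatched.
      axis_name='batch',
  )(jax.random.PRNGKey(0), jnp.ones((4, 64)))


masks, seed_out = run(restore_seed=True)  # Works: the seed is lane-independent.

try:
  run(restore_seed=False)  # The folded seed differs per lane.
  failed = False
except ValueError:
  failed = True
```

The masks still differ per lane in the working case, since only the value returned as state has to be lane-independent.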


3. Init vs Runtime

The is_locked Pattern

There’s a fundamental distinction between initialization and runtime that affects how RNG should behave:

Initialization (params.is_locked == False): We’re creating parameters. When vmapped, we want identical params across lanes. If each lane initialized with different random values, the params wouldn’t match and couldn’t be merged. So during init, we should NOT fold in the lane index.

Runtime (params.is_locked == True): We’re using existing parameters. When vmapped, we want unique randomness per batch item for operations like dropout. So during runtime, we SHOULD fold in the lane index.

Use params.is_locked to write functions that behave correctly in both contexts:

[14]:
def forward(params, x):
  """Works correctly for both init and runtime."""
  original_seed = rng2.get_seed(params)

  if params.is_locked:
    # Runtime: fold in axis index for unique randomness per batch item.
    folded_seed = jax.random.fold_in(original_seed, jax.lax.axis_index('batch'))
    params = rng2.seed(params, seed=folded_seed)
  # Init (unlocked): don't fold, so all lanes create identical params.

  out, params = dropout(params, x, is_training=True)

  # Always restore seed so replicated params stay consistent.
  params = rng2.seed(params, seed=original_seed)
  return out, params


# Init phase: params are unlocked, no folding happens.
def init_fn(x):
  p = rng2.seed(bx.Params(), seed=42)
  _, p = forward(p, x)
  return p.locked()


params3 = jax.vmap(init_fn, axis_name='batch', out_axes=None)(x_batch)
print(f'Init complete: is_locked={params3.is_locked}')

# Runtime phase: params are locked, folding is applied.
out_runtime, _ = jax.vmap(
    forward,
    in_axes=(None, 0),
    out_axes=(0, None),
    axis_name='batch',
)(params3, x_batch)

print('Runtime: unique masks per lane!')
for i in range(4):
  mask = (out_runtime[i] == 0).astype(int).tolist()
  print(f'  Lane {i}: {mask}')
Init complete: is_locked=True
Runtime: unique masks per lane!
  Lane 0: [0, 0, 0, 0, 1, 0, 1, 1]
  Lane 1: [0, 1, 0, 1, 0, 0, 1, 0]
  Lane 2: [0, 0, 0, 1, 0, 0, 1, 0]
  Lane 3: [1, 1, 0, 0, 0, 0, 1, 1]
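
The key point of this pattern is that the lock flag is an ordinary Python bool, so the branch is resolved at trace time. A plain-JAX sketch of the same idea, with a hypothetical `sample` function and an explicit `is_locked` argument standing in for params.is_locked:

```python
import jax
import jax.numpy as jnp


def sample(seed_key, x, is_locked):
  """Folds in the lane index only at runtime (is_locked=True)."""
  key = seed_key
  if is_locked:  # Python bool: decided while tracing, not a traced value.
    key = jax.random.fold_in(seed_key, jax.lax.axis_index('batch'))
  return jax.random.normal(key, x.shape)


seed_key = jax.random.PRNGKey(42)
x_batch = jnp.ones((4, 8))

# Init: no folding, so the result is lane-independent and out_axes=None works.
init_value = jax.vmap(
    lambda x: sample(seed_key, x, is_locked=False),
    out_axes=None,
    axis_name='batch',
)(x_batch)

# Runtime: folding gives each lane its own random values.
runtime_values = jax.vmap(
    lambda x: sample(seed_key, x, is_locked=True),
    axis_name='batch',
)(x_batch)
```

Here `init_value` has shape (8,) because all lanes agree, while `runtime_values` has shape (4, 8) with a different row per lane.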

Sharded Initialization: JIT vs shard_map

Models can use jax.jit or jax.experimental.shard_map depending on the level of control you need over device placement. Both can shard initialization and runtime, but they have different tradeoffs.

The challenge with shard_map for initialization:

When using shard_map at runtime, sharded params go in and params with the same sharding come out. But how do you create those sharded params in the first place? The difficulty is that different parameters may need different sharding specifications, and managing which axes to fold for which params becomes complex.

Recommended approach: Initialize with jax.jit

If you’re using shard_map at runtime, we recommend initializing params using plain jax.jit with out_shardings. JIT can:

  • Use a replicated RNG to create global (unsharded) parameters.

  • Automatically partition them according to the specified out_shardings.

mesh = jax.make_mesh((4,), ('model',))
params_sharding = ...  # Build from param metadata.

@jax.jit(out_shardings=params_sharding)
def init():
  params = rng.seed(bx.Params(), seed=42)
  _, params = model(params, dummy_input)
  return params.locked()

params = init()  # JIT handles sharding automatically.

This approach is simpler and less error-prone than trying to manage sharding manually in shard_map.
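
For a version that runs on any number of devices (including one), here is a plain-JAX sketch of the same idea without blox. The parameter names, shapes, and the choice to shard only the kernel's first axis are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())
mesh = Mesh(devices, ('model',))

# Shard the kernel's first axis over 'model'; keep the bias replicated.
shardings = {
    'kernel': NamedSharding(mesh, PartitionSpec('model', None)),
    'bias': NamedSharding(mesh, PartitionSpec()),
}

rows = 2 * devices.size  # Divisible by the mesh axis size by construction.


def init():
  key = jax.random.PRNGKey(42)  # One replicated key creates global arrays.
  return {
      'kernel': jax.random.normal(key, (rows, 4)),
      'bias': jnp.zeros((4,)),
  }


# The out_shardings pytree mirrors the structure of the returned params.
params = jax.jit(init, out_shardings=shardings)()
```

The init function is written as if for a single device; the partitioning is stated once in `shardings` and applied by the compiler.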


4. Graph and Module Modifications

The Problem: JIT Caching and Module State

When you JIT-compile a function, JAX traces it once and caches the compiled result. If you then modify a module’s properties, JAX will not automatically recompile—it uses the cached trace from before your modification. This leads to silent bugs where your changes appear to have no effect.

See JAX Gotchas: Using jax.jit with class methods for the underlying JAX behavior.

Let’s demonstrate with a simple custom module:

[15]:
class AddConstant(bx.Module):
  """Simple module that adds a constant to the input."""

  def __init__(self, graph: bx.Graph, value: float):
    super().__init__(graph)
    self.value = value  # Static config stored as Python float, not JAX array.

  def __call__(self, params: bx.Params, x: jax.Array):
    return x + self.value, params


graph4 = bx.Graph('modify_demo')
adder = AddConstant(graph4.child('adder'), value=10.0)
params4 = bx.Params().locked()
x = jnp.array([1.0, 2.0, 3.0])

# JIT-compile adder's __call__ function.
add_constant = jax.jit(adder)

# First call traces with value=10.
out1, _ = add_constant(params4, x)
print(f'value=10: {x.tolist()} -> {out1.tolist()}')

# Modify the module.
adder.value = 100.0
print(f'Changed value to {adder.value}')

# Second call uses CACHED trace—change is NOT visible!
out2, _ = add_constant(params4, x)
print(f'After change: {x.tolist()} -> {out2.tolist()}')
print('BUG: Still adding 10, not 100!')
value=10: [1.0, 2.0, 3.0] -> [11.0, 12.0, 13.0]
Changed value to 100.0
After change: [1.0, 2.0, 3.0] -> [11.0, 12.0, 13.0]
BUG: Still adding 10, not 100!

blox Philosophy: Prefer Static Models

JAX’s recommended solution is to register classes as pytrees so that changes trigger recompilation. However, this approach brings complexity when modules contain non-hashable types (like lists or dicts) and requires careful handling of what counts as “static” vs “dynamic”.
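
For reference, the pytree-registration approach looks roughly like the plain-JAX sketch below. `StaticAdder` is a hypothetical class; its config is stored as static aux data, so it becomes part of the jit cache key.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class StaticAdder:
  """Hypothetical class whose config is static pytree metadata."""

  def __init__(self, value):
    self.value = value

  def tree_flatten(self):
    # No array children; `value` is aux data and so part of the jit cache key.
    return (), (self.value,)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    del children  # Unused: this class holds no arrays.
    return cls(*aux_data)

  def __call__(self, x):
    return x + self.value


@jax.jit
def apply(adder, x):
  return adder(x)


adder = StaticAdder(10.0)
x = jnp.array([1.0, 2.0, 3.0])
out_before = apply(adder, x)

adder.value = 100.0          # Changes the aux data, hence the cache key.
out_after = apply(adder, x)  # Retraced with the new value.
```

This works because the object is passed as an argument and its aux data is hashable. A list or dict in `aux_data` would fail the hashing that jit's cache requires, which is the complexity the paragraph above refers to.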

blox takes a different approach: keep graphs static.

Instead of modifying modules after creation, build separate models for different configurations. Since params are decoupled from the graph, multiple models can share the same params—you just need matching parameter paths.

Separation of concerns:

  • Graph describes what operations to perform (static Python objects).

  • Params holds what values to use (dynamic JAX arrays).

This separation means:

  • Graph should not contain JAX arrays in constructor arguments.

  • Modules should not create params in __init__ (this wastes memory and prevents lazy creation).

  • Use module.get_param() for lazy param creation during the first forward pass.

Pattern 1: Build Separate Static Models

If you need different configurations, create separate models from the start using a factory function. Each model gets its own JIT cache, avoiding confusion:

[ ]:
# Factory function creates fresh models with desired config.
def create_adder(value: float):
  graph = bx.Graph('adder')
  return AddConstant(graph.child('add'), value=value)


adder_10 = create_adder(10.0)
adder_100 = create_adder(100.0)


# JIT each module with donate_argnames for efficiency.
@jax.jit(donate_argnames=['params'])
def add_10(params, x):
  return adder_10(params, x)


@jax.jit(donate_argnames=['params'])
def add_100(params, x):
  return adder_100(params, x)


out_10, _ = add_10(params4, x)
out_100, _ = add_100(params4, x)

print(f'adder_10: {x.tolist()} -> {out_10.tolist()}')
print(f'adder_100: {x.tolist()} -> {out_100.tolist()}')
print('Both work correctly with static models!')

Pattern 2: Create Separate Functions

If you need to modify a module, define a new function that captures the modified state. Each new function gets its own JIT cache:

[ ]:
# Create a mutable module.
graph5 = bx.Graph('wrap_demo')
mutable_adder = AddConstant(graph5.child('adder'), value=10.0)


# Define and JIT a function for value=10.
@jax.jit(donate_argnames=['params'])
def apply_v10(params, x):
  return mutable_adder(params, x)


out1, _ = apply_v10(params4, x)
print(f'value=10: {x.tolist()} -> {out1.tolist()}')

# Modify the module.
mutable_adder.value = 100.0


# Define and JIT a NEW function for value=100.
@jax.jit(donate_argnames=['params'])
def apply_v100(params, x):
  return mutable_adder(params, x)


out2, _ = apply_v100(params4, x)
print(f'value=100: {x.tolist()} -> {out2.tolist()}')
print('Each function gets its own JIT cache!')

Sharing Params Across Models

Since params are decoupled from the graph, different models can use the same params as long as they have matching parameter paths. This enables powerful patterns:

[ ]:
# Create two LSTM variants with the same parameter structure.
def create_lstm(is_static: bool):
  graph = bx.Graph('model')
  rng = bx.Rng(graph.child('rng'))
  lstm = bx.LSTM(
      graph.child('lstm'), hidden_size=16, rng=rng, is_static=is_static
  )
  return rng, lstm


rng_static, lstm_static = create_lstm(is_static=True)  # Uses Python for-loop.
rng_dynamic, lstm_dynamic = create_lstm(is_static=False)  # Uses jax.lax.scan.

# Initialize params using one of them.
x = jnp.ones((2, 5, 8))  # Shape: [batch, time, features].
params5 = rng_static.seed(bx.Params(), seed=42)
state, params5 = lstm_static.initial_state(params5, x[:, 0])
(out, _), params5 = lstm_static.apply(params5, x, prev_state=state)
params5 = params5.locked()

# Both models work with the SAME params!
out_static, _ = lstm_static.apply(params5, x, prev_state=state)
out_dynamic, _ = lstm_dynamic.apply(params5, x, prev_state=state)

print(f'Static LSTM output sum: {out_static[0].sum():.4f}')
print(f'Dynamic LSTM output sum: {out_dynamic[0].sum():.4f}')
print('Same params, different execution modes!')

Use cases for shared params:

  • Actor vs Learner in RL: Separate models for data collection and training that share weights.

  • Training vs Evaluation: Different logic (e.g., dropout enabled/disabled) while using the same parameters.

  • Debug vs Production: Python loop for debugging with breakpoints, lax.scan for production speed.
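
The debug-versus-production case can be sketched in plain JAX: two execution strategies that consume the same weights and are expected to agree numerically. The toy recurrent `cell` and the weight names are hypothetical and unrelated to bx.LSTM.

```python
import jax
import jax.numpy as jnp


def cell(weights, h, x_t):
  """One step of a toy recurrent cell."""
  return jnp.tanh(x_t @ weights['w_x'] + h @ weights['w_h'])


def run_loop(weights, h0, xs):
  """Python for-loop: unrolled at trace time, easy to step through."""
  h = h0
  for t in range(xs.shape[0]):
    h = cell(weights, h, xs[t])
  return h


def run_scan(weights, h0, xs):
  """jax.lax.scan: a single compiled loop over the leading time axis."""
  def step(h, x_t):
    return cell(weights, h, x_t), None

  h, _ = jax.lax.scan(step, h0, xs)
  return h


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
weights = {
    'w_x': 0.1 * jax.random.normal(k1, (8, 16)),
    'w_h': 0.1 * jax.random.normal(k2, (16, 16)),
}
xs = jax.random.normal(k3, (5, 2, 8))  # Shape: [time, batch, features].
h0 = jnp.zeros((2, 16))

h_loop = run_loop(weights, h0, xs)
h_scan = run_scan(weights, h0, xs)
```

Since the weights live outside both functions, switching between the two is a matter of which function is called, not of converting or copying parameters.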


5. Summary

Sharp Edges and Patterns

| Issue | Pattern |
|---|---|
| Same random values in vmap | Fold in lane index; check is_locked for init vs runtime. |
| Sharded param initialization | Use jax.jit with out_shardings, not shard_map. |
| Creating params during runtime | Lock params after init with params.locked(). |
| Module changes not visible after JIT | Build separate static models, or create fresh JIT wrappers. |
| Memory waste from old param references | Donate params to JIT; avoid storing references in loops. |

Key Design Principles

| Principle | Why It Matters |
|---|---|
| Params decoupled from Graph | Multiple models can share the same params. |
| Graph is static | Avoids JIT caching bugs; clear separation of concerns. |
| Params are immutable | Explicit state changes; works with JAX transforms. |
| Lazy param creation | No wasted memory; no need to specify shapes upfront. |
| Locked vs unlocked | Catches bugs; enables init vs runtime RNG behavior. |
| Explicit params ownership | Full visibility and control; enables buffer donation. |
