Comparison: Equinox & Flax NNX
JAX already ships a strong abstraction: composable transformations over pure functions, with state threaded explicitly through function signatures. The question a neural-network library answers is how much of that stays visible. Equinox and Flax NNX are both mature, well-built libraries that take different routes; both make simple things easy but add cognitive overhead as models get more complicated. This page documents where each one departs from JAX’s “explicit, no hidden state, no magic” core, and how blox avoids those departures by construction.
A note on fairness up front. Neither library is bad. Equinox stays close to JAX and is explicitly designed to “feel like (and be fully compatible with) the main JAX library itself.” NNX’s reference semantics make model surgery and inspection genuinely easy, and it has a large community and a strong development team behind it. The argument below is about where the abstraction leaks, not about whether these libraries are good.
Equinox: modules as pytrees, plus a filtering layer
Equinox keeps the model as a PyTree, which is close to JAX in spirit. But
because the module pytree mixes array leaves with arbitrary Python (non-array)
leaves, plain jax.jit / jax.grad don’t apply directly to a realistic
model. From the Equinox FAQ:
“This error happens because a model, when treated as a PyTree, may have leaves that are not JAX types (such as functions). It only makes sense to trace arrays. … Instead of
jax.jit, use equinox.filter_jit. Likewise for other transformations.”
So in practice you adopt a parallel family of filtered transformations
(filter_jit / filter_grad / filter_vmap / filter_pmap), and the
underlying ceremony is partition → transform → combine. You also need to
know which default filter applies, is_array (all arrays) or is_inexact_array
(only floating-point arrays get gradients), because filter_grad silently
returns None gradients for non-float leaves.
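The partition → transform → combine ceremony can be sketched in plain JAX. This is a hypothetical, hand-rolled illustration, not Equinox's API: the dict-based model and the `partition` / `combine` helpers below are assumptions (Equinox's own eqx.partition and eqx.combine generalise the idea to arbitrary pytrees and filters). Passing `model` straight to jax.jit fails because the function leaf is not a valid JAX type, so the non-array half is split off and closed over instead.

```python
import jax
import jax.numpy as jnp

# Hypothetical model: one pytree mixing array leaves with a non-array leaf
# (a Python function), the situation the Equinox FAQ describes.
model = {
    "w": jnp.array([[1.0, 2.0], [3.0, 4.0]]),
    "b": jnp.array([0.5, -0.5]),
    "act": jnp.tanh,
}

def partition(tree):
    # Hand-rolled filter: arrays on one side, everything else on the other.
    arrays = {k: v for k, v in tree.items() if isinstance(v, jax.Array)}
    static = {k: v for k, v in tree.items() if not isinstance(v, jax.Array)}
    return arrays, static

def combine(arrays, static):
    # Reassemble the original model from its two halves.
    return {**arrays, **static}

def apply(m, x):
    return m["act"](x @ m["w"] + m["b"])

arrays, static = partition(model)
x = jnp.ones((3, 2))

def loss(arrays):
    # The static half is closed over, so jit and grad only ever see arrays.
    return jnp.sum(apply(combine(arrays, static), x) ** 2)

grads = jax.jit(jax.grad(loss))(arrays)
print(sorted(grads))  # ['b', 'w']
```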
Coupling state to the module pytree has real foot-guns. With BatchNorm, the
module’s flattened pytree form changes across iterations (equinox#238). Tied
or shared layers diverge under value semantics: the FAQ notes that after some
gradient updates “you’ll find that self.linear1 and self.linear2 are now
different.” State for stateful layers must be threaded by hand:
x, state = self.norm(x, state).
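Why tied layers drift apart can be shown with plain JAX pytrees. A minimal sketch, assuming a hypothetical dict model in which two entries refer to the same array: flattening records two independent leaves, gradients differ per leaf, and one SGD step is enough to separate them.

```python
import jax
import jax.numpy as jnp

# Hypothetical "tied" layers: both entries point at the same Python object.
w = jnp.ones(2)
model = {"linear1": w, "linear2": w}

# Flattening sees two independent leaves; the sharing is not recorded.
print(len(jax.tree_util.tree_leaves(model)))  # 2

def loss(m, x):
    # The two "tied" layers contribute differently to the loss.
    return jnp.sum(m["linear1"] * x) + 2.0 * jnp.sum(m["linear2"] * x)

x = jnp.array([1.0, 2.0])
grads = jax.grad(loss)(model, x)

# One plain SGD step applied leaf by leaf: the two copies no longer match.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, model, grads)
```

Keeping a single entry for the shared weight and reading it from both call sites avoids the problem, at the cost of the model no longer mirroring the module hierarchy.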
Flax NNX: a separate reimplementation of JAX’s transforms
Flax NNX goes the other way, with mutable, PyTorch-style module objects and reference semantics. Because JAX transforms can’t operate on mutable reference-semantic objects, NNX ships its own version of essentially the entire JAX transform surface. From the NNX Transformations guide:
“JAX transformations operate on pytrees of
jax.Arrays and abide by value semantics. This presents a challenge for Flax NNX, which represents nnx.Modules as regular Python objects that follow reference semantics.”
The result is a full parallel transform suite (nnx.jit, nnx.grad,
nnx.value_and_grad, nnx.vmap, nnx.pmap, nnx.scan, nnx.remat,
nnx.cond, nnx.while_loop, and more), described in the
NNX basics as
“supersets of their equivalent JAX counterparts.” It is a divergent
reimplementation, not a thin wrapper: nnx.scan, for example, “(consciously)
deviates from jax.lax.scan.” Underneath sits a Module / State / GraphDef
system (NNX is a graph, not a pytree, by design), and crossing into a JAX
transform requires a split / merge ceremony.
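The split / merge ceremony can be sketched with an ordinary mutable Python object. This is a hypothetical illustration, not NNX's API: the `Linear` class and the `split` / `merge` helpers are assumptions (NNX's own nnx.split and nnx.merge do this for real module graphs). The object is taken apart into static structure plus an array pytree, the pure function runs under jax.jit and jax.grad, and the result is written back into a fresh object.

```python
import jax
import jax.numpy as jnp

class Linear:
    """Hypothetical mutable module with reference semantics (not an NNX class)."""
    def __init__(self, w, b):
        self.w, self.b = w, b

    def __call__(self, x):
        return x @ self.w + self.b

def split(module):
    # Hypothetical: separate static structure (here just the class) from array state.
    return type(module), dict(vars(module))

def merge(graph, state):
    # Hypothetical: rebuild a fresh object from structure plus state.
    obj = graph.__new__(graph)
    obj.__dict__.update(state)
    return obj

model = Linear(jnp.ones((2, 2)), jnp.zeros(2))
graph, state = split(model)  # cross the boundary: object -> (structure, arrays)

@jax.jit
def sgd_step(state, x):
    def loss(s):
        return jnp.sum(merge(graph, s)(x) ** 2)
    grads = jax.grad(loss)(state)
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, state, grads)

new_state = sgd_step(state, jnp.ones((1, 2)))
model = merge(graph, new_state)  # cross back: (structure, arrays) -> object
```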
That suite is a large surface to maintain, and the abstraction is known to leak. Mutation
“must be used with care because it can clash with JAX’s underlying assumptions”
(changing graph structure inside nnx.jit “causes continuous
recompilations”), and the maintainers’ own long-term plan is to make NNX
implement the pytree protocol specifically so it can be
“used with raw JAX transformations and other libraries”, i.e. to stop
diverging from raw JAX.
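The recompilation trap rests on a plain JAX behaviour, sketched below: jax.jit caches compiled code per input pytree structure, shape, and dtype, so every new structure means a new trace. This is a plain-JAX analogue, not NNX internals; the trace counter is an illustrative Python side effect that only runs while tracing.

```python
import jax
import jax.numpy as jnp

traces = []

@jax.jit
def total(tree):
    traces.append(None)  # Python side effect: runs once per trace, not per call
    return sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(tree))

total({"a": jnp.ones(2)})                    # first call: traced and compiled
total({"a": jnp.ones(2) * 3.0})              # same structure and shapes: cache hit
total({"a": jnp.ones(2), "b": jnp.ones(2)})  # new structure: traced again
print(len(traces))  # 2
```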
The pattern: simple is easy, complex gets steep
Both libraries make the toy case easy and add overhead once you leave it.
Equinox’s surface grows with custom initialisation, parameter surgery, and
scan-over-layers (“The above code probably seems a bit complicated!”, as the
Equinox tricks page puts it). NNX’s
mutation (the thing that buys the PyTorch-like ergonomics) is also the thing
that introduces a new error class (Inconsistent aliasing detected) and the
silent-recompile trap. In both cases the cost lands at the edges, where you
cross a transform boundary or step outside the happy path.
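For comparison, scan-over-layers with parameters held as an ordinary pytree is a stack plus jax.lax.scan. A minimal sketch, assuming identical layers whose parameters can be stacked along a leading axis; the layer shape and names are illustrative.

```python
import jax
import jax.numpy as jnp

num_layers, width = 4, 3
keys = jax.random.split(jax.random.PRNGKey(0), num_layers)

# Stack per-layer parameters along a leading axis: arrays of shape (L, ...).
stacked = {
    "w": jax.vmap(lambda k: 0.1 * jax.random.normal(k, (width, width)))(keys),
    "b": jnp.zeros((num_layers, width)),
}

def layer(x, p):
    # scan body: the carry is the activation, p is one layer's slice of `stacked`.
    return jnp.tanh(x @ p["w"] + p["b"]), None

x = jnp.ones((2, width))
y, _ = jax.lax.scan(layer, x, stacked)

# Same result as an unrolled Python loop over the layers.
y_loop = x
for i in range(num_layers):
    y_loop = jnp.tanh(y_loop @ stacked["w"][i] + stacked["b"][i])
```

Nothing needs filtering or splitting first, because `stacked` already contains only arrays.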
How blox avoids both boundaries
blox keeps Graph (static Python describing structure) and Params
(dynamic JAX arrays) as two separate objects. There is no mixed-leaf pytree
to filter (as in Equinox) and no mutable graph to split and merge (as in NNX).
jax.jit, jax.grad, jax.vmap, and jax.checkpoint apply to a plain
function with no filter_* and no nnx.* wrapper in the way. See
Design Principles for the full rationale behind the params/graph split.
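The shape of that split can be sketched in plain JAX. This is not blox's API: a static Python tuple stands in for Graph and a list of dicts stands in for Params (both hypothetical), just to show that once structure and arrays live apart, the stock transforms apply to an ordinary pure function.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a graph: static Python describing layer sizes.
graph = (2, 4, 1)

def init(graph, key):
    # Hypothetical stand-in for params: a plain pytree of arrays, nothing else.
    keys = jax.random.split(key, len(graph) - 1)
    return [
        {"w": 0.1 * jax.random.normal(k, (d_in, d_out)), "b": jnp.zeros(d_out)}
        for k, d_in, d_out in zip(keys, graph[:-1], graph[1:])
    ]

def apply(params, x):
    for p in params[:-1]:
        x = jnp.tanh(x @ p["w"] + p["b"])
    return x @ params[-1]["w"] + params[-1]["b"]

def loss(params, x, y):
    return jnp.mean((apply(params, x) - y) ** 2)

params = init(graph, jax.random.PRNGKey(0))
x, y = jnp.ones((8, 2)), jnp.zeros((8, 1))

# Every transform below is the stock jax.* one, applied to a pure function.
grads = jax.jit(jax.grad(loss))(params, x, y)
per_example = jax.vmap(apply, in_axes=(None, 0))(params, x)  # map over rows of x
value = jax.checkpoint(loss)(params, x, y)
```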
Randomness is the one sharp edge every JAX library inherits, and the three
handle it differently. In Equinox you thread jax.random keys through your
functions by hand. NNX hides them inside a stateful nnx.Rngs object whose
keys live in the graph and need “extra tricks with nnx.vmap” (per the Flax
randomness guide)
to behave correctly under transforms, namely nnx.split_rngs and
nnx.StateAxes when you vmap or scan. blox keeps JAX’s own counter-based fold_in
pattern and surfaces it explicitly rather than wrapping it (see
Sharp Bits / the RNG notes). The sharp edge is JAX’s, and so is
everything you learn working around it.
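JAX's counter-based fold_in pattern looks like this in plain JAX. A minimal sketch: the `dropout` helper, the base key, and the (step, layer, example) folding scheme are illustrative assumptions, not blox's own RNG helpers. Keys are derived deterministically from integers rather than stored as state, which is also why the pattern composes with jax.vmap without extra machinery.

```python
import jax
import jax.numpy as jnp

base_key = jax.random.PRNGKey(42)

def dropout(key, x, rate):
    # Keep each element with probability 1 - rate, rescaling the survivors.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

def layer_key(step, layer):
    # Derive a key from (step, layer) with fold_in: no key state is stored,
    # so the same integers always give the same key.
    return jax.random.fold_in(jax.random.fold_in(base_key, step), layer)

def per_example(i, x):
    # Fold the example index in as well, so each vmapped example gets its own key.
    return dropout(jax.random.fold_in(layer_key(0, 1), i), x, 0.5)

xs = jnp.ones((4, 8))
out = jax.vmap(per_example)(jnp.arange(4), xs)
```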
| | Equinox | Flax NNX | blox |
|---|---|---|---|
| Calls JAX transforms directly | Mostly, but non-array leaves force ``filter_*`` transforms | No, it ships ``nnx.*`` supersets | Yes, unwrapped ``jax.jit`` / ``grad`` / ``vmap`` / ``checkpoint`` |
| Boundary ceremony | ``partition`` → transform → ``combine`` | ``split`` / ``merge`` | None, ``params`` is already a clean array pytree |
| Where state lives | In the module pytree | In mutable ``nnx.Module`` objects | In a separate ``Params`` container |
| Library-specific transforms to maintain | Filtered transforms | A full parallel transform suite | Zero |
| Randomness | Manual ``jax.random`` key threading | Stateful ``nnx.Rngs`` | JAX’s own ``fold_in`` pattern, surfaced |
| Main foot-gun | Value-semantics surprises (shared layers, BatchNorm tree drift) | Mutation vs JAX’s assumptions (aliasing errors, silent recompiles) | JAX PRNG folding, surfaced, not hidden |
Both Equinox and NNX are mature and a great fit for many projects: Equinox if
you like “the model is a pytree,” NNX if mutable PyTorch-style objects feel
natural. blox makes a different bet. Rather than building a framework on top
of JAX, it grows directly out of JAX’s own philosophy: explicit state, pure
functions, no hidden magic. The graph and the parameters stay separate, every
transformation is the real jax.* one, and the randomness is JAX’s own. What
you learn using blox is JAX itself, so your understanding and your code keep
paying off as the ecosystem moves, with nothing library-specific standing in
the way.