blox
A functional and lightweight neural network library for JAX.
blox unlocks the full potential of JAX by embracing its functional nature instead of fighting it.
JAX gives you composable transformations that let you write math and run it fast on CPU, GPU, or TPU. blox is a thin layer on top that keeps all of that power accessible while giving you just enough structure to organize your neural networks.
The entire mental model fits in one line:
outputs, params = model(params, inputs)
Parameters go in, outputs and updated parameters come out. Because state
flows explicitly through your code, all JAX transformations
(jax.jit, jax.grad, jax.vmap, jax.checkpoint) work out
of the box. No wrappers, no decorators, no surprises.
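To make the signature concrete, here is a minimal sketch in plain JAX rather than blox itself. It assumes a plain dict as the parameter container and a hand-written `linear` function standing in for a model; both are illustrative names, not part of the blox API. It shows why a function of the form `outputs, params = fn(params, inputs)` needs no adapter for `jax.jit` or `jax.grad`.

```python
import jax
import jax.numpy as jnp


def linear(params, x):
    """Plain-JAX stand-in for a layer (an assumption, not blox code)."""
    out = x @ params['kernel'] + params['bias']
    # Return params alongside the output to mirror the explicit state flow.
    return out, params


def loss_fn(params, x, y):
    pred, params = linear(params, x)
    # The second return value is carried through jax.grad as auxiliary data.
    return jnp.mean((pred - y) ** 2), params


key = jax.random.PRNGKey(0)
params = {
    'kernel': 0.01 * jax.random.normal(key, (784, 10)),
    'bias': jnp.zeros((10,)),
}
x = jnp.ones((4, 784))
y = jnp.zeros((4, 10))

# jit applies directly: params and inputs are ordinary arguments.
outputs, params = jax.jit(linear)(params, x)

# grad differentiates with respect to the first argument (params);
# has_aux=True passes the returned params through untouched.
grads, params = jax.grad(loss_fn, has_aux=True)(params, x, y)
```

The gradients come back with the same structure as the parameter dict, so an optimizer step can be written as an ordinary tree operation over `params` and `grads`.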
Who is blox for?
Learners. There is no “framework magic” to learn. What you see is what you get, which makes blox a good way to understand how neural networks actually work at the JAX level.
Practitioners. If you’re tired of fighting frameworks that hide important details, blox gives you complete transparency. Whether you’re building custom training loops, implementing novel architectures, or scaling up, you have direct access to the full execution stack.
Installation
Since blox uses JAX, check the JAX installation guide for your specific hardware (CPU / GPU / TPU). You will need Python 3.11 or later.
Install the latest release from PyPI:
pip install jax-blox
Or install the development version from source:
pip install git+https://github.com/hamzamerzic/blox.git
Quickstart
Define a layer:
import jax
import jax.numpy as jnp
import blox as bx


class Linear(bx.Module):

    def __init__(self, graph: bx.Graph, output_size: int, rng: bx.Rng):
        super().__init__(graph)
        self.output_size = output_size
        self.rng = rng

    def __call__(self, params: bx.Params, x: jax.Array):
        # Parameters are created lazily on first use.
        kernel, params = self.get_param(
            params,
            name='kernel',
            shape=(x.shape[-1], self.output_size),
            init=jax.nn.initializers.normal(),
            rng=self.rng,
        )
        bias, params = self.get_param(
            params,
            name='bias',
            shape=(self.output_size,),
            init=jax.nn.initializers.zeros,
            rng=self.rng,
        )
        return x @ kernel + bias, params
Wire it up and run a forward pass:
graph = bx.Graph('net')
rng = bx.Rng(graph.child('rng'))
model = Linear(graph.child('linear'), output_size=10, rng=rng)
params = bx.Params()
params = rng.seed(params, seed=42)
outputs, params = model(params, jnp.ones((4, 784)))
params = params.locked() # Lock to prevent accidental param creation.
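The create-on-first-use and locking behaviour can be pictured with a small functional sketch. This is not how blox implements `get_param` or `locked()`; it assumes a plain dict as the container, and the free function `get_param` with its `locked` flag is a hypothetical stand-in written only to illustrate the idea.

```python
import jax
import jax.numpy as jnp


def get_param(params, name, shape, init, key, locked=False):
    """Hypothetical helper for illustration; not blox's implementation."""
    if name in params:
        # Already created: return the stored value, params unchanged.
        return params[name], params
    if locked:
        raise KeyError(f'parameter {name!r} is missing and params are locked')
    value = init(key, shape, jnp.float32)
    # Return a new dict rather than mutating the input.
    return value, {**params, name: value}


key = jax.random.PRNGKey(42)
params = {}

# The first call creates the kernel and returns the extended container.
kernel, params = get_param(
    params, 'kernel', (784, 10), jax.nn.initializers.normal(), key)

# Later calls find the stored value, so they succeed even when locked.
same_kernel, params = get_param(
    params, 'kernel', (784, 10), jax.nn.initializers.normal(), key,
    locked=True)
```

In this sketch, asking a locked container for a name that was never created raises an error instead of silently adding a parameter, which is the kind of mistake locking is meant to catch.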
JIT, grad, vmap: they all just work, because the function signature already tells the whole truth.
inputs = jnp.ones((4, 784))  # Same shape as the forward pass above.
outputs, params = jax.jit(model)(params, inputs)
Continue with the MNIST tutorial for an end-to-end example, or read the design notes to see how the pieces fit together.