{ "cells": [ { "cell_type": "markdown", "id": "cell-0", "metadata": {}, "source": "# Sharp Bits in blox\n\nThis notebook covers common pitfalls when using blox with JAX. Understanding these will help you write correct, efficient code.\n\n**blox extends JAX's stateful pattern.** JAX recommends explicit state passing for [stateful computations](https://docs.jax.dev/en/latest/stateful-computations.html). blox makes this pattern ergonomic for neural networks: `outputs, params = model(params, inputs)`. All state flows explicitly through function signatures.\n\n## Contents\n\n1. [The Params Container](#1-the-params-container) - Locked vs unlocked, immutability, split/merge, memory ownership\n2. [RNG Handling in Parallel Contexts](#2-rng-handling-in-parallel-contexts) - Unique randomness in vmap/shard_map\n3. [Init vs Runtime](#3-init-vs-runtime) - The `is_locked` pattern, sharded initialization\n4. [Graph and Module Modifications](#4-graph-and-module-modifications) - JIT caching, static models\n5. [Summary](#5-summary)" }, { "cell_type": "code", "execution_count": 1, "id": "cell-1", "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "sys.path.insert(0, '../src')\n", "\n", "import blox as bx\n", "import jax\n", "import jax.numpy as jnp" ] }, { "cell_type": "markdown", "id": "cell-2", "metadata": {}, "source": [ "---\n", "\n", "# 1. The Params Container\n", "\n", "The `Params` container holds all model state: weights, RNG counters, batch norm statistics, and any other values that need to be threaded through computations. Understanding how it works is fundamental to using blox correctly.\n", "\n", "## Locked vs Unlocked\n", "\n", "The `Params` container operates in two modes that correspond to different phases of model usage:\n", "\n", "**Unlocked (default):** New parameters can be created via `module.get_param`. This is the initialization phase where the model discovers its parameter structure by running a forward pass. 
When a module calls `get_param` for a parameter that doesn't exist yet, the parameter is created and added to params.\n", "\n", "**Locked (after `.locked()`):** No new parameters can be added. Any attempt to create a new parameter will raise an error. This is the runtime phase where you're training or evaluating the model. Locking params catches bugs where you accidentally try to create new parameters during training—for example, if you have a typo in a parameter name, the error will be immediate rather than silently creating a duplicate parameter.\n", "\n", "The locked/unlocked distinction is also used internally to detect whether code is running during initialization or runtime, which matters for RNG handling (covered in the next section)." ] }, { "cell_type": "code", "execution_count": 2, "id": "cell-3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial state: is_locked=False, num_params=2\n", "After forward: is_locked=False, num_params=4\n", "After locking: is_locked=True\n", "Forward pass works: output shape (1, 4)\n" ] } ], "source": [ "# Create a simple model.\n", "graph = bx.Graph('demo')\n", "rng = bx.Rng(graph.child('rng'))\n", "linear = bx.Linear(graph.child('linear'), output_size=4, rng=rng)\n", "\n", "# Start with unlocked params.\n", "params = rng.seed(bx.Params(), seed=42)\n", "print(f'Initial state: is_locked={params.is_locked}, num_params={len(params)}')\n", "\n", "# First forward pass creates the parameters lazily.\n", "_, params = linear(params, jnp.ones((1, 3)))\n", "print(f'After forward: is_locked={params.is_locked}, num_params={len(params)}')\n", "\n", "# Lock for runtime use.\n", "params = params.locked()\n", "print(f'After locking: is_locked={params.is_locked}')\n", "\n", "# Subsequent forward passes work fine with existing params.\n", "out, _ = linear(params, jnp.ones((1, 3)))\n", "print(f'Forward pass works: output shape {out.shape}')" ] }, { "cell_type": "markdown", "id": "cell-4", "metadata": 
{}, "source": [ "If you try to create a new parameter with locked params, you get an error:" ] }, { "cell_type": "code", "execution_count": 4, "id": "cell-5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "After unlocking and adding: num_params=6\n" ] } ], "source": [ "# This would fail because params is locked.\n", "# Uncommenting the next line would raise: RuntimeError: Params is locked.\n", "\n", "# another_linear = bx.Linear(graph.child('another'), output_size=8, rng=rng)\n", "# _, params = another_linear(params, jnp.ones((1, 4))) # Error!\n", "\n", "# If you need to add more params, unlock first.\n", "params_unlocked = params.unlocked()\n", "another_linear = bx.Linear(graph.child('another'), output_size=8, rng=rng)\n", "_, params_unlocked = another_linear(params_unlocked, jnp.ones((1, 4)))\n", "print(f'After unlocking and adding: num_params={len(params_unlocked)}')" ] }, { "cell_type": "markdown", "id": "cell-6", "metadata": {}, "source": [ "## Immutability\n", "\n", "blox's `Params` is immutable by design. Every operation that modifies params returns a **new** params object, leaving the original unchanged. 
This is a deliberate design choice that:\n", "\n", "- Makes state changes explicit and trackable in your code.\n", "- Avoids subtle bugs from shared mutable state.\n", "- Works naturally with JAX's functional transformations like `jax.grad`.\n", "\n", "Always capture the returned params from operations:" ] }, { "cell_type": "code", "execution_count": 8, "id": "cell-7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original kernel sum: 0.5145\n", "Modified kernel sum: 1.0290\n", "Original params is unchanged!\n" ] } ], "source": [ "# Functional update pattern: set_param returns a NEW params object.\n", "old_kernel, _ = linear.get_param(params, 'kernel')\n", "new_kernel = old_kernel * 2\n", "\n", "params_modified = linear.set_param(params, 'kernel', new_kernel)\n", "\n", "kernel_path = linear.param_path('kernel')\n", "print(f'Original kernel sum: {params[kernel_path].value.sum():.4f}')\n", "print(f'Modified kernel sum: {params_modified[kernel_path].value.sum():.4f}')\n", "print('Original params is unchanged!')" ] }, { "cell_type": "markdown", "id": "cell-8", "metadata": {}, "source": [ "## Split and Merge for Gradient Computation\n", "\n", "During training, you need gradients for trainable parameters (weights, biases) but also need to propagate updates to non-trainable state (RNG counters, batch norm running statistics). 
The `params.split()` method separates these two categories:\n", "\n", "- **Trainable:** Parameters that receive gradients (marked with `trainable=True`).\n", "- **Non-trainable:** State that updates during forward passes but shouldn't receive gradients.\n", "\n", "After computing gradients and applying updates, use `.merge()` to recombine them:" ] }, { "cell_type": "code", "execution_count": 10, "id": "cell-9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Trainable (receive gradients):\n", " ('demo', 'linear', 'kernel')\n", " ('demo', 'linear', 'bias')\n", "\n", "Non-trainable (state propagation only):\n", " ('demo', 'rng', 'seed')\n", " ('demo', 'rng', 'counter')\n" ] } ], "source": [ "trainable, non_trainable = params.split()\n", "\n", "print('Trainable (receive gradients):')\n", "for path in trainable.keys():\n", " print(f' {path}')\n", "\n", "print('\\nNon-trainable (state propagation only):')\n", "for path in non_trainable.keys():\n", " print(f' {path}')" ] }, { "cell_type": "markdown", "id": "cell-10", "metadata": {}, "source": [ "The standard training pattern uses split/merge to handle both gradient computation and state propagation:\n", "\n", "```python\n", "def train_step(params, inputs, targets):\n", " trainable, non_trainable = params.split()\n", "\n", " def loss_fn(t, nt):\n", " # Merge to run the forward pass.\n", " preds, new_params = model(t.merge(nt), inputs)\n", " loss = jnp.mean((preds - targets) ** 2)\n", " # Extract updated non-trainable state.\n", " _, new_nt = new_params.split()\n", " return loss, new_nt\n", "\n", " grads, new_non_trainable = jax.grad(loss_fn, has_aux=True)(trainable, non_trainable)\n", " new_trainable = jax.tree.map(lambda w, g: w - 0.01 * g, trainable, grads)\n", " return new_trainable.merge(new_non_trainable)\n", "```" ] }, { "cell_type": "markdown", "id": "y0hh14jyndj", "source": "## Params Ownership and Memory\n\nOther neural network libraries typically hide parameters inside the 
model pytree or in thread-local context. This allows users to work with a single, implicit copy of params, which is convenient but opaque. For example, when switching a model to \"eval mode,\" it's not always clear which parameters will be updated (batch norm running stats, dropout state) and which will remain unchanged, especially when building custom modules.\n\nblox takes a different approach: params are explicit and user-controlled. You see params, you pass params, you get params back. This aligns with blox's philosophy that **explicit is better than implicit**: you have full visibility and control over all state.\n\nHowever, explicit ownership requires care:\n\n- **Avoid keeping references to old params.** JAX arrays live in device memory, and every reference you hold to an old params object keeps its buffers allocated. This is especially problematic in loops, where storing previous params makes memory usage grow with every step.\n\n- **Donate params to JIT.** Using `donate_argnames=['params']` allows JAX to reuse the input param buffers for the outputs rather than allocating new ones. Donated buffers are invalidated by the call, so always continue with the returned params. This is a significant memory optimization that explicit ownership enables.\n\nThe tradeoff is intentional: blox gives you flexibility for advanced use cases (branching params for evaluation, keeping param snapshots, using the same params across different model variants) while requiring you to be deliberate about memory. Other libraries manage this automatically but opaquely, so you don't know when buffers are shared or copied. With blox, you're in control.", "metadata": {} }, { "cell_type": "markdown", "id": "cell-11", "metadata": {}, "source": [ "---\n", "\n", "# 2. RNG Handling in Parallel Contexts\n", "\n", "## How blox Handles Randomness\n", "\n", "blox uses a seed-and-counter approach for random number generation. The `Rng` module stores a **seed** and a **counter** in params. 
Each call to `rng(params)` returns a key and increments the counter:\n", "\n", "```python\n", "key = jax.random.fold_in(seed, counter)\n", "counter += 1\n", "```\n", "\n", "This gives deterministic, reproducible randomness while threading state functionally through your computations." ] }, { "cell_type": "markdown", "id": "cell-12", "metadata": {}, "source": [ "## The Problem: Replicated RNG State\n", "\n", "blox recommends **replicating the RNG state** across parallel lanes:\n", "- For `vmap`, we typically replicate all of params (using `in_axes=(None, 0)`).\n", "- For `shard_map`, it depends on sharding, but RNG state should generally be replicated.\n", "\n", "When RNG state is replicated, all lanes share the same seed and counter, so they compute **identical random keys**. This is problematic for operations like dropout where each batch item should receive a unique mask—otherwise all items get dropped in the same positions, defeating the purpose of dropout." ] }, { "cell_type": "code", "execution_count": 11, "id": "cell-13", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Problem: all lanes get the SAME dropout mask!\n", " Lane 0: [1, 1, 0, 0, 0, 0, 0, 0]\n", " Lane 1: [1, 1, 0, 0, 0, 0, 0, 0]\n", " Lane 2: [1, 1, 0, 0, 0, 0, 0, 0]\n", " Lane 3: [1, 1, 0, 0, 0, 0, 0, 0]\n" ] } ], "source": [ "# Setup: create a dropout layer.\n", "graph2 = bx.Graph('rng_demo')\n", "rng2 = bx.Rng(graph2.child('rng'))\n", "dropout = bx.Dropout(graph2.child('dropout'), rate=0.5, rng=rng2)\n", "\n", "# Initialize and lock.\n", "params2 = rng2.seed(bx.Params(), seed=42)\n", "_, params2 = dropout(params2, jnp.ones((1, 8)), is_training=True)\n", "params2 = params2.locked()\n", "\n", "\n", "def apply_dropout_naive(params, x):\n", " \"\"\"Naive approach with no axis folding.\"\"\"\n", " out, params = dropout(params, x, is_training=True)\n", " return out, params\n", "\n", "\n", "# vmap over 4 samples with replicated params.\n", "x_batch = jnp.ones((4, 
8))\n", "out_naive, _ = jax.vmap(\n", " apply_dropout_naive,\n", " in_axes=(None, 0), # None means params is replicated.\n", " out_axes=(0, None),\n", ")(params2, x_batch)\n", "\n", "print('Problem: all lanes get the SAME dropout mask!')\n", "for i in range(4):\n", " mask = (out_naive[i] == 0).astype(int).tolist()\n", " print(f' Lane {i}: {mask}')" ] }, { "cell_type": "markdown", "id": "cell-14", "metadata": {}, "source": "## The Solution: Fold in the Lane Index\n\nTo get unique randomness per lane, fold the lane index into the seed before using it. This creates a unique derived seed for each lane while keeping the counter synchronized.\n\n**Approach 2 (using `axis_index`) is recommended** as it's more idiomatic—you don't need to pass extra arguments through your function signatures. However, approach 1 is useful when you need the batch index for other purposes." }, { "cell_type": "code", "execution_count": 12, "id": "cell-15", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Solution: each lane gets a UNIQUE dropout mask!\n", " Lane 0: [0, 1, 0, 1, 1, 1, 0, 0]\n", " Lane 1: [0, 1, 1, 0, 0, 0, 1, 1]\n", " Lane 2: [0, 1, 0, 0, 1, 1, 1, 0]\n", " Lane 3: [0, 0, 0, 1, 0, 1, 0, 0]\n" ] } ], "source": [ "# Approach 1: Pass the batch index explicitly.\n", "def apply_dropout_explicit(params, x, batch_idx):\n", " \"\"\"Fold in an explicit batch index for unique randomness.\"\"\"\n", " original_seed = rng2.get_seed(params)\n", " folded_seed = jax.random.fold_in(original_seed, batch_idx)\n", " params = rng2.seed(params, seed=folded_seed)\n", "\n", " out, params = dropout(params, x, is_training=True)\n", "\n", " # Restore original seed so all lanes return identical params.\n", " params = rng2.seed(params, seed=original_seed)\n", " return out, params\n", "\n", "\n", "# Reset counter for fair comparison.\n", "params2 = rng2.seed(params2, counter=0)\n", "\n", "batch_indices = jnp.arange(4)\n", "out_explicit, _ = jax.vmap(\n", " 
apply_dropout_explicit,\n", " in_axes=(None, 0, 0),\n", " out_axes=(0, None),\n", ")(params2, x_batch, batch_indices)\n", "\n", "print('Solution: each lane gets a UNIQUE dropout mask!')\n", "for i in range(4):\n", " mask = (out_explicit[i] == 0).astype(int).tolist()\n", " print(f' Lane {i}: {mask}')" ] }, { "cell_type": "code", "execution_count": 13, "id": "cell-16", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Same result with axis_index (cleaner when index is not needed elsewhere):\n", " Lane 0: [0, 1, 0, 1, 1, 1, 0, 0]\n", " Lane 1: [0, 1, 1, 0, 0, 0, 1, 1]\n", " Lane 2: [0, 1, 0, 0, 1, 1, 1, 0]\n", " Lane 3: [0, 0, 0, 1, 0, 1, 0, 0]\n" ] } ], "source": [ "# Approach 2: Use jax.lax.axis_index with axis_name.\n", "def apply_dropout_axis_index(params, x):\n", " \"\"\"Use axis_index to get the lane index implicitly.\"\"\"\n", " original_seed = rng2.get_seed(params)\n", " folded_seed = jax.random.fold_in(original_seed, jax.lax.axis_index('batch'))\n", " params = rng2.seed(params, seed=folded_seed)\n", "\n", " out, params = dropout(params, x, is_training=True)\n", "\n", " # Restore original seed.\n", " params = rng2.seed(params, seed=original_seed)\n", " return out, params\n", "\n", "\n", "# Reset counter.\n", "params2 = rng2.seed(params2, counter=0)\n", "\n", "out_axis, _ = jax.vmap(\n", " apply_dropout_axis_index,\n", " in_axes=(None, 0),\n", " out_axes=(0, None),\n", " axis_name='batch', # Required for axis_index to work.\n", ")(params2, x_batch)\n", "\n", "print(\n", " 'Same result with axis_index (cleaner when index is not needed elsewhere):'\n", ")\n", "for i in range(4):\n", " mask = (out_axis[i] == 0).astype(int).tolist()\n", " print(f' Lane {i}: {mask}')" ] }, { "cell_type": "markdown", "id": "cell-17", "metadata": {}, "source": [ "## Why Restore the Original Seed?\n", "\n", "When params are replicated across lanes (`out_axes=None` for params), JAX requires all lanes to return identical pytrees. 
The counter increments identically in each lane (same operations happen in each), but the seed differs due to folding. By restoring the original seed at the end, we ensure the returned params are identical across all lanes, satisfying JAX's requirement." ] }, { "cell_type": "markdown", "id": "cell-18", "metadata": {}, "source": "---\n\n# 3. Init vs Runtime\n\n## The `is_locked` Pattern\n\nThere's a fundamental distinction between initialization and runtime that affects how RNG should behave:\n\n**Initialization** (`params.is_locked == False`): We're creating parameters. When vmapped, we want **identical params** across lanes. If each lane initialized with different random values, the params wouldn't match and couldn't be merged. So during init, we should NOT fold in the lane index.\n\n**Runtime** (`params.is_locked == True`): We're using existing parameters. When vmapped, we want **unique randomness** per batch item for operations like dropout. So during runtime, we SHOULD fold in the lane index.\n\nUse `params.is_locked` to write functions that behave correctly in both contexts:" }, { "cell_type": "code", "execution_count": 14, "id": "cell-19", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Init complete: is_locked=True\n", "Runtime: unique masks per lane!\n", " Lane 0: [0, 0, 0, 0, 1, 0, 1, 1]\n", " Lane 1: [0, 1, 0, 1, 0, 0, 1, 0]\n", " Lane 2: [0, 0, 0, 1, 0, 0, 1, 0]\n", " Lane 3: [1, 1, 0, 0, 0, 0, 1, 1]\n" ] } ], "source": [ "def forward(params, x):\n", " \"\"\"Works correctly for both init and runtime.\"\"\"\n", " original_seed = rng2.get_seed(params)\n", "\n", " if params.is_locked:\n", " # Runtime: fold in axis index for unique randomness per batch item.\n", " folded_seed = jax.random.fold_in(original_seed, jax.lax.axis_index('batch'))\n", " params = rng2.seed(params, seed=folded_seed)\n", " # Init (unlocked): don't fold, so all lanes create identical params.\n", "\n", " out, params = dropout(params, x, 
is_training=True)\n", "\n", " # Always restore seed so replicated params stay consistent.\n", " params = rng2.seed(params, seed=original_seed)\n", " return out, params\n", "\n", "\n", "# Init phase: params are unlocked, no folding happens.\n", "def init_fn(x):\n", " p = rng2.seed(bx.Params(), seed=42)\n", " _, p = forward(p, x)\n", " return p.locked()\n", "\n", "\n", "params3 = jax.vmap(init_fn, axis_name='batch', out_axes=None)(x_batch)\n", "print(f'Init complete: is_locked={params3.is_locked}')\n", "\n", "# Runtime phase: params are locked, folding is applied.\n", "out_runtime, _ = jax.vmap(\n", " forward,\n", " in_axes=(None, 0),\n", " out_axes=(0, None),\n", " axis_name='batch',\n", ")(params3, x_batch)\n", "\n", "print('Runtime: unique masks per lane!')\n", "for i in range(4):\n", " mask = (out_runtime[i] == 0).astype(int).tolist()\n", " print(f' Lane {i}: {mask}')" ] }, { "cell_type": "markdown", "id": "cell-20", "metadata": {}, "source": [ "## Sharded Initialization: JIT vs shard_map\n", "\n", "Models can use `jax.jit` or `jax.experimental.shard_map` depending on the level of control you need over device placement. Both can shard initialization and runtime, but they have different tradeoffs.\n", "\n", "**The challenge with `shard_map` for initialization:**\n", "\n", "When using `shard_map` at runtime, sharded params go in and params with the same sharding come out. But how do you *create* those sharded params in the first place? The difficulty is that different parameters may need different sharding specifications, and managing which axes to fold for which params becomes complex.\n", "\n", "**Recommended approach: Initialize with `jax.jit`**\n", "\n", "If you're using `shard_map` at runtime, we recommend initializing params using plain `jax.jit` with `out_shardings`. 
JIT can:\n", "- Use a replicated RNG to create global (unsharded) parameters.\n", "- Automatically partition them according to the specified `out_shardings`.\n", "\n", "```python\n", "mesh = jax.make_mesh((4,), ('model',))\n", "params_sharding = ... # Build from param metadata.\n", "\n", "@jax.jit(out_shardings=params_sharding)\n", "def init():\n", " params = rng.seed(bx.Params(), seed=42)\n", " _, params = model(params, dummy_input)\n", " return params.locked()\n", "\n", "params = init() # JIT handles sharding automatically.\n", "```\n", "\n", "This approach is simpler and less error-prone than trying to manage sharding manually in `shard_map`." ] }, { "cell_type": "markdown", "id": "cell-21", "metadata": {}, "source": [ "---\n", "\n", "# 4. Graph and Module Modifications\n", "\n", "## The Problem: JIT Caching and Module State\n", "\n", "When you JIT-compile a function, JAX traces it the first time it is called with a given input signature (shapes, dtypes, and pytree structure) and caches the compiled result. If you then modify a module's attributes, **JAX will not automatically recompile**: it reuses the cached trace from before your modification. 
This leads to silent bugs where your changes appear to have no effect.\n", "\n", "See [JAX Gotchas: Using jax.jit with class methods](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#using-jax-jit-with-class-methods) for the underlying JAX behavior.\n", "\n", "Let's demonstrate with a simple custom module:" ] }, { "cell_type": "code", "execution_count": 15, "id": "cell-22", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "value=10: [1.0, 2.0, 3.0] -> [11.0, 12.0, 13.0]\n", "Changed value to 100.0\n", "After change: [1.0, 2.0, 3.0] -> [11.0, 12.0, 13.0]\n", "BUG: Still adding 10, not 100!\n" ] } ], "source": [ "class AddConstant(bx.Module):\n", " \"\"\"Simple module that adds a constant to the input.\"\"\"\n", "\n", " def __init__(self, graph: bx.Graph, value: float):\n", " super().__init__(graph)\n", " self.value = value # Static config stored as Python float, not JAX array.\n", "\n", " def __call__(self, params: bx.Params, x: jax.Array):\n", " return x + self.value, params\n", "\n", "\n", "graph4 = bx.Graph('modify_demo')\n", "adder = AddConstant(graph4.child('adder'), value=10.0)\n", "params4 = bx.Params().locked()\n", "x = jnp.array([1.0, 2.0, 3.0])\n", "\n", "# JIT-compile adder's __call__ function.\n", "add_constant = jax.jit(adder)\n", "\n", "# First call traces with value=10.\n", "out1, _ = add_constant(params4, x)\n", "print(f'value=10: {x.tolist()} -> {out1.tolist()}')\n", "\n", "# Modify the module.\n", "adder.value = 100.0\n", "print(f'Changed value to {adder.value}')\n", "\n", "# Second call uses CACHED trace—change is NOT visible!\n", "out2, _ = add_constant(params4, x)\n", "print(f'After change: {x.tolist()} -> {out2.tolist()}')\n", "print('BUG: Still adding 10, not 100!')" ] }, { "cell_type": "markdown", "id": "cell-23", "metadata": {}, "source": [ "## blox Philosophy: Prefer Static Models\n", "\n", "JAX's [recommended 
solution](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#using-jax-jit-with-class-methods) is to register classes as pytrees, so that dynamic attributes are passed to the compiled function as inputs and changes to static attributes trigger recompilation. However, this approach brings complexity when modules contain non-hashable types (like lists or dicts) and requires careful handling of what counts as \"static\" vs \"dynamic\".\n", "\n", "**blox takes a different approach: keep graphs static.**\n", "\n", "Instead of modifying modules after creation, build separate models for different configurations. Since params are decoupled from the graph, multiple models can share the same params as long as their parameter paths match.\n", "\n", "**Separation of concerns:**\n", "- **Graph** describes *what* operations to perform (static Python objects).\n", "- **Params** holds *what values* to use (dynamic JAX arrays).\n", "\n", "This separation means:\n", "- The graph should not hold JAX arrays (for example, ones passed in as constructor arguments).\n", "- Modules should not create params in `__init__` (this wastes memory and prevents lazy creation).\n", "- Use `module.get_param()` for lazy param creation during the first forward pass." ] }, { "cell_type": "markdown", "id": "cell-24", "metadata": {}, "source": [ "## Pattern 1: Build Separate Static Models\n", "\n", "If you need different configurations, create separate models from the start using a factory function. 
Each model gets its own JIT cache, avoiding confusion:" ] }, { "cell_type": "code", "execution_count": null, "id": "cell-25", "metadata": {}, "outputs": [], "source": "# Factory function creates fresh models with desired config.\ndef create_adder(value: float):\n graph = bx.Graph('adder')\n return AddConstant(graph.child('add'), value=value)\n\n\nadder_10 = create_adder(10.0)\nadder_100 = create_adder(100.0)\n\n\n# JIT each module with donate_argnames for efficiency.\n@jax.jit(donate_argnames=['params'])\ndef add_10(params, x):\n return adder_10(params, x)\n\n\n@jax.jit(donate_argnames=['params'])\ndef add_100(params, x):\n return adder_100(params, x)\n\n\nout_10, _ = add_10(params4, x)\nout_100, _ = add_100(params4, x)\n\nprint(f'adder_10: {x.tolist()} -> {out_10.tolist()}')\nprint(f'adder_100: {x.tolist()} -> {out_100.tolist()}')\nprint('Both work correctly with static models!')" }, { "cell_type": "markdown", "id": "cell-26", "metadata": {}, "source": "## Pattern 2: Create Separate Functions\n\nIf you need to modify a module, define a new function that captures the modified state. 
Each new function gets its own JIT cache:" }, { "cell_type": "code", "execution_count": null, "id": "cell-27", "metadata": {}, "outputs": [], "source": [ "# Create a mutable module.\n", "graph5 = bx.Graph('wrap_demo')\n", "mutable_adder = AddConstant(graph5.child('adder'), value=10.0)\n", "\n", "\n", "# Define and JIT a function for value=10.\n", "@jax.jit(donate_argnames=['params'])\n", "def apply_v10(params, x):\n", " return mutable_adder(params, x)\n", "\n", "\n", "out1, _ = apply_v10(params4, x)\n", "print(f'value=10: {x.tolist()} -> {out1.tolist()}')\n", "\n", "# Modify the module.\n", "mutable_adder.value = 100.0\n", "\n", "\n", "# Define and JIT a NEW function for value=100.\n", "@jax.jit(donate_argnames=['params'])\n", "def apply_v100(params, x):\n", " return mutable_adder(params, x)\n", "\n", "\n", "out2, _ = apply_v100(params4, x)\n", "print(f'value=100: {x.tolist()} -> {out2.tolist()}')\n", "print('Each function gets its own JIT cache!')" ] }, { "cell_type": "markdown", "id": "cell-28", "metadata": {}, "source": [ "## Sharing Params Across Models\n", "\n", "Since params are decoupled from the graph, different models can use the same params as long as they have matching parameter paths. 
This enables powerful patterns:" ] }, { "cell_type": "code", "execution_count": null, "id": "cell-29", "metadata": {}, "outputs": [], "source": [ "# Create two LSTM variants with the same parameter structure.\n", "def create_lstm(is_static: bool):\n", " graph = bx.Graph('model')\n", " rng = bx.Rng(graph.child('rng'))\n", " lstm = bx.LSTM(\n", " graph.child('lstm'), hidden_size=16, rng=rng, is_static=is_static\n", " )\n", " return rng, lstm\n", "\n", "\n", "rng_static, lstm_static = create_lstm(is_static=True) # Uses Python for-loop.\n", "rng_dynamic, lstm_dynamic = create_lstm(is_static=False) # Uses jax.lax.scan.\n", "\n", "# Initialize params using one of them.\n", "x = jnp.ones((2, 5, 8)) # Shape: [batch, time, features].\n", "params5 = rng_static.seed(bx.Params(), seed=42)\n", "state, params5 = lstm_static.initial_state(params5, x[:, 0])\n", "(out, _), params5 = lstm_static.apply(params5, x, prev_state=state)\n", "params5 = params5.locked()\n", "\n", "# Both models work with the SAME params!\n", "out_static, _ = lstm_static.apply(params5, x, prev_state=state)\n", "out_dynamic, _ = lstm_dynamic.apply(params5, x, prev_state=state)\n", "\n", "print(f'Static LSTM output sum: {out_static[0].sum():.4f}')\n", "print(f'Dynamic LSTM output sum: {out_dynamic[0].sum():.4f}')\n", "print('Same params, different execution modes!')" ] }, { "cell_type": "markdown", "id": "cell-30", "metadata": {}, "source": [ "**Use cases for shared params:**\n", "\n", "- **Actor vs Learner in RL**: Separate models for data collection and training that share weights.\n", "- **Training vs Evaluation**: Different logic (e.g., dropout enabled/disabled) while using the same parameters.\n", "- **Debug vs Production**: Python loop for debugging with breakpoints, lax.scan for production speed." ] }, { "cell_type": "markdown", "id": "cell-31", "metadata": {}, "source": "---\n\n# 5. 
Summary\n\n## Sharp Edges and Patterns\n\n| Issue | Pattern |\n|-------|---------|\n| Same random values in vmap | Fold in lane index; check `is_locked` for init vs runtime. |\n| Sharded param initialization | Use `jax.jit` with `out_shardings`, not `shard_map`. |\n| Accidentally creating params at runtime | Lock params after init with `params.locked()`. |\n| Module changes not visible after JIT | Build separate static models, or create fresh JIT wrappers. |\n| Memory waste from old param references | Donate params to JIT; avoid storing references in loops. |\n\n## Key Design Principles\n\n| Principle | Why It Matters |\n|-----------|----------------|\n| Params decoupled from Graph | Multiple models can share the same params. |\n| Graph is static | Avoids JIT caching bugs; clear separation of concerns. |\n| Params are immutable | Explicit state changes; works with JAX transforms. |\n| Lazy param creation | No wasted memory; no need to specify shapes upfront. |\n| Locked vs unlocked | Catches bugs; enables init vs runtime RNG behavior. |\n| Explicit params ownership | Full visibility and control; enables buffer donation. |\n\n## References\n\n- [JAX Stateful Computations](https://docs.jax.dev/en/latest/stateful-computations.html)\n- [JAX Common Gotchas](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html)" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" } }, "nbformat": 4, "nbformat_minor": 5 }