{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": "# Checkpointable Training on Google Colab\n\nThis notebook shows how to structure training code that **automatically resumes** after Colab runtime disconnects.\n\n> **Prerequisite**: This notebook builds on the [MNIST Tutorial](mnist_tutorial.ipynb). If you're new to blox, start there first.\n\n## The Problem\n\nGoogle Colab is amazing for learning and experimentation, but has frustrating limitations:\n\n1. **Idle timeout**: Leave your laptop for lunch? Runtime disconnected. Hours of training lost.\n2. **Session limits**: Long training runs get cut off after ~12 hours (or less on the free tier).\n3. **No persistent storage**: Runtime resets lose all local files.\n\n**The pain**: You start training, step away, and come back to find your session dead and all progress gone.\n\n## The Solution\n\nCheckpoint everything to **Google Drive** (persistent storage that survives runtime resets):\n- Model weights and optimizer state\n- Dataset iteration position (just an index!)\n- Training RNG state\n\nWhen you reconnect and re-run the notebook, training **automatically resumes from the latest checkpoint**.\n\n## The Functional Approach\n\nThis implementation demonstrates blox's philosophy of **leaning into the functional paradigm**. Just as blox makes all neural network state explicit via `outputs, params = model(params, inputs)`, we make all training state explicit and checkpointable:\n\n- **JAX's explicit randomness** gives deterministic RNG streams derived from fixed seeds\n- **TensorFlow's `index_shuffle`** computes shuffled indices on the fly without materializing a permutation array\n- **An explicit global index** tracks position across all epochs with a single integer\n\nBecause all randomness and state are explicit, training can be interrupted and resumed from any saved checkpoint while replaying the same data order and RNG stream as an uninterrupted run.\n\n## How to Use This Notebook\n\n1. First run: Training starts fresh and saves a checkpoint to Google Drive every `CHECKPOINT_EVERY` steps.\n2. Colab disconnects mid-training.\n3. Reconnect and click \"Run All\".\n4. The training cell detects the latest checkpoint and resumes from it. Only the steps since that checkpoint (at most `CHECKPOINT_EVERY`) are repeated.\n5. You get the same final result as an uninterrupted run, up to any run-to-run floating-point nondeterminism of GPU kernels." },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "# 1. Setup" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Mount Google Drive for persistent checkpoint storage.\n", "from google.colab import drive\n", "\n", "drive.mount('/content/drive')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install dependencies.\n", "!pip install -q orbax-checkpoint tensorflow-datasets tensorflow jax-blox optax" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import sys\n",
"from pathlib import Path\n",
"\n",
"import blox as bx\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import optax\n",
"import orbax.checkpoint as ocp\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"\n",
"# Hide GPUs from TensorFlow (used here only for tf.data) so it does not\n",
"# preallocate GPU memory that JAX needs.\n",
"tf.config.set_visible_devices([], 'GPU')\n",
"\n",
"print(f'JAX devices: {jax.devices()}')"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "# 2. Configuration" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Paths - change EXPERIMENT_NAME for different runs.\n",
"EXPERIMENT_NAME = 'mnist_cnn_v1'\n",
"CHECKPOINT_DIR = (\n",
"    Path('/content/drive/MyDrive/blox_checkpoints') / EXPERIMENT_NAME\n",
")\n",
"CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# Training hyperparameters.\n",
"BATCH_SIZE = 128\n",
"LEARNING_RATE = 1e-3\n",
"MAX_STEPS = 2000\n",
"CHECKPOINT_EVERY = 100\n",
"\n",
"# Random seeds.\n",
"MODEL_SEED = 42\n",
"DATA_SEED = 123\n",
"\n",
"# Dummy MNIST-shaped batch used only to initialize params (avoids loading\n",
"# the dataset just to get shapes).\n",
"DATA_SPEC = {\n",
"    'image': jnp.zeros((BATCH_SIZE, 28, 28, 1)),\n",
"    'label': jnp.zeros((BATCH_SIZE,), dtype=jnp.int32),\n",
"}\n",
"\n",
"print(f'Checkpoint directory: {CHECKPOINT_DIR}')"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "# 3. Utilities\n", "\n", "Logging and checkpointing helpers. These cells can be collapsed." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Logging utilities.\n",
"class TeeOutput:\n",
"  \"\"\"Redirect writes to both the original stream and a log file.\"\"\"\n",
"\n",
"  def __init__(self, file_path: Path, stream):\n",
"    self.file = open(file_path, 'a')\n",
"    self.stream = stream\n",
"\n",
"  def write(self, data):\n",
"    self.file.write(data)\n",
"    self.stream.write(data)\n",
"\n",
"  def flush(self):\n",
"    self.file.flush()\n",
"    self.stream.flush()\n",
"\n",
"  def __getattr__(self, name):\n",
"    # Delegate other attributes (e.g. isatty, encoding) to the wrapped stream.\n",
"    return getattr(self.stream, name)\n",
"\n",
"\n",
"def _unwrap_tee(stream):\n",
"  # Peel off TeeOutput wrappers left by an earlier run of the setup cell.\n",
"  while type(stream).__name__ == 'TeeOutput':\n",
"    stream = stream.stream\n",
"  return stream\n",
"\n",
"\n",
"def setup_logging(checkpoint_dir: Path, resuming: bool):\n",
"  \"\"\"Set up logging to a file. Shows previous logs if resuming.\"\"\"\n",
"  log_file = checkpoint_dir / 'training.log'\n",
"  if resuming and log_file.exists():\n",
"    print('=== Previous training logs ===')\n",
"    print(log_file.read_text())\n",
"    print('=== Resuming ===')\n",
"  # Wrap the notebook's current streams. sys.__stdout__ is the kernel\n",
"  # process's original stdout, not the stream shown in cell output.\n",
"  sys.stdout = TeeOutput(log_file, _unwrap_tee(sys.stdout))\n",
"  sys.stderr = TeeOutput(log_file, _unwrap_tee(sys.stderr))\n",
"\n",
"\n",
"def flush_logs():\n",
"  \"\"\"Flush logs to disk (call before checkpointing).\"\"\"\n",
"  sys.stdout.flush()\n",
"  sys.stderr.flush()"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Checkpointing utilities.\n",
"def create_checkpoint_manager(checkpoint_dir: Path):\n",
"  \"\"\"Create an Orbax checkpoint manager.\"\"\"\n",
"  return ocp.CheckpointManager(\n",
"      checkpoint_dir,\n",
"      options=ocp.CheckpointManagerOptions(\n",
"          max_to_keep=3, save_interval_steps=1\n",
"      ),\n",
"  )\n",
"\n",
"\n",
"def save_checkpoint(manager, step, params, opt_state, global_index, data_seed):\n",
"  \"\"\"Save a checkpoint (Orbax may finish writing it asynchronously).\"\"\"\n",
"  state = {\n",
"      'params': params.unlocked(),\n",
"      'opt_state': opt_state,\n",
"      'step': step,\n",
"      'global_index': global_index,\n",
"      'data_seed': data_seed,\n",
"  }\n",
"  saved = manager.save(step, args=ocp.args.StandardSave(state))\n",
"  # save() returns False when it skips a step, e.g. one already saved.\n",
"  if saved:\n",
"    print(f'Saved checkpoint at step {step}')\n",
"\n",
"\n",
"def restore_checkpoint(manager, abstract_state):\n",
"  \"\"\"Restore the latest checkpoint, or return None if there is none.\"\"\"\n",
"  step = manager.latest_step()\n",
"  if step is None:\n",
"    return None\n",
"  restored = manager.restore(\n",
"      step, args=ocp.args.StandardRestore(abstract_state)\n",
"  )\n",
"  restored['params'] = restored['params'].locked()\n",
"  return restored"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "# 4. Resumable Dataset\n", "\n", "Uses `tf.random.experimental.index_shuffle` to compute each sample's shuffled position on the fly with O(1) memory (no permutation array is materialized), and `tfds.data_source` for random access to any sample by index."
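, "\n",
"\n",
"The next cell is an optional sanity check of this scheme on a hypothetical 10-element dataset (`TOY_LEN` and `TOY_SEED` are illustrative names, not part of the pipeline). It mirrors the index arithmetic in `get_shuffled_sample`: the epoch number is folded into the base seed, and `index_shuffle` maps each in-epoch index to a distinct position. Like the pipeline itself, it assumes a TensorFlow version whose `index_shuffle` accepts the shape-`[2]` seeds returned by `tf.random.fold_in`. Because the mapping depends only on the seed and the global index, restarting from any global index replays the same order." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Optional sanity check of the index_shuffle scheme on a toy dataset.\n",
"# TOY_LEN and TOY_SEED are hypothetical values for illustration only.\n",
"import tensorflow as tf\n",
"\n",
"TOY_LEN = 10\n",
"TOY_SEED = tf.random.create_rng_state(123, 'threefry')\n",
"\n",
"\n",
"def toy_shuffled_index(global_idx):\n",
"  # Same arithmetic as get_shuffled_sample: split the global index into an\n",
"  # epoch and a position within that epoch, then permute the position.\n",
"  epoch = global_idx // TOY_LEN\n",
"  idx_in_epoch = global_idx % TOY_LEN\n",
"  epoch_seed = tf.random.fold_in(TOY_SEED, tf.constant(epoch, tf.int64))\n",
"  shuffled = tf.random.experimental.index_shuffle(\n",
"      tf.constant(idx_in_epoch, tf.int64),\n",
"      epoch_seed,\n",
"      tf.constant(TOY_LEN - 1, tf.int64),  # max_index is inclusive.\n",
"  )\n",
"  return int(shuffled.numpy())\n",
"\n",
"\n",
"order = [toy_shuffled_index(g) for g in range(2 * TOY_LEN)]\n",
"\n",
"# Each epoch visits every sample exactly once.\n",
"assert sorted(order[:TOY_LEN]) == list(range(TOY_LEN))\n",
"assert sorted(order[TOY_LEN:]) == list(range(TOY_LEN))\n",
"# Restarting from global index 7 replays the tail of the same order.\n",
"assert [toy_shuffled_index(g) for g in range(7, 2 * TOY_LEN)] == order[7:]\n",
"print('epoch 0:', order[:TOY_LEN])\n",
"print('epoch 1:', order[TOY_LEN:])"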
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"def create_resumable_mnist(\n",
"    split: str, batch_size: int, seed: int, start_index: int = 0\n",
"):\n",
"  \"\"\"Create a resumable MNIST dataset.\n",
"\n",
"  To adapt for other datasets, change Tout and the loading logic.\n",
"\n",
"  Args:\n",
"    split: Dataset split ('train' or 'test').\n",
"    batch_size: Batch size.\n",
"    seed: Base shuffle seed.\n",
"    start_index: Global index to resume from.\n",
"\n",
"  Returns:\n",
"    A tuple (dataset, dataset_len). The dataset yields batches of\n",
"    {'global_index', 'image', 'label'}.\n",
"  \"\"\"\n",
"  data_source = tfds.data_source('mnist', split=split)\n",
"  dataset_len = len(data_source)\n",
"  tf_seed = tf.random.create_rng_state(seed, 'threefry')\n",
"\n",
"  @tf.py_function(\n",
"      Tout={'global_index': tf.int64, 'image': tf.float32, 'label': tf.int32}\n",
"  )\n",
"  def get_shuffled_sample(global_idx):\n",
"    global_idx = global_idx.numpy()\n",
"    epoch = global_idx // dataset_len\n",
"    idx_in_epoch = global_idx % dataset_len\n",
"\n",
"    # Fold the epoch into the base seed: a new permutation each epoch.\n",
"    epoch_seed = tf.random.fold_in(tf_seed, epoch)\n",
"    # max_index is inclusive, so use dataset_len - 1.\n",
"    shuffled_idx = tf.random.experimental.index_shuffle(\n",
"        idx_in_epoch, epoch_seed, dataset_len - 1\n",
"    )\n",
"\n",
"    record = data_source[int(shuffled_idx)]\n",
"    image = record['image'].astype(np.float32) / 255.0\n",
"    # Cast to match the tf.int32 declared in Tout.\n",
"    label = np.int32(record['label'])\n",
"\n",
"    return {'global_index': global_idx, 'image': image, 'label': label}\n",
"\n",
"  # Indices for 100 epochs past the resume point, far more than MAX_STEPS needs.\n",
"  ds = tf.data.Dataset.range(start_index, start_index + dataset_len * 100)\n",
"  ds = ds.map(get_shuffled_sample, num_parallel_calls=tf.data.AUTOTUNE)\n",
"  ds = ds.batch(batch_size)\n",
"  ds = ds.prefetch(tf.data.AUTOTUNE)\n",
"\n",
"  return ds, dataset_len"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "# 5. Model" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"class SimpleCNN(bx.Module):\n",
"  \"\"\"Simple CNN for MNIST.\"\"\"\n",
"\n",
"  def __init__(self, graph: bx.Graph, rng: bx.Rng):\n",
"    super().__init__(graph)\n",
"    self.conv1 = bx.Conv(\n",
"        graph.child('conv1'), output_channels=32, kernel_size=(3, 3), rng=rng\n",
"    )\n",
"    self.conv2 = bx.Conv(\n",
"        graph.child('conv2'), output_channels=64, kernel_size=(3, 3), rng=rng\n",
"    )\n",
"    self.linear1 = bx.Linear(graph.child('linear1'), output_size=128, rng=rng)\n",
"    self.linear2 = bx.Linear(graph.child('linear2'), output_size=10, rng=rng)\n",
"\n",
"  def __call__(self, params: bx.Params, x: jax.Array):\n",
"    x, params = self.conv1(params, x)\n",
"    x = jax.nn.relu(x)\n",
"    x = bx.max_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
"\n",
"    x, params = self.conv2(params, x)\n",
"    x = jax.nn.relu(x)\n",
"    x = bx.max_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
"\n",
"    x = x.reshape(x.shape[0], -1)\n",
"    x, params = self.linear1(params, x)\n",
"    x = jax.nn.relu(x)\n",
"    x, params = self.linear2(params, x)\n",
"\n",
"    return x, params"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "# 6. Training Step" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"def create_train_step(model, optimizer):\n",
"  \"\"\"Create a JIT-compiled training step function.\"\"\"\n",
"\n",
"  def train_step(params, opt_state, batch):\n",
"    images, labels = batch['image'], batch['label']\n",
"    trainable, non_trainable = params.split()\n",
"\n",
"    def loss_fn(t, nt):\n",
"      logits, new_params = model(t.merge(nt), images)\n",
"      loss = optax.softmax_cross_entropy_with_integer_labels(\n",
"          logits, labels\n",
"      ).mean()\n",
"      _, new_nt = new_params.split()\n",
"      return loss, (new_nt, logits)\n",
"\n",
"    grads, (new_non_trainable, logits) = jax.grad(loss_fn, has_aux=True)(\n",
"        trainable, non_trainable\n",
"    )\n",
"\n",
"    updates, new_opt_state = optimizer.update(grads, opt_state, trainable)\n",
"    new_trainable = optax.apply_updates(trainable, updates)\n",
"    new_params = new_trainable.merge(new_non_trainable)\n",
"\n",
"    accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()\n",
"    return new_params, new_opt_state, accuracy\n",
"\n",
"  # Donating params and opt_state lets XLA reuse their buffers for the outputs.\n",
"  # Calling jax.jit directly (not @jax.jit(...)) also works on older JAX.\n",
"  return jax.jit(train_step, donate_argnames=('params', 'opt_state'))"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "# 7. Initialize and Restore" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Create model and optimizer.\n",
"graph = bx.Graph('model')\n",
"rng = bx.Rng(graph.child('rng'))\n",
"model = SimpleCNN(graph.child('cnn'), rng=rng)\n",
"optimizer = optax.adam(learning_rate=LEARNING_RATE)\n",
"\n",
"# Initialize params using the data spec.\n",
"params = rng.seed(bx.Params(), seed=MODEL_SEED)\n",
"_, params = model(params, DATA_SPEC['image'])\n",
"params = params.locked()\n",
"# train_step only updates the trainable subset, so the optimizer state must\n",
"# have that subset's structure.\n",
"trainable_params, _ = params.split()\n",
"opt_state = optimizer.init(trainable_params)\n",
"\n",
"# Try to restore from checkpoint.\n",
"ckpt_manager = create_checkpoint_manager(CHECKPOINT_DIR)\n",
"abstract_state = {\n",
"    'params': params.unlocked(),\n",
"    'opt_state': opt_state,\n",
"    'step': 0,\n",
"    'global_index': 0,\n",
"    'data_seed': DATA_SEED,\n",
"}\n",
"restored = restore_checkpoint(ckpt_manager, abstract_state)\n",
"\n",
"if restored is not None:\n",
"  params = restored['params']\n",
"  opt_state = restored['opt_state']\n",
"  # Scalars may come back as NumPy or JAX scalars; use plain Python ints.\n",
"  step = int(restored['step'])\n",
"  global_index = int(restored['global_index'])\n",
"  data_seed = int(restored['data_seed'])\n",
"  resuming = True\n",
"  print(f'Resumed from step {step}, global_index {global_index}')\n",
"else:\n",
"  step = 0\n",
"  global_index = 0\n",
"  data_seed = DATA_SEED\n",
"  resuming = False\n",
"  print('Starting fresh')\n",
"\n",
"# Set up logging and the dataset.\n",
"setup_logging(CHECKPOINT_DIR, resuming)\n",
"train_ds, dataset_len = create_resumable_mnist(\n",
"    split='train',\n",
"    batch_size=BATCH_SIZE,\n",
"    seed=data_seed,\n",
"    start_index=global_index,\n",
")\n",
"train_step = create_train_step(model, optimizer)\n",
"\n",
"print(f'Training from step {step} to {MAX_STEPS}...')"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "# 8. Training Loop" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Iterate as NumPy arrays: jax.jit does not accept tf.Tensor arguments.\n",
"for batch in train_ds.as_numpy_iterator():\n",
"  if step >= MAX_STEPS:\n",
"    break\n",
"\n",
"  inputs = {'image': batch['image'], 'label': batch['label']}\n",
"  params, opt_state, accuracy = train_step(params, opt_state, inputs)\n",
"  step += 1\n",
"  # Next global index to consume; checkpointed so the dataset resumes here.\n",
"  global_index = int(batch['global_index'][-1]) + 1\n",
"\n",
"  if step % 50 == 0:\n",
"    epoch = global_index // dataset_len\n",
"    print(f'Step {step}, epoch {epoch}, accuracy: {accuracy:.4f}')\n",
"\n",
"  if step % CHECKPOINT_EVERY == 0:\n",
"    flush_logs()\n",
"    save_checkpoint(\n",
"        ckpt_manager, step, params, opt_state, global_index, data_seed\n",
"    )\n",
"\n",
"print(f'Training complete! Final step: {step}')\n",
"flush_logs()\n",
"save_checkpoint(ckpt_manager, step, params, opt_state, global_index, data_seed)\n",
"# Block until any asynchronous checkpoint writes have finished.\n",
"ckpt_manager.wait_until_finished()"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "# 9. Evaluation" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Evaluation uses standard tf.data batching (no checkpointable logic needed).\n",
"test_ds = tfds.load('mnist', split='test')\n",
"# Keep the final partial batch so all 10,000 test images are counted (this\n",
"# costs one extra JIT compilation for the smaller batch shape).\n",
"test_ds = test_ds.batch(BATCH_SIZE)\n",
"test_ds = test_ds.prefetch(tf.data.AUTOTUNE)\n",
"\n",
"\n",
"@jax.jit\n",
"def evaluate_batch(params, images, labels):\n",
"  \"\"\"Evaluate a batch and return a per-example correctness mask.\"\"\"\n",
"  logits, _ = model(params, images)\n",
"  predictions = jnp.argmax(logits, axis=-1)\n",
"  return predictions == labels\n",
"\n",
"\n",
"correct = 0\n",
"total = 0\n",
"for batch in test_ds.as_numpy_iterator():\n",
"  images = batch['image'].astype(np.float32) / 255.0\n",
"  labels = batch['label']\n",
"  is_correct = evaluate_batch(params, images, labels)\n",
"  correct += int(is_correct.sum())\n",
"  total += len(labels)\n",
"\n",
"test_accuracy = correct / total\n",
"print(f'Test accuracy: {test_accuracy:.4f} ({correct}/{total})')"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"---\n",
"\n",
"# 10. Summary\n",
"\n",
"## What We Checkpoint\n",
"\n",
"| State | Why |\n",
"|-------|-----|\n",
"| `params` | Model weights and blox RNG state |\n",
"| `opt_state` | Optimizer state (for Adam: first and second moment estimates and step count) |\n",
"| `step` | Current training step |\n",
"| `global_index` | Position across all epochs |\n",
"| `data_seed` | Base seed for reproducible shuffling |\n",
"\n",
"## Key Patterns\n",
"\n",
"1. **O(1)-memory shuffling with `index_shuffle`**: Computes shuffled indices on the fly without materializing a permutation array.\n",
"2. **Random access with `tfds.data_source`**: Access any sample by index.\n",
"3. **Global index**: A single integer tracks position across all epochs.\n",
"4. **Google Drive storage**: Survives runtime resets.\n",
"\n",
"## References\n",
"\n",
"- [tf.random.experimental.index_shuffle](https://www.tensorflow.org/api_docs/python/tf/random/experimental/index_shuffle)\n",
"- [Orbax Documentation](https://orbax.readthedocs.io/)\n",
"- [TFDS Documentation](https://www.tensorflow.org/datasets)"
] } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }