
Checkpointable Training on Google Colab

This notebook shows how to structure training code that automatically resumes after Colab runtime disconnects.

Prerequisite: This notebook builds on the MNIST Tutorial. If you’re new to blox, start there first.

The Problem

Google Colab is amazing for learning and experimentation, but has frustrating limitations:

  1. Idle timeout: Leave your laptop for lunch? Runtime disconnected. Hours of training lost.

  2. Session limits: Long training runs get cut off after ~12 hours (or less on free tier).

  3. No persistent storage: Runtime resets lose all local files.

The pain: You start training, step away, come back to find your session dead and all progress gone.

The Solution

Checkpoint everything to Google Drive (persistent storage that survives runtime resets):

  • Model weights and optimizer state

  • Dataset iteration position (just an index!)

  • Training RNG state

When you reconnect and re-run the notebook, training automatically resumes exactly where it left off.
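
The save-then-resume pattern can be sketched with the standard library alone. This is only an illustration of the idea, not the notebook's implementation (which uses Orbax below). The names `save_state` and `load_state` and the JSON file layout are hypothetical, and a temporary directory stands in for the Drive folder. The rename step is atomic on local POSIX filesystems; the sketch does not verify whether a mounted Drive folder gives the same guarantee.

```python
import json
import os
import tempfile
from pathlib import Path


def save_state(ckpt_dir: Path, state: dict) -> None:
  """Write to a temp file, then rename it over the previous checkpoint."""
  ckpt_dir.mkdir(parents=True, exist_ok=True)
  fd, tmp_name = tempfile.mkstemp(dir=ckpt_dir, suffix='.tmp')
  with os.fdopen(fd, 'w') as f:
    json.dump(state, f)
  os.replace(tmp_name, ckpt_dir / 'state.json')


def load_state(ckpt_dir: Path, default: dict) -> dict:
  """Return the saved state if one exists, else a copy of the default."""
  path = ckpt_dir / 'state.json'
  if path.exists():
    return json.loads(path.read_text())
  return dict(default)


# A temporary directory stands in for the Google Drive folder (assumption).
demo_dir = Path(tempfile.mkdtemp()) / 'demo_run'
fresh = {'step': 0, 'global_index': 0}

first = load_state(demo_dir, fresh)  # No checkpoint yet: start fresh.
save_state(demo_dir, {'step': 100, 'global_index': 12800})
second = load_state(demo_dir, fresh)  # "Run All" after a disconnect.
print(first, second)
```

The same cell serves both the first run and every re-run: it loads whatever state exists and falls back to a fresh start otherwise.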

The Functional Approach

This implementation demonstrates blox’s philosophy of leaning into the functional paradigm. Just as blox makes all neural network state explicit via outputs, params = model(params, inputs), we make all training state explicit and checkpointable:

  • JAX’s explicit PRNG keys make randomness deterministic: the same seed always produces the same values

  • TensorFlow’s index_shuffle computes shuffled indices on-the-fly without materializing arrays

  • Explicit global index tracks position across all epochs with a single integer

By fully controlling randomness and state, we achieve exact reproducibility: training can be interrupted and resumed at any point with identical results.
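
A toy version of this claim, in plain Python rather than JAX or blox: when each step is a pure function of explicit state, and its randomness is derived only from the seed and the step number, an interrupted run ends up in the same state as an uninterrupted one. The `train_step` and `run` functions and the seed-mixing constant below are hypothetical stand-ins, not blox APIs.

```python
import random


def train_step(state: dict, seed: int) -> dict:
  """Pure step: randomness depends only on (seed, step), nothing hidden."""
  rng = random.Random(seed * 1_000_003 + state['step'])
  noise = rng.uniform(-1.0, 1.0)
  return {'step': state['step'] + 1, 'weight': state['weight'] + 0.1 * noise}


def run(state: dict, seed: int, until: int) -> dict:
  """Apply train_step until the step counter reaches `until`."""
  while state['step'] < until:
    state = train_step(state, seed)
  return state


init = {'step': 0, 'weight': 0.0}

# Uninterrupted: 10 steps in one go.
straight = run(init, seed=42, until=10)

# Interrupted: stop at step 4, "restore" a copy of the state, then continue.
checkpoint = dict(run(init, seed=42, until=4))
resumed = run(checkpoint, seed=42, until=10)

print(straight == resumed)
```

Nothing outside `state` and `seed` influences a step, so saving `state` is enough to resume.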

How to Use This Notebook

  1. First run: Training starts fresh, checkpoints saved periodically to Google Drive.

  2. Colab disconnects mid-training.

  3. Reconnect, click “Run All”.

  4. Training cell detects checkpoint, resumes seamlessly.

  5. Same final result as if never interrupted!


1. Setup

[ ]:
# Mount Google Drive for persistent checkpoint storage.
from google.colab import drive

drive.mount('/content/drive')
[ ]:
# Install dependencies.
!pip install -q orbax-checkpoint tensorflow-datasets tensorflow jax-blox optax
[ ]:
import sys
from pathlib import Path

import blox as bx
import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
import tensorflow as tf
import tensorflow_datasets as tfds

# Disable TensorFlow GPU usage (we use JAX for GPU).
tf.config.set_visible_devices([], 'GPU')

print(f'JAX devices: {jax.devices()}')

2. Configuration

[ ]:
# Paths - change EXPERIMENT_NAME for different runs.
EXPERIMENT_NAME = 'mnist_cnn_v1'
CHECKPOINT_DIR = (
    Path('/content/drive/MyDrive/blox_checkpoints') / EXPERIMENT_NAME
)
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Training hyperparameters.
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
MAX_STEPS = 2000
CHECKPOINT_EVERY = 100

# Random seeds.
MODEL_SEED = 42
DATA_SEED = 123

# MNIST data spec (avoids loading dataset just for shape).
DATA_SPEC = {
    'image': jnp.zeros((BATCH_SIZE, 28, 28, 1)),
    'label': jnp.zeros((BATCH_SIZE,), dtype=jnp.int32),
}

print(f'Checkpoint directory: {CHECKPOINT_DIR}')

3. Utilities

Logging and checkpointing helpers. These cells can be collapsed.

[ ]:
# Logging utilities.
class TeeOutput:
  """Redirect stdout/stderr to both console and file."""

  def __init__(self, file_path: Path, stream):
    self.file = open(file_path, 'a')
    self.stream = stream

  def write(self, data):
    self.file.write(data)
    self.stream.write(data)

  def flush(self):
    self.file.flush()
    self.stream.flush()


def setup_logging(checkpoint_dir: Path, resuming: bool):
  """Setup logging to file. Shows previous logs if resuming."""
  log_file = checkpoint_dir / 'training.log'
  if resuming and log_file.exists():
    print('=== Previous training logs ===')
    print(log_file.read_text())
    print('=== Resuming ===')
  sys.stdout = TeeOutput(log_file, sys.__stdout__)
  sys.stderr = TeeOutput(log_file, sys.__stderr__)


def flush_logs():
  """Flush logs to disk (call before checkpointing)."""
  sys.stdout.flush()
  sys.stderr.flush()
[ ]:
# Checkpointing utilities.
def create_checkpoint_manager(checkpoint_dir: Path):
  """Create an Orbax checkpoint manager."""
  return ocp.CheckpointManager(
      checkpoint_dir,
      options=ocp.CheckpointManagerOptions(
          max_to_keep=3, save_interval_steps=1
      ),
  )


def save_checkpoint(manager, step, params, opt_state, global_index, data_seed):
  """Save a checkpoint."""
  state = {
      'params': params.unlocked(),
      'opt_state': opt_state,
      'step': step,
      'global_index': global_index,
      'data_seed': data_seed,
  }
  manager.save(step, args=ocp.args.StandardSave(state))
  print(f'Saved checkpoint at step {step}')


def restore_checkpoint(manager, abstract_state):
  """Restore the latest checkpoint."""
  step = manager.latest_step()
  if step is None:
    return None
  restored = manager.restore(
      step, args=ocp.args.StandardRestore(abstract_state)
  )
  restored['params'] = restored['params'].locked()
  return restored
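
The `max_to_keep=3` option bounds Drive usage by keeping only the most recent checkpoints. The sketch below mimics that retention policy with plain directories. It illustrates the idea and is not Orbax's implementation; `prune_checkpoints` and the `step_N` directory naming are hypothetical.

```python
import shutil
import tempfile
from pathlib import Path


def prune_checkpoints(ckpt_dir: Path, max_to_keep: int) -> list[int]:
  """Delete all but the newest `max_to_keep` step directories."""
  if max_to_keep < 1:
    raise ValueError('max_to_keep must be at least 1')
  steps = sorted(
      int(p.name.split('_')[1]) for p in ckpt_dir.glob('step_*') if p.is_dir()
  )
  for step in steps[:-max_to_keep]:
    shutil.rmtree(ckpt_dir / f'step_{step}')
  return steps[-max_to_keep:]


# Simulate five saves; after each one, prune down to the newest three.
root = Path(tempfile.mkdtemp())
kept = []
for step in (100, 200, 300, 400, 500):
  (root / f'step_{step}').mkdir()
  kept = prune_checkpoints(root, max_to_keep=3)

print(kept)
```

Keeping more than one checkpoint leaves a fallback if the newest save was cut short by a disconnect.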

4. Resumable Dataset

Uses tf.random.experimental.index_shuffle to compute each shuffled index on demand (O(1) memory, no shuffle buffer) and tfds.data_source for random access to individual records by index.

[ ]:
def create_resumable_mnist(
    split: str, batch_size: int, seed: int, start_index: int = 0
):
  """Create a resumable MNIST dataset.

  To adapt for other datasets, change Tout and the loading logic.

  Args:
    split: Dataset split ('train' or 'test').
    batch_size: Batch size.
    seed: Base shuffle seed.
    start_index: Global index to resume from.

  Returns:
    A (dataset, dataset_len) tuple. The dataset yields batches with keys
    'global_index', 'image', and 'label'.
  """
  data_source = tfds.data_source('mnist', split=split)
  dataset_len = len(data_source)
  tf_seed = tf.random.create_rng_state(seed, 'threefry')

  @tf.py_function(
      Tout={'global_index': tf.int64, 'image': tf.float32, 'label': tf.int32}
  )
  def get_shuffled_sample(global_idx):
    global_idx = global_idx.numpy()
    epoch = global_idx // dataset_len
    idx_in_epoch = global_idx % dataset_len

    epoch_seed = tf.random.fold_in(tf_seed, epoch)
    # max_index is inclusive, so use dataset_len - 1.
    shuffled_idx = tf.random.experimental.index_shuffle(
        idx_in_epoch, epoch_seed, dataset_len - 1
    )

    record = data_source[shuffled_idx]
    image = record['image'].astype(np.float32) / 255.0
    label = record['label']

    return {'global_index': global_idx, 'image': image, 'label': label}

  # Upper bound of 100 epochs past start_index; raise it for longer runs.
  ds = tf.data.Dataset.range(start_index, start_index + dataset_len * 100)
  ds = ds.map(get_shuffled_sample, num_parallel_calls=tf.data.AUTOTUNE)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.AUTOTUNE)

  return ds, dataset_len
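
To see why a shuffled position can be computed without storing a permutation, here is a pure-Python sketch of one way to build a stateless shuffle: a small Feistel network with cycle-walking. It is not claimed to be TensorFlow's algorithm. `feistel_index_shuffle`, `shuffled_position`, and the per-epoch seed mixing are hypothetical names and choices made for illustration.

```python
import hashlib


def _round(value: int, key: int, round_idx: int, half_bits: int) -> int:
  """Keyed pseudo-random round function returning `half_bits` bits."""
  digest = hashlib.sha256(f'{key}:{round_idx}:{value}'.encode()).digest()
  return int.from_bytes(digest[:8], 'big') & ((1 << half_bits) - 1)


def feistel_index_shuffle(
    index: int, seed: int, n: int, rounds: int = 4
) -> int:
  """Map index to a permuted index in [0, n) using O(1) memory."""
  if not 0 <= index < n:
    raise ValueError('index out of range')
  # Half the smallest even bit width whose range covers [0, n).
  half_bits = max(1, ((n - 1).bit_length() + 1) // 2)
  mask = (1 << half_bits) - 1
  x = index
  while True:
    left, right = x >> half_bits, x & mask
    for r in range(rounds):
      left, right = right, left ^ _round(right, seed, r, half_bits)
    x = (left << half_bits) | right
    if x < n:  # Cycle-walk: re-apply until the value lands inside [0, n).
      return x


def shuffled_position(global_idx: int, seed: int, n: int) -> int:
  """Split a global index into (epoch, offset) and shuffle within the epoch."""
  epoch, idx_in_epoch = divmod(global_idx, n)
  return feistel_index_shuffle(idx_in_epoch, seed * 1_000_003 + epoch, n)


n = 10
epoch0 = [shuffled_position(i, seed=123, n=n) for i in range(n)]
epoch1 = [shuffled_position(i, seed=123, n=n) for i in range(n, 2 * n)]

# Resuming at global index 7 reproduces the tail of epoch 0 exactly.
resumed_tail = [shuffled_position(i, seed=123, n=n) for i in range(7, n)]
print(epoch0, epoch1, resumed_tail == epoch0[7:])
```

Each epoch visits every index exactly once, and any position can be recomputed from the global index and the seed alone, which is what makes a single integer sufficient for resuming.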

5. Model

[ ]:
class SimpleCNN(bx.Module):
  """Simple CNN for MNIST."""

  def __init__(self, graph: bx.Graph, rng: bx.Rng):
    super().__init__(graph)
    self.conv1 = bx.Conv(
        graph.child('conv1'), output_channels=32, kernel_size=(3, 3), rng=rng
    )
    self.conv2 = bx.Conv(
        graph.child('conv2'), output_channels=64, kernel_size=(3, 3), rng=rng
    )
    self.linear1 = bx.Linear(graph.child('linear1'), output_size=128, rng=rng)
    self.linear2 = bx.Linear(graph.child('linear2'), output_size=10, rng=rng)

  def __call__(self, params: bx.Params, x: jax.Array):
    x, params = self.conv1(params, x)
    x = jax.nn.relu(x)
    x = bx.max_pool(x, window_shape=(2, 2), strides=(2, 2))

    x, params = self.conv2(params, x)
    x = jax.nn.relu(x)
    x = bx.max_pool(x, window_shape=(2, 2), strides=(2, 2))

    x = x.reshape(x.shape[0], -1)
    x, params = self.linear1(params, x)
    x = jax.nn.relu(x)
    x, params = self.linear2(params, x)

    return x, params

6. Training Step

[ ]:
def create_train_step(model, optimizer):
  """Create a JIT-compiled training step function."""

  @jax.jit(donate_argnames=['params', 'opt_state'])
  def train_step(params, opt_state, batch):
    images, labels = batch['image'], batch['label']
    trainable, non_trainable = params.split()

    def loss_fn(t, nt):
      logits, new_params = model(t.merge(nt), images)
      loss = optax.softmax_cross_entropy_with_integer_labels(
          logits, labels
      ).mean()
      _, new_nt = new_params.split()
      return loss, (new_nt, logits)

    grads, (new_non_trainable, logits) = jax.grad(loss_fn, has_aux=True)(
        trainable, non_trainable
    )

    updates, new_opt_state = optimizer.update(grads, opt_state, trainable)
    new_trainable = optax.apply_updates(trainable, updates)
    new_params = new_trainable.merge(new_non_trainable)

    accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()
    return new_params, new_opt_state, accuracy

  return train_step

7. Initialize and Restore

[ ]:
# Create model and optimizer.
graph = bx.Graph('model')
rng = bx.Rng(graph.child('rng'))
model = SimpleCNN(graph.child('cnn'), rng=rng)
optimizer = optax.adam(learning_rate=LEARNING_RATE)

# Initialize params using data spec.
params = rng.seed(bx.Params(), seed=MODEL_SEED)
_, params = model(params, DATA_SPEC['image'])
params = params.locked()
opt_state = optimizer.init(params)

# Try to restore from checkpoint.
ckpt_manager = create_checkpoint_manager(CHECKPOINT_DIR)
abstract_state = {
    'params': params.unlocked(),
    'opt_state': opt_state,
    'step': 0,
    'global_index': 0,
    'data_seed': DATA_SEED,
}
restored = restore_checkpoint(ckpt_manager, abstract_state)

if restored is not None:
  params = restored['params']
  opt_state = restored['opt_state']
  step = restored['step']
  global_index = restored['global_index']
  data_seed = restored['data_seed']
  resuming = True
  print(f'Resumed from step {step}, global_index {global_index}')
else:
  step = 0
  global_index = 0
  data_seed = DATA_SEED
  resuming = False
  print('Starting fresh')

# Setup logging and dataset.
setup_logging(CHECKPOINT_DIR, resuming)
train_ds, dataset_len = create_resumable_mnist(
    split='train',
    batch_size=BATCH_SIZE,
    seed=data_seed,
    start_index=global_index,
)
train_step = create_train_step(model, optimizer)

print(f'Training from step {step} to {MAX_STEPS}...')

8. Training Loop

[ ]:
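# Iterate as NumPy arrays: jax.jit accepts NumPy inputs directly, whereas
# TensorFlow tensors may be rejected as invalid JAX types.
for batch in train_ds.as_numpy_iterator():
  if step >= MAX_STEPS:
    break

  params, opt_state, accuracy = train_step(params, opt_state, batch)
  step += 1
  # Next unread sample: one past the last index in this batch.
  global_index = int(batch['global_index'][-1]) + 1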

  if step % 50 == 0:
    epoch = global_index // dataset_len
    print(f'Step {step}, epoch {epoch}, accuracy: {accuracy:.4f}')

  if step % CHECKPOINT_EVERY == 0:
    flush_logs()
    save_checkpoint(
        ckpt_manager, step, params, opt_state, global_index, data_seed
    )

print(f'Training complete! Final step: {step}')
flush_logs()
save_checkpoint(ckpt_manager, step, params, opt_state, global_index, data_seed)
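
As a worked example of the bookkeeping in this loop, the sketch below computes where training restarts after a disconnect, using the configuration values from this notebook. It assumes a fresh start at global index 0 and full batches throughout; `resume_point` is a hypothetical helper, not part of the notebook.

```python
BATCH_SIZE = 128
CHECKPOINT_EVERY = 100
DATASET_LEN = 60_000  # Size of the MNIST training split.


def resume_point(disconnect_step: int) -> dict:
  """Where training restarts if the runtime dies after `disconnect_step`."""
  # The last checkpoint is at the largest multiple of CHECKPOINT_EVERY.
  step = (disconnect_step // CHECKPOINT_EVERY) * CHECKPOINT_EVERY
  global_index = step * BATCH_SIZE
  epoch, idx_in_epoch = divmod(global_index, DATASET_LEN)
  return {
      'step': step,
      'global_index': global_index,
      'epoch': epoch,
      'idx_in_epoch': idx_in_epoch,
      'steps_lost': disconnect_step - step,
  }


print(resume_point(1234))
```

For a disconnect after step 1234, training restarts at step 1200 with global index 153,600 (epoch 2, position 33,600), so 34 steps are repeated. A smaller CHECKPOINT_EVERY reduces the repeated work at the cost of more frequent writes to Drive.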

9. Evaluation

[ ]:
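# Evaluation uses standard tf.data batching (no checkpointable logic needed).
# drop_remainder=True keeps batch shapes static (one JIT compilation) but skips
# the final 16 of the 10,000 test images.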
test_ds = tfds.load('mnist', split='test')
test_ds = test_ds.batch(BATCH_SIZE, drop_remainder=True)
test_ds = test_ds.prefetch(tf.data.AUTOTUNE)


@jax.jit
def evaluate_batch(params, images, labels):
  """Evaluate a batch and return correctness mask."""
  logits, _ = model(params, images)
  predictions = jnp.argmax(logits, axis=-1)
  return predictions == labels


correct = 0
total = 0
for batch in test_ds:
  images = batch['image'].numpy().astype(np.float32) / 255.0
  labels = batch['label'].numpy()
  is_correct = evaluate_batch(params, images, labels)
  correct += is_correct.sum()
  total += len(labels)

accuracy = correct / total
print(f'Test accuracy: {accuracy:.4f}')

10. Summary

What We Checkpoint

State           Why
--------------  --------------------------------------------
params          Model weights and blox RNG state
opt_state       Optimizer momentum, learning rate schedules
step            Current training step
global_index    Position across all epochs
data_seed       Base seed for reproducible shuffling

Key Patterns

  1. O(1) shuffling with index_shuffle: Computes shuffled indices on-the-fly without materializing arrays.

  2. Random access with tfds.data_source: Access any sample by index.

  3. Global index: A single integer tracks position across all epochs.

  4. Google Drive storage: Survives runtime resets.
