Checkpointable Training on Google Colab
This notebook shows how to structure training code that automatically resumes after Colab runtime disconnects.
Prerequisite: This notebook builds on the MNIST Tutorial. If you’re new to blox, start there first.
The Problem
Google Colab is amazing for learning and experimentation, but has frustrating limitations:
Idle timeout: Leave your laptop for lunch? Runtime disconnected. Hours of training lost.
Session limits: Long training runs get cut off after ~12 hours (or less on the free tier).
No persistent storage: A runtime reset wipes all local files.
The pain: You start training, step away, come back to find your session dead and all progress gone.
The Solution
Checkpoint everything to Google Drive (persistent storage that survives runtime resets):
Model weights and optimizer state
Dataset iteration position (just an index!)
Training RNG state
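When you reconnect and re-run the notebook, training automatically resumes from the most recent checkpoint and follows the same trajectory it would have taken without the interruption. At most CHECKPOINT_EVERY steps of work are repeated.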
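The idea can be illustrated with a toy loop that has nothing to do with neural networks. This is a minimal sketch, assuming a made-up integer "weights" update and hypothetical names (`train`, `crash_at`, `BATCH`): because the update depends only on the explicit state, restarting from the last checkpoint ends in the same place as an uninterrupted run.

```python
import copy

BATCH = 4  # toy batch size (hypothetical, for illustration only)


def train(state, max_steps, checkpoint_every, crash_at=None):
    """Toy loop. Returns (final_state, last_checkpoint).

    final_state is None when a simulated disconnect occurs.
    """
    checkpoint = copy.deepcopy(state)
    while state["step"] < max_steps:
        if state["step"] == crash_at:
            # Simulated disconnect: the live state is lost, the checkpoint survives.
            return None, checkpoint
        # Deterministic "update" that depends only on the explicit state.
        state["weights"] = (state["weights"] * 31 + state["data_index"]) % 1_000_003
        state["data_index"] += BATCH
        state["step"] += 1
        if state["step"] % checkpoint_every == 0:
            checkpoint = copy.deepcopy(state)
    return state, checkpoint


fresh = {"step": 0, "weights": 7, "data_index": 0}

# Run 1: never interrupted.
uninterrupted, _ = train(copy.deepcopy(fresh), max_steps=20, checkpoint_every=5)

# Run 2: disconnect at step 13, then resume from the step-10 checkpoint.
_, ckpt = train(copy.deepcopy(fresh), max_steps=20, checkpoint_every=5, crash_at=13)
resumed, _ = train(copy.deepcopy(ckpt), max_steps=20, checkpoint_every=5)

print(ckpt["step"], resumed == uninterrupted)
```

The three steps between the checkpoint and the disconnect are simply recomputed; nothing outside the checkpointed state influences the result.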
The Functional Approach
This implementation demonstrates blox’s philosophy of leaning into the functional paradigm. Just as blox makes all neural network state explicit via outputs, params = model(params, inputs), we make all training state explicit and checkpointable:
JAX’s explicit handling of randomness makes every random draw a deterministic function of a seed
TensorFlow’s index_shuffle computes shuffled indices on-the-fly without materializing arrays
Explicit global index tracks position across all epochs with a single integer
By fully controlling randomness and state, we achieve exact reproducibility: training can be interrupted and resumed at any point with identical results.
How to Use This Notebook
First run: Training starts fresh, checkpoints saved periodically to Google Drive.
Colab disconnects mid-training.
Reconnect, click “Run All”.
Training cell detects checkpoint, resumes seamlessly.
Same final result as if never interrupted!
1. Setup
[ ]:
# Mount Google Drive for persistent checkpoint storage.
from google.colab import drive
drive.mount('/content/drive')
[ ]:
# Install dependencies.
!pip install -q orbax-checkpoint tensorflow-datasets tensorflow jax-blox optax
[ ]:
import sys
from pathlib import Path
import blox as bx
import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
import tensorflow as tf
import tensorflow_datasets as tfds
# Disable TensorFlow GPU usage (we use JAX for GPU).
tf.config.set_visible_devices([], 'GPU')
print(f'JAX devices: {jax.devices()}')
2. Configuration
[ ]:
# Paths - change EXPERIMENT_NAME for different runs.
EXPERIMENT_NAME = 'mnist_cnn_v1'
CHECKPOINT_DIR = (
    Path('/content/drive/MyDrive/blox_checkpoints') / EXPERIMENT_NAME
)
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
# Training hyperparameters.
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
MAX_STEPS = 2000
CHECKPOINT_EVERY = 100
# Random seeds.
MODEL_SEED = 42
DATA_SEED = 123
# MNIST data spec (avoids loading dataset just for shape).
DATA_SPEC = {
    'image': jnp.zeros((BATCH_SIZE, 28, 28, 1)),
    'label': jnp.zeros((BATCH_SIZE,), dtype=jnp.int32),
}
print(f'Checkpoint directory: {CHECKPOINT_DIR}')
3. Utilities
Logging and checkpointing helpers. These cells can be collapsed.
[ ]:
# Logging utilities.
class TeeOutput:
  """Redirect stdout/stderr to both console and file."""

  def __init__(self, file_path: Path, stream):
    self.file = open(file_path, 'a')
    self.stream = stream

  def write(self, data):
    self.file.write(data)
    self.stream.write(data)

  def flush(self):
    self.file.flush()
    self.stream.flush()


def setup_logging(checkpoint_dir: Path, resuming: bool):
  """Setup logging to file. Shows previous logs if resuming."""
  log_file = checkpoint_dir / 'training.log'
  if resuming and log_file.exists():
    print('=== Previous training logs ===')
    print(log_file.read_text())
    print('=== Resuming ===')
  sys.stdout = TeeOutput(log_file, sys.__stdout__)
  sys.stderr = TeeOutput(log_file, sys.__stderr__)


def flush_logs():
  """Flush logs to disk (call before checkpointing)."""
  sys.stdout.flush()
  sys.stderr.flush()
[ ]:
# Checkpointing utilities.
def create_checkpoint_manager(checkpoint_dir: Path):
  """Create an Orbax checkpoint manager."""
  return ocp.CheckpointManager(
      checkpoint_dir,
      options=ocp.CheckpointManagerOptions(
          max_to_keep=3, save_interval_steps=1
      ),
  )


def save_checkpoint(manager, step, params, opt_state, global_index, data_seed):
  """Save a checkpoint."""
  state = {
      'params': params.unlocked(),
      'opt_state': opt_state,
      'step': step,
      'global_index': global_index,
      'data_seed': data_seed,
  }
  manager.save(step, args=ocp.args.StandardSave(state))
  print(f'Saved checkpoint at step {step}')


def restore_checkpoint(manager, abstract_state):
  """Restore the latest checkpoint."""
  step = manager.latest_step()
  if step is None:
    return None
  restored = manager.restore(
      step, args=ocp.args.StandardRestore(abstract_state)
  )
  restored['params'] = restored['params'].locked()
  return restored
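To make the manager's role concrete, here is a small stdlib-only stand-in for two behaviors the helpers above rely on: finding the latest step and keeping only the newest three checkpoints. It is a sketch under stated assumptions, not how Orbax is implemented: the on-disk layout (one directory per step holding a `state.json`) and the function names `save`, `restore`, `all_steps`, and `latest_step` are hypothetical, and it assumes `max_to_keep >= 1`.

```python
import json
import shutil
import tempfile
from pathlib import Path


def all_steps(root: Path):
    """Steps that have a finished checkpoint directory under root."""
    return [int(p.name) for p in root.iterdir() if p.name.isdigit()]


def latest_step(root: Path):
    steps = all_steps(root)
    return max(steps) if steps else None


def save(root: Path, step: int, state: dict, max_to_keep: int = 3) -> None:
    # Write into a temporary directory first, then rename, so a crash
    # mid-write never leaves a half-written directory named like a step.
    tmp = root / f"tmp_{step}"
    tmp.mkdir(parents=True, exist_ok=True)
    (tmp / "state.json").write_text(json.dumps(state))
    tmp.rename(root / str(step))
    # Retention: delete everything except the newest max_to_keep steps.
    for old in sorted(all_steps(root))[:-max_to_keep]:
        shutil.rmtree(root / str(old))


def restore(root: Path):
    step = latest_step(root)
    if step is None:
        return None  # no checkpoint yet: the caller starts fresh
    return json.loads((root / str(step) / "state.json").read_text())


with tempfile.TemporaryDirectory() as tmpdir:
    root = Path(tmpdir)
    first_attempt = restore(root)
    for step in (100, 200, 300, 400):
        save(root, step, {"step": step, "global_index": step * 128})
    kept = sorted(all_steps(root))
    restored = restore(root)

print(first_attempt, kept, restored)
```

Returning `None` when nothing has been saved mirrors `restore_checkpoint`, which lets the calling cell choose between resuming and starting fresh with a single check.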
4. Resumable Dataset
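Uses tf.random.experimental.index_shuffle to compute each shuffled index on demand in constant memory (no permutation array is ever materialized), and tfds.data_source for high-performance random access to individual records.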
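The following stdlib-only sketch shows why a stateless, index-based shuffle makes resumption trivial. It is not the algorithm TensorFlow uses: it swaps in a much weaker affine permutation, `(a * i + b) % n` with `a` coprime to `n`, purely to illustrate the pattern. The names `shuffled_index` and `lookup` are hypothetical, and the sketch assumes `dataset_len > 1`.

```python
import math
import random


def shuffled_index(idx_in_epoch, epoch, seed, dataset_len):
    """Map a position within an epoch to a dataset index, with no stored state."""
    # Derive per-epoch constants from (seed, epoch) only.
    rng = random.Random(seed * 1_000_003 + epoch)
    a = rng.randrange(1, dataset_len)
    while math.gcd(a, dataset_len) != 1:  # a must be coprime to n for a bijection
        a = rng.randrange(1, dataset_len)
    b = rng.randrange(dataset_len)
    return (a * idx_in_epoch + b) % dataset_len


def lookup(global_index, seed, dataset_len):
    """Split a global index into (epoch, shuffled dataset index)."""
    epoch, idx_in_epoch = divmod(global_index, dataset_len)
    return epoch, shuffled_index(idx_in_epoch, epoch, seed, dataset_len)


SEED, N = 123, 10

# A full pass over 2.5 epochs.
full_order = [lookup(g, SEED, N) for g in range(25)]

# "Resuming" at global index 7 needs nothing but the integer 7 and the seed.
resumed_order = [lookup(g, SEED, N) for g in range(7, 25)]

print(full_order[7:] == resumed_order)
```

Each epoch visits every index exactly once, and the sample at any global index can be recomputed from the seed alone, which is why a single integer is enough to checkpoint the dataset position.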
[ ]:
def create_resumable_mnist(
    split: str, batch_size: int, seed: int, start_index: int = 0
):
  """Create a resumable MNIST dataset.

  To adapt for other datasets, change Tout and the loading logic.

  Args:
    split: Dataset split ('train' or 'test').
    batch_size: Batch size.
    seed: Base shuffle seed.
    start_index: Global index to resume from.

  Returns:
    Dataset yielding {'global_index', 'image', 'label'} batches, dataset_len.
  """
  data_source = tfds.data_source('mnist', split=split)
  dataset_len = len(data_source)
  tf_seed = tf.random.create_rng_state(seed, 'threefry')

  @tf.py_function(
      Tout={'global_index': tf.int64, 'image': tf.float32, 'label': tf.int32}
  )
  def get_shuffled_sample(global_idx):
    global_idx = global_idx.numpy()
    epoch = global_idx // dataset_len
    idx_in_epoch = global_idx % dataset_len
    epoch_seed = tf.random.fold_in(tf_seed, epoch)
    # max_index is inclusive, so use dataset_len - 1.
    shuffled_idx = tf.random.experimental.index_shuffle(
        idx_in_epoch, epoch_seed, dataset_len - 1
    )
    record = data_source[shuffled_idx]
    image = record['image'].astype(np.float32) / 255.0
    label = record['label']
    return {'global_index': global_idx, 'image': image, 'label': label}

  ds = tf.data.Dataset.range(start_index, start_index + dataset_len * 100)
  ds = ds.map(get_shuffled_sample, num_parallel_calls=tf.data.AUTOTUNE)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.AUTOTUNE)
  return ds, dataset_len
5. Model
[ ]:
class SimpleCNN(bx.Module):
  """Simple CNN for MNIST."""

  def __init__(self, graph: bx.Graph, rng: bx.Rng):
    super().__init__(graph)
    self.conv1 = bx.Conv(
        graph.child('conv1'), output_channels=32, kernel_size=(3, 3), rng=rng
    )
    self.conv2 = bx.Conv(
        graph.child('conv2'), output_channels=64, kernel_size=(3, 3), rng=rng
    )
    self.linear1 = bx.Linear(graph.child('linear1'), output_size=128, rng=rng)
    self.linear2 = bx.Linear(graph.child('linear2'), output_size=10, rng=rng)

  def __call__(self, params: bx.Params, x: jax.Array):
    x, params = self.conv1(params, x)
    x = jax.nn.relu(x)
    x = bx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    x, params = self.conv2(params, x)
    x = jax.nn.relu(x)
    x = bx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape(x.shape[0], -1)
    x, params = self.linear1(params, x)
    x = jax.nn.relu(x)
    x, params = self.linear2(params, x)
    return x, params
6. Training Step
[ ]:
def create_train_step(model, optimizer):
  """Create a JIT-compiled training step function."""

  @jax.jit(donate_argnames=['params', 'opt_state'])
  def train_step(params, opt_state, batch):
    images, labels = batch['image'], batch['label']
    trainable, non_trainable = params.split()

    def loss_fn(t, nt):
      logits, new_params = model(t.merge(nt), images)
      loss = optax.softmax_cross_entropy_with_integer_labels(
          logits, labels
      ).mean()
      _, new_nt = new_params.split()
      return loss, (new_nt, logits)

    grads, (new_non_trainable, logits) = jax.grad(loss_fn, has_aux=True)(
        trainable, non_trainable
    )
    updates, new_opt_state = optimizer.update(grads, opt_state, trainable)
    new_trainable = optax.apply_updates(trainable, updates)
    new_params = new_trainable.merge(new_non_trainable)
    accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()
    return new_params, new_opt_state, accuracy

  return train_step
7. Initialize and Restore
[ ]:
# Create model and optimizer.
graph = bx.Graph('model')
rng = bx.Rng(graph.child('rng'))
model = SimpleCNN(graph.child('cnn'), rng=rng)
optimizer = optax.adam(learning_rate=LEARNING_RATE)
# Initialize params using data spec.
params = rng.seed(bx.Params(), seed=MODEL_SEED)
_, params = model(params, DATA_SPEC['image'])
params = params.locked()
opt_state = optimizer.init(params)
# Try to restore from checkpoint.
ckpt_manager = create_checkpoint_manager(CHECKPOINT_DIR)
abstract_state = {
    'params': params.unlocked(),
    'opt_state': opt_state,
    'step': 0,
    'global_index': 0,
    'data_seed': DATA_SEED,
}
restored = restore_checkpoint(ckpt_manager, abstract_state)
if restored is not None:
  params = restored['params']
  opt_state = restored['opt_state']
  step = restored['step']
  global_index = restored['global_index']
  data_seed = restored['data_seed']
  resuming = True
  print(f'Resumed from step {step}, global_index {global_index}')
else:
  step = 0
  global_index = 0
  data_seed = DATA_SEED
  resuming = False
  print('Starting fresh')
# Setup logging and dataset.
setup_logging(CHECKPOINT_DIR, resuming)
train_ds, dataset_len = create_resumable_mnist(
    split='train',
    batch_size=BATCH_SIZE,
    seed=data_seed,
    start_index=global_index,
)
train_step = create_train_step(model, optimizer)
print(f'Training from step {step} to {MAX_STEPS}...')
8. Training Loop
[ ]:
for batch in train_ds:
  if step >= MAX_STEPS:
    break
  params, opt_state, accuracy = train_step(params, opt_state, batch)
  step += 1
  global_index = int(batch['global_index'][-1].numpy()) + 1
  if step % 50 == 0:
    epoch = global_index // dataset_len
    print(f'Step {step}, epoch {epoch}, accuracy: {accuracy:.4f}')
  if step % CHECKPOINT_EVERY == 0:
    flush_logs()
    save_checkpoint(
        ckpt_manager, step, params, opt_state, global_index, data_seed
    )
print(f'Training complete! Final step: {step}')
flush_logs()
save_checkpoint(ckpt_manager, step, params, opt_state, global_index, data_seed)
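As a worked example of the bookkeeping in this loop, the arithmetic below uses the notebook's own settings plus one outside fact, that the MNIST training split has 60,000 images. It assumes every batch holds exactly BATCH_SIZE samples, which follows from the index range running continuously across epoch boundaries. The helper name `progress` and the disconnect at step 1234 are hypothetical.

```python
BATCH_SIZE = 128
MAX_STEPS = 2000
CHECKPOINT_EVERY = 100
TRAIN_LEN = 60_000  # size of the MNIST training split


def progress(step):
    """Return (global_index, epoch, position within epoch) after `step` steps."""
    global_index = step * BATCH_SIZE
    epoch, idx_in_epoch = divmod(global_index, TRAIN_LEN)
    return global_index, epoch, idx_in_epoch


# End of training: 2000 * 128 = 256,000 samples, i.e. 16,000 samples into epoch 4.
final = progress(MAX_STEPS)

# Hypothetical disconnect at step 1234: the newest checkpoint is from step 1200.
crash_step = 1234
resume_step = (crash_step // CHECKPOINT_EVERY) * CHECKPOINT_EVERY
repeated_steps = crash_step - resume_step
resume = progress(resume_step)

print(final, resume_step, repeated_steps, resume)
```

With these settings a disconnect costs at most 99 repeated steps, and the dataset restarts at exactly the sample it would have served next after the saved step.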
9. Evaluation
[ ]:
# Evaluation uses standard tf.data batching (no checkpointable logic needed).
# Note: drop_remainder=True skips the final partial batch, so 9,984 of the
# 10,000 test images are scored (10,000 = 78 * 128 + 16).
test_ds = tfds.load('mnist', split='test')
test_ds = test_ds.batch(BATCH_SIZE, drop_remainder=True)
test_ds = test_ds.prefetch(tf.data.AUTOTUNE)


@jax.jit
def evaluate_batch(params, images, labels):
  """Evaluate a batch and return correctness mask."""
  logits, _ = model(params, images)
  predictions = jnp.argmax(logits, axis=-1)
  return predictions == labels


correct = 0
total = 0
for batch in test_ds:
  images = batch['image'].numpy().astype(np.float32) / 255.0
  labels = batch['label'].numpy()
  is_correct = evaluate_batch(params, images, labels)
  correct += is_correct.sum()
  total += len(labels)
accuracy = correct / total
print(f'Test accuracy: {accuracy:.4f}')
10. Summary
What We Checkpoint
| State | Why |
|---|---|
| params | Model weights and blox RNG state |
| opt_state | Optimizer momentum, learning rate schedules |
| step | Current training step |
| global_index | Position across all epochs |
| data_seed | Base seed for reproducible shuffling |
Key Patterns
Constant-memory shuffling with index_shuffle: Computes shuffled indices on-the-fly without materializing arrays.
Random access with tfds.data_source: Access any sample by index.
Global index: A single integer tracks position across all epochs.
Google Drive storage: Survives runtime resets.