LoRA (Low-Rank Adaptation)

This notebook demonstrates how to implement LoRA fine-tuning with blox.

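LoRA is a parameter-efficient fine-tuning technique that freezes the pretrained model weights and injects trainable low-rank decomposition matrices. Instead of updating a weight matrix W directly, LoRA keeps W fixed and learns a low-rank update: the layer uses W + A @ B, where A and B are small matrices whose product has rank at most r (the adapter rank).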
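
To make "small" concrete: for a kernel W of shape (d_in, d_out), full fine-tuning updates d_in * d_out values, while a rank-r adapter trains only r * (d_in + d_out). The arithmetic check below uses an illustrative 4096 x 4096 layer with rank 8; these sizes are assumptions chosen for the example and do not come from this notebook.

```python
def full_param_count(d_in: int, d_out: int) -> int:
  """Trainable values when fine-tuning the whole kernel W."""
  return d_in * d_out


def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
  """Trainable values in A (d_in x rank) plus B (rank x d_out)."""
  return rank * (d_in + d_out)


# Illustrative layer size and rank (assumed, not taken from this notebook).
d_in, d_out, rank = 4096, 4096, 8
full = full_param_count(d_in, d_out)
lora = lora_param_count(d_in, d_out, rank)
print(f'full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}')
# full: 16,777,216  lora: 65,536  ratio: 0.39%
```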

Key benefits:

  • Memory efficient: Only LoRA parameters need gradients

  • Modular: Adapters can be added or removed without changing the base model

  • Stackable: Multiple adapters can be combined
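
The "modular" and "stackable" points both follow from the fact that each adapter contributes its own additive update to a kernel that is never modified. The sketch below demonstrates this with NumPy standing in for jax.numpy. The random adapter values are placeholders for trained ones, and the helper `delta` is hypothetical.

```python
import numpy as np

np_rng = np.random.default_rng(0)
d_in, d_out, rank, alpha = 16, 8, 2, 4.0

W = np_rng.normal(size=(d_in, d_out))  # frozen base kernel, never modified


def delta(a, b):
  """Hypothetical helper: one adapter's update (alpha / rank) * a @ b."""
  return (alpha / rank) * (a @ b)


# Two adapters; random values stand in for separately trained ones.
a1, b1 = np_rng.normal(size=(d_in, rank)), np_rng.normal(size=(rank, d_out))
a2, b2 = np_rng.normal(size=(d_in, rank)), np_rng.normal(size=(rank, d_out))

# Stacking: the effective kernel is W plus the sum of the adapter updates.
kernel_both = W + delta(a1, b1) + delta(a2, b2)

# Removing adapter 2 recovers W + adapter 1, up to floating-point rounding.
kernel_first = kernel_both - delta(a2, b2)
print(np.allclose(kernel_first, W + delta(a1, b1)))  # True
```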

Reference: LoRA: Low-Rank Adaptation of Large Language Models (Hu et al., 2021)

Two Approaches

This notebook shows two ways to add LoRA to a model with blox:

  1. Monkey-patching (shown first): Take an existing pretrained model and wrap its layers with LoRA at runtime. This is useful when you have a pretrained model and want to add adapters without modifying the original code.

  2. LoRA-aware design (shown later): Design your model to support LoRA from the start by using wrapper layers. This is cleaner and more explicit, but requires planning ahead.

Both approaches are valid; choose based on whether you’re adapting an existing model or building a new one.

[1]:
import sys

sys.path.insert(0, '../src')

import blox as bx
import jax
import jax.numpy as jnp

Approach 1: Monkey-patching an Existing Model

This approach is useful when you have a pretrained model and want to add LoRA adapters without modifying the original model code.

1. Define an MLP Module

First, let’s create a reusable MLP module that we’ll later adapt with LoRA.

[2]:
class MLP(bx.Module):
  """Multi-layer perceptron with configurable layers and activation."""

  def __init__(
      self,
      graph: bx.Graph,
      output_sizes: list[int],
      rng: bx.Rng,
      activation=jax.nn.relu,
  ):
    super().__init__(graph)
    self.activation = activation
    self.layers = []
    for i, size in enumerate(output_sizes):
      name = f'hidden{i}' if i < len(output_sizes) - 1 else 'output'
      self.layers.append(bx.Linear(graph.child(name), size, rng=rng))

  def __call__(self, params, x):
    for i, layer in enumerate(self.layers):
      x, params = layer(params, x)
      # Apply activation to all but the last layer.
      if i < len(self.layers) - 1:
        x = self.activation(x)
    return x, params


print('MLP module defined!')
MLP module defined!
[3]:
# Create the model.
graph = bx.Graph('net')
rng = bx.Rng(graph.child('rng'))

model = MLP(
    graph.child('mlp'),
    output_sizes=[64, 32, 10],
    rng=rng,
    activation=jax.nn.relu,
)

# Initialize the model.
x = jnp.ones((4, 16))
params = rng.seed(bx.Params(), seed=42)
_, params = model(params, x)
params = params.locked()

print('Model initialized!')
print(f'Total parameters: {len(params)}')
Model initialized!
Total parameters: 8

2. Explore the Graph

blox provides Graph.walk() to iterate over all modules in the graph. This is useful for finding which layers to apply LoRA to.

[ ]:
# Walk the graph to see all modules.
print('Modules in graph:')
for path, module in graph.walk():
  print(f'  {path}: {module!r}')
Modules in graph:
  ('net', 'rng'): Rng(auto_fold_in_axes=True)
  ('net', 'mlp'): MLP(output_sizes=[64, 32, 10], rng=Rng(auto_fold_in_axes=True), activation=<jax._src.custom_derivatives.custom_jvp object at 0x7aade2b336b0>)
  ('net', 'mlp', 'hidden0'): Linear(output_size=64, rng=Rng(auto_fold_in_axes=True), use_bias=True, kernel_init=<function variance_scaling.<locals>.init at 0x7aade29d76a0>, bias_init=<function zeros at 0x7aadf813ade0>, kernel_metadata=None, bias_metadata=None)
  ('net', 'mlp', 'hidden1'): Linear(output_size=32, rng=Rng(auto_fold_in_axes=True), use_bias=True, kernel_init=<function variance_scaling.<locals>.init at 0x7aade29d76a0>, bias_init=<function zeros at 0x7aadf813ade0>, kernel_metadata=None, bias_metadata=None)
  ('net', 'mlp', 'output'): Linear(output_size=10, rng=Rng(auto_fold_in_axes=True), use_bias=True, kernel_init=<function variance_scaling.<locals>.init at 0x7aade29d76a0>, bias_init=<function zeros at 0x7aadf813ade0>, kernel_metadata=None, bias_metadata=None)

3. The LoRA Pattern

LoRA replaces y = x @ W with y = x @ W + (alpha / rank) * x @ A @ B (bias omitted), where:

  • W is frozen (the original weights)

  • A has shape (in_features, rank) and is initialized from a Gaussian

  • B has shape (rank, out_features) and is initialized to zeros

Initialization: Following the original paper, A is initialized with a Gaussian distribution and B is initialized to zeros. This ensures A @ B = 0 at the start, so the model begins with its original pretrained behavior.

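Scaling: The LoRA term (not the base output) is scaled by alpha / rank. The original paper uses this scaling to reduce the need to retune hyperparameters, such as the learning rate, when the rank changes.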
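
The sketch below checks the initialization and scaling facts above with NumPy standing in for jax.numpy; the sizes mirror the notebook's first hidden layer. It confirms two things. First, with B at zero, the adapted layer reproduces the base layer exactly. Second, with the alpha = 2 * rank default that apply_lora uses, the scale is the same for every rank.

```python
import numpy as np

np_rng = np.random.default_rng(0)
d_in, d_out, rank = 16, 64, 4
alpha = 2 * rank  # apply_lora's default
scale = alpha / rank

W = np_rng.normal(size=(d_in, d_out))  # stands in for a pretrained kernel
A = np_rng.normal(scale=1.0 / np.sqrt(rank), size=(d_in, rank))  # Gaussian
B = np.zeros((rank, d_out))  # zeros, so A @ B == 0 at initialization
inputs = np_rng.normal(size=(4, d_in))

base_out = inputs @ W
lora_out = base_out + scale * (inputs @ A @ B)
print(np.array_equal(base_out, lora_out))  # True: adapter starts as a no-op

# With alpha = 2 * rank, the scale is 2.0 whatever the rank.
print([(2 * r) / r for r in (1, 4, 16)])  # [2.0, 2.0, 2.0]
```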

[5]:
def apply_lora(layer: bx.Linear, rank: int = 4, alpha: float | None = None):
  """Applies LoRA to a Linear layer by monkey-patching get_param.

  Args:
    layer: The Linear layer to adapt.
    rank: The rank of the low-rank matrices.
    alpha: Scaling factor. Defaults to 2 * rank (common heuristic).
  """
  if alpha is None:
    alpha = 2 * rank

  # Save the original get_param.
  original_get_param = layer.get_param

  def lora_get_param(params, name, shape=None, init=None, rng=None, **kwargs):
    # Get the original parameter.
    value, params = original_get_param(
        params, name, shape, init, rng=rng, **kwargs
    )

    # Only apply LoRA to the kernel.
    if name != 'kernel':
      return value, params

    in_features, out_features = value.shape

    # LoRA initialization (following the original paper):
    # - A: Gaussian; std = 1/sqrt(rank) is this notebook's choice of scale.
    # - B: Zeros, so A @ B = 0 and the adapted kernel starts equal to W.
    lora_a, params = original_get_param(
        params,
        'lora_a',
        (in_features, rank),
        jax.nn.initializers.normal(stddev=1.0 / jnp.sqrt(rank)),
        rng=rng,
    )
    lora_b, params = original_get_param(
        params,
        'lora_b',
        (rank, out_features),
        jax.nn.initializers.zeros,
        rng=rng,
    )

    # Compute W + (alpha / rank) * A @ B.
    scale = alpha / rank
    adapted_kernel = value + scale * (lora_a @ lora_b)

    return adapted_kernel, params

  # Replace get_param with our LoRA version.
  layer.get_param = lora_get_param

  # Keep the original get_param so LoRA can be merged or removed later.
  layer._lora_original_get_param = original_get_param


def remove_lora(layer: bx.Linear):
  """Removes LoRA from a layer, restoring original behavior."""
  if hasattr(layer, '_lora_original_get_param'):
    layer.get_param = layer._lora_original_get_param
    del layer._lora_original_get_param


print('LoRA functions defined!')
LoRA functions defined!

4. Apply LoRA to Selected Layers

Let’s apply LoRA to the hidden layers but not the output layer.

[6]:
# Apply LoRA to all Linear layers except the output layer.
for path, module in graph.walk():
  if isinstance(module, bx.Linear) and path[-1] != 'output':
    apply_lora(module, rank=4)
    print(f'Applied LoRA to {path}')
Applied LoRA to ('net', 'mlp', 'hidden0')
Applied LoRA to ('net', 'mlp', 'hidden1')

5. Initialize LoRA Parameters

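The LoRA parameters are created lazily, on the first forward pass through the patched get_param. Because params was locked after initialization, new parameters cannot be added to it. We therefore unlock it, run the forward pass, and lock it again.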

[7]:
# Unlock params to allow new parameters.
params = params.unlocked()

# Run forward pass to initialize LoRA params.
_, params = model(params, x)

# Lock again.
params = params.locked()

print(f'Total parameters after LoRA: {len(params)}')
print('\nNew LoRA parameters:')
for path, param in params.items():
  if 'lora' in path[-1]:
    print(f'  {path}: {param.value.shape}')
Total parameters after LoRA: 12

New LoRA parameters:
  ('net', 'mlp', 'hidden0', 'lora_a'): (16, 4)
  ('net', 'mlp', 'hidden0', 'lora_b'): (4, 64)
  ('net', 'mlp', 'hidden1', 'lora_a'): (64, 4)
  ('net', 'mlp', 'hidden1', 'lora_b'): (4, 32)

6. Freeze Base Weights

For LoRA training, we want to freeze the original weights and only train the LoRA parameters.

[8]:
def freeze_base_weights(params, layers):
  """Freezes all non-LoRA parameters in the given layers."""
  for layer in layers:
    # Freeze kernel and bias.
    params = layer.set_param(params, 'kernel', None, trainable=False)
    if layer.use_bias:
      params = layer.set_param(params, 'bias', None, trainable=False)
  return params


# Freeze all base weights.
params = freeze_base_weights(params, model.layers)

print('Base weights frozen!')
print('\nTrainable parameters:')
for path, param in params.items():
  if param.trainable:
    print(f'  {path}: {param.value.shape}')
Base weights frozen!

Trainable parameters:
  ('net', 'mlp', 'hidden0', 'lora_a'): (16, 4)
  ('net', 'mlp', 'hidden0', 'lora_b'): (4, 64)
  ('net', 'mlp', 'hidden1', 'lora_a'): (64, 4)
  ('net', 'mlp', 'hidden1', 'lora_b'): (4, 32)
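
The listings above count parameter arrays. Counting scalar values instead makes the saving visible. The LoRA shapes below are the ones printed above. The base shapes are derived from the layer sizes (input width 16; outputs 64, 32, 10), assuming kernels of shape (in, out) and biases of shape (out,). For a model this small, the adapter is a sizeable fraction of the total; the ratio shrinks as layers get wider.

```python
import math

# LoRA shapes as printed above.
lora_shapes = [(16, 4), (4, 64), (64, 4), (4, 32)]
# Base shapes derived from the layer sizes (assumed (in, out) kernels, (out,) biases).
base_shapes = [(16, 64), (64,), (64, 32), (32,), (32, 10), (10,)]

lora_count = sum(math.prod(s) for s in lora_shapes)
base_count = sum(math.prod(s) for s in base_shapes)
print(lora_count, base_count)  # 704 3498
print(f'trainable fraction: {lora_count / (lora_count + base_count):.1%}')
# trainable fraction: 16.8%
```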

7. Training with LoRA

Now we can train! Only LoRA parameters will receive gradients.

[9]:
# Generate some dummy training data.
x_train = jax.random.normal(jax.random.key(0), (32, 16))
y_train = jax.random.normal(jax.random.key(1), (32, 10))

# Save params before training for comparison later.
params_before_training = jax.tree.map(lambda x: x.copy(), params)


@jax.jit(donate_argnames='params')
def train_step(params, x, y):
  # Split into trainable (LoRA) and non-trainable (base + RNG).
  trainable, non_trainable = params.split()

  def loss_fn(t, nt):
    full_params = t.merge(nt)
    pred, new_params = model(full_params, x)
    _, new_nt = new_params.split()
    return jnp.mean((pred - y) ** 2), new_nt

  grads, new_non_trainable = jax.grad(loss_fn, has_aux=True)(
      trainable, non_trainable
  )

  # SGD update.
  new_trainable = jax.tree.map(lambda w, g: w - 0.01 * g, trainable, grads)

  return new_trainable.merge(new_non_trainable)


# Compute initial loss.
pred, _ = model(params, x_train)
initial_loss = jnp.mean((pred - y_train) ** 2)
print(f'Initial loss: {initial_loss:.4f}')

# Train for a few steps.
for step in range(100):
  params = train_step(params, x_train, y_train)

# Compute final loss.
pred, _ = model(params, x_train)
final_loss = jnp.mean((pred - y_train) ** 2)
print(f'Final loss: {final_loss:.4f}')
print(f'Loss reduced by {(1 - final_loss/initial_loss) * 100:.1f}%')
Initial loss: 1.2383
Final loss: 0.9382
Loss reduced by 24.2%
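
Writing the gradients out by hand shows why only the LoRA factors move, and reveals a subtlety of the zero initialization. The sketch below is a NumPy version of the same MSE objective on a single linear layer (no hidden layers, no bias); jax.grad computes these gradients automatically in the cell above. Because B starts at zero, the gradient with respect to A is exactly zero on the first step, so B has to move before A can.

```python
import numpy as np

np_rng = np.random.default_rng(0)
n, d_in, d_out, rank = 32, 16, 10, 4
scale, lr = 2.0, 0.01  # alpha / rank with alpha = 2 * rank; SGD step size

xs = np_rng.normal(size=(n, d_in))
ys = np_rng.normal(size=(n, d_out))
W = np_rng.normal(size=(d_in, d_out))  # frozen: never updated below
W_initial = W.copy()
A = np_rng.normal(scale=1.0 / np.sqrt(rank), size=(d_in, rank))
B = np.zeros((rank, d_out))


def loss_and_grads(a, b):
  """MSE loss and its gradients with respect to the LoRA factors only."""
  err = xs @ (W + scale * (a @ b)) - ys
  d_kernel = 2.0 * xs.T @ err / err.size  # dLoss/dKernel for the mean
  # Chain rule through kernel = W + scale * a @ b; W gets no update.
  return np.mean(err ** 2), scale * d_kernel @ b.T, scale * a.T @ d_kernel


initial_loss, grad_a0, _ = loss_and_grads(A, B)
print(np.abs(grad_a0).max())  # 0.0: while B == 0, A receives no gradient
for _ in range(100):
  _, grad_a, grad_b = loss_and_grads(A, B)
  A, B = A - lr * grad_a, B - lr * grad_b
final_loss, _, _ = loss_and_grads(A, B)
print(final_loss < initial_loss, np.array_equal(W, W_initial))  # True True
```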

8. Verify Only LoRA Weights Changed

Let’s verify that only the LoRA parameters were updated during training.

[10]:
# Compare base weights before and after training.
print('Base weights unchanged:')
for path, param in params_before_training.items():
  if 'lora' not in path[-1]:
    match = jnp.allclose(param.value, params[path].value)
    print(f'  {path}: {match}')
Base weights unchanged:
  ('net', 'mlp', 'hidden0', 'bias'): True
  ('net', 'mlp', 'hidden0', 'kernel'): True
  ('net', 'mlp', 'hidden1', 'bias'): True
  ('net', 'mlp', 'hidden1', 'kernel'): True
  ('net', 'mlp', 'output', 'bias'): True
  ('net', 'mlp', 'output', 'kernel'): True
  ('net', 'rng', 'counter'): True
  ('net', 'rng', 'seed'): True

9. Merging LoRA Weights (Optional)

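For inference, you can fold the LoRA update into the base kernel (W ← W + (alpha / rank) * A @ B). The merged model then needs neither the LoRA parameters nor the extra LoRA computation on every forward pass.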

[ ]:
def merge_lora_weights(params, layer):
  """Merges LoRA weights into base weights for efficient inference."""
  if not hasattr(layer, '_lora_original_get_param'):
    return params

  # get_param returns the adapted kernel W + (alpha / rank) * A @ B while LoRA is active.
  merged_kernel, _ = layer.get_param(params, 'kernel')

  # Update the base kernel with the merged value.
  return layer.set_param(params, 'kernel', merged_kernel)


# Store output before merging.
out_with_lora, _ = model(params, x_train[:1])

# Merge LoRA weights into base kernels.
merged_params = params
for layer in model.layers:
  merged_params = merge_lora_weights(merged_params, layer)

# Remove LoRA from layers.
for layer in model.layers:
  remove_lora(layer)


# Remove LoRA params from the container.
def is_lora_param(path, param):
  return 'lora' in path[-1]


lora_params, merged_params = merged_params.split(is_lora_param)

# Test that output is the same.
out_merged, _ = model(merged_params, x_train[:1])
outputs_match = jnp.allclose(out_with_lora, out_merged, atol=1e-5)
print(f'Outputs match after merge: {outputs_match}')
print(f'Parameters after merge: {len(merged_params)} (was {len(params)})')
Outputs match after merge: True
Parameters after merge: 8 (was 12)
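
What merging saves depends on how the adapter is applied. The arithmetic sketch below compares two cases, using hypothetical helper functions and ignoring bias. The activation-side form x @ W + scale * (x @ A) @ B, which the LoraLinear wrapper in Approach 2 uses, adds rank * (d_in + d_out) multiply-accumulates per example. The monkey-patched apply_lora instead forms A @ B, costing d_in * rank * d_out multiply-accumulates once per forward pass, before the usual x @ kernel. After merging, only d_in * d_out per example remains.

```python
def activation_side_macs(d_in: int, d_out: int, rank: int) -> int:
  """Per-example multiply-accumulates for x @ W + scale * (x @ A) @ B."""
  return d_in * d_out + rank * (d_in + d_out)


def kernel_rebuild_macs(d_in: int, d_out: int, rank: int) -> int:
  """Per-call cost of forming A @ B, as apply_lora does, before x @ kernel."""
  return d_in * rank * d_out


def merged_macs(d_in: int, d_out: int) -> int:
  """Per-example multiply-accumulates once the update is folded into W."""
  return d_in * d_out


# The two adapted hidden layers in this notebook, with rank 4.
for d_in, d_out in [(16, 64), (64, 32)]:
  print(
      (d_in, d_out),
      activation_side_macs(d_in, d_out, 4),
      kernel_rebuild_macs(d_in, d_out, 4),
      merged_macs(d_in, d_out),
  )
# (16, 64) 1344 4096 1024
# (64, 32) 2432 8192 2048
```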

Approach 2: LoRA-aware Model Design

Instead of monkey-patching an existing model, you can design your model to support LoRA from the start. This approach is cleaner and more explicit.

The idea is to create a LoraLinear wrapper that:

  • Wraps a standard bx.Linear layer

  • Adds LoRA parameters alongside the base parameters

  • Can be enabled/disabled via a flag

[ ]:
class LoraLinear(bx.Module):
  """A Linear layer with optional LoRA adaptation.

  This wrapper creates both the base Linear layer and LoRA parameters.
  LoRA can be enabled/disabled without changing the model structure.
  """

  def __init__(
      self,
      graph: bx.Graph,
      output_size: int,
      rng: bx.Rng,
      rank: int = 4,
      alpha: float | None = None,
      use_lora: bool = True,
  ):
    super().__init__(graph)
    self.rank = rank
    self.alpha = alpha if alpha is not None else 2 * rank
    self.use_lora = use_lora
    self.rng = rng

    # Create the base Linear layer.
    self.linear = bx.Linear(graph.child('base'), output_size, rng=rng)

  def __call__(self, params, x):
    # Get base output.
    out, params = self.linear(params, x)

    if not self.use_lora:
      return out, params

    # Get or create LoRA parameters.
    in_features = x.shape[-1]
    out_features = self.linear.output_size

    lora_a, params = self.get_param(
        params,
        'lora_a',
        (in_features, self.rank),
        jax.nn.initializers.normal(stddev=1.0 / jnp.sqrt(self.rank)),
        rng=self.rng,
    )
    lora_b, params = self.get_param(
        params,
        'lora_b',
        (self.rank, out_features),
        jax.nn.initializers.zeros,
        rng=self.rng,
    )

    # Add LoRA contribution: scale * x @ A @ B
    scale = self.alpha / self.rank
    out = out + scale * (x @ lora_a @ lora_b)

    return out, params


print('LoraLinear module defined!')

Create a LoRA-aware MLP

Now we can create an MLP that uses LoraLinear layers. The model supports LoRA from the start, so no monkey-patching is needed.

[ ]:
class LoraAwareMLP(bx.Module):
  """MLP that supports LoRA from the start."""

  def __init__(
      self,
      graph: bx.Graph,
      output_sizes: list[int],
      rng: bx.Rng,
      lora_rank: int = 4,
      use_lora: bool = True,
      activation=jax.nn.relu,
  ):
    super().__init__(graph)
    self.activation = activation
    self.layers = []
    for i, size in enumerate(output_sizes):
      name = f'hidden{i}' if i < len(output_sizes) - 1 else 'output'
      # Use LoraLinear for hidden layers, regular Linear for output.
      if i < len(output_sizes) - 1:
        layer = LoraLinear(
            graph.child(name), size, rng=rng, rank=lora_rank, use_lora=use_lora
        )
      else:
        layer = bx.Linear(graph.child(name), size, rng=rng)
      self.layers.append(layer)

  def __call__(self, params, x):
    for i, layer in enumerate(self.layers):
      x, params = layer(params, x)
      if i < len(self.layers) - 1:
        x = self.activation(x)
    return x, params


# Create and initialize the LoRA-aware model.
graph2 = bx.Graph('lora_net')
rng2 = bx.Rng(graph2.child('rng'))

lora_model = LoraAwareMLP(
    graph2.child('mlp'),
    output_sizes=[64, 32, 10],
    rng=rng2,
    lora_rank=4,
    use_lora=True,
)

# Initialize.
x2 = jnp.ones((4, 16))
params2 = rng2.seed(bx.Params(), seed=42)
_, params2 = lora_model(params2, x2)
params2 = params2.locked()

print('LoRA-aware model initialized!')
print(f'Total parameters: {len(params2)}')
print('\nParameter structure:')
for path, param in params2.items():
  if 'lora' in path[-1] or 'kernel' in path[-1]:
    print(f'  {path}: {param.value.shape}')

Benefits of the LoRA-aware Approach

Notice that the LoRA parameters are created alongside the base parameters during initialization. There is no need to unlock and relock params or to run an extra forward pass.

To fine-tune, we still freeze the base weights and train only the LoRA parameters:

[ ]:
# Freeze all non-LoRA weights: the nested 'base' Linear inside each LoraLinear,
# plus the plain bx.Linear output layer.
for layer in lora_model.layers:
  base = layer.linear if isinstance(layer, LoraLinear) else layer
  params2 = base.set_param(params2, 'kernel', None, trainable=False)
  if base.use_bias:
    params2 = base.set_param(params2, 'bias', None, trainable=False)

print('Trainable parameters (LoRA only):')
for path, param in params2.items():
  if param.trainable:
    print(f'  {path}: {param.value.shape}')
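
The per-layer calls above work, but they have to mirror the model's structure. An alternative is a single rule over parameter paths, in the spirit of is_lora_param from the merge step: train only leaves whose name starts with 'lora_'. The pure-Python sketch below computes such a mask. The path list is an assumed reconstruction of lora_model's parameters, since the cell above does not print them. The sketch only builds the mask; applying it to a blox Params container would still go through blox, for example the set_param calls above.

```python
# Assumed parameter paths for lora_model (hidden1 follows the same pattern).
param_paths = [
    ('lora_net', 'mlp', 'hidden0', 'base', 'kernel'),
    ('lora_net', 'mlp', 'hidden0', 'base', 'bias'),
    ('lora_net', 'mlp', 'hidden0', 'lora_a'),
    ('lora_net', 'mlp', 'hidden0', 'lora_b'),
    ('lora_net', 'mlp', 'output', 'kernel'),
    ('lora_net', 'mlp', 'output', 'bias'),
]


def is_trainable(path: tuple[str, ...]) -> bool:
  """Rule: only leaves named lora_* are trained; everything else is frozen."""
  return path[-1].startswith('lora_')


trainable_mask = {path: is_trainable(path) for path in param_paths}
print([path[-1] for path, train in trainable_mask.items() if train])
# ['lora_a', 'lora_b']
```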

Toggling LoRA

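With the LoRA-aware design, you can disable LoRA by setting use_lora=False on each LoraLinear layer. This is useful for comparing the adapted model against the base model. However, lora_b is initialized to zeros and this model has not been fine-tuned yet, so the two outputs below are still identical. They only diverge once the LoRA parameters have been trained.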

[ ]:
# Output with LoRA enabled.
out_with_lora, _ = lora_model(params2, x2[:1])

# Disable LoRA on all layers.
for layer in lora_model.layers:
  if isinstance(layer, LoraLinear):
    layer.use_lora = False

# Output without LoRA (base model only).
out_without_lora, _ = lora_model(params2, x2[:1])

# Re-enable LoRA.
for layer in lora_model.layers:
  if isinstance(layer, LoraLinear):
    layer.use_lora = True

print('Output with LoRA:', out_with_lora[0, :5])
print('Output without LoRA:', out_without_lora[0, :5])
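# Before any training, lora_b is all zeros, so the outputs match; after
# fine-tuning the LoRA parameters they will differ.
print(f'\nOutputs match: {jnp.allclose(out_with_lora, out_without_lora)}')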
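
The toggle only becomes informative after fine-tuning, because a freshly initialized adapter leaves the output unchanged. The NumPy sketch below (bias omitted) demonstrates this with the same activation-side computation. The "tuned" B is random noise standing in for the result of training, and `forward` is a hypothetical stand-in for LoraLinear.__call__.

```python
import numpy as np

np_rng = np.random.default_rng(0)
d_in, d_out, rank, scale = 16, 64, 4, 2.0
W = np_rng.normal(size=(d_in, d_out))
A = np_rng.normal(scale=1.0 / np.sqrt(rank), size=(d_in, rank))
inputs = np_rng.normal(size=(1, d_in))


def forward(b, use_lora):
  """Activation-side LoRA, bias omitted: x @ W, plus scale * (x @ A) @ b."""
  out = inputs @ W
  if use_lora:
    out = out + scale * ((inputs @ A) @ b)
  return out


b_init = np.zeros((rank, d_out))  # freshly initialized adapter
b_tuned = np_rng.normal(scale=0.1, size=(rank, d_out))  # stands in for a trained B
print(np.allclose(forward(b_init, True), forward(b_init, False)))  # True
print(np.allclose(forward(b_tuned, True), forward(b_tuned, False)))  # False
```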

Summary

This notebook demonstrated two approaches to implementing LoRA with blox:

Approach 1: Monkey-patching

Best for adapting existing pretrained models without modifying their code.

  1. Graph traversal - Used Graph.walk() to find target layers

  2. Monkey-patching - Wrapped get_param to inject W + A @ B

  3. Parameter freezing - Marked base weights as non-trainable

  4. Weight merging - Combined LoRA weights back into base weights for inference

Approach 2: LoRA-aware design

Best for new models where you plan to use LoRA from the start.

  1. Wrapper layer - Created LoraLinear that encapsulates base + LoRA parameters

  2. Clean initialization - LoRA parameters created alongside base parameters

  3. Easy toggling - Enable/disable LoRA via use_lora flag

  4. Explicit structure - Model structure shows LoRA is supported

Scenario                                  Recommended approach
Adapting a pretrained model               Monkey-patching
Building a new model with LoRA support    LoRA-aware design
Research/experimentation                  Either (monkey-patching is more flexible)
Production deployment                     LoRA-aware (cleaner, more maintainable)