LoRA (Low-Rank Adaptation)
This notebook demonstrates how to implement LoRA fine-tuning with blox.
LoRA is a parameter-efficient fine-tuning technique that freezes the pretrained model weights and injects trainable low-rank decomposition matrices. Instead of fine-tuning W directly, we compute W + A @ B, where A and B are small matrices whose shared inner dimension (the rank) is much smaller than the dimensions of W.
Key benefits:
Memory efficient: Only LoRA parameters need gradients
Modular: Can add/remove adapters without changing base model
Stackable: Multiple adapters can be combined
Reference: LoRA: Low-Rank Adaptation of Large Language Models
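To make the memory benefit concrete, the following sketch counts the entries in a dense kernel against the entries in its two LoRA factors. The helper names `full_param_count` and `lora_param_count` are illustrative and not part of blox, and the second layer size is an assumed example rather than a measurement from a real model.

```python
def full_param_count(in_features: int, out_features: int) -> int:
    """Number of entries in a dense kernel W of shape (in_features, out_features)."""
    return in_features * out_features


def lora_param_count(in_features: int, out_features: int, rank: int) -> int:
    """Number of entries in A (in_features, rank) plus B (rank, out_features)."""
    return rank * (in_features + out_features)


# Illustrative layer sizes (assumed, not taken from any particular model).
for d_in, d_out, rank in [(16, 64, 4), (4096, 4096, 8)]:
    full = full_param_count(d_in, d_out)
    lora = lora_param_count(d_in, d_out, rank)
    print(f'{d_in}x{d_out}, rank={rank}: full={full}, lora={lora}, ratio={lora / full:.4f}')
```

The savings grow with layer width: a 16x64 layer at rank 4 needs 320 LoRA entries against 1,024 kernel entries, while a 4096x4096 layer at rank 8 needs 65,536 against 16,777,216.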
Two Approaches
There are two ways to add LoRA to a model with blox:
Monkey-patching (shown first): Take an existing pretrained model and wrap its layers with LoRA at runtime. This is useful when you have a pretrained model and want to add adapters without modifying the original code.
LoRA-aware design (shown later): Design your model to support LoRA from the start by using wrapper layers. This is cleaner and more explicit, but requires planning ahead.
Both approaches are valid; choose based on whether you’re adapting an existing model or building a new one.
[1]:
import sys
sys.path.insert(0, '../src')
import blox as bx
import jax
import jax.numpy as jnp
Approach 1: Monkey-patching an Existing Model
This approach is useful when you have a pretrained model and want to add LoRA adapters without modifying the original model code.
1. Define an MLP Module
First, let’s create a reusable MLP module that we’ll later adapt with LoRA.
[2]:
class MLP(bx.Module):
    """Multi-layer perceptron with configurable layers and activation."""

    def __init__(
        self,
        graph: bx.Graph,
        output_sizes: list[int],
        rng: bx.Rng,
        activation=jax.nn.relu,
    ):
        super().__init__(graph)
        self.activation = activation
        self.layers = []
        for i, size in enumerate(output_sizes):
            name = f'hidden{i}' if i < len(output_sizes) - 1 else 'output'
            self.layers.append(bx.Linear(graph.child(name), size, rng=rng))

    def __call__(self, params, x):
        for i, layer in enumerate(self.layers):
            x, params = layer(params, x)
            # Apply activation to all but the last layer.
            if i < len(self.layers) - 1:
                x = self.activation(x)
        return x, params
print('MLP module defined!')
MLP module defined!
[3]:
# Create the model.
graph = bx.Graph('net')
rng = bx.Rng(graph.child('rng'))
model = MLP(
graph.child('mlp'),
output_sizes=[64, 32, 10],
rng=rng,
activation=jax.nn.relu,
)
# Initialize the model.
x = jnp.ones((4, 16))
params = rng.seed(bx.Params(), seed=42)
_, params = model(params, x)
params = params.locked()
print('Model initialized!')
print(f'Total parameters: {len(params)}')
Model initialized!
Total parameters: 8
2. Explore the Graph
blox provides Graph.walk() to iterate over all modules in the graph. This is useful for finding which layers to apply LoRA to.
[ ]:
# Walk the graph to see all modules.
print('Modules in graph:')
for path, module in graph.walk():
    print(f' {path}: {module!r}')
Modules in graph:
('net', 'rng'): Rng(auto_fold_in_axes=True)
('net', 'mlp'): MLP(output_sizes=[64, 32, 10], rng=Rng(auto_fold_in_axes=True), activation=<jax._src.custom_derivatives.custom_jvp object at 0x7aade2b336b0>)
('net', 'mlp', 'hidden0'): Linear(output_size=64, rng=Rng(auto_fold_in_axes=True), use_bias=True, kernel_init=<function variance_scaling.<locals>.init at 0x7aade29d76a0>, bias_init=<function zeros at 0x7aadf813ade0>, kernel_metadata=None, bias_metadata=None)
('net', 'mlp', 'hidden1'): Linear(output_size=32, rng=Rng(auto_fold_in_axes=True), use_bias=True, kernel_init=<function variance_scaling.<locals>.init at 0x7aade29d76a0>, bias_init=<function zeros at 0x7aadf813ade0>, kernel_metadata=None, bias_metadata=None)
('net', 'mlp', 'output'): Linear(output_size=10, rng=Rng(auto_fold_in_axes=True), use_bias=True, kernel_init=<function variance_scaling.<locals>.init at 0x7aade29d76a0>, bias_init=<function zeros at 0x7aadf813ade0>, kernel_metadata=None, bias_metadata=None)
3. The LoRA Pattern
LoRA replaces y = x @ W with y = x @ W + x @ A @ B where:
W is frozen (original weights)
A has shape (in_features, rank) - initialized with a Gaussian
B has shape (rank, out_features) - initialized to zeros
Initialization: Following the original paper, A is initialized with a Gaussian distribution and B is initialized to zeros. This ensures A @ B = 0 at the start, so the model begins with its original pretrained behavior.
Scaling: The LoRA update is scaled by alpha / rank to stabilize training across different rank values, so the layer computes y = x @ W + (alpha / rank) * x @ A @ B.
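The initialization and scaling rules above can be checked numerically without blox. The sketch below uses NumPy and a hypothetical `lora_forward` helper; the sizes and the value of `alpha` are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features, rank, alpha = 16, 64, 4, 8.0

x = rng.normal(size=(4, in_features))
W = rng.normal(size=(in_features, out_features))          # frozen base kernel
A = rng.normal(size=(in_features, rank)) / np.sqrt(rank)  # Gaussian init
B = np.zeros((rank, out_features))                        # zero init

scale = alpha / rank


def lora_forward(x, W, A, B, scale):
    """Computes x @ W + scale * (x @ A) @ B."""
    return x @ W + scale * (x @ A) @ B


# With B = 0 the adapter contributes nothing, so the output equals the base output.
base_out = x @ W
init_out = lora_forward(x, W, A, B, scale)
print(np.allclose(base_out, init_out))

# Once B is nonzero, the difference from the base output is the scaled update.
B_trained = rng.normal(size=(rank, out_features))
delta = lora_forward(x, W, A, B_trained, scale) - base_out
print(np.allclose(delta, scale * (x @ A @ B_trained)))
```

Both checks should print `True`: the first confirms that a freshly initialized adapter leaves the model unchanged, the second that the adapter's contribution is exactly the low-rank product times alpha / rank.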
[5]:
def apply_lora(layer: bx.Linear, rank: int = 4, alpha: float | None = None):
    """Applies LoRA to a Linear layer by monkey-patching get_param.

    Args:
        layer: The Linear layer to adapt.
        rank: The rank of the low-rank matrices.
        alpha: Scaling factor. Defaults to 2 * rank (common heuristic).
    """
    if alpha is None:
        alpha = 2 * rank
    # Save the original get_param.
    original_get_param = layer.get_param

    def lora_get_param(params, name, shape=None, init=None, rng=None, **kwargs):
        # Get the original parameter.
        value, params = original_get_param(
            params, name, shape, init, rng=rng, **kwargs
        )
        # Only apply LoRA to the kernel.
        if name != 'kernel':
            return value, params
        in_features, out_features = value.shape
        # LoRA initialization (from the original paper):
        # - A: Gaussian with std = 1/sqrt(rank) for stable gradients.
        # - B: Zeros, so the update starts at zero (A @ 0 = 0) and the layer
        #   initially behaves exactly like the base layer.
        lora_a, params = original_get_param(
            params,
            'lora_a',
            (in_features, rank),
            jax.nn.initializers.normal(stddev=1.0 / jnp.sqrt(rank)),
            rng=rng,
        )
        lora_b, params = original_get_param(
            params,
            'lora_b',
            (rank, out_features),
            jax.nn.initializers.zeros,
            rng=rng,
        )
        # Compute W + (alpha / rank) * A @ B.
        scale = alpha / rank
        adapted_kernel = value + scale * (lora_a @ lora_b)
        return adapted_kernel, params

    # Replace get_param with our LoRA version.
    layer.get_param = lora_get_param
    # Store the original method so remove_lora can restore it.
    layer._lora_original_get_param = original_get_param


def remove_lora(layer: bx.Linear):
    """Removes LoRA from a layer, restoring original behavior."""
    if hasattr(layer, '_lora_original_get_param'):
        layer.get_param = layer._lora_original_get_param
        del layer._lora_original_get_param
print('LoRA functions defined!')
LoRA functions defined!
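The wrap-and-restore pattern used by `apply_lora` and `remove_lora` can be seen in isolation in plain Python. In this sketch the `Layer` class and the `patch`/`unpatch` helpers are hypothetical stand-ins, not blox APIs; the wrapper adds a constant offset where the real code adds the scaled A @ B term.

```python
class Layer:
    """Hypothetical stand-in for a layer with a parameter lookup method."""

    def __init__(self, weights: dict[str, float]):
        self.weights = weights

    def get_param(self, name: str) -> float:
        return self.weights[name]


def patch(layer: Layer, offset: float) -> None:
    """Wraps layer.get_param so that 'kernel' lookups are shifted by offset."""
    original = layer.get_param  # bound method captured before patching

    def wrapped(name: str) -> float:
        value = original(name)
        return value + offset if name == 'kernel' else value

    layer.get_param = wrapped  # instance attribute shadows the class method
    layer._original_get_param = original


def unpatch(layer: Layer) -> None:
    """Restores the original method if the layer was patched."""
    if hasattr(layer, '_original_get_param'):
        layer.get_param = layer._original_get_param
        del layer._original_get_param


layer = Layer({'kernel': 1.0, 'bias': 0.5})
patch(layer, offset=0.25)
patched = (layer.get_param('kernel'), layer.get_param('bias'))
unpatch(layer)
restored = layer.get_param('kernel')
print(patched, restored)
```

One caveat of this pattern: calling `patch` twice on the same object stacks two wrappers and overwrites the saved original with the first wrapper, so a single `unpatch` no longer returns the object to its unpatched state. A guard that skips already-patched layers is worth considering, and the same consideration may apply to `apply_lora`.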
4. Apply LoRA to Selected Layers
Let’s apply LoRA to the hidden layers but not the output layer.
[6]:
# Apply LoRA to all Linear layers except the output layer.
for path, module in graph.walk():
    if isinstance(module, bx.Linear) and path[-1] != 'output':
        apply_lora(module, rank=4)
        print(f'Applied LoRA to {path}')
Applied LoRA to ('net', 'mlp', 'hidden0')
Applied LoRA to ('net', 'mlp', 'hidden1')
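For larger models, selecting layers by a path pattern can be more convenient than comparing the last path component. The sketch below is one possible way to do this with the standard library; `select_paths` is a hypothetical helper, and the path tuples are copied from the `Graph.walk()` output shown earlier rather than produced by blox here.

```python
from fnmatch import fnmatchcase

# Paths copied from the Graph.walk() output shown earlier.
paths = [
    ('net', 'rng'),
    ('net', 'mlp'),
    ('net', 'mlp', 'hidden0'),
    ('net', 'mlp', 'hidden1'),
    ('net', 'mlp', 'output'),
]


def select_paths(paths, pattern: str):
    """Returns the paths whose '/'-joined form matches a glob pattern."""
    return [p for p in paths if fnmatchcase('/'.join(p), pattern)]


targets = select_paths(paths, 'net/mlp/hidden*')
print(targets)
```

This filters on the path only. The loop in the notebook also checks `isinstance(module, bx.Linear)`, which a real selection helper would still need.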
5. Initialize LoRA Parameters
Now we need to run a forward pass to create the LoRA parameters. Since params are locked, we first unlock them.
[7]:
# Unlock params to allow new parameters.
params = params.unlocked()
# Run forward pass to initialize LoRA params.
_, params = model(params, x)
# Lock again.
params = params.locked()
print(f'Total parameters after LoRA: {len(params)}')
print('\nNew LoRA parameters:')
for path, param in params.items():
    if 'lora' in path[-1]:
        print(f' {path}: {param.value.shape}')
Total parameters after LoRA: 12
New LoRA parameters:
('net', 'mlp', 'hidden0', 'lora_a'): (16, 4)
('net', 'mlp', 'hidden0', 'lora_b'): (4, 64)
('net', 'mlp', 'hidden1', 'lora_a'): (64, 4)
('net', 'mlp', 'hidden1', 'lora_b'): (4, 32)
6. Freeze Base Weights
For LoRA training, we want to freeze the original weights and only train the LoRA parameters.
[8]:
def freeze_base_weights(params, layers):
    """Freezes all non-LoRA parameters in the given layers."""
    for layer in layers:
        # Freeze kernel and bias.
        params = layer.set_param(params, 'kernel', None, trainable=False)
        if layer.use_bias:
            params = layer.set_param(params, 'bias', None, trainable=False)
    return params


# Freeze all base weights.
params = freeze_base_weights(params, model.layers)
print('Base weights frozen!')
print('\nTrainable parameters:')
for path, param in params.items():
    if param.trainable:
        print(f' {path}: {param.value.shape}')
Base weights frozen!
Trainable parameters:
('net', 'mlp', 'hidden0', 'lora_a'): (16, 4)
('net', 'mlp', 'hidden0', 'lora_b'): (4, 64)
('net', 'mlp', 'hidden1', 'lora_a'): (64, 4)
('net', 'mlp', 'hidden1', 'lora_b'): (4, 32)
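The freezing step amounts to flipping a per-parameter flag and later filtering on it. The following sketch shows that idea with a plain dictionary; the `Param` dataclass and `freeze_non_lora` are hypothetical and only mimic the `trainable` attribute used above, not the blox `Params` container.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Param:
    """Hypothetical parameter record: a shape plus a trainable flag."""
    shape: tuple[int, ...]
    trainable: bool = True


params = {
    ('net', 'mlp', 'hidden0', 'kernel'): Param((16, 64)),
    ('net', 'mlp', 'hidden0', 'bias'): Param((64,)),
    ('net', 'mlp', 'hidden0', 'lora_a'): Param((16, 4)),
    ('net', 'mlp', 'hidden0', 'lora_b'): Param((4, 64)),
}


def freeze_non_lora(params):
    """Returns a new dict in which every non-LoRA entry is marked non-trainable."""
    return {
        path: p if path[-1].startswith('lora_') else replace(p, trainable=False)
        for path, p in params.items()
    }


frozen = freeze_non_lora(params)
trainable_paths = [path for path, p in frozen.items() if p.trainable]
print(trainable_paths)
```

Returning a new dictionary instead of mutating in place mirrors the functional style of the notebook, where `set_param` returns an updated `params`.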
7. Training with LoRA
Now we can train! Only LoRA parameters will receive gradients.
[9]:
# Generate some dummy training data.
x_train = jax.random.normal(jax.random.key(0), (32, 16))
y_train = jax.random.normal(jax.random.key(1), (32, 10))
# Save params before training for comparison later.
params_before_training = jax.tree.map(lambda x: x.copy(), params)
@jax.jit(donate_argnames='params')
def train_step(params, x, y):
    # Split into trainable (LoRA) and non-trainable (base + RNG).
    trainable, non_trainable = params.split()

    def loss_fn(t, nt):
        full_params = t.merge(nt)
        pred, new_params = model(full_params, x)
        _, new_nt = new_params.split()
        return jnp.mean((pred - y) ** 2), new_nt

    grads, new_non_trainable = jax.grad(loss_fn, has_aux=True)(
        trainable, non_trainable
    )
    # SGD update.
    new_trainable = jax.tree.map(lambda w, g: w - 0.01 * g, trainable, grads)
    return new_trainable.merge(new_non_trainable)
# Compute initial loss.
pred, _ = model(params, x_train)
initial_loss = jnp.mean((pred - y_train) ** 2)
print(f'Initial loss: {initial_loss:.4f}')
# Train for a few steps.
for step in range(100):
    params = train_step(params, x_train, y_train)
# Compute final loss.
pred, _ = model(params, x_train)
final_loss = jnp.mean((pred - y_train) ** 2)
print(f'Final loss: {final_loss:.4f}')
print(f'Loss reduced by {(1 - final_loss/initial_loss) * 100:.1f}%')
Initial loss: 1.2383
Final loss: 0.9382
Loss reduced by 24.2%
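To see what the gradient step does to the LoRA factors, here is a single-layer sketch in NumPy with the gradients written out by hand. It is not the notebook's model: there is one linear layer, no bias, and the data, sizes, and learning rate are assumptions. The base kernel `W` never appears on the left-hand side of an update.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out, rank, alpha = 32, 16, 10, 4, 8.0
scale = alpha / rank

x = rng.normal(size=(n, d_in))
y = rng.normal(size=(n, d_out))
W = rng.normal(size=(d_in, d_out)) / np.sqrt(d_in)  # frozen base kernel
A = rng.normal(size=(d_in, rank)) / np.sqrt(rank)   # trainable
B = np.zeros((rank, d_out))                         # trainable
W_before = W.copy()


def loss(A, B):
    pred = x @ (W + scale * A @ B)
    return np.mean((pred - y) ** 2)


initial_loss = loss(A, B)
learning_rate = 0.01
for _ in range(100):
    pred = x @ (W + scale * A @ B)
    # Gradient of the mean squared error with respect to the adapted kernel.
    grad_kernel = x.T @ (2.0 * (pred - y) / pred.size)
    # Chain rule through kernel = W + scale * A @ B; W gets no update.
    grad_A = scale * grad_kernel @ B.T
    grad_B = scale * A.T @ grad_kernel
    A = A - learning_rate * grad_A
    B = B - learning_rate * grad_B

final_loss = loss(A, B)
print(final_loss < initial_loss, np.array_equal(W, W_before))
```

Because B starts at zero, the gradient with respect to A is zero on the first step and only B moves; A begins to receive gradient once B is nonzero.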
8. Verify Only LoRA Weights Changed
Let’s verify that only the LoRA parameters were updated during training.
[10]:
# Compare base weights before and after training.
print('Base weights unchanged:')
for path, param in params_before_training.items():
    if 'lora' not in path[-1]:
        match = jnp.allclose(param.value, params[path].value)
        print(f' {path}: {match}')
Base weights unchanged:
('net', 'mlp', 'hidden0', 'bias'): True
('net', 'mlp', 'hidden0', 'kernel'): True
('net', 'mlp', 'hidden1', 'bias'): True
('net', 'mlp', 'hidden1', 'kernel'): True
('net', 'mlp', 'output', 'bias'): True
('net', 'mlp', 'output', 'kernel'): True
('net', 'rng', 'counter'): True
('net', 'rng', 'seed'): True
9. Merging LoRA Weights (Optional)
For inference, you can merge LoRA weights into the base weights to avoid the extra computation.
[ ]:
def merge_lora_weights(params, layer):
    """Merges LoRA weights into base weights for efficient inference."""
    if not hasattr(layer, '_lora_original_get_param'):
        return params
    # While LoRA is active, get_param returns the adapted kernel
    # W + (alpha / rank) * A @ B.
    merged_kernel, _ = layer.get_param(params, 'kernel')
    # Update the base kernel with the merged value.
    return layer.set_param(params, 'kernel', merged_kernel)


# Store output before merging.
out_with_lora, _ = model(params, x_train[:1])
# Merge LoRA weights into base kernels.
merged_params = params
for layer in model.layers:
    merged_params = merge_lora_weights(merged_params, layer)
# Remove LoRA from layers.
for layer in model.layers:
    remove_lora(layer)


# Remove LoRA params from the container.
def is_lora_param(path, param):
    return 'lora' in path[-1]


lora_params, merged_params = merged_params.split(is_lora_param)
# Test that output is the same.
out_merged, _ = model(merged_params, x_train[:1])
outputs_match = jnp.allclose(out_with_lora, out_merged, atol=1e-5)
print(f'Outputs match after merge: {outputs_match}')
print(f'Parameters after merge: {len(merged_params)} (was {len(params)})')
Outputs match after merge: True
Parameters after merge: 8 (was 12)
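The arithmetic behind the merge is small enough to verify directly. This NumPy sketch uses a plain dictionary and a hypothetical `merge` function in place of the blox `Params` container; the shapes and `alpha` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
d_in, d_out, rank, alpha = 16, 64, 4, 8.0
scale = alpha / rank

x = rng.normal(size=(1, d_in))
params = {
    'kernel': rng.normal(size=(d_in, d_out)),
    'lora_a': rng.normal(size=(d_in, rank)),
    'lora_b': rng.normal(size=(rank, d_out)),
}


def merge(params, scale):
    """Returns a dict holding only the merged kernel W + scale * A @ B."""
    merged = params['kernel'] + scale * params['lora_a'] @ params['lora_b']
    return {'kernel': merged}


out_adapter = x @ params['kernel'] + scale * (x @ params['lora_a']) @ params['lora_b']
merged_params = merge(params, scale)
out_merged = x @ merged_params['kernel']
print(np.allclose(out_adapter, out_merged), sorted(merged_params))
```

After merging, the adapter can only be separated again if A and B are kept somewhere: subtracting scale * A @ B from the merged kernel recovers the base kernel up to floating-point rounding.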
Approach 2: LoRA-aware Model Design
Instead of monkey-patching an existing model, you can design your model to support LoRA from the start. This approach is cleaner and more explicit.
The idea is to create a LoraLinear wrapper that:
Wraps a standard bx.Linear layer
Adds LoRA parameters alongside the base parameters
Can be enabled/disabled via a flag
[ ]:
class LoraLinear(bx.Module):
    """A Linear layer with optional LoRA adaptation.

    This wrapper creates both the base Linear layer and LoRA parameters.
    LoRA can be enabled/disabled without changing the model structure.
    """

    def __init__(
        self,
        graph: bx.Graph,
        output_size: int,
        rng: bx.Rng,
        rank: int = 4,
        alpha: float | None = None,
        use_lora: bool = True,
    ):
        super().__init__(graph)
        self.rank = rank
        self.alpha = alpha if alpha is not None else 2 * rank
        self.use_lora = use_lora
        self.rng = rng
        # Create the base Linear layer.
        self.linear = bx.Linear(graph.child('base'), output_size, rng=rng)

    def __call__(self, params, x):
        # Get base output.
        out, params = self.linear(params, x)
        if not self.use_lora:
            return out, params
        # Get or create LoRA parameters.
        in_features = x.shape[-1]
        out_features = self.linear.output_size
        lora_a, params = self.get_param(
            params,
            'lora_a',
            (in_features, self.rank),
            jax.nn.initializers.normal(stddev=1.0 / jnp.sqrt(self.rank)),
            rng=self.rng,
        )
        lora_b, params = self.get_param(
            params,
            'lora_b',
            (self.rank, out_features),
            jax.nn.initializers.zeros,
            rng=self.rng,
        )
        # Add LoRA contribution: scale * x @ A @ B
        scale = self.alpha / self.rank
        out = out + scale * (x @ lora_a @ lora_b)
        return out, params
print('LoraLinear module defined!')
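The two approaches place the low-rank term at different points: the monkey-patched version adapts the kernel before the matrix multiply, while `LoraLinear` adds a separate low-rank path to the base output. The NumPy sketch below, with assumed sizes, checks that both forms give the same result and compares the number of scalar multiplications spent on the low-rank term.

```python
import numpy as np

rng = np.random.default_rng(2)
n, d_in, d_out, rank, alpha = 4, 16, 64, 4, 8.0
scale = alpha / rank

x = rng.normal(size=(n, d_in))
W = rng.normal(size=(d_in, d_out))
A = rng.normal(size=(d_in, rank))
B = rng.normal(size=(rank, d_out))

# Kernel-space form (Approach 1): adapt the kernel, then apply it.
out_kernel_space = x @ (W + scale * (A @ B))
# Activation-space form (Approach 2): add the low-rank path to the base output.
out_activation_space = x @ W + scale * ((x @ A) @ B)
print(np.allclose(out_kernel_space, out_activation_space))

# Multiply counts for the low-rank term only.
mults_kernel_space = d_in * rank * d_out            # forming A @ B
mults_activation_space = n * rank * (d_in + d_out)  # (x @ A), then (... @ B)
print(mults_kernel_space, mults_activation_space)
```

Python evaluates `x @ lora_a @ lora_b` left to right, so the expression in `LoraLinear` already takes the cheaper `(x @ A) @ B` order. Which form costs less in practice depends on the batch size relative to the layer width.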
Create a LoRA-aware MLP
Now we can create an MLP that uses LoraLinear layers. The model supports LoRA from the start, so no monkey-patching is needed.
[ ]:
class LoraAwareMLP(bx.Module):
    """MLP that supports LoRA from the start."""

    def __init__(
        self,
        graph: bx.Graph,
        output_sizes: list[int],
        rng: bx.Rng,
        lora_rank: int = 4,
        use_lora: bool = True,
        activation=jax.nn.relu,
    ):
        super().__init__(graph)
        self.activation = activation
        self.layers = []
        for i, size in enumerate(output_sizes):
            name = f'hidden{i}' if i < len(output_sizes) - 1 else 'output'
            # Use LoraLinear for hidden layers, regular Linear for output.
            if i < len(output_sizes) - 1:
                layer = LoraLinear(
                    graph.child(name), size, rng=rng, rank=lora_rank, use_lora=use_lora
                )
            else:
                layer = bx.Linear(graph.child(name), size, rng=rng)
            self.layers.append(layer)

    def __call__(self, params, x):
        for i, layer in enumerate(self.layers):
            x, params = layer(params, x)
            if i < len(self.layers) - 1:
                x = self.activation(x)
        return x, params
# Create and initialize the LoRA-aware model.
graph2 = bx.Graph('lora_net')
rng2 = bx.Rng(graph2.child('rng'))
lora_model = LoraAwareMLP(
graph2.child('mlp'),
output_sizes=[64, 32, 10],
rng=rng2,
lora_rank=4,
use_lora=True,
)
# Initialize.
x2 = jnp.ones((4, 16))
params2 = rng2.seed(bx.Params(), seed=42)
_, params2 = lora_model(params2, x2)
params2 = params2.locked()
print('LoRA-aware model initialized!')
print(f'Total parameters: {len(params2)}')
print('\nParameter structure:')
for path, param in params2.items():
    if 'lora' in path[-1] or 'kernel' in path[-1]:
        print(f' {path}: {param.value.shape}')
Benefits of the LoRA-aware Approach
Notice how the LoRA parameters are created alongside the base parameters during initialization, so there is no need to unlock/relock params or run an extra forward pass.
To fine-tune, we still freeze the base weights and train only the LoRA parameters:
[ ]:
# Freeze base weights (kernel and bias in the nested 'base' Linear layers).
for layer in lora_model.layers:
    if isinstance(layer, LoraLinear):
        params2 = layer.linear.set_param(params2, 'kernel', None, trainable=False)
        params2 = layer.linear.set_param(params2, 'bias', None, trainable=False)
    elif isinstance(layer, bx.Linear):
        params2 = layer.set_param(params2, 'kernel', None, trainable=False)
        params2 = layer.set_param(params2, 'bias', None, trainable=False)
print('Trainable parameters (LoRA only):')
for path, param in params2.items():
    if param.trainable:
        print(f' {path}: {param.value.shape}')
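Toggling LoRA
With the LoRA-aware design, you can easily disable LoRA by setting use_lora=False on each layer. This is useful for comparing adapted vs original model behavior. Note that params2 has not been trained at this point, so lora_b is still all zeros and the LoRA term vanishes; the two outputs below will only differ once the adapters have been trained: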
[ ]:
# Output with LoRA enabled.
out_with_lora, _ = lora_model(params2, x2[:1])
# Disable LoRA on all layers.
for layer in lora_model.layers:
    if isinstance(layer, LoraLinear):
        layer.use_lora = False
# Output without LoRA (base model only).
out_without_lora, _ = lora_model(params2, x2[:1])
# Re-enable LoRA.
for layer in lora_model.layers:
    if isinstance(layer, LoraLinear):
        layer.use_lora = True
print('Output with LoRA:', out_with_lora[0, :5])
print('Output without LoRA:', out_without_lora[0, :5])
print(f'\nOutputs differ: {not jnp.allclose(out_with_lora, out_without_lora)}')
Summary
This notebook demonstrated two approaches to implementing LoRA with blox:
Approach 1: Monkey-patching
Best for adapting existing pretrained models without modifying their code.
Graph traversal - Used Graph.walk() to find target layers
Monkey-patching - Wrapped get_param to inject W + A @ B
Parameter freezing - Marked base weights as non-trainable
Weight merging - Combined LoRA weights back into base weights for inference
Approach 2: LoRA-aware design
Best for new models where you plan to use LoRA from the start.
Wrapper layer - Created LoraLinear that encapsulates base + LoRA parameters
Clean initialization - LoRA parameters created alongside base parameters
Easy toggling - Enable/disable LoRA via use_lora flag
Explicit structure - Model structure shows LoRA is supported
| Scenario | Recommended Approach |
|---|---|
| Adapting a pretrained model | Monkey-patching |
| Building a new model with LoRA support | LoRA-aware design |
| Research/experimentation | Either (monkey-patching is more flexible) |
| Production deployment | LoRA-aware (cleaner, more maintainable) |