{
"cells": [
{
"cell_type": "markdown",
"id": "cell-0",
"metadata": {},
"source": "# LoRA (Low-Rank Adaptation) with blox\n\nThis notebook demonstrates how to implement LoRA fine-tuning with blox.\n\nLoRA is a parameter-efficient fine-tuning technique that freezes the pretrained weights and injects trainable low-rank decomposition matrices. Instead of updating a kernel `W` of shape `(in_features, out_features)` directly, we keep `W` frozen and use `W + A @ B`, where `A` has shape `(in_features, rank)`, `B` has shape `(rank, out_features)`, and `rank` is much smaller than either dimension.\n\nKey benefits:\n- **Memory efficient**: Only the LoRA parameters need gradients and optimizer state\n- **Modular**: Adapters can be added or removed without changing the base model\n- **Stackable**: Multiple adapters can be combined\n\nReference: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)\n\n## Two Approaches\n\nThis notebook shows two ways to add LoRA to a model with blox:\n\n1. **Monkey-patching** (shown first): Take an existing model and wrap its layers with LoRA at runtime. This is useful when you have a pretrained model and want to add adapters without modifying its code.\n\n2. **LoRA-aware design** (shown later): Design your model to support LoRA from the start by using wrapper layers. This is cleaner and more explicit, but requires planning ahead.\n\nBoth approaches are valid; choose based on whether you are adapting an existing model or building a new one."
},
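{
"cell_type": "markdown",
"id": "lora-param-count-md",
"metadata": {},
"source": "To make the memory claim concrete, here is a back-of-the-envelope count of trainable parameters for a single dense layer. Fine-tuning the full kernel trains `in_features * out_features` values, while LoRA trains only `rank * (in_features + out_features)`. The `4096 x 4096` layer and rank `8` used below are illustrative assumptions, not taken from a specific model. For the tiny `hidden0` layer in this notebook (`16 x 64`, rank 4) the saving is much smaller: 320 versus 1,024 values."
},
{
"cell_type": "code",
"execution_count": null,
"id": "lora-param-count",
"metadata": {},
"outputs": [],
"source": [
"def full_param_count(in_features: int, out_features: int) -> int:\n",
"    \"\"\"Trainable parameters when fine-tuning the full kernel W.\"\"\"\n",
"    return in_features * out_features\n",
"\n",
"\n",
"def lora_param_count(in_features: int, out_features: int, rank: int) -> int:\n",
"    \"\"\"Trainable parameters for A (in_features x rank) plus B (rank x out_features).\"\"\"\n",
"    return rank * (in_features + out_features)\n",
"\n",
"\n",
"# Illustrative layer size and rank (assumptions, not from a real model).\n",
"full = full_param_count(4096, 4096)\n",
"lora = lora_param_count(4096, 4096, rank=8)\n",
"print(f'Full kernel:   {full:,} params')\n",
"print(f'LoRA (rank 8): {lora:,} params ({100 * lora / full:.2f}% of full)')"
]
},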
{
"cell_type": "code",
"execution_count": 1,
"id": "cell-1",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.insert(0, '../src')\n",
"\n",
"import blox as bx\n",
"import jax\n",
"import jax.numpy as jnp"
]
},
{
"cell_type": "markdown",
"id": "cell-2",
"metadata": {},
"source": "---\n\n# Approach 1: Monkey-patching an Existing Model\n\nThis approach is useful when you have a pretrained model and want to add LoRA adapters without modifying the original model code.\n\n## 1. Define an MLP Module\n\nFirst, let's create a reusable MLP module that we'll later adapt with LoRA."
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cell-3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MLP module defined!\n"
]
}
],
"source": [
"class MLP(bx.Module):\n",
" \"\"\"Multi-layer perceptron with configurable layers and activation.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" graph: bx.Graph,\n",
" output_sizes: list[int],\n",
" rng: bx.Rng,\n",
" activation=jax.nn.relu,\n",
" ):\n",
" super().__init__(graph)\n",
" self.activation = activation\n",
" self.layers = []\n",
" for i, size in enumerate(output_sizes):\n",
" name = f'hidden{i}' if i < len(output_sizes) - 1 else 'output'\n",
" self.layers.append(bx.Linear(graph.child(name), size, rng=rng))\n",
"\n",
" def __call__(self, params, x):\n",
" for i, layer in enumerate(self.layers):\n",
" x, params = layer(params, x)\n",
" # Apply activation to all but the last layer.\n",
" if i < len(self.layers) - 1:\n",
" x = self.activation(x)\n",
" return x, params\n",
"\n",
"\n",
"print('MLP module defined!')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cell-4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model initialized!\n",
"Total parameters: 8\n"
]
}
],
"source": [
"# Create the model.\n",
"graph = bx.Graph('net')\n",
"rng = bx.Rng(graph.child('rng'))\n",
"\n",
"model = MLP(\n",
" graph.child('mlp'),\n",
" output_sizes=[64, 32, 10],\n",
" rng=rng,\n",
" activation=jax.nn.relu,\n",
")\n",
"\n",
"# Initialize the model.\n",
"x = jnp.ones((4, 16))\n",
"params = rng.seed(bx.Params(), seed=42)\n",
"_, params = model(params, x)\n",
"params = params.locked()\n",
"\n",
"print('Model initialized!')\n",
"print(f'Total parameters: {len(params)}')"
]
},
{
"cell_type": "markdown",
"id": "cell-5",
"metadata": {},
"source": [
"## 2. Explore the Graph\n",
"\n",
"`Graph.walk()` yields a `(path, module)` pair for every module in the graph. Filtering on the module type and the path is a convenient way to choose which layers receive LoRA adapters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cell-6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Modules in graph:\n",
" ('net', 'rng'): Rng(auto_fold_in_axes=True)\n",
" ('net', 'mlp'): MLP(output_sizes=[64, 32, 10], rng=Rng(auto_fold_in_axes=True), activation=)\n",
" ('net', 'mlp', 'hidden0'): Linear(output_size=64, rng=Rng(auto_fold_in_axes=True), use_bias=True, kernel_init=.init at 0x7aade29d76a0>, bias_init=, kernel_metadata=None, bias_metadata=None)\n",
" ('net', 'mlp', 'hidden1'): Linear(output_size=32, rng=Rng(auto_fold_in_axes=True), use_bias=True, kernel_init=.init at 0x7aade29d76a0>, bias_init=, kernel_metadata=None, bias_metadata=None)\n",
" ('net', 'mlp', 'output'): Linear(output_size=10, rng=Rng(auto_fold_in_axes=True), use_bias=True, kernel_init=.init at 0x7aade29d76a0>, bias_init=, kernel_metadata=None, bias_metadata=None)\n"
]
}
],
"source": [
"# Walk the graph to see all modules.\n",
"print('Modules in graph:')\n",
"for path, module in graph.walk():\n",
" print(f' {path}: {module!r}')"
]
},
{
"cell_type": "markdown",
"id": "cell-8",
"metadata": {},
"source": [
"## 3. The LoRA Pattern\n",
"\n",
"LoRA replaces `y = x @ W` with `y = x @ W + (alpha / rank) * (x @ A @ B)`, where:\n",
"- `W` is frozen (original weights)\n",
"- `A` has shape `(in_features, rank)` and is initialized from a Gaussian\n",
"- `B` has shape `(rank, out_features)` and is initialized to zeros\n",
"\n",
"**Initialization**: Following the [original paper](https://arxiv.org/abs/2106.09685), `A` is initialized from a random Gaussian and `B` is initialized to zeros. This ensures `A @ B = 0` at the start, so the model begins with exactly its original pretrained behavior.\n",
"\n",
"**Scaling**: The low-rank update (not the whole output) is multiplied by `alpha / rank`. As noted in the paper, this reduces the need to retune hyperparameters such as the learning rate when `rank` changes."
]
},
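{
"cell_type": "markdown",
"id": "lora-init-md",
"metadata": {},
"source": "The pattern above can be checked numerically with plain JAX, independent of blox. This minimal sketch uses illustrative shapes (an assumption) and the row-vector convention `y = x @ W`. It computes the LoRA term in factored form, `(x @ A) @ B`, which avoids materializing the `in_features x out_features` matrix `A @ B`, and confirms that a zero-initialized `B` leaves the layer output unchanged."
},
{
"cell_type": "code",
"execution_count": null,
"id": "lora-init-sketch",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Illustrative shapes and hyperparameters (assumptions).\n",
"in_features, out_features, rank = 16, 64, 4\n",
"alpha = 2 * rank  # Same default heuristic as apply_lora below.\n",
"scale = alpha / rank\n",
"\n",
"k_w, k_a, k_x = jax.random.split(jax.random.key(0), 3)\n",
"W = jax.random.normal(k_w, (in_features, out_features))  # Frozen base kernel.\n",
"A = jax.random.normal(k_a, (in_features, rank)) / jnp.sqrt(rank)  # Gaussian init.\n",
"B = jnp.zeros((rank, out_features))  # Zero init.\n",
"x = jax.random.normal(k_x, (8, in_features))\n",
"\n",
"\n",
"def lora_update(x, A, B, scale):\n",
"    \"\"\"Returns scale * (x @ A) @ B without forming the full A @ B matrix.\"\"\"\n",
"    return scale * ((x @ A) @ B)\n",
"\n",
"\n",
"base_out = x @ W\n",
"lora_out = base_out + lora_update(x, A, B, scale)\n",
"\n",
"# With B = 0 the low-rank update is exactly zero, so the outputs are identical.\n",
"print('Update is zero at init:', bool(jnp.all(lora_update(x, A, B, scale) == 0)))\n",
"print('Outputs identical at init:', bool(jnp.array_equal(base_out, lora_out)))"
]
},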
{
"cell_type": "code",
"execution_count": 5,
"id": "cell-9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LoRA functions defined!\n"
]
}
],
"source": [
"def apply_lora(layer: bx.Linear, rank: int = 4, alpha: float | None = None):\n",
" \"\"\"Applies LoRA to a Linear layer by monkey-patching get_param.\n",
"\n",
" Args:\n",
" layer: The Linear layer to adapt.\n",
" rank: The rank of the low-rank matrices.\n",
" alpha: Scaling factor. Defaults to 2 * rank (common heuristic).\n",
" \"\"\"\n",
" if alpha is None:\n",
" alpha = 2 * rank\n",
"\n",
" # Save the original get_param.\n",
" original_get_param = layer.get_param\n",
"\n",
" def lora_get_param(params, name, shape=None, init=None, rng=None, **kwargs):\n",
" # Get the original parameter.\n",
" value, params = original_get_param(\n",
" params, name, shape, init, rng=rng, **kwargs\n",
" )\n",
"\n",
" # Only apply LoRA to the kernel.\n",
" if name != 'kernel':\n",
" return value, params\n",
"\n",
" in_features, out_features = value.shape\n",
"\n",
"        # LoRA initialization:\n",
"        # - A: Gaussian, as in the original paper (std = 1/sqrt(rank) is a choice made here).\n",
"        # - B: Zeros, so the update A @ B starts at zero and the layer initially matches the base model.\n",
" lora_a, params = original_get_param(\n",
" params,\n",
" 'lora_a',\n",
" (in_features, rank),\n",
" jax.nn.initializers.normal(stddev=1.0 / jnp.sqrt(rank)),\n",
" rng=rng,\n",
" )\n",
" lora_b, params = original_get_param(\n",
" params,\n",
" 'lora_b',\n",
" (rank, out_features),\n",
" jax.nn.initializers.zeros,\n",
" rng=rng,\n",
" )\n",
"\n",
" # Compute W + (alpha / rank) * A @ B.\n",
" scale = alpha / rank\n",
" adapted_kernel = value + scale * (lora_a @ lora_b)\n",
"\n",
" return adapted_kernel, params\n",
"\n",
" # Replace get_param with our LoRA version.\n",
" layer.get_param = lora_get_param\n",
"\n",
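"    # Keep the original get_param: remove_lora restores it, and merge_lora_weights\n",
"    # uses its presence to detect LoRA-patched layers.\n",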
" layer._lora_original_get_param = original_get_param\n",
"\n",
"\n",
"def remove_lora(layer: bx.Linear):\n",
" \"\"\"Removes LoRA from a layer, restoring original behavior.\"\"\"\n",
" if hasattr(layer, '_lora_original_get_param'):\n",
" layer.get_param = layer._lora_original_get_param\n",
" del layer._lora_original_get_param\n",
"\n",
"\n",
"print('LoRA functions defined!')"
]
},
{
"cell_type": "markdown",
"id": "cell-10",
"metadata": {},
"source": [
"## 4. Apply LoRA to Selected Layers\n",
"\n",
"Let's apply LoRA to the hidden layers but not the output layer."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cell-11",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Applied LoRA to ('net', 'mlp', 'hidden0')\n",
"Applied LoRA to ('net', 'mlp', 'hidden1')\n"
]
}
],
"source": [
"# Apply LoRA to all Linear layers except the output layer.\n",
"for path, module in graph.walk():\n",
" if isinstance(module, bx.Linear) and path[-1] != 'output':\n",
" apply_lora(module, rank=4)\n",
" print(f'Applied LoRA to {path}')"
]
},
{
"cell_type": "markdown",
"id": "cell-12",
"metadata": {},
"source": [
"## 5. Initialize LoRA Parameters\n",
"\n",
"Next, run a forward pass so the patched layers create their `lora_a` and `lora_b` parameters. Because `params` was locked after initialization, it must be unlocked first so that new parameters can be added."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "cell-13",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total parameters after LoRA: 12\n",
"\n",
"New LoRA parameters:\n",
" ('net', 'mlp', 'hidden0', 'lora_a'): (16, 4)\n",
" ('net', 'mlp', 'hidden0', 'lora_b'): (4, 64)\n",
" ('net', 'mlp', 'hidden1', 'lora_a'): (64, 4)\n",
" ('net', 'mlp', 'hidden1', 'lora_b'): (4, 32)\n"
]
}
],
"source": [
"# Unlock params to allow new parameters.\n",
"params = params.unlocked()\n",
"\n",
"# Run forward pass to initialize LoRA params.\n",
"_, params = model(params, x)\n",
"\n",
"# Lock again.\n",
"params = params.locked()\n",
"\n",
"print(f'Total parameters after LoRA: {len(params)}')\n",
"print('\\nNew LoRA parameters:')\n",
"for path, param in params.items():\n",
" if 'lora' in path[-1]:\n",
" print(f' {path}: {param.value.shape}')"
]
},
{
"cell_type": "markdown",
"id": "cell-14",
"metadata": {},
"source": [
"## 6. Freeze Base Weights\n",
"\n",
"For LoRA training, we freeze the original weights and train only the LoRA parameters. `freeze_base_weights` also freezes the output layer, which has no adapter."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cell-15",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Base weights frozen!\n",
"\n",
"Trainable parameters:\n",
" ('net', 'mlp', 'hidden0', 'lora_a'): (16, 4)\n",
" ('net', 'mlp', 'hidden0', 'lora_b'): (4, 64)\n",
" ('net', 'mlp', 'hidden1', 'lora_a'): (64, 4)\n",
" ('net', 'mlp', 'hidden1', 'lora_b'): (4, 32)\n"
]
}
],
"source": [
"def freeze_base_weights(params, layers):\n",
" \"\"\"Freezes all non-LoRA parameters in the given layers.\"\"\"\n",
" for layer in layers:\n",
" # Freeze kernel and bias.\n",
" params = layer.set_param(params, 'kernel', None, trainable=False)\n",
" if layer.use_bias:\n",
" params = layer.set_param(params, 'bias', None, trainable=False)\n",
" return params\n",
"\n",
"\n",
"# Freeze all base weights.\n",
"params = freeze_base_weights(params, model.layers)\n",
"\n",
"print('Base weights frozen!')\n",
"print('\\nTrainable parameters:')\n",
"for path, param in params.items():\n",
" if param.trainable:\n",
" print(f' {path}: {param.value.shape}')"
]
},
{
"cell_type": "markdown",
"id": "cell-16",
"metadata": {},
"source": [
"## 7. Training with LoRA\n",
"\n",
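"Now we can train. `params.split()` separates the trainable LoRA parameters from the frozen ones (base weights and RNG state), and gradients are computed only with respect to the trainable part."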
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "cell-17",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initial loss: 1.2383\n",
"Final loss: 0.9382\n",
"Loss reduced by 24.2%\n"
]
}
],
"source": [
"# Generate some dummy training data.\n",
"x_train = jax.random.normal(jax.random.key(0), (32, 16))\n",
"y_train = jax.random.normal(jax.random.key(1), (32, 10))\n",
"\n",
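"# Copy params before training for comparison later. The copy is needed because\n",
"# train_step donates its params buffers, which can invalidate the originals.\n",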
"params_before_training = jax.tree.map(lambda x: x.copy(), params)\n",
"\n",
"\n",
"@jax.jit(donate_argnames='params')\n",
"def train_step(params, x, y):\n",
" # Split into trainable (LoRA) and non-trainable (base + RNG).\n",
" trainable, non_trainable = params.split()\n",
"\n",
" def loss_fn(t, nt):\n",
" full_params = t.merge(nt)\n",
" pred, new_params = model(full_params, x)\n",
" _, new_nt = new_params.split()\n",
" return jnp.mean((pred - y) ** 2), new_nt\n",
"\n",
" grads, new_non_trainable = jax.grad(loss_fn, has_aux=True)(\n",
" trainable, non_trainable\n",
" )\n",
"\n",
" # SGD update.\n",
" new_trainable = jax.tree.map(lambda w, g: w - 0.01 * g, trainable, grads)\n",
"\n",
" return new_trainable.merge(new_non_trainable)\n",
"\n",
"\n",
"# Compute initial loss.\n",
"pred, _ = model(params, x_train)\n",
"initial_loss = jnp.mean((pred - y_train) ** 2)\n",
"print(f'Initial loss: {initial_loss:.4f}')\n",
"\n",
"# Train for a few steps.\n",
"for step in range(100):\n",
" params = train_step(params, x_train, y_train)\n",
"\n",
"# Compute final loss.\n",
"pred, _ = model(params, x_train)\n",
"final_loss = jnp.mean((pred - y_train) ** 2)\n",
"print(f'Final loss: {final_loss:.4f}')\n",
"print(f'Loss reduced by {(1 - final_loss/initial_loss) * 100:.1f}%')"
]
},
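{
"cell_type": "markdown",
"id": "lora-train-md",
"metadata": {},
"source": "The same idea in plain JAX, without the blox containers: in this minimal sketch (shapes, data, learning rate, and step count are illustrative assumptions) the adapter `{A, B}` is the only argument that `jax.value_and_grad` differentiates, so the base kernel `W` receives no gradient and never changes. Here, passing the adapter as a separate argument does the job that `params.split()` does in the blox training step above."
},
{
"cell_type": "code",
"execution_count": null,
"id": "lora-train-sketch",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Illustrative sizes and data (assumptions).\n",
"in_features, out_features, rank = 16, 10, 4\n",
"scale = (2 * rank) / rank  # alpha / rank with alpha = 2 * rank.\n",
"\n",
"k_w, k_a, k_x, k_y = jax.random.split(jax.random.key(0), 4)\n",
"W = jax.random.normal(k_w, (in_features, out_features)) / jnp.sqrt(in_features)  # Frozen.\n",
"adapter = {\n",
"    'A': jax.random.normal(k_a, (in_features, rank)) / jnp.sqrt(rank),\n",
"    'B': jnp.zeros((rank, out_features)),\n",
"}\n",
"x = jax.random.normal(k_x, (32, in_features))\n",
"y = jax.random.normal(k_y, (32, out_features))\n",
"\n",
"\n",
"def loss_fn(adapter, W, x, y):\n",
"    pred = x @ W + scale * ((x @ adapter['A']) @ adapter['B'])\n",
"    return jnp.mean((pred - y) ** 2)\n",
"\n",
"\n",
"@jax.jit\n",
"def train_step(adapter, W, x, y):\n",
"    # value_and_grad differentiates only w.r.t. argument 0 (the adapter) by default.\n",
"    loss, grads = jax.value_and_grad(loss_fn)(adapter, W, x, y)\n",
"    # Plain SGD on the adapter; W is never returned, so it cannot be updated.\n",
"    adapter = jax.tree.map(lambda p, g: p - 0.05 * g, adapter, grads)\n",
"    return adapter, loss\n",
"\n",
"\n",
"initial_loss = loss_fn(adapter, W, x, y)\n",
"for _ in range(100):\n",
"    adapter, _ = train_step(adapter, W, x, y)\n",
"final_loss = loss_fn(adapter, W, x, y)\n",
"print(f'Loss: {initial_loss:.4f} -> {final_loss:.4f}')"
]
},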
{
"cell_type": "markdown",
"id": "cell-18",
"metadata": {},
"source": [
"## 8. Verify Only LoRA Weights Changed\n",
"\n",
"Let's verify that only the LoRA parameters were updated during training."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "cell-19",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Base weights unchanged:\n",
" ('net', 'mlp', 'hidden0', 'bias'): True\n",
" ('net', 'mlp', 'hidden0', 'kernel'): True\n",
" ('net', 'mlp', 'hidden1', 'bias'): True\n",
" ('net', 'mlp', 'hidden1', 'kernel'): True\n",
" ('net', 'mlp', 'output', 'bias'): True\n",
" ('net', 'mlp', 'output', 'kernel'): True\n",
" ('net', 'rng', 'counter'): True\n",
" ('net', 'rng', 'seed'): True\n"
]
}
],
"source": [
"# Compare base weights before and after training.\n",
"print('Base weights unchanged:')\n",
"for path, param in params_before_training.items():\n",
" if 'lora' not in path[-1]:\n",
" match = jnp.allclose(param.value, params[path].value)\n",
" print(f' {path}: {match}')"
]
},
{
"cell_type": "markdown",
"id": "cell-20",
"metadata": {},
"source": [
"## 9. Merging LoRA Weights (Optional)\n",
"\n",
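"For inference, you can fold the LoRA update into the base kernel, replacing `W` with `W + (alpha / rank) * A @ B`. The layer then no longer recomputes `A @ B` on every forward pass, and the LoRA parameters can be dropped."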
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cell-21",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Outputs match after merge: True\n",
"Parameters after merge: 8 (was 12)\n"
]
}
],
"source": [
"def merge_lora_weights(params, layer):\n",
" \"\"\"Merges LoRA weights into base weights for efficient inference.\"\"\"\n",
" if not hasattr(layer, '_lora_original_get_param'):\n",
" return params\n",
"\n",
"    # While LoRA is active, get_param returns the adapted kernel W + (alpha / rank) * A @ B.\n",
" merged_kernel, _ = layer.get_param(params, 'kernel')\n",
"\n",
" # Update the base kernel with the merged value.\n",
" return layer.set_param(params, 'kernel', merged_kernel)\n",
"\n",
"\n",
"# Store output before merging.\n",
"out_with_lora, _ = model(params, x_train[:1])\n",
"\n",
"# Merge LoRA weights into base kernels.\n",
"merged_params = params\n",
"for layer in model.layers:\n",
" merged_params = merge_lora_weights(merged_params, layer)\n",
"\n",
"# Remove LoRA from layers.\n",
"for layer in model.layers:\n",
" remove_lora(layer)\n",
"\n",
"\n",
"# Remove LoRA params from the container.\n",
"def is_lora_param(path, param):\n",
" return 'lora' in path[-1]\n",
"\n",
"\n",
"lora_params, merged_params = merged_params.split(is_lora_param)\n",
"\n",
"# Test that output is the same.\n",
"out_merged, _ = model(merged_params, x_train[:1])\n",
"outputs_match = jnp.allclose(out_with_lora, out_merged, atol=1e-5)\n",
"print(f'Outputs match after merge: {outputs_match}')\n",
"print(f'Parameters after merge: {len(merged_params)} (was {len(params)})')"
]
},
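{
"cell_type": "markdown",
"id": "lora-merge-md",
"metadata": {},
"source": "The merge step reduces to one line of linear algebra. This minimal plain-JAX sketch (shapes, and the random values standing in for a trained `B`, are assumptions) checks that `x @ (W + s * A @ B)` matches the factored form `x @ W + s * (x @ A) @ B` with `s = alpha / rank`. It also checks that subtracting the same update recovers `W` up to float rounding, which is one way an adapter could be un-merged before swapping in a different one."
},
{
"cell_type": "code",
"execution_count": null,
"id": "lora-merge-sketch",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Illustrative shapes (assumptions).\n",
"in_features, out_features, rank = 16, 64, 4\n",
"scale = (2 * rank) / rank  # alpha / rank with alpha = 2 * rank.\n",
"\n",
"k_w, k_a, k_b, k_x = jax.random.split(jax.random.key(0), 4)\n",
"W = jax.random.normal(k_w, (in_features, out_features))\n",
"A = jax.random.normal(k_a, (in_features, rank)) / jnp.sqrt(rank)\n",
"# Random B stands in for a trained adapter (a zero B would make the check trivial).\n",
"B = 0.1 * jax.random.normal(k_b, (rank, out_features))\n",
"x = jax.random.normal(k_x, (8, in_features))\n",
"\n",
"# Merged kernel: inference needs only a single dense matmul.\n",
"W_merged = W + scale * (A @ B)\n",
"\n",
"out_factored = x @ W + scale * ((x @ A) @ B)\n",
"out_merged = x @ W_merged\n",
"print('Outputs match:', bool(jnp.allclose(out_factored, out_merged, atol=1e-4)))\n",
"\n",
"# Un-merging subtracts the same update, recovering W up to float rounding.\n",
"W_restored = W_merged - scale * (A @ B)\n",
"print('W restored:', bool(jnp.allclose(W, W_restored, atol=1e-5)))"
]
},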
{
"cell_type": "markdown",
"id": "583y5f6xhos",
"source": "---\n\n# Approach 2: LoRA-aware Model Design\n\nInstead of monkey-patching an existing model, you can design your model to support LoRA from the start. This approach is cleaner and more explicit.\n\nThe idea is to create a `LoraLinear` wrapper that:\n- Wraps a standard `bx.Linear` layer\n- Adds LoRA parameters alongside the base parameters\n- Can be enabled/disabled via a flag",
"metadata": {}
},
{
"cell_type": "code",
"id": "fran0nkauvl",
"source": "class LoraLinear(bx.Module):\n \"\"\"A Linear layer with optional LoRA adaptation.\n\n This wrapper creates both the base Linear layer and LoRA parameters.\n LoRA can be enabled/disabled without changing the model structure.\n \"\"\"\n\n def __init__(\n self,\n graph: bx.Graph,\n output_size: int,\n rng: bx.Rng,\n rank: int = 4,\n alpha: float | None = None,\n use_lora: bool = True,\n ):\n super().__init__(graph)\n self.rank = rank\n self.alpha = alpha if alpha is not None else 2 * rank\n self.use_lora = use_lora\n self.rng = rng\n\n # Create the base Linear layer.\n self.linear = bx.Linear(graph.child('base'), output_size, rng=rng)\n\n def __call__(self, params, x):\n # Get base output.\n out, params = self.linear(params, x)\n\n if not self.use_lora:\n return out, params\n\n # Get or create LoRA parameters.\n in_features = x.shape[-1]\n out_features = self.linear.output_size\n\n lora_a, params = self.get_param(\n params,\n 'lora_a',\n (in_features, self.rank),\n jax.nn.initializers.normal(stddev=1.0 / jnp.sqrt(self.rank)),\n rng=self.rng,\n )\n lora_b, params = self.get_param(\n params,\n 'lora_b',\n (self.rank, out_features),\n jax.nn.initializers.zeros,\n rng=self.rng,\n )\n\n # Add LoRA contribution: scale * x @ A @ B\n scale = self.alpha / self.rank\n out = out + scale * (x @ lora_a @ lora_b)\n\n return out, params\n\n\nprint('LoraLinear module defined!')",
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"id": "8keo2vlyfil",
"source": "## Create a LoRA-aware MLP\n\nNow we can create an MLP that uses `LoraLinear` layers. The model supports LoRA from the start - no monkey-patching needed.",
"metadata": {}
},
{
"cell_type": "code",
"id": "xrzwgvpxjt",
"source": "class LoraAwareMLP(bx.Module):\n \"\"\"MLP that supports LoRA from the start.\"\"\"\n\n def __init__(\n self,\n graph: bx.Graph,\n output_sizes: list[int],\n rng: bx.Rng,\n lora_rank: int = 4,\n use_lora: bool = True,\n activation=jax.nn.relu,\n ):\n super().__init__(graph)\n self.activation = activation\n self.layers = []\n for i, size in enumerate(output_sizes):\n name = f'hidden{i}' if i < len(output_sizes) - 1 else 'output'\n # Use LoraLinear for hidden layers, regular Linear for output.\n if i < len(output_sizes) - 1:\n layer = LoraLinear(\n graph.child(name), size, rng=rng, rank=lora_rank, use_lora=use_lora\n )\n else:\n layer = bx.Linear(graph.child(name), size, rng=rng)\n self.layers.append(layer)\n\n def __call__(self, params, x):\n for i, layer in enumerate(self.layers):\n x, params = layer(params, x)\n if i < len(self.layers) - 1:\n x = self.activation(x)\n return x, params\n\n\n# Create and initialize the LoRA-aware model.\ngraph2 = bx.Graph('lora_net')\nrng2 = bx.Rng(graph2.child('rng'))\n\nlora_model = LoraAwareMLP(\n graph2.child('mlp'),\n output_sizes=[64, 32, 10],\n rng=rng2,\n lora_rank=4,\n use_lora=True,\n)\n\n# Initialize.\nx2 = jnp.ones((4, 16))\nparams2 = rng2.seed(bx.Params(), seed=42)\n_, params2 = lora_model(params2, x2)\nparams2 = params2.locked()\n\nprint('LoRA-aware model initialized!')\nprint(f'Total parameters: {len(params2)}')\nprint('\\nParameter structure:')\nfor path, param in params2.items():\n if 'lora' in path[-1] or 'kernel' in path[-1]:\n print(f' {path}: {param.value.shape}')",
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"id": "kvjwq4hw43i",
"source": "## Benefits of the LoRA-aware Approach\n\nNotice how the LoRA parameters are created alongside the base parameters during initialization - no need to unlock/relock params or run an extra forward pass.\n\nTo fine-tune, we still freeze the base weights and train only the LoRA parameters:",
"metadata": {}
},
{
"cell_type": "code",
"id": "gndw2vefmt",
"source": "# Freeze every base kernel and bias. For LoraLinear layers these live in the\n# nested 'base' Linear; the plain Linear output layer is frozen directly.\nfor layer in lora_model.layers:\n    base = layer.linear if isinstance(layer, LoraLinear) else layer\n    params2 = base.set_param(params2, 'kernel', None, trainable=False)\n    if base.use_bias:\n        params2 = base.set_param(params2, 'bias', None, trainable=False)\n\nprint('Trainable parameters (LoRA only):')\nfor path, param in params2.items():\n    if param.trainable:\n        print(f' {path}: {param.value.shape}')",
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"id": "umzio4wxdse",
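"source": "## Toggling LoRA\n\nWith the LoRA-aware design, you can disable LoRA by setting `use_lora = False` on each `LoraLinear` layer. The LoRA parameters stay in `params2` but are skipped in the forward pass, which is useful for comparing the adapted model with the original one. Because `B` is zero-initialized, the two outputs only differ once `B` holds non-zero values (for example, after training).",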
"metadata": {}
},
{
"cell_type": "code",
"id": "6nwspxjhb9p",
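"source": "# B is zero-initialized and this model has not been trained, so the adapter\n# would not change the output yet. For illustration only, fill each lora_b with\n# small random values that stand in for a trained adapter.\ndemo_params = params2\nfor i, layer in enumerate(lora_model.layers):\n    if isinstance(layer, LoraLinear):\n        fake_b = 0.1 * jax.random.normal(\n            jax.random.key(i), (layer.rank, layer.linear.output_size)\n        )\n        demo_params = layer.set_param(demo_params, 'lora_b', fake_b)\n\n# Output with LoRA enabled.\nout_with_lora, _ = lora_model(demo_params, x2[:1])\n\n# Disable LoRA on all layers.\nfor layer in lora_model.layers:\n    if isinstance(layer, LoraLinear):\n        layer.use_lora = False\n\n# Output without LoRA (base model only).\nout_without_lora, _ = lora_model(demo_params, x2[:1])\n\n# Re-enable LoRA.\nfor layer in lora_model.layers:\n    if isinstance(layer, LoraLinear):\n        layer.use_lora = True\n\nprint('Output with LoRA:', out_with_lora[0, :5])\nprint('Output without LoRA:', out_without_lora[0, :5])\nprint(f'\\nOutputs differ: {not jnp.allclose(out_with_lora, out_without_lora)}')",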
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"id": "cell-22",
"metadata": {},
"source": "## Summary\n\nThis notebook demonstrated two approaches to implementing LoRA with blox:\n\n### Approach 1: Monkey-patching\nBest for adapting existing pretrained models without modifying their code.\n\n1. **Graph traversal** - Used `Graph.walk()` to find target layers\n2. **Monkey-patching** - Wrapped `get_param` to inject `W + (alpha / rank) * A @ B`\n3. **Parameter freezing** - Marked base weights as non-trainable\n4. **Weight merging** - Folded the LoRA update back into the base weights for inference\n\n### Approach 2: LoRA-aware Design\nBest for new models where you plan to use LoRA from the start.\n\n1. **Wrapper layer** - Created `LoraLinear`, which holds both the base and the LoRA parameters\n2. **Clean initialization** - LoRA parameters are created alongside the base parameters\n3. **Easy toggling** - Enable or disable LoRA via the `use_lora` flag\n4. **Explicit structure** - The model definition makes LoRA support visible\n\n### Which to Choose?\n\n| Scenario | Recommended Approach |\n|----------|---------------------|\n| Adapting a pretrained model | Monkey-patching |\n| Building a new model with LoRA support | LoRA-aware design |\n| Research/experimentation | Either (monkey-patching is more flexible) |\n| Production deployment | LoRA-aware (cleaner, more maintainable) |\n\n### References\n\n- [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)\n- [Practical Tips for Finetuning LLMs Using LoRA](https://magazine.sebastianraschka.com/p/practical-tips-for-finetuning-llms)"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}