{ "cells": [ { "cell_type": "markdown", "id": "cell-0", "metadata": {}, "source": "# LoRA (Low-Rank Adaptation) with blox\n\nThis notebook demonstrates how to implement LoRA fine-tuning with blox.\n\nLoRA is a parameter-efficient fine-tuning technique that freezes the pretrained model weights and injects trainable low-rank decomposition matrices. Instead of fine-tuning `W`, we compute `W + A @ B` where `A` and `B` are small matrices.\n\nKey benefits:\n- **Memory efficient**: Only LoRA parameters need gradients\n- **Modular**: Can add/remove adapters without changing base model\n- **Stackable**: Multiple adapters can be combined\n\nReference: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)\n\n## Two Approaches\n\nThere are multiple ways to add LoRA to a model with blox:\n\n1. **Monkey-patching** (shown first): Take an existing pretrained model and wrap its layers with LoRA at runtime. This is useful when you have a pretrained model and want to add adapters without modifying the original code.\n\n2. **LoRA-aware design** (shown later): Design your model to support LoRA from the start by using wrapper layers. This is cleaner and more explicit, but requires planning ahead.\n\nBoth approaches are valid - choose based on whether you're adapting an existing model or building a new one." }, { "cell_type": "code", "execution_count": 1, "id": "cell-1", "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "sys.path.insert(0, '../src')\n", "\n", "import blox as bx\n", "import jax\n", "import jax.numpy as jnp" ] }, { "cell_type": "markdown", "id": "cell-2", "metadata": {}, "source": "---\n\n# Approach 1: Monkey-patching an Existing Model\n\nThis approach is useful when you have a pretrained model and want to add LoRA adapters without modifying the original model code.\n\n## 1. Define an MLP Module\n\nFirst, let's create a reusable MLP module that we'll later adapt with LoRA."
}, { "cell_type": "code", "execution_count": 2, "id": "cell-3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MLP module defined!\n" ] } ], "source": [ "class MLP(bx.Module):\n", "    \"\"\"Multi-layer perceptron with configurable layers and activation.\"\"\"\n", "\n", "    def __init__(\n", "        self,\n", "        graph: bx.Graph,\n", "        output_sizes: list[int],\n", "        rng: bx.Rng,\n", "        activation=jax.nn.relu,\n", "    ):\n", "        super().__init__(graph)\n", "        self.activation = activation\n", "        self.layers = []\n", "        for i, size in enumerate(output_sizes):\n", "            name = f'hidden{i}' if i < len(output_sizes) - 1 else 'output'\n", "            self.layers.append(bx.Linear(graph.child(name), size, rng=rng))\n", "\n", "    def __call__(self, params, x):\n", "        for i, layer in enumerate(self.layers):\n", "            x, params = layer(params, x)\n", "            # Apply activation to all but the last layer.\n", "            if i < len(self.layers) - 1:\n", "                x = self.activation(x)\n", "        return x, params\n", "\n", "\n", "print('MLP module defined!')" ] }, { "cell_type": "code", "execution_count": 3, "id": "cell-4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model initialized!\n", "Total parameters: 8\n" ] } ], "source": [ "# Create the model.\n", "graph = bx.Graph('net')\n", "rng = bx.Rng(graph.child('rng'))\n", "\n", "model = MLP(\n", "    graph.child('mlp'),\n", "    output_sizes=[64, 32, 10],\n", "    rng=rng,\n", "    activation=jax.nn.relu,\n", ")\n", "\n", "# Initialize the model.\n", "x = jnp.ones((4, 16))\n", "params = rng.seed(bx.Params(), seed=42)\n", "_, params = model(params, x)\n", "params = params.locked()\n", "\n", "print('Model initialized!')\n", "print(f'Total parameters: {len(params)}')" ] }, { "cell_type": "markdown", "id": "cell-5", "metadata": {}, "source": [ "## 2. Explore the Graph\n", "\n", "blox provides `Graph.walk()` to iterate over all modules in the graph. This is useful for finding which layers to apply LoRA to."
] }, { "cell_type": "code", "execution_count": null, "id": "cell-6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Modules in graph:\n", " ('net', 'rng'): Rng(auto_fold_in_axes=True)\n", " ('net', 'mlp'): MLP(output_sizes=[64, 32, 10], rng=Rng(auto_fold_in_axes=True), activation=)\n", " ('net', 'mlp', 'hidden0'): Linear(output_size=64, rng=Rng(auto_fold_in_axes=True), use_bias=True, kernel_init=.init at 0x7aade29d76a0>, bias_init=, kernel_metadata=None, bias_metadata=None)\n", " ('net', 'mlp', 'hidden1'): Linear(output_size=32, rng=Rng(auto_fold_in_axes=True), use_bias=True, kernel_init=.init at 0x7aade29d76a0>, bias_init=, kernel_metadata=None, bias_metadata=None)\n", " ('net', 'mlp', 'output'): Linear(output_size=10, rng=Rng(auto_fold_in_axes=True), use_bias=True, kernel_init=.init at 0x7aade29d76a0>, bias_init=, kernel_metadata=None, bias_metadata=None)\n" ] } ], "source": [ "# Walk the graph to see all modules.\n", "print('Modules in graph:')\n", "for path, module in graph.walk():\n", "    print(f' {path}: {module!r}')" ] }, { "cell_type": "markdown", "id": "cell-8", "metadata": {}, "source": [ "## 3. The LoRA Pattern\n", "\n", "LoRA replaces `y = x @ W` with `y = x @ W + x @ A @ B` where:\n", "- `W` is frozen (original weights)\n", "- `A` has shape `(in_features, rank)` - initialized with Gaussian\n", "- `B` has shape `(rank, out_features)` - initialized to zeros\n", "\n", "**Initialization**: Following the [original paper](https://arxiv.org/abs/2106.09685), A is initialized with a Gaussian distribution and B is initialized to zeros. This ensures `A @ B = 0` at the start, so the model begins with its original pretrained behavior.\n", "\n", "**Scaling**: The LoRA term `x @ A @ B` is scaled by `alpha / rank` to stabilize training across different rank values."
] }, { "cell_type": "code", "execution_count": 5, "id": "cell-9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LoRA functions defined!\n" ] } ], "source": [ "def apply_lora(layer: bx.Linear, rank: int = 4, alpha: float | None = None):\n", "    \"\"\"Applies LoRA to a Linear layer by monkey-patching get_param.\n", "\n", "    Args:\n", "        layer: The Linear layer to adapt.\n", "        rank: The rank of the low-rank matrices.\n", "        alpha: Scaling factor. Defaults to 2 * rank (common heuristic).\n", "    \"\"\"\n", "    if alpha is None:\n", "        alpha = 2 * rank\n", "\n", "    # Save the original get_param.\n", "    original_get_param = layer.get_param\n", "\n", "    def lora_get_param(params, name, shape=None, init=None, rng=None, **kwargs):\n", "        # Get the original parameter.\n", "        value, params = original_get_param(\n", "            params, name, shape, init, rng=rng, **kwargs\n", "        )\n", "\n", "        # Only apply LoRA to the kernel.\n", "        if name != 'kernel':\n", "            return value, params\n", "\n", "        in_features, out_features = value.shape\n", "\n", "        # LoRA initialization (A Gaussian, B zeros, as in the original paper):\n", "        # - A: Gaussian; std = 1/sqrt(rank) is the choice made in this notebook.\n", "        # - B: Zeros, so A @ B = 0 and the layer starts out identical to the base.\n", "        lora_a, params = original_get_param(\n", "            params,\n", "            'lora_a',\n", "            (in_features, rank),\n", "            jax.nn.initializers.normal(stddev=1.0 / jnp.sqrt(rank)),\n", "            rng=rng,\n", "        )\n", "        lora_b, params = original_get_param(\n", "            params,\n", "            'lora_b',\n", "            (rank, out_features),\n", "            jax.nn.initializers.zeros,\n", "            rng=rng,\n", "        )\n", "\n", "        # Compute W + (alpha / rank) * A @ B.\n", "        scale = alpha / rank\n", "        adapted_kernel = value + scale * (lora_a @ lora_b)\n", "\n", "        return adapted_kernel, params\n", "\n", "    # Replace get_param with our LoRA version.\n", "    layer.get_param = lora_get_param\n", "\n", "    # Keep a handle to the original get_param for merge/removal.\n", "    layer._lora_original_get_param = original_get_param\n", "\n", "\n", "def remove_lora(layer: 
bx.Linear):\n", "    \"\"\"Removes LoRA from a layer, restoring original behavior.\"\"\"\n", "    if hasattr(layer, '_lora_original_get_param'):\n", "        layer.get_param = layer._lora_original_get_param\n", "        del layer._lora_original_get_param\n", "\n", "\n", "print('LoRA functions defined!')" ] }, { "cell_type": "markdown", "id": "cell-10", "metadata": {}, "source": [ "## 4. Apply LoRA to Selected Layers\n", "\n", "Let's apply LoRA to the hidden layers but not the output layer." ] }, { "cell_type": "code", "execution_count": 6, "id": "cell-11", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Applied LoRA to ('net', 'mlp', 'hidden0')\n", "Applied LoRA to ('net', 'mlp', 'hidden1')\n" ] } ], "source": [ "# Apply LoRA to all Linear layers except the output layer.\n", "for path, module in graph.walk():\n", "    if isinstance(module, bx.Linear) and path[-1] != 'output':\n", "        apply_lora(module, rank=4)\n", "        print(f'Applied LoRA to {path}')" ] }, { "cell_type": "markdown", "id": "cell-12", "metadata": {}, "source": [ "## 5. Initialize LoRA Parameters\n", "\n", "Now we need to run a forward pass to create the LoRA parameters. Since params are locked, we first unlock them."
] }, { "cell_type": "code", "execution_count": 7, "id": "cell-13", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total parameters after LoRA: 12\n", "\n", "New LoRA parameters:\n", " ('net', 'mlp', 'hidden0', 'lora_a'): (16, 4)\n", " ('net', 'mlp', 'hidden0', 'lora_b'): (4, 64)\n", " ('net', 'mlp', 'hidden1', 'lora_a'): (64, 4)\n", " ('net', 'mlp', 'hidden1', 'lora_b'): (4, 32)\n" ] } ], "source": [ "# Unlock params to allow new parameters.\n", "params = params.unlocked()\n", "\n", "# Run forward pass to initialize LoRA params.\n", "_, params = model(params, x)\n", "\n", "# Lock again.\n", "params = params.locked()\n", "\n", "print(f'Total parameters after LoRA: {len(params)}')\n", "print('\\nNew LoRA parameters:')\n", "for path, param in params.items():\n", "    if 'lora' in path[-1]:\n", "        print(f' {path}: {param.value.shape}')" ] }, { "cell_type": "markdown", "id": "cell-14", "metadata": {}, "source": [ "## 6. Freeze Base Weights\n", "\n", "For LoRA training, we want to freeze the original weights and only train the LoRA parameters."
] }, { "cell_type": "code", "execution_count": 8, "id": "cell-15", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Base weights frozen!\n", "\n", "Trainable parameters:\n", " ('net', 'mlp', 'hidden0', 'lora_a'): (16, 4)\n", " ('net', 'mlp', 'hidden0', 'lora_b'): (4, 64)\n", " ('net', 'mlp', 'hidden1', 'lora_a'): (64, 4)\n", " ('net', 'mlp', 'hidden1', 'lora_b'): (4, 32)\n" ] } ], "source": [ "def freeze_base_weights(params, layers):\n", "    \"\"\"Freezes all non-LoRA parameters in the given layers.\"\"\"\n", "    for layer in layers:\n", "        # Freeze kernel and bias.\n", "        params = layer.set_param(params, 'kernel', None, trainable=False)\n", "        if layer.use_bias:\n", "            params = layer.set_param(params, 'bias', None, trainable=False)\n", "    return params\n", "\n", "\n", "# Freeze all base weights.\n", "params = freeze_base_weights(params, model.layers)\n", "\n", "print('Base weights frozen!')\n", "print('\\nTrainable parameters:')\n", "for path, param in params.items():\n", "    if param.trainable:\n", "        print(f' {path}: {param.value.shape}')" ] }, { "cell_type": "markdown", "id": "cell-16", "metadata": {}, "source": [ "## 7. Training with LoRA\n", "\n", "Now we can train! Only LoRA parameters will receive gradients."
] }, { "cell_type": "code", "execution_count": 9, "id": "cell-17", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial loss: 1.2383\n", "Final loss: 0.9382\n", "Loss reduced by 24.2%\n" ] } ], "source": [ "# Generate some dummy training data.\n", "x_train = jax.random.normal(jax.random.key(0), (32, 16))\n", "y_train = jax.random.normal(jax.random.key(1), (32, 10))\n", "\n", "# Save params before training for comparison later.\n", "params_before_training = jax.tree.map(lambda x: x.copy(), params)\n", "\n", "\n", "@jax.jit(donate_argnames='params')\n", "def train_step(params, x, y):\n", "    # Split into trainable (LoRA) and non-trainable (base + RNG).\n", "    trainable, non_trainable = params.split()\n", "\n", "    def loss_fn(t, nt):\n", "        full_params = t.merge(nt)\n", "        pred, new_params = model(full_params, x)\n", "        _, new_nt = new_params.split()\n", "        return jnp.mean((pred - y) ** 2), new_nt\n", "\n", "    grads, new_non_trainable = jax.grad(loss_fn, has_aux=True)(\n", "        trainable, non_trainable\n", "    )\n", "\n", "    # SGD update.\n", "    new_trainable = jax.tree.map(lambda w, g: w - 0.01 * g, trainable, grads)\n", "\n", "    return new_trainable.merge(new_non_trainable)\n", "\n", "\n", "# Compute initial loss.\n", "pred, _ = model(params, x_train)\n", "initial_loss = jnp.mean((pred - y_train) ** 2)\n", "print(f'Initial loss: {initial_loss:.4f}')\n", "\n", "# Train for a few steps.\n", "for step in range(100):\n", "    params = train_step(params, x_train, y_train)\n", "\n", "# Compute final loss.\n", "pred, _ = model(params, x_train)\n", "final_loss = jnp.mean((pred - y_train) ** 2)\n", "print(f'Final loss: {final_loss:.4f}')\n", "print(f'Loss reduced by {(1 - final_loss/initial_loss) * 100:.1f}%')" ] }, { "cell_type": "markdown", "id": "cell-18", "metadata": {}, "source": [ "## 8. Verify Only LoRA Weights Changed\n", "\n", "Let's verify that only the LoRA parameters were updated during training."
] }, { "cell_type": "code", "execution_count": 10, "id": "cell-19", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Base weights unchanged:\n", " ('net', 'mlp', 'hidden0', 'bias'): True\n", " ('net', 'mlp', 'hidden0', 'kernel'): True\n", " ('net', 'mlp', 'hidden1', 'bias'): True\n", " ('net', 'mlp', 'hidden1', 'kernel'): True\n", " ('net', 'mlp', 'output', 'bias'): True\n", " ('net', 'mlp', 'output', 'kernel'): True\n", " ('net', 'rng', 'counter'): True\n", " ('net', 'rng', 'seed'): True\n" ] } ], "source": [ "# Compare base weights before and after training.\n", "print('Base weights unchanged:')\n", "for path, param in params_before_training.items():\n", "    if 'lora' not in path[-1]:\n", "        match = jnp.allclose(param.value, params[path].value)\n", "        print(f' {path}: {match}')" ] }, { "cell_type": "markdown", "id": "cell-20", "metadata": {}, "source": [ "## 9. Merging LoRA Weights (Optional)\n", "\n", "For inference, you can merge LoRA weights into the base weights to avoid the extra computation."
] }, { "cell_type": "code", "execution_count": null, "id": "cell-21", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Outputs match after merge: True\n", "Parameters after merge: 8 (was 12)\n" ] } ], "source": [ "def merge_lora_weights(params, layer):\n", "    \"\"\"Merges LoRA weights into base weights for efficient inference.\"\"\"\n", "    if not hasattr(layer, '_lora_original_get_param'):\n", "        return params\n", "\n", "    # While LoRA is active, get_param returns the adapted kernel\n", "    # W + (alpha / rank) * A @ B.\n", "    merged_kernel, _ = layer.get_param(params, 'kernel')\n", "\n", "    # Update the base kernel with the merged value.\n", "    return layer.set_param(params, 'kernel', merged_kernel)\n", "\n", "\n", "# Store output before merging.\n", "out_with_lora, _ = model(params, x_train[:1])\n", "\n", "# Merge LoRA weights into base kernels.\n", "merged_params = params\n", "for layer in model.layers:\n", "    merged_params = merge_lora_weights(merged_params, layer)\n", "\n", "# Remove LoRA from layers.\n", "for layer in model.layers:\n", "    remove_lora(layer)\n", "\n", "\n", "# Remove LoRA params from the container.\n", "def is_lora_param(path, param):\n", "    return 'lora' in path[-1]\n", "\n", "\n", "lora_params, merged_params = merged_params.split(is_lora_param)\n", "\n", "# Test that output is the same.\n", "out_merged, _ = model(merged_params, x_train[:1])\n", "outputs_match = jnp.allclose(out_with_lora, out_merged, atol=1e-5)\n", "print(f'Outputs match after merge: {outputs_match}')\n", "print(f'Parameters after merge: {len(merged_params)} (was {len(params)})')" ] }, { "cell_type": "markdown", "id": "583y5f6xhos", "source": "---\n\n# Approach 2: LoRA-aware Model Design\n\nInstead of monkey-patching an existing model, you can design your model to support LoRA from the start. 
This approach is cleaner and more explicit.\n\nThe idea is to create a `LoraLinear` wrapper that:\n- Wraps a standard `bx.Linear` layer\n- Adds LoRA parameters alongside the base parameters\n- Can be enabled/disabled via a flag", "metadata": {} }, { "cell_type": "code", "id": "fran0nkauvl", "source": "class LoraLinear(bx.Module):\n    \"\"\"A Linear layer with optional LoRA adaptation.\n\n    This wrapper creates both the base Linear layer and LoRA parameters.\n    LoRA can be enabled/disabled without changing the model structure.\n    \"\"\"\n\n    def __init__(\n        self,\n        graph: bx.Graph,\n        output_size: int,\n        rng: bx.Rng,\n        rank: int = 4,\n        alpha: float | None = None,\n        use_lora: bool = True,\n    ):\n        super().__init__(graph)\n        self.rank = rank\n        self.alpha = alpha if alpha is not None else 2 * rank\n        self.use_lora = use_lora\n        self.rng = rng\n\n        # Create the base Linear layer.\n        self.linear = bx.Linear(graph.child('base'), output_size, rng=rng)\n\n    def __call__(self, params, x):\n        # Get base output.\n        out, params = self.linear(params, x)\n\n        if not self.use_lora:\n            return out, params\n\n        # Get or create LoRA parameters.\n        in_features = x.shape[-1]\n        out_features = self.linear.output_size\n\n        lora_a, params = self.get_param(\n            params,\n            'lora_a',\n            (in_features, self.rank),\n            jax.nn.initializers.normal(stddev=1.0 / jnp.sqrt(self.rank)),\n            rng=self.rng,\n        )\n        lora_b, params = self.get_param(\n            params,\n            'lora_b',\n            (self.rank, out_features),\n            jax.nn.initializers.zeros,\n            rng=self.rng,\n        )\n\n        # Add LoRA contribution: scale * x @ A @ B\n        scale = self.alpha / self.rank\n        out = out + scale * (x @ lora_a @ lora_b)\n\n        return out, params\n\n\nprint('LoraLinear module defined!')", "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "id": "8keo2vlyfil", "source": "## Create a LoRA-aware MLP\n\nNow we can create an MLP that uses `LoraLinear` layers. 
The model supports LoRA from the start - no monkey-patching needed.", "metadata": {} }, { "cell_type": "code", "id": "xrzwgvpxjt", "source": "class LoraAwareMLP(bx.Module):\n    \"\"\"MLP that supports LoRA from the start.\"\"\"\n\n    def __init__(\n        self,\n        graph: bx.Graph,\n        output_sizes: list[int],\n        rng: bx.Rng,\n        lora_rank: int = 4,\n        use_lora: bool = True,\n        activation=jax.nn.relu,\n    ):\n        super().__init__(graph)\n        self.activation = activation\n        self.layers = []\n        for i, size in enumerate(output_sizes):\n            name = f'hidden{i}' if i < len(output_sizes) - 1 else 'output'\n            # Use LoraLinear for hidden layers, regular Linear for output.\n            if i < len(output_sizes) - 1:\n                layer = LoraLinear(\n                    graph.child(name), size, rng=rng, rank=lora_rank, use_lora=use_lora\n                )\n            else:\n                layer = bx.Linear(graph.child(name), size, rng=rng)\n            self.layers.append(layer)\n\n    def __call__(self, params, x):\n        for i, layer in enumerate(self.layers):\n            x, params = layer(params, x)\n            if i < len(self.layers) - 1:\n                x = self.activation(x)\n        return x, params\n\n\n# Create and initialize the LoRA-aware model.\ngraph2 = bx.Graph('lora_net')\nrng2 = bx.Rng(graph2.child('rng'))\n\nlora_model = LoraAwareMLP(\n    graph2.child('mlp'),\n    output_sizes=[64, 32, 10],\n    rng=rng2,\n    lora_rank=4,\n    use_lora=True,\n)\n\n# Initialize.\nx2 = jnp.ones((4, 16))\nparams2 = rng2.seed(bx.Params(), seed=42)\n_, params2 = lora_model(params2, x2)\nparams2 = params2.locked()\n\nprint('LoRA-aware model initialized!')\nprint(f'Total parameters: {len(params2)}')\nprint('\\nParameter structure:')\nfor path, param in params2.items():\n    if 'lora' in path[-1] or 'kernel' in path[-1]:\n        print(f' {path}: {param.value.shape}')", "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "id": "kvjwq4hw43i", "source": "## Benefits of the LoRA-aware Approach\n\nNotice how the LoRA parameters are created alongside the base parameters during initialization - no need to unlock/relock params or run an extra forward 
pass.\n\nTo fine-tune, we still freeze the base weights and train only the LoRA parameters:", "metadata": {} }, { "cell_type": "code", "id": "gndw2vefmt", "source": "# Freeze base weights (kernel and bias in the nested 'base' Linear layers).\nfor layer in lora_model.layers:\n    if isinstance(layer, LoraLinear):\n        params2 = layer.linear.set_param(params2, 'kernel', None, trainable=False)\n        params2 = layer.linear.set_param(params2, 'bias', None, trainable=False)\n    elif isinstance(layer, bx.Linear):\n        params2 = layer.set_param(params2, 'kernel', None, trainable=False)\n        params2 = layer.set_param(params2, 'bias', None, trainable=False)\n\nprint('Trainable parameters (LoRA only):')\nfor path, param in params2.items():\n    if param.trainable:\n        print(f' {path}: {param.value.shape}')", "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "id": "umzio4wxdse", "source": "## Toggling LoRA\n\nWith the LoRA-aware design, you can easily disable LoRA by setting `use_lora=False` on each layer. 
This is useful for comparing adapted vs original model behavior:", "metadata": {} }, { "cell_type": "code", "id": "6nwspxjhb9p", "source": "# Output with LoRA enabled.\nout_with_lora, _ = lora_model(params2, x2[:1])\n\n# Disable LoRA on all layers.\nfor layer in lora_model.layers:\n    if isinstance(layer, LoraLinear):\n        layer.use_lora = False\n\n# Output without LoRA (base model only).\nout_without_lora, _ = lora_model(params2, x2[:1])\n\n# Re-enable LoRA.\nfor layer in lora_model.layers:\n    if isinstance(layer, LoraLinear):\n        layer.use_lora = True\n\nprint('Output with LoRA:', out_with_lora[0, :5])\nprint('Output without LoRA:', out_without_lora[0, :5])\nprint(f'\\nOutputs differ: {not jnp.allclose(out_with_lora, out_without_lora)}')", "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "id": "cell-22", "metadata": {}, "source": "## Summary\n\nThis notebook demonstrated two approaches to implementing LoRA with blox:\n\n### Approach 1: Monkey-patching\nBest for adapting existing pretrained models without modifying their code.\n\n1. **Graph traversal** - Used `Graph.walk()` to find target layers\n2. **Monkey-patching** - Wrapped `get_param` to inject `W + A @ B`\n3. **Parameter freezing** - Marked base weights as non-trainable\n4. **Weight merging** - Combined LoRA weights back into base weights for inference\n\n### Approach 2: LoRA-aware Design\nBest for new models where you plan to use LoRA from the start.\n\n1. **Wrapper layer** - Created `LoraLinear` that encapsulates base + LoRA parameters\n2. **Clean initialization** - LoRA parameters created alongside base parameters\n3. **Easy toggling** - Enable/disable LoRA via `use_lora` flag\n4. 
**Explicit structure** - Model structure shows LoRA is supported\n\n### Which to Choose?\n\n| Scenario | Recommended Approach |\n|----------|---------------------|\n| Adapting a pretrained model | Monkey-patching |\n| Building a new model with LoRA support | LoRA-aware design |\n| Research/experimentation | Either (monkey-patching is more flexible) |\n| Production deployment | LoRA-aware (cleaner, more maintainable) |\n\n### References\n\n- [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)\n- [Practical Tips for Finetuning LLMs Using LoRA](https://magazine.sebastianraschka.com/p/practical-tips-for-finetuning-llms)" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" } }, "nbformat": 4, "nbformat_minor": 5 }