Modern Large Language Models (LLMs) are powerful general-purpose tools that need to be trained in multiple stages involving different methodologies. For the purpose of this post, we will discuss two of them.
The first is the foundational pre-training stage, where the model is trained from scratch on a massive, diverse corpus of text. The objective is purely self-supervised learning, and the model typically learns to predict the next token in a given sentence. This step is computationally intensive, and its objective is to allow the model to learn general representations of language. At the end of this stage, the model is a powerful generalist but is not specialized for any specific task.
The second stage is fine-tuning, where the base model is trained using a much smaller, curated dataset for a specific task (e.g., becoming an assistant that follows instructions). The goal is to adapt the model for this task without it losing the general knowledge acquired during pre-training.
In this context, Parameter-Efficient Fine-Tuning (PEFT), as the name suggests, is a family of methods that allow us to perform the fine-tuning step in a more efficient and controlled manner.
The most straightforward method to adapt a pre-trained model is to continue the training process, updating every single weight using a new, task-specific dataset. This approach is known as full fine-tuning (FFT).
However, full fine-tuning presents several significant drawbacks. First, by directly modifying all pre-trained weights, it risks erasing the valuable general-purpose knowledge they contain, a phenomenon known as catastrophic forgetting.
Second, updating the entire model is computationally expensive. It requires a large amount of GPU memory to store not only the gradients for every parameter but also the memory-intensive states required by optimizers like AdamW. For a 7-billion-parameter model, this can easily exceed 80 GB of VRAM.
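As a rough back-of-the-envelope sketch (assuming mixed-precision training with AdamW and ignoring activations and framework overhead), the memory adds up quickly:

# Rough full fine-tuning memory estimate for a 7B-parameter model.
# Assumes fp16 weights and gradients plus fp32 master weights and
# two fp32 AdamW moments; activations and KV caches come on top of this.
params = 7e9
bytes_per_param = (
    2    # fp16 weights
    + 2  # fp16 gradients
    + 4  # fp32 master copy of the weights
    + 4  # fp32 AdamW first moment (m)
    + 4  # fp32 AdamW second moment (v)
)
print(f"~{params * bytes_per_param / 1e9:.0f} GB")  # ~112 GB before activations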
Another challenge is storage and deployment. Full fine-tuning requires creating and storing a complete, independent copy of the model for each task. Fine-tuning a 7B model for five different tasks would mean storing five separate 14 GB models, totaling 70 GB.
The core idea behind all PEFT methods is to freeze the vast majority of the pre-trained model’s parameters and only update a small, targeted subset (often less than 1%). This drastically reduces the memory and compute required for training and mitigates catastrophic forgetting.
There are several families of PEFT methods, including adapter-based approaches, prompt and prefix tuning, and reparameterization methods such as LoRA.
This post will focus on Low-Rank Adaptation (LoRA), which has become the most dominant and widely-used PEFT method due to its unique combination of performance and efficiency.
LoRA (Hu, Edward J., et al., 2021) is built on the observation that the weight updates needed to adapt a pre-trained model tend to have a low "intrinsic rank", and can therefore be well approximated by low-rank matrices.
Reminder: The rank of a matrix is the number of linearly independent columns (or rows) it has. Intuitively, it measures the “dimensionality” of the information the matrix contains. A low-rank matrix is one where this information is highly redundant.
Instead of updating the original pre-trained weight matrix $W$, LoRA represents the update $\Delta W$ as the product of two much smaller matrices, $A$ (the “down-projection”) and $B$ (the “up-projection”):
\[\Delta W \approx A \cdot B\]

where, if the original weight matrix $W$ has dimensions $d \times k$, the new matrices have dimensions $d \times r$ (for $A$) and $r \times k$ (for $B$). The hyperparameter $r$ is the rank of the decomposition, and it is chosen to be much smaller than $d$ or $k$, which leads to a massive reduction in trainable parameters.
For a typical 4096×4096 attention matrix (~16.7M parameters), using LoRA with a rank of $r=8$ requires training only two matrices of size 4096×8 and 8×4096, totaling ~65k parameters. This represents a reduction of over 99.6%.
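A quick sanity check of that arithmetic:

d = k = 4096
r = 8
full = d * k          # 16,777,216 parameters in the original matrix
lora = d * r + r * k  # 65,536 parameters in A and B combined
print(f"{100 * (1 - lora / full):.1f}% fewer trainable parameters")  # 99.6%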
During the forward pass, the model’s output is calculated by adding the output of the frozen pre-trained layer to the output of the new LoRA path:
\[y = xW + xAB = x(W + AB)\]

(treating the input $x$ as a row vector, so that the shapes match the definitions above). Since $AB$ has the same shape as $W$, the learned update can later be merged directly into the pre-trained weights. LoRA has become the default PEFT method due to its unique combination of efficiency and performance, offering key advantages over other strategies: the merged model adds no inference latency, each task only requires storing the tiny $A$ and $B$ matrices, and many adapters can share a single frozen base model.
LoRA is strategically applied to layers where adaptation is most effective. The most common targets are the linear projection matrices within the multi-head attention (MHA) block, specifically the Query ($W_Q$) and Value ($W_V$) matrices. Adapting how the model queries for and extracts information is a highly efficient way to specialize it for a new task. For more complex adaptations, LoRA can also be applied to the other attention projections ($W_K, W_O$) and the layers of the feed-forward network (FFN).
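In practice this just means listing the names of those layers as targets; the exact module names depend on how the model is implemented, so the ones below are purely illustrative:

# Illustrative only: module names vary between implementations.
target_modules = ["w_q", "w_v"]                                            # common starting point
target_modules_full = ["w_q", "w_k", "w_v", "w_o", "ffn_up", "ffn_down"]   # heavier adaptation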
The LoRA matrix $A$ is typically initialized with small random values (a Gaussian in the original paper, a Kaiming initialization in many implementations), while $B$ is initialized to all zeros. This ensures the update $AB$ is zero at the start of training, so the model initially behaves exactly like the pre-trained one. An additional hyperparameter $\alpha$ scales the update, allowing control over the adaptation strength independently of the rank $r$. The forward pass becomes:

\[y = xW + \frac{\alpha}{r} \, xAB\]
import math

import torch
import torch.nn as nn


class LoRALayer(nn.Module):
    """Wraps a frozen nn.Linear layer and adds a trainable low-rank update."""

    def __init__(self, W: nn.Linear, r: int, alpha: float):
        super().__init__()
        self.r = r
        self.alpha = alpha
        device = W.weight.device
        d, k = W.in_features, W.out_features
        self.W = W  # frozen pre-trained layer
        # A gets a random (Kaiming uniform) init, B starts at zero,
        # so the update A @ B is zero at the beginning of training.
        self.lora_A = nn.Parameter(torch.empty(d, r, device=device))
        self.lora_B = nn.Parameter(torch.zeros(r, k, device=device))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.merged:
            # y = xW + (alpha / r) * xAB
            frozen_out = self.W(x)
            lora_update = (self.alpha / self.r) * (x @ self.lora_A @ self.lora_B)
            return frozen_out + lora_update
        else:
            # After merging, the update already lives inside W's weights.
            return self.W(x)

    @torch.no_grad()
    def merge(self):
        if not self.merged:
            update = (self.alpha / self.r) * (self.lora_A @ self.lora_B)
            # nn.Linear stores its weight as (out_features, in_features), hence the transpose.
            self.W.weight.data += update.T
            self.merged = True

    @torch.no_grad()
    def unmerge(self):
        if self.merged:
            update = (self.alpha / self.r) * (self.lora_A @ self.lora_B)
            self.W.weight.data -= update.T
            self.merged = False
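A quick sanity check of the layer (a minimal sketch with arbitrary sizes): at initialization $B$ is all zeros, so the wrapped layer must reproduce the original nn.Linear exactly, and merging a (pretend-trained) update into $W$ must not change the layer's output:

linear = nn.Linear(64, 64)
lora = LoRALayer(linear, r=8, alpha=16.0)
x = torch.randn(2, 10, 64)

# B starts at zero, so the LoRA path contributes nothing at initialization.
assert torch.allclose(lora(x), linear(x))

# Pretend training produced a non-zero B, then check that merging the update
# into W leaves the output unchanged.
with torch.no_grad():
    lora.lora_B.normal_()
out_before_merge = lora(x)
lora.merge()
assert torch.allclose(lora(x), out_before_merge, atol=1e-5)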
To integrate LoRA into an existing PyTorch model without having to modify its implementation, we can write helper functions that handle the process of replacing existing linear layers with LoRA layers, and then after fine-tuning, re-inserting the original layers now with the updated weights. Our strategy is to dynamically modify the pre-trained model in-place.
The get_lora_model function prepares a standard Transformer for LoRA fine-tuning. It iterates through the model, freezes all of its original parameters, and replaces the target nn.Linear layers with our custom LoRALayer modules. Finally, it unfreezes only the newly added lora_ parameters, ensuring that the optimizer will update only the adapter weights.
def get_lora_model(
    model: DecoderTransformer,
    r: int,
    alpha: float,
    target_modules: list[str],
) -> DecoderTransformer:
    # Freeze all layers of the model
    for parameter in model.parameters():
        parameter.requires_grad = False

    # Replace target modules with LoRALayers
    for name, module in model.named_modules():
        layer_name = name.split(".")[-1]
        if layer_name in target_modules and isinstance(module, nn.Linear):
            parent_name = ".".join(name.split(".")[:-1])
            parent_module = model.get_submodule(parent_name)
            setattr(parent_module, layer_name, LoRALayer(module, r, alpha))

    # Unfreeze only the LoRA weights for training
    for name, param in model.named_parameters():
        if "lora_" in name:
            param.requires_grad = True

    return model
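Using it is a one-liner; the constructor arguments and module names below are illustrative and depend on the DecoderTransformer implementation used throughout this post. Counting the trainable parameters afterwards confirms that only a tiny fraction of the model will be updated:

# Illustrative: constructor arguments and target module names depend on the model.
model = DecoderTransformer(vocab_size=32_000, d_model=512, n_layers=8, n_heads=8)
model = get_lora_model(model, r=8, alpha=16.0, target_modules=["w_q", "w_v"])

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")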
After fine-tuning is complete, the merge_and_unload function iterates through the model again, calling the .merge() method on each LoRALayer to add the learned update to the original weights. It then replaces each LoRALayer module with the now-updated, standard nn.Linear layer, reverting the model to its original architecture without any LoRALayer modules.
def merge_and_unload(model: DecoderTransformer) -> DecoderTransformer:
    for name, module in model.named_modules():
        if isinstance(module, LoRALayer):
            # Fold the learned (alpha / r) * A @ B update into the frozen weights
            module.merge()
            # Swap the LoRALayer back for the (now updated) original nn.Linear
            parent_name = ".".join(name.split(".")[:-1])
            parent_module = model.get_submodule(parent_name)
            layer_name = name.split(".")[-1]
            setattr(parent_module, layer_name, module.W)
    return model
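Putting the two helpers together (a sketch; the training loop itself is omitted, and the dummy input assumes the model takes a batch of token ids and returns logits): after merging, no LoRALayer modules should remain and the outputs should match the un-merged model.

# Sketch: fine-tune the adapters (not shown), then fold them into the base weights.
model.eval()
tokens = torch.randint(0, 32_000, (1, 16))  # dummy token ids, illustrative
with torch.no_grad():
    out_lora = model(tokens)

model = merge_and_unload(model)

with torch.no_grad():
    out_merged = model(tokens)

assert not any(isinstance(m, LoRALayer) for m in model.modules())
assert torch.allclose(out_lora, out_merged, atol=1e-5)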
Visit this Google Colab notebook for a working example with LoRA.