LoRA (Low-Rank Adaptation of Large Language Models) is a technique designed to efficiently fine-tune large language models (LLMs) by introducing trainable low-rank matrices while freezing the original model weights. This method drastically reduces the computational and memory costs associated with training massive models like GPT, BERT, or others.
How LoRA Works
Concept of Low-Rank Decomposition:
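LoRA assumes that the weight updates needed for fine-tuning have a low intrinsic rank, so they can be represented well by low-rank matrices.
Instead of learning a full update ΔW for a d × k weight matrix W0, LoRA factors it into two smaller matrices, ΔW = BA, where B is d × r, A is r × k, and the rank r is much smaller than d and k. Only A and B are trained; the original weights W0 stay frozen.
Integration:
LoRA attaches these low-rank matrices to selected layers of the model, typically the projection layers of the attention mechanism in each transformer block.
During training, each adapted layer computes h = W0x + (α / r)·BAx, adding the scaled low-rank update to the frozen layer's output (α is called lora_alpha in the PEFT library). For inference, BA can be merged into W0, so the adapted model runs with no extra latency.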
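The mechanics above fit in a few lines of NumPy. This is a minimal illustration, not the PEFT implementation: the class name LoRALinear, the matrix shapes, and the initialization scale are assumptions chosen for the example. Following the scheme described in the LoRA paper, A starts random and B starts at zero, so the adapted layer initially behaves exactly like the frozen one, and the update is scaled by α / r.

```python
import numpy as np


class LoRALinear:
    """Hypothetical minimal LoRA layer: frozen W0 plus a trainable low-rank update B @ A."""

    def __init__(self, w0, r=8, lora_alpha=32, seed=0):
        rng = np.random.default_rng(seed)
        d_out, d_in = w0.shape
        self.w0 = w0                                    # frozen pretrained weight, shape (d_out, d_in)
        self.a = rng.normal(0.0, 0.01, size=(r, d_in))  # trainable, small random init (assumed scale)
        self.b = np.zeros((d_out, r))                   # trainable, zero init so B @ A = 0 at the start
        self.scaling = lora_alpha / r                   # the alpha / r factor

    def forward(self, x):
        # h = W0 x + (alpha / r) * B A x, without ever forming the full d_out x d_in update
        return x @ self.w0.T + self.scaling * (x @ self.a.T) @ self.b.T

    def merged_weight(self):
        # For inference, fold the update into one dense matrix: W = W0 + (alpha / r) * B A
        return self.w0 + self.scaling * self.b @ self.a


rng = np.random.default_rng(42)
w0 = rng.normal(size=(64, 32))
layer = LoRALinear(w0, r=4, lora_alpha=16)
x = rng.normal(size=(5, 32))

# At initialization B is zero, so the output equals the frozen layer's output
print(np.allclose(layer.forward(x), x @ w0.T))  # True

# Pretend training changed B; the merged weight reproduces the two-branch forward pass
layer.b = rng.normal(size=layer.b.shape)
print(np.allclose(layer.forward(x), x @ layer.merged_weight().T))  # True
```

Because B starts at zero, training begins from exactly the pretrained model's behavior, and merged_weight shows why a merged adapter adds no cost at inference time.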
Benefits of LoRA
Reduced Compute Requirements:
Since only the low-rank matrices are trainable, gradients and optimizer states are kept for a small fraction of the parameters, which significantly lowers memory use and training cost.
This is especially helpful for fine-tuning large models on resource-constrained devices (e.g., GPUs with limited VRAM).
Efficiency:
Fine-tuning with LoRA is often faster because it updates fewer parameters compared to full-model fine-tuning.
Scalability:
Multiple tasks can be fine-tuned on the same base model by storing different sets of LoRA parameters without duplicating the entire model.
Compatibility:
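LoRA is modular. Adapters are small and stored separately from the base weights, so you can swap task-specific adapters on the same base model or share them on their own. An adapter is tied to the base model it was trained on, however, and cannot be reused with a different base model or architecture.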
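To put the savings in numbers, here is a back-of-the-envelope count for the configuration used in the script (r=8 on GPT-2's c_attn). It assumes GPT-2 small's published shapes: hidden size 768, 12 transformer blocks, and c_attn as a fused query/key/value projection mapping 768 to 3 × 768 = 2304 features. For a d_in → d_out layer, LoRA adds r × (d_in + d_out) parameters (A is r × d_in, B is d_out × r), versus d_in × d_out for the full weight. The function name is illustrative, and real adapter files carry some extra metadata.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_in -> d_out layer: A is (r, d_in), B is (d_out, r)."""
    return r * (d_in + d_out)


# GPT-2 small shapes (assumed from the published architecture)
HIDDEN = 768
N_LAYERS = 12
C_ATTN_OUT = 3 * HIDDEN          # fused Q, K, V projection
GPT2_TOTAL_PARAMS = 124_439_808  # GPT-2 small parameter count with tied embeddings

full_per_layer = HIDDEN * C_ATTN_OUT
lora_per_layer = lora_param_count(HIDDEN, C_ATTN_OUT, r=8)
lora_total = N_LAYERS * lora_per_layer

print(f"full c_attn weight per layer: {full_per_layer:,}")              # 1,769,472
print(f"LoRA params per layer (r=8): {lora_per_layer:,}")               # 24,576
print(f"LoRA params in all 12 layers: {lora_total:,}")                  # 294,912
print(f"share of GPT-2 small: {lora_total / GPT2_TOTAL_PARAMS:.2%}")    # 0.24%
# At 4 bytes per fp32 value, one adapter is tiny compared with the full model
print(f"adapter size (fp32): {lora_total * 4 / 1e6:.1f} MB")            # 1.2 MB
print(f"full model size (fp32): {GPT2_TOTAL_PARAMS * 4 / 1e6:.0f} MB")  # 498 MB
```

Storing one such adapter per task costs roughly a megabyte each, which is what makes the scalability point above practical.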
Let's jump into the code.
To install the packages the script needs, run the following pip command:
```bash
pip install transformers peft datasets torch
```
This installs the transformers, peft, datasets, and torch libraries, which the script requires.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import torch

# Step 1: Load the base model and tokenizer
model_name = "gpt2"  # Replace with the desired model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    # GPT-2 ships without a padding token, so reuse the end-of-sequence token
    tokenizer.pad_token = tokenizer.eos_token

# Step 2: Configure and apply LoRA
lora_config = LoraConfig(
    r=8,                        # Rank of the low-rank update matrices
    lora_alpha=32,              # Scaling factor; the update is scaled by lora_alpha / r
    target_modules=["c_attn"],  # GPT-2's fused Q/K/V projection (module names vary by architecture)
    lora_dropout=0.1,           # Dropout on the LoRA branch
    bias="none",                # Keep all bias terms frozen
    fan_in_fan_out=True,        # GPT-2's c_attn is a Conv1D layer that stores its weight transposed
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # Only the LoRA matrices require gradients

# Step 3: Load and preprocess the dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
# WikiText-2 contains many empty lines; drop them so no example is pure padding
dataset = dataset.filter(lambda example: len(example["text"].strip()) > 0)

def tokenize_function(examples):
    tokens = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)
    # Labels are the input IDs with padding positions set to -100 so the loss ignores them
    tokens["labels"] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
        for ids, attn in zip(tokens["input_ids"], tokens["attention_mask"])
    ]
    return tokens

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Step 4: Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",      # Named `evaluation_strategy` in older transformers releases
    learning_rate=5e-4,
    per_device_train_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    save_strategy="epoch",
    save_total_limit=2,
    label_names=["labels"],     # Tell the Trainer which column holds the labels for the PEFT model
    prediction_loss_only=True,  # Report eval loss without keeping full-vocabulary logits in memory
)

# Step 5: Initialize the Trainer with a custom compute_loss function
class CustomTrainer(Trainer):
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer Trainer versions
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens; positions labelled -100 (padding) are ignored
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return (loss, outputs) if return_outputs else loss
    # training_step is not overridden: the Trainer's own loop zeroes gradients,
    # backpropagates, and steps the optimizer and learning-rate scheduler

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
)

# Step 6: Fine-tune the model with LoRA
trainer.train()

# Step 7: Save the LoRA adapter (low-rank weights plus adapter config) and the tokenizer
model.save_pretrained("./lora_finetuned_model")
tokenizer.save_pretrained("./lora_finetuned_model")
print("Fine-tuning complete! The LoRA adapter is saved at './lora_finetuned_model'.")
```
Here's a quick explanation of the code:
Imports:
The script imports the required libraries: transformers, peft, datasets, and torch.
Load Model and Tokenizer:
The base model (gpt2) and its tokenizer are loaded with AutoModelForCausalLM and AutoTokenizer from the transformers library.
If the tokenizer has no padding token (GPT-2 does not define one), the end-of-sequence token is used for padding.
Configure and Apply LoRA:
A LoraConfig object sets the LoRA hyperparameters: rank, scaling factor, target modules, dropout, and bias handling. For GPT-2, fan_in_fan_out=True is set because its attention projection is a Conv1D layer.
get_peft_model wraps the base model, injects the LoRA matrices into the target modules, and freezes the original weights.
Load and Preprocess Dataset:
The wikitext-2-raw-v1 dataset is loaded with the datasets library, and its empty lines are filtered out.
tokenize_function tokenizes each example to a fixed length of 128 tokens and builds labels in which padding positions are set to -100.
The dataset is tokenized with the map method, and the raw text column is removed.
Define Training Arguments:
TrainingArguments specifies the output directory, evaluation and save strategies, learning rate, batch size, number of epochs, weight decay, logging directory, and checkpoint limit. label_names tells the Trainer which column holds the labels, since it cannot always infer this for a PEFT-wrapped model, and prediction_loss_only limits evaluation to computing the validation loss.
Custom Trainer Class:
A CustomTrainer class is defined, inheriting from Trainer.
The compute_loss method is overridden to compute next-token cross-entropy: the logits at each position are compared with the following token, and padding positions are ignored.
Gradient zeroing, backpropagation, and the optimizer and scheduler steps are left to the Trainer's built-in training loop, which also handles gradient accumulation and mixed precision.
Initialize Trainer:
An instance of CustomTrainer is created with the model, training arguments, and tokenized datasets.
Fine-tune the Model:
The train method of the CustomTrainer instance is called to fine-tune the model.
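Save the Fine-tuned Model:
Calling save_pretrained on the PEFT-wrapped model saves only the LoRA adapter weights and the adapter configuration, not a full copy of GPT-2. The tokenizer is saved to the same directory.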
Print Completion Message:
A message is printed to indicate that fine-tuning is complete and the model is saved.
This script fine-tunes GPT-2 with LoRA on the WikiText-2 dataset and saves the resulting LoRA adapter and the tokenizer.
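The heart of the custom compute_loss is next-token cross-entropy: the logits at position n are scored against the token at position n + 1. The NumPy re-implementation below runs that shift on a made-up batch with a toy vocabulary, to show what the slicing does. It assumes padding positions carry the label -100 and are skipped, which matches the default ignore_index of torch.nn.CrossEntropyLoss; the function name and data are illustrative.

```python
import numpy as np

IGNORE_INDEX = -100  # same default as torch.nn.CrossEntropyLoss(ignore_index=-100)


def causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy. logits: (batch, seq, vocab); labels: (batch, seq)."""
    # Shift so that the prediction at position n is compared with the token at n + 1
    shift_logits = logits[:, :-1, :].reshape(-1, logits.shape[-1])
    shift_labels = labels[:, 1:].reshape(-1)
    keep = shift_labels != IGNORE_INDEX  # drop padding positions
    shift_logits, shift_labels = shift_logits[keep], shift_labels[keep]
    # Numerically stable log-softmax over the vocabulary
    shifted = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(shift_labels)), shift_labels].mean()


vocab = 5
labels = np.array([[1, 2, 3, IGNORE_INDEX]])  # last position is padding

# Uniform logits: every prediction has probability 1/vocab, so the loss is log(vocab)
uniform = np.zeros((1, 4, vocab))
print(np.isclose(causal_lm_loss(uniform, labels), np.log(vocab)))  # True

# Logits that strongly favour the correct next token at each real position give a loss near 0
confident = np.full((1, 4, vocab), -10.0)
for pos, next_tok in enumerate(labels[0, 1:]):
    if next_tok != IGNORE_INDEX:
        confident[0, pos, next_tok] = 10.0
print(causal_lm_loss(confident, labels) < 1e-6)  # True
```

The final logit position has no next token and is dropped by the shift, which is why a sequence of length 4 contributes at most 3 loss terms.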
Happy coding!