
Commit 2813b9c

FEAT Add DeLoRA (#2780)
Implements DeLoRA: "Decoupling Angles and Strength in Low-rank Adaptation" (https://huggingface.co/papers/2503.18225). Similar to DoRA, DeLoRA decouples angular learning from adaptation strength, but it additionally makes it possible to bound the norm of the weight change. In this way, DeLoRA promises to reduce the risk of catastrophic forgetting and to be more robust to hyperparameter settings such as the learning rate.
1 parent 8d8aa0b commit 2813b9c

File tree

23 files changed: +1059 −2 lines changed


docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -136,6 +136,8 @@
       title: RoAd
     - local: package_reference/waveft
       title: WaveFT
+    - local: package_reference/delora
+      title: DeLoRA

       title: Adapters
   - sections:
docs/source/package_reference/delora.md

Lines changed: 35 additions & 0 deletions

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# DeLoRA: Decoupled Low-rank Adaptation

[DeLoRA](https://huggingface.co/papers/2503.18225) is a parameter-efficient fine-tuning technique that implicitly maintains a Frobenius-norm boundary with respect to the pretrained weights by normalizing and scaling learnable low-rank matrices. This effectively decouples the learning of the directions (the BA term) and the magnitude (the boundary term) of the weight updates, avoiding catastrophic shifts in the adapted weights and enhancing robustness to hyperparameter choices.
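
Schematically (a sketch based on the description above; the exact layer-wise scaling factor is defined in the paper), the DeLoRA update takes the form

$$ W' = W + \frac{\lambda}{r}\,\xi \sum_{i=1}^{r} \frac{b_i a_i^{\top}}{\lVert b_i \rVert\,\lVert a_i \rVert}, $$

where $b_i$ and $a_i$ are the columns of $B$ and rows of $A$, and $\xi$ is a layer-wise factor tied to the pretrained weight norm $\lVert W \rVert$. Each normalized rank-1 term has unit Frobenius norm, so the size of the update is bounded in terms of the learnable boundary $\lambda$.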

Notes (a minimal configuration sketch follows this list):
- Use a 10-100x larger learning rate than for standard LoRA variants (typical values are around 1e-3 to 1e-2).
- Do not set the initial boundary parameter lambda too small (typical values are around 10-15).
- Setting different lambdas for different layers is possible.
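
A minimal configuration sketch following these recommendations (the optimizer setup is illustrative and not part of the DeLoRA API):

```python
import torch
from peft import DeloraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
config = DeloraConfig(r=32, delora_lambda=15)  # boundary lambda in the recommended 10-15 range
model = get_peft_model(base_model, config)

# DeLoRA works well with a much larger learning rate than standard LoRA variants
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
```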

The abstract from the paper is:

> Parameter-Efficient FineTuning (PEFT) methods have recently gained significant popularity thanks to the widespread availability of large-scale pretrained models. These methods allow for quick adaptation to downstream tasks with minimal computational cost. However, popular finetuning methods such as LoRA exhibit limited robustness when it comes to hyperparameter choices or extended training regimes, preventing optimal out-of-the-box performance. In contrast, bounded approaches, such as ETHER, provide greater robustness but are limited to extremely low-rank adaptations and fixed-strength transformations, reducing their adaptation expressive power. In this work, we propose Decoupled Low-rank Adaptation (DeLoRA), a novel finetuning method that normalizes and scales learnable low-rank matrices. By bounding the distance of the transformation, DeLoRA effectively decouples the angular learning from the adaptation strength, enhancing robustness without compromising performance. Through evaluations on subject-driven image generation, natural language understanding, and instruction tuning, we show that DeLoRA matches or surpasses performance of competing PEFT methods, while exhibiting stronger robustness.

## DeloraConfig

[[autodoc]] tuners.delora.config.DeloraConfig

## DeloraModel

[[autodoc]] tuners.delora.model.DeloraModel
examples/delora_finetuning/README.md

Lines changed: 102 additions & 0 deletions

# DeLoRA: Decoupled Low-Rank Adaptation

## Introduction
[DeLoRA](https://huggingface.co/papers/2503.18225) tackles finetuning in a Frobenius-norm-bounded setup: this prevents divergence from the pretrained model and effectively decouples the learning of angles and magnitudes.

This is achieved by (i) normalizing the BA low-rank matrices, which bounds the updates' Frobenius norm, (ii) a learnable scaling lambda, which controls the update's boundary/magnitude, and (iii) a layer-wise scaling by ||W||, which adapts each update's norm to the norm of the original weights.
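
The following snippet illustrates points (i)-(iii) numerically. It is a simplified sketch of the idea, not the actual PEFT implementation, and the exact layer-wise scaling used in the paper may differ:

```python
import torch

d_out, d_in, r = 64, 64, 8
W = torch.randn(d_out, d_in)      # pretrained weight (frozen)
B = torch.randn(d_out, r)         # learnable low-rank factor
A = torch.randn(r, d_in)          # learnable low-rank factor
delora_lambda = 15.0

# (i) normalize each rank-1 component b_i a_i^T to unit Frobenius norm
B_hat = B / B.norm(dim=0, keepdim=True)   # column-normalize B
A_hat = A / A.norm(dim=1, keepdim=True)   # row-normalize A

# (ii) scale by the learnable boundary lambda / r and (iii) by the pretrained weight norm
delta_w = (delora_lambda / r) * W.norm() * (B_hat @ A_hat)

# the update's Frobenius norm is bounded: ||delta_w||_F <= lambda * ||W||_F
assert delta_w.norm() <= delora_lambda * W.norm() + 1e-4
```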

## Quick start
Compared to a standard PEFT training run with LoRA, simply swap your `LoraConfig` for a `DeloraConfig`. Note, however, that the `lora_alpha` parameter is replaced by the `delora_lambda` parameter, which sets an upper bound on the Frobenius norm of the weight change.

```python
import torch
from peft import DeloraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.pad_token_id = tokenizer.eos_token_id
delora_config = DeloraConfig(r=32, delora_lambda=15)

peft_model = get_peft_model(model, delora_config)
peft_model.print_trainable_parameters()

dataset = load_dataset("imdb", split="train[:1%]")

training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
peft_model.save_pretrained("delora-llama-3-8b")
```

To use the fine-tuned DeLoRA modules, simply load them back onto the base model:
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", dtype=torch.bfloat16, device_map="auto"
)
peft_model = PeftModel.from_pretrained(model, "delora-llama-3-8b")
```

## Advanced Usage
In this script, the DeLoRA adapters are applied to the query and value projection layers of the Llama model by default. Applying adapters to more layers increases memory usage. If you wish to apply DeLoRA to a different set of layers, you can specify them like this:
```bash
python examples/delora_finetuning/delora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --target_modules "q_proj,k_proj,v_proj,o_proj"
```

Using different lambdas for different layers is also possible by setting `lambda_pattern`.
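
For example, assuming `lambda_pattern` maps layer-name patterns to per-layer lambda values, analogous to LoRA's `rank_pattern` and `alpha_pattern` (this keying is an assumption; check the `DeloraConfig` reference for the exact semantics), a sketch could look like this:

```python
from peft import DeloraConfig

config = DeloraConfig(
    r=32,
    delora_lambda=15,                      # default boundary for all targeted layers
    target_modules=["q_proj", "v_proj"],
    # assumed: keys are layer-name patterns, values override `delora_lambda`
    lambda_pattern={"q_proj": 10},
)
```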

### Fine-tune
```bash
python delora_finetuning.py \
    --base_model "PATH_TO_MODEL" \
    --data_path "PATH_TO_DATASET" \
    --output_dir "PATH_TO_OUTPUT_DIR" \
    --batch_size 1 \
    --num_epochs 3 \
    --learning_rate 3e-3 \
    --cutoff_len 512 \
    --val_set_size 500 \
    --eval_step 10 \
    --save_step 100 \
    --device "auto" \
    --rank 32 \
    --delora_lambda 15 \
    --module_dropout 0.1 \
    --target_modules "q_proj,v_proj" \
    --hub_model_id "YOUR_HF_REPO" \
    --push_to_hub
```

## Additional Notes
### Best practices
- Use a 10-100x larger learning rate than for standard LoRA variants (typical values are around 1e-3 to 1e-2).
- Do not set the initial boundary parameter lambda too small (typical values are around 10-15).

### DeLoRA vs DoRA
DeLoRA may look quite similar to DoRA, given the shared goal of decoupling angular from magnitude learning, but there are key differences: (i) DoRA applies its normalization and scaling operations to the fully finetuned weights ($W + \Delta W$), and (ii) DoRA's normalization is performed over the column space of the weight matrices.

Conversely, DeLoRA (i) applies the normalization and scaling operations directly to the weight updates $\Delta W$, better preventing divergence from the pretrained model, and (ii) normalizes the inner low-dimensional space, which enforces a Frobenius-norm bound on the weight updates.
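
Schematically (a hedged sketch rather than the exact formulations; $m$ is DoRA's learnable magnitude vector, $\lVert\cdot\rVert_c$ a column-wise norm, and $\xi$ a layer-wise factor tied to $\lVert W \rVert$):

$$ \text{DoRA:}\quad W' = m \odot \frac{W + BA}{\lVert W + BA \rVert_c}, \qquad \text{DeLoRA:}\quad W' = W + \frac{\lambda}{r}\,\xi \sum_{i=1}^{r} \frac{b_i a_i^{\top}}{\lVert b_i \rVert\,\lVert a_i \rVert}. $$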

## Citation
```
@inproceedings{bini2025decouplinganglesstrengthlowrank,
    title={Decoupling Angles and Strength in Low-rank Adaptation},
    author={Massimo Bini and Leander Girrbach and Zeynep Akata},
    year={2025},
    booktitle={International Conference on Learning Representations (ICLR)},
}
```
examples/delora_finetuning/delora_finetuning.py

Lines changed: 189 additions & 0 deletions
# This script is based on examples/randlora_finetuning/randlora_finetuning.py
import os

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import DeloraConfig, get_peft_model


def train_model(
    base_model: str,
    data_path: str,
    output_dir: str,
    batch_size: int,
    num_epochs: int,
    learning_rate: float,
    cutoff_len: int,
    val_set_size: int,
    eval_step: int,
    save_step: int,
    device: str,
    rank: int,
    delora_lambda: int,
    module_dropout: float,
    target_modules: str,
    hub_model_id: str,
    push_to_hub: bool,
):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    hf_token = os.getenv("HF_TOKEN")

    # Setup device
    device = torch.device(device)
    print(f"Using device: {device}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)

    # Compute dtype
    device_type = device.type
    device_module = getattr(torch, device_type, torch.cuda)
    bf16_supported = device_module.is_available() and device_module.is_bf16_supported()
    dtype = torch.bfloat16 if bf16_supported else torch.float32

    # Load the base model
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        dtype=dtype,
    )

    # DeLoRA config for the PEFT model
    peft_config = DeloraConfig(
        r=rank,
        delora_lambda=delora_lambda,
        target_modules=(target_modules.split(",") if target_modules else None),
        module_dropout=module_dropout,
        bias="none",
    )

    # Get the PEFT model with the DeLoRA config
    model = get_peft_model(model, peft_config)

    model.to(device)  # move the model to the accelerator
    tokenizer.pad_token = tokenizer.eos_token

    # Load the dataset
    dataset = load_dataset(data_path)

    def tokenize_function(examples):
        inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=cutoff_len)
        inputs["labels"] = inputs["input_ids"].copy()  # setting labels for a language modeling task
        return inputs

    # Tokenize the dataset and prepare for training
    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

    # Data collator to dynamically pad the batched examples
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # Compute the total number of training steps for warmup
    max_steps = int((len(tokenized_datasets["train"]) // batch_size) * num_epochs)

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=int(max_steps * 0.1),  # 10% of total training steps
        weight_decay=0.0,
        logging_steps=eval_step,
        save_steps=save_step,
        save_total_limit=2,
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id,
        gradient_accumulation_steps=16,
        learning_rate=learning_rate,
        hub_token=hf_token,
        label_names=["labels"],
    )

    # Clear accelerator cache to free memory
    device_module.empty_cache()

    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
    )

    # Start model training
    trainer.train()

    # Save and push the trained model and tokenizer
    if push_to_hub:
        # Push the main model to the hub
        trainer.push_to_hub(commit_message="Fine-tuned model")

    # Save the model and tokenizer locally
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Fine-tune LLaMA with DeLoRA")
    parser.add_argument("--base_model", type=str, default="huggyllama/llama-7b", help="Base model path or name")
    parser.add_argument(
        "--data_path", type=str, default="timdettmers/openassistant-guanaco", help="Dataset path or name"
    )
    parser.add_argument(
        "--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model"
    )
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=3e-3, help="Learning rate")
    parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization")
    parser.add_argument("--val_set_size", type=int, default=500, help="Validation set size")
    parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
    parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
    parser.add_argument("--device", type=str, default="auto", help="Device to use for training")
    parser.add_argument("--rank", type=int, default=32, help="DeLoRA basis rank")
    parser.add_argument("--delora_lambda", type=int, default=640, help="DeLoRA lambda (bound on the update norm)")
    parser.add_argument("--module_dropout", type=float, default=0.05, help="DeLoRA dropout rate")
    parser.add_argument(
        "--target_modules", type=str, default=None, help="Comma-separated list of target modules for DeLoRA"
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default="path/to/repo",
        help="Repository name to push the model on the Hugging Face Hub",
    )
    parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to Hugging Face Hub")
    args = parser.parse_args()

    if args.device == "auto":
        args.device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"

    train_model(
        base_model=args.base_model,
        data_path=args.data_path,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        cutoff_len=args.cutoff_len,
        val_set_size=args.val_set_size,
        eval_step=args.eval_step,
        save_step=args.save_step,
        device=args.device,
        rank=args.rank,
        delora_lambda=args.delora_lambda,
        module_dropout=args.module_dropout,
        target_modules=args.target_modules,
        hub_model_id=args.hub_model_id,
        push_to_hub=args.push_to_hub,
    )
Lines changed: 20 additions & 0 deletions
{
  "lambda_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": null,
  "bias": "none",
  "exclude_modules": null,
  "inference_mode": false,
  "init_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "delora_lambda": 15,
  "module_dropout": 0.0,
  "modules_to_save": null,
  "peft_type": "DELORA",
  "r": 32,
  "rank_pattern": {},
  "revision": null,
  "target_modules": null,
  "task_type": "CAUSAL_LM"
}
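
For reference, a `DeloraConfig` roughly matching the adapter config above could be built like this (a sketch; several of these values are also the defaults):

```python
from peft import DeloraConfig, TaskType

config = DeloraConfig(
    r=32,
    delora_lambda=15,
    module_dropout=0.0,
    bias="none",
    init_weights=True,
    target_modules=None,  # fall back to the model's default target modules
    task_type=TaskType.CAUSAL_LM,
)
```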
Lines changed: 6 additions & 0 deletions
{
  "optimizer_kwargs": {
    "lr": 1e-3
  }
}
src/peft/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,8 @@
     C3AModel,
     CPTConfig,
     CPTEmbedding,
+    DeloraConfig,
+    DeloraModel,
     EvaConfig,
     FourierFTConfig,
     FourierFTModel,
@@ -154,6 +156,8 @@
     "C3AModel",
     "CPTConfig",
     "CPTEmbedding",
+    "DeloraConfig",
+    "DeloraModel",
     "EvaConfig",
     "FourierFTConfig",
     "FourierFTModel",