|
| 1 | +# This script is based on examples/randlora_finetuning/randlora_finetuning.py |
| 2 | +import os |
| 3 | + |
| 4 | +import torch |
| 5 | +from datasets import load_dataset |
| 6 | +from transformers import ( |
| 7 | + AutoModelForCausalLM, |
| 8 | + AutoTokenizer, |
| 9 | + DataCollatorForLanguageModeling, |
| 10 | + Trainer, |
| 11 | + TrainingArguments, |
| 12 | +) |
| 13 | + |
| 14 | +from peft import DeloraConfig, get_peft_model |
| 15 | + |
| 16 | + |
| 17 | +def train_model( |
| 18 | + base_model: str, |
| 19 | + data_path: str, |
| 20 | + output_dir: str, |
| 21 | + batch_size: int, |
| 22 | + num_epochs: int, |
| 23 | + learning_rate: float, |
| 24 | + cutoff_len: int, |
| 25 | + val_set_size: int, |
| 26 | + eval_step: int, |
| 27 | + save_step: int, |
| 28 | + device: str, |
| 29 | + rank: int, |
| 30 | + delora_lambda: int, |
| 31 | + module_dropout: float, |
| 32 | + target_modules: str, |
| 33 | + hub_model_id: str, |
| 34 | + push_to_hub: bool, |
| 35 | +): |
| 36 | + os.environ["TOKENIZERS_PARALLELISM"] = "false" |
| 37 | + hf_token = os.getenv("HF_TOKEN") |
| 38 | + |
| 39 | + # Setup device |
| 40 | + device = torch.device(device) |
| 41 | + print(f"Using device: {device}") |
| 42 | + |
| 43 | + # load tokenizer |
| 44 | + tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token) |
| 45 | + |
| 46 | + # Compute type |
| 47 | + device_type = device.type |
| 48 | + device_module = getattr(torch, device_type, torch.cuda) |
| 49 | + bf16_supported = device_module.is_available() and device_module.is_bf16_supported() |
| 50 | + dtype = torch.bfloat16 if bf16_supported else torch.float32 |
| 51 | + |
| 52 | + # Load the base model |
| 53 | + model = AutoModelForCausalLM.from_pretrained( |
| 54 | + base_model, |
| 55 | + dtype=dtype, |
| 56 | + ) |
| 57 | + |
| 58 | + # DeLoRA config for the PEFT model |
| 59 | + peft_config = DeloraConfig( |
| 60 | + r=rank, |
| 61 | + delora_lambda=delora_lambda, |
| 62 | + target_modules=(target_modules.split(",") if target_modules else None), |
| 63 | + module_dropout=module_dropout, |
| 64 | + bias="none", |
| 65 | + ) |
| 66 | + |
| 67 | + # get the peft model with DeLoRA config |
| 68 | + model = get_peft_model(model, peft_config) |
| 69 | + |
| 70 | + model.to(device) # MODEL TO ACCELERATOR |
| 71 | + tokenizer.pad_token = tokenizer.eos_token |
| 72 | + |
| 73 | + # Load the dataset |
| 74 | + dataset = load_dataset(data_path) |
| 75 | + |
| 76 | + def tokenize_function(examples): |
| 77 | + inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=cutoff_len) |
| 78 | + inputs["labels"] = inputs["input_ids"].copy() # setting labels for a language modeling task |
| 79 | + return inputs |
| 80 | + |
| 81 | + # Tokenize the dataset and prepare for training |
| 82 | + tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names) |
| 83 | + |
| 84 | + # Data collator to dynamically pad the batched examples |
| 85 | + data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) |
| 86 | + |
| 87 | + # Compute the total amount of training step for warmup |
| 88 | + max_steps = int((len(dataset) // batch_size) * num_epochs) |
| 89 | + |
| 90 | + # Define training arguments |
| 91 | + training_args = TrainingArguments( |
| 92 | + output_dir=output_dir, |
| 93 | + num_train_epochs=num_epochs, |
| 94 | + per_device_train_batch_size=batch_size, |
| 95 | + per_device_eval_batch_size=batch_size, |
| 96 | + warmup_steps=int(max_steps * 0.1), # 10% of total trainig steps |
| 97 | + weight_decay=0.0, |
| 98 | + logging_steps=eval_step, |
| 99 | + save_steps=save_step, |
| 100 | + save_total_limit=2, |
| 101 | + push_to_hub=push_to_hub, |
| 102 | + hub_model_id=hub_model_id, |
| 103 | + gradient_accumulation_steps=16, |
| 104 | + learning_rate=learning_rate, |
| 105 | + hub_token=hf_token, |
| 106 | + label_names=["labels"], |
| 107 | + ) |
| 108 | + |
| 109 | + # Clear accelerator cache to free memory |
| 110 | + device_module.empty_cache() |
| 111 | + |
| 112 | + # Initialize the Trainer |
| 113 | + trainer = Trainer( |
| 114 | + model=model, |
| 115 | + args=training_args, |
| 116 | + train_dataset=tokenized_datasets["train"], |
| 117 | + eval_dataset=tokenized_datasets["test"], |
| 118 | + data_collator=data_collator, |
| 119 | + ) |
| 120 | + |
| 121 | + # Start model training |
| 122 | + trainer.train() |
| 123 | + |
| 124 | + # Save and push the trained model and tokenizer |
| 125 | + if push_to_hub: |
| 126 | + # Push the main model to the hub |
| 127 | + trainer.push_to_hub(commit_message="Fine-tuned model") |
| 128 | + |
| 129 | + # Save the model and tokenizer locally |
| 130 | + model.save_pretrained(output_dir) |
| 131 | + tokenizer.save_pretrained(output_dir) |
| 132 | + |
| 133 | + |
| 134 | +if __name__ == "__main__": |
| 135 | + import argparse |
| 136 | + |
| 137 | + parser = argparse.ArgumentParser(description="Fine-tune LLaMA with DeLoRA") |
| 138 | + parser.add_argument("--base_model", type=str, default="huggyllama/llama-7b", help="Base model path or name") |
| 139 | + parser.add_argument( |
| 140 | + "--data_path", type=str, default="timdettmers/openassistant-guanaco", help="Dataset path or name" |
| 141 | + ) |
| 142 | + parser.add_argument( |
| 143 | + "--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model" |
| 144 | + ) |
| 145 | + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") |
| 146 | + parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") |
| 147 | + parser.add_argument("--learning_rate", type=float, default=3e-3, help="Learning rate") |
| 148 | + parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization") |
| 149 | + parser.add_argument("--val_set_size", type=int, default=500, help="Validation set size") |
| 150 | + parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval") |
| 151 | + parser.add_argument("--save_step", type=int, default=100, help="Save step interval") |
| 152 | + parser.add_argument("--device", type=str, default="auto", help="Device to use for training") |
| 153 | + parser.add_argument("--rank", type=int, default=32, help="DeLoRA basis rank") |
| 154 | + parser.add_argument("--delora_lambda", type=int, default=640, help="DeLoRA alpha") |
| 155 | + parser.add_argument("--module_dropout", type=float, default=0.05, help="DeLoRA dropout rate") |
| 156 | + parser.add_argument( |
| 157 | + "--target_modules", type=str, default=None, help="Comma-separated list of target modules for DeLoRA" |
| 158 | + ) |
| 159 | + parser.add_argument( |
| 160 | + "--hub_model_id", |
| 161 | + type=str, |
| 162 | + default="path/to/repo", |
| 163 | + help="Repository name to push the model on the Hugging Face Hub", |
| 164 | + ) |
| 165 | + parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to Hugging Face Hub") |
| 166 | + args = parser.parse_args() |
| 167 | + |
| 168 | + if args.device == "auto": |
| 169 | + args.device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" |
| 170 | + |
| 171 | + train_model( |
| 172 | + base_model=args.base_model, |
| 173 | + data_path=args.data_path, |
| 174 | + output_dir=args.output_dir, |
| 175 | + batch_size=args.batch_size, |
| 176 | + num_epochs=args.num_epochs, |
| 177 | + learning_rate=args.learning_rate, |
| 178 | + cutoff_len=args.cutoff_len, |
| 179 | + val_set_size=args.val_set_size, |
| 180 | + eval_step=args.eval_step, |
| 181 | + save_step=args.save_step, |
| 182 | + device=args.device, |
| 183 | + rank=args.rank, |
| 184 | + delora_lambda=args.delora_lambda, |
| 185 | + module_dropout=args.module_dropout, |
| 186 | + target_modules=args.target_modules, |
| 187 | + hub_model_id=args.hub_model_id, |
| 188 | + push_to_hub=args.push_to_hub, |
| 189 | + ) |
0 commit comments