Resizing an untied output head after resize_token_embeddings breaks under DeepSpeed ZeRO-3 + LoRA #41959

@XiangZhang-zx

Description

System Info

transformers==4.57.0
accelerate==1.7.0
deepspeed==0.18.0
peft==0.17.1

Who can help?

@ArthurZucker @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm encountering a RuntimeError when trying to resize the token embeddings (both the input embedding wte and the custom output layer ff_out) of a model trained with DeepSpeed ZeRO-3 + LoRA.

The error occurs when manually resizing the output layer after calling model.resize_token_embeddings(). Resizing the vocabulary of an untied LLaDA model under DeepSpeed ZeRO-3 fails once the output projection (ff_out) is rebuilt:

  • old_ff_out.weight is exposed as an empty local shard, so copying the old weights into the new layer raises a shape mismatch (a short illustration follows this list).
  • If the layer is rebuilt anyway, training later crashes because the new Linear module lacks the ZeRO bookkeeping attributes (ds_grads_remaining, applied_pre_backward_ref_cnt).
  1. DeepSpeed Config (deepspeed_zero3.yaml):
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
mixed_precision: bf16
deepspeed_config:
  zero_stage: 3
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
num_processes: 2
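
The run is launched through Accelerate with this config (the script name train.py below is just a placeholder for the actual training script):

accelerate launch --config_file deepspeed_zero3.yaml train.py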

  2. Reproduction script:

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
import torch

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("model_name")
tokenizer = AutoTokenizer.from_pretrained("model_name")

# Add new tokens
new_tokens = ["<token1>", "<token2>", ...]
tokenizer.add_tokens(new_tokens)

# Resize embeddings (this works fine)
model.resize_token_embeddings(len(tokenizer))

# For models with separate output layer (e.g., LLaDA with weight_tying=False)
# Need to manually resize the output layer
if hasattr(model.model.transformer, 'ff_out'):
    old_ff_out = model.model.transformer.ff_out
    old_out_features = old_ff_out.out_features
    
    # Create new output layer
    new_ff_out = torch.nn.Linear(
        old_ff_out.in_features,
        len(tokenizer),
        bias=old_ff_out.bias is not None,
        device=old_ff_out.weight.device,
        dtype=old_ff_out.weight.dtype
    )
    
    # ❌ ERROR HERE: old_ff_out.weight is empty (shape [0]) in ZeRO-3
    with torch.no_grad():
        new_ff_out.weight[:old_out_features] = old_ff_out.weight  # RuntimeError!
    
    model.model.transformer.ff_out = new_ff_out

# Apply LoRA
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)

# Train
trainer = Trainer(model=model, args=training_args, ...)
trainer.train()  # ❌ Error: 'Linear' object has no attribute 'ds_grads_remaining'
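
For completeness, this is roughly the kind of workaround referred to below as "hacking around it": gathering the full parameter with deepspeed.zero.GatheredParameters before copying. It is a sketch, not a recommended fix; it unblocks the copy but not the later training crash, since the replacement module is never registered with the ZeRO-3 engine.

import deepspeed  # torch is already imported above

# Gather the full (unsharded) weights on every rank; outside this context the
# local shard of each parameter is an empty [0] tensor under ZeRO-3.
with deepspeed.zero.GatheredParameters(list(old_ff_out.parameters()), modifier_rank=None):
    with torch.no_grad():
        new_ff_out.weight[:old_out_features].copy_(old_ff_out.weight)
        if old_ff_out.bias is not None:
            new_ff_out.bias[:old_out_features].copy_(old_ff_out.bias)

model.model.transformer.ff_out = new_ff_out
# The copy now succeeds, but the fresh Linear still lacks the ZeRO bookkeeping
# hooks (ds_grads_remaining, applied_pre_backward_ref_cnt), so trainer.train()
# still fails as described below.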

Expected behavior

❌ Actual behavior

  • The copy step fails with:
    RuntimeError: Target sizes [126464, 4096] vs tensor [0]
  • After working around the copy (see the sketch above), training crashes with:
    AttributeError: 'Linear' object has no attribute 'ds_grads_remaining'

✅ Expected behavior

resize_token_embeddings should let us expand the vocabulary and rebuild untied output heads while keeping the model compatible with ZeRO-3 and LoRA. Ideally, there would be an official way to re-register newly created modules with the ZeRO-3 engine.
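
In other words, something like the following should work end to end (a sketch of the desired behavior; that get_output_embeddings() exposes the untied ff_out head is an assumption about how the LLaDA remote code wires things up):

# Desired behavior (sketch): one call resizes both the input embedding and the
# untied output head, and the model remains trainable under ZeRO-3 + LoRA.
model.resize_token_embeddings(len(tokenizer))
assert model.get_input_embeddings().num_embeddings == len(tokenizer)
assert model.get_output_embeddings().out_features == len(tokenizer)  # assumes ff_out is registered as the output embedding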
